22 minute read

前几天在看 Mamba 模型的基本原理框架,本来以为已经掌握了,但是后续在看 Mamba-2 的时候,发现一些内容和之前我的理解对不上(尤其是输入和输出的维度部分),所以感觉还是需要结合代码理解比较好。昨天花了一天的时间配环境,踩了不少坑(笨死了!)这里也简单提一嘴好了。然后会对 Mamba 官方代码逐行分析。

ベスト オブ ゴー!ゴー! — GO!GO!7188

环境的配置

一开始是在本地部署的,本地是 Windows11 操作系统,一直配不好,后来看网上的教程和官方 README.md 的要求都是 Linux 系统,于是转移阵地,在课题组的服务器上操作。这个服务器的 CUDA 版本是 12.6,因为 Mamba 官方要求是 CUDA 11.6+ 我担心配太新的版本会有一些版本兼容的问题,并且网上的博客基本使用的也都是较低版本,所以我跟着一篇比较靠谱的博客,选了 CUDA 11.8 版本。

下面是我的环境配置:

Python             3.12.9
cuda               11.8
torch              2.2.2
torchaudio         2.2.2
torchvision        0.17.2
causal_conv1d      1.5.0.post8
mamba_ssm          2.2.4

安装完毕之后总是会出现 ImportError xxxx selective_scan_cuda.cpython-xxx-linux-gnu.so undefined symbol 报错,网上的回复基本都说是版本兼容的问题,我重新装了好几次都不行,最后终于看到一篇博客,查到了问题所在,如果大家在配环境时出现问题可以直接看这篇博客。(真的知道问题出在哪的时候,我感觉自己像个煞笔)一定要注意下载 mamba_ssm 和 causal_conv1d 时看看是 True 还是 False 版本。

代码解读

原代码有两百多行,一下子全部贴出来看起来也费劲,所以这里就分几个部分看。

导入依赖库

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

导入必须的库,这没什么好说的。

类初始化

Mamba 类集成了 nn.Module 模块,这是它的类的初始化,信息量还是比较大的。我们逐步看下去,首先是定义时传入的参数部分:

class Mamba(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):

模型维度的相关参数

  • d_model: 是模型的嵌入维度,即输入和输出的特征维度。
  • d_state: 是 SSM 的内部状态维度,这意味着模型中间传递的状态维度是 (batch_size, d_state)
  • expand: 是内部扩展因子,Mamba 块内部的扩展层会将 d_model 维度的数据扩展到 d_model * expand 维度进行处理,然后再投影回 d_model

卷积层相关参数

在前一篇博客中,我们有提到卷积的作用是支持并行计算,但是好像这里的卷积层和前面我们提到的不太一样。在 Mamba 块中,SSM 前会有一个短的因果卷积层,用于捕获局部特征和信息。下面是它的参数:

  • d_conv: 是因果卷积层的核大小。
  • conv_bias: 控制因果卷积层是否使用偏置项。

选择性机制相关参数

选择扫描算法是 Mamba 的核心,是其时变特征的体现,所以这一块的参数比较重要也比较多,主要是针对 \(\Delta\) 的设置:

  • dt_rank: 是 \(\Delta\) 参数的秩,其决定了计算 \(\Delta\) 时所使用的线性投影的维度,默认是 d_model / 16 向上取整,较低的秩可以减少参数量,提高效率,但是可能会限制模型的表达能力。
  • dt_min: 是 \(\Delta\) 参数的最小值,确保其不会过小,避免数值不稳定或信息丢失过快。
  • dt_max: 是 \(\Delta\) 参数的最大值,确保其不会过大,防止信息保留过久或数值溢出。
  • dt_init: 是 \(\Delta\) 参数的初始化方法,API 中默认提供 "random""constant" 两种方法
  • dt_scale: 是 \(\Delta\) 初始化值的缩放因子,在 \(\Delta\) 初始化后,会乘以这个缩放因子,用于微调其初始值的分布范围。
  • dt_init_floor: 是 \(\Delta\) 初始化的下限,避免其初始化出现极小值,提高数值稳定性。

其他通用参数

  • bias: 控制 MLP 层和输出投影层是否使用偏置项
  • use_fast_path: 是否使用优化的融合核(fused kernel),之前我们提到了什么是融合核,简单来说就是从底层实现的加速算子,使用之后可以显著提高训练和推理的速度。
  • layer_idx: 当前 Mamba 块的层索引,通常在堆叠多个 Mamba 块时使用,可以为每个块提供唯一的标识。
  • device 和 dtype 就不多说了。
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

上面是类初始化函数里面的内容,初始化一些属性和网络。

  • factory_kwargs

创建了一个字典,用于存储 devicedtype 参数,作为关键字(**factory_kwargs传递给后续创建的 nn.Linearnn.Conv1d 层,确保这些层的权重和偏置都在指定的设备上,并使用指定的数据类型。这是一种方便地将设备和数据类型设置应用到多个模块的方法。

  • self.d_inner

计算内部隐藏层的维度,使用 d_model 乘以 expand 因子得到,即 Mamba 块内部会首先将输入维度扩展到 d_inner 进行处理。

  • self.in_proj

输入投影层,是一个全连接层,输入维度是 self.d_model,输出维度为 self.d_inner * 2,这是因为 Mamba 块的输入通常会被分为两部分,一个用于卷积和 SSM 路径,另一个用于门控——残差连接

  • self.conv1d

定义一个因果一维卷积层,输入输出通道数都是 self.d_inner,其中 groups=self.d_inner 意味着分组卷积,由于组数等于输入和输出的通道数,因此每组的通道数是 1,实际上实现了深度可分离卷积(Depthwise Separable Convolution)中的深度卷积部分,每个输入通道独立地进行卷积,不与其他通道混合。

另外,padding=d_conv - 1 是输入的填充大小,是为了保持输入和输出的大小一致,并且保证了卷积是因果的,即输出的每个元素只依赖于输入中当前及之前的元素。

  • self.act

定义激活层,使用 SiLU 作为激活函数,其表达式为:

\[f(x) = x \cdot \text{sigmoid}(x)\]
  • self.x_proj

定义一个名为 x_proj 的投影层,其从卷积的输出中投影 SSM 的三个关键参数:\(\Delta、B、C\),输入维度是 self.d_inner,输出维度为 self.dt_rank + self.d_state * 2

  • self.dt_proj

定义一个名为 dt_proj 的 \(\Delta\) 投影层,输入维度为 self.dt_rank,输出维度为 self.d_inner,其输出就是控制 SSM 动态的 \(\Delta\) 参数。

接着对该层进行特殊初始化,dt_init_std 是标准差,其值为 self.dt_rank ** -0.5 + dt_scale,如果使用 constant 初始化方法,则该层所有权重都被设置为标准差,如果选择 random 则使用均匀分布初始化权重。

这种精心设置的初始化使得 \(\Delta\) 参数在训练初期就具有合适的尺度,从而稳定训练过程,并提高收敛速度。

  • dt
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
    torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
    + math.log(dt_min)
).clamp(min=dt_init_floor)

这部分代码是在初始化 dt_proj 层的偏置(bias),目的是让经过 F.softplus 激活后的 \(\Delta\) 值(实际上是 dt_proj.bias 经过 softplus 后的值)落在 [dt_min, dt_max] 之间。

(math.log(dt_max) - math.log(dt_min)) + math.log(dt_min): 这一部分是数学技巧,用于将 [0, 1] 的随机数映射到一个对数尺度上的范围,使得其指数化后(torch.exp)的值落在 [dt_min, dt_max] 之间。

.clamp(min=dt_init_floor): 确保所有生成的 dt 值至少为 dt_init_floor,以避免数值不稳定。

这种初始化确保了 \(\Delta\) 参数在训练开始时就处于一个合理的动态范围,这对于 SSM 的稳定性至关重要。

  • inv_dt
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
    self.dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_proj.bias._no_reinit = True

inv_dt = dt + torch.log(-torch.expm1(-dt)): 计算 softplus 函数(\(f(x)=\log(1+e^x)\))的逆函数。因为在 Mamba 的前向传播中,\(\Delta\) 是通过 F.softplus(dt_proj.bias) 得到的,为了让初始化的 dt 值通过 softplus 后能达到目标范围,我们需要将目标 dt 值通过 softplus 的逆函数来初始化 dt_proj.bias

  • expm1(x) 是 \(e^x - 1\),但是数值稳定性更好。是基于泰勒展开计算的。

当 t 接近 0 时,\(e^t\)的值会非常接近 1,在计算机中,浮点数有有限的精度。当一个非常接近 1 的数减去 1 时,就会发生灾难性抵消 (catastrophic cancellation)。这意味着,exp(t) 的许多有效数字在减去 1 后会丢失,导致结果的相对误差非常大,精度急剧下降。

例如,如果 exp(t) 计算结果是 1.0000000000000001 (假设这是双精度浮点数的最小可表示差异),那么 exp(t) - 1 的结果就是 0.0000000000000001。这看起来没问题,但如果 exp(t) 稍微有点误差,比如 1.0000000000000002,那么减去 1 后的结果就是 0.0000000000000002。对于这种非常小的结果,原始的微小误差会变得非常显著,导致结果的相对误差很大。

  • log(-torch.expm1(-dt)) 是 \(\log(-(e^{-dt - 1})) = \log(1-e^{-dt})\)
  • 所以 inv_dt = dt + log(1 - exp(-dt))softplus(x)逆函数,之所以和手推公式有差异,是为了避免 \(e^x\) 导致的溢出,详情可见该博客

with torch.no_grad(): self.dt_proj.bias.copy_(inv_dt): 在不记录梯度的情况下,将计算出的 inv_dt 值复制到 dt_proj 层的偏置项。no_grad() 确保这个初始化操作不会被视为模型计算图的一部分。

self.dt_proj.bias._no_reinit = True: 这是一个自定义属性,通常用于标记这个偏置项已经进行了特殊的初始化,在后续的模型重置或加载预训练权重时,可能需要跳过对其的默认初始化(例如,将所有偏置归零)。

接下来的代码是在初始化 SSM 的核心参数 \(A\)

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

repeat 操作构建了一个维度为 (self.d_inner, self.d_state) 的矩阵,每一行的数字从 1 开始到 self.d_state

.contiguous() 是为了确保张量在内存中是连续的,这对于某些操作(尤其是 CUDA 核)很重要。

A_log = torch.log(A): 计算 A 的对数。在 Mamba 的实现中,通常会维护 A 的对数形式,因为 A 参数在 SSM 中会经历指数化操作,在对数空间操作可以提高数值稳定性

self.A_log._no_weight_decay = True: 这是一个自定义属性,标记这个参数不应该应用权重衰减(weight decay)。在 SSM 中,A 参数通常不希望被权重衰减影响,因为它代表了固定的状态空间动态

权重衰减就是一种正则化技术,用于防止过拟合的,通过在损失函数中加入参数的 L2 范数避免参数过大。

  • self.D

初始化 SSM 的 D 参数,也被称为跳跃连接(skip connection)参数,同样设置其不应用权重衰减。

  • self.out_proj

定义一个输出投影层,输入维度为 self.d_inner,输出维度为 self.d_model

总而言之,初始化代码主要是针对 SSM 的参数尤其是 A 和 D 进行特殊初始化,对于 \(\Delta\) 和 B,C,则使用神经网络得到,同时定义了一些需要使用到的网络,例如输入和输出的线性投影层,局部特征提取的因果深度卷积层。

前向传播

Mamba 的前向传播函数是一个精心设计的流程,它将传统的序列处理与选择性状态空间模型相结合。它首先通过线性投影 (in_proj) 扩展输入维度,然后通过一个短卷积捕获局部信息。接着,核心的选择性状态空间模型 (selective_scan_fn) 根据输入动态生成参数 (Δ, B, C),并进行高效的序列扫描,以捕捉长距离依赖。最后,一个门控机制(通过 z 路径)与 SSM 输出结合,并通过输出投影 (out_proj) 将结果映射回原始维度。

为了提高性能,Mamba 大量使用了融合核 (mamba_inner_fn),它将多个操作合并到单个 GPU 核中,从而减少了内存访问和计算开销。在无法使用融合核或进行单步推理时,它会回退到由 PyTorch 模块组成的“慢路径”实现。

下面就是 Mamba 的前向传播函数的原实现:

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, cnov_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), 
            "d (b l) -> b d l",                                                 
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

主要解释一些关键的,或者难以理解的步骤吧。

首先是这个前向传播函数接受两个输入,一个是 hidden_states,维度为 (batch_size, len_seq, d_model),另一个是 inference_params,表示推理模式下的参数,主要是为了维护和更新内部状态(如卷积和 SMM 的状态),输出是和输入同维度的张量。

如果是推理模式的话,首先调用内部方法 self._get_states_from_cache 从缓存中获取或初始化状态。inference_params.seqlen_offset > 0 表示模型正在处理序列的中间或末尾部分,这时我们直接执行 self.step 方法更新下一个时间步的状态。

然后我们遇到了一个比较复杂的运算,在运算过程中,矩阵的维度上做了许多的变化,我们细细来讲。

首先是将输入 hidden_states 的维度从 (batch_size, len_seq, d_model) 转化为 (d_model, batch_size * seq_len),然后通过 self.in_proj 层,还记得该层输入为 d_model,输出为 d_inner * 2,所以我们可以知道其 weight 矩阵应该是 (d_inner * 2, d_model),所以矩阵运算得到的结果再 reshape 的结果是 (batch_size, d_inner * 2, seq_len),这就是 xz 的维度。

这一步操作后,下面又加上了 bias 其实就是手动拆解了 self.in_proj(hidden_states),通过显示控制张量的形状变换和矩阵乘法的执行,使得开发者可以精确地调整数据流,以适应某些特定的计算模式或利用底层库的优势,改变内存访问模式,提高性能。

然后我们可以观察到,当我们使用 A_log 来复原 A 时除了取了指数之外,还加了个负号,其确保了 A 是负数,负号是为了确保 S6 模型在离散化后保持稳定性,并控制状态信息的衰减,从而使得模型能够有效地处理长序列并进行稳定的训练。

在连续时间系统中,为了保证系统的稳定性(即避免状态在时间步长上无限增长,导致数值不稳定或发散),A 矩阵的特征值(eigenvalues)的实部必须是负的。当将其离散化后,为了保持稳定性,离散化后的 A_bar 矩阵的特征值的模(magnitude)必须小于 1。

关于这一点,动力系统中有讲过相似的理论来着,可惜我已经忘得差不多了(悲

如果满足使用优化过的融合核的条件,就调用高度优化的 mamba_inner_fn,其为一个单一体的 CUDA 核函数,通过将 Mamba 块的多个操作(卷积、SSM 扫描、门控等)融合在一起,以提高计算效率和减少内存开销。其传入参数部分,由于 B 和 C 参数是根据输入 x 动态生成,所以传入 None,该分支将直接计算 out 并跳过后续的 Pytorch 操作。

否则如果不满足条件,则利用慢路径(Pytorch)实现

首先将矩阵 xz 分成两个形状为 (batch_size, self.d_inner, len_seq) 的矩阵,其中x 用于卷积核 SSM 路径,而 z 用于门控机制。

如果 conv_state 存在,则更新卷积状态,F.pad 用于处理序列长度小于卷积核大小的情况,使用零进行填充,该操作确保了 conv_state 始终包含最近 d_conv-1 个输入元素,以便进行因果卷积。

如果没有可用的优化因果卷积函数,则使用 Pytorch 自带的卷积层处理序列,由于该卷积使用了 padding 操作使得输出和输入的长度一样,所以需要使用截断操作,之后使用激活函数。

而如果可以使用优化因果卷积函数,则调用该函数,注意这里只支持 siluswish 两种激活函数,并且调用函数之后,权重被重新排列,因为 causal_conv1d_fn 可能期望不同的权重布局。

接下来就是比较关键的部分,即通过输入 x 得到参数 dt, B, C,注意这里 x 的维度是 (batch_size, self.d_inner, len_seq),我们首先将 x 进行重排得到 (batch_size * len_seq, self.d_inner) 的矩阵,然后通过 x_proj 投影层得到维度为 (batch_size * len_seq, dt_rank + d_state * 2) 的输出结果,然后将它分成三部分分别赋值给 dt, B, C

随后使用 dt_proj.weightdt 的转置进行计算,得到的结果为 (d_inner, dt_rank) x (dt_rank, batch_size * len_seq) -> (d_inner, batch_size * len_seq),最后将三个参数进行维度变化,使其符合要求。三者最终的维度为 (batch_size, d_inner / d_state, len_seq)

接着执行选择性扫描操作,这一部分的实现到时候再讲,输出的 y 应该是个列表,其中有两个元素,一个是所有状态的张量,维度是 (batch_size, d_inner, len_seq),另一个是最后一个时刻的状态,维度是 (batch_size, d_state)ssm_state.copy_(last_state)是为了将更新后的 last_state 复制回 ssm_state,以便在下一个时间步使用。

最后我们将 y 进行 reshape 操作,变成维度为 (batch_size, len_seq, d_inner) 的矩阵,然后通过输出线性层 self.out_proj 得到最终结果 out,其维度为 (batch_size, len_seq, d_inner)

单步推理

Mamba 模型定义了 step 方法,主要用于单步推理,每次处理一个时间步的输入并更新模型的内部状态。

step 方法是 Mamba 模型在推理模式下进行自回归生成的关键。它实现了 Mamba 块的单时间步计算:

  1. 输入投影: 处理单个 token 的输入并将其分解为 x 和 z。
  2. 卷积状态更新: 更新并应用因果卷积,利用 conv_state 来记忆局部历史信息。
  3. SSM 参数生成: 根据当前输入动态生成 SSM 参数 dt、B、C
  4. SSM 状态更新和输出: 使用新的参数和旧的 ssm_state 来更新 SSM 内部状态,并计算当前时间步的 SSM 输出。
  5. 门控和残差连接: 将 SSM 输出与门控机制 (z) 和 D (skip) 连接结合。
  6. 输出投影: 将最终结果投影回模型维度。
    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D) 
            # PyTorch nn.Conv1d 层的权重维度标准形式是:(out_channels, in_channels / groups, kernel_size)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

step 方法接受三个输入,分别是当前时间步的输入 hidden_states,以及上一个时间步的卷积状态 conv_state 和 SSM 状态 ssm_state

  • 输入投影 (in_proj)

hidden_states.squeeze(1):将输入 hidden_states(B, 1, D) 形状中移除长度维度,变为 (B, D)self.in_proj(...):通过 in_proj 线性层进行投影,输出维度是 self.d_inner * 2。因此 xz 的形状是 (B, self.d_inner * 2)x, z = xz.chunk(2, dim=-1): 将 xz 沿最后一个维度分成两部分: x (形状 (B, self.d_inner)) 用于卷积和 SSM 路径。 z (形状 (B, self.d_inner)) 用于门控。

  • 卷积步进 (Conv Step)

如果没有高度优化的 causal_conv1d_update 函数(例如在 CPU 上):

conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)): 更新卷积状态。conv_state 的形状是 (B, self.d_inner, d_conv-1) (这里假设 conv_state 预先初始化为 d_conv-1 长度)。torch.roll 会将 conv_state 向左(负方向)平移一位,最左边的元素被移除,为新输入腾出空间。

conv_state[:, :, -1] = x:将当前时间步的输入 x 放入 conv_state 的最右端(最新的位置)。

x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1):执行卷积操作。conv_state 包含了当前输入和之前的 d_conv-1 个输入。将其与重排后的卷积核 (self.conv1d.weight) 相乘,并在最后一个维度求和,得到卷积输出。

if self.conv1d.bias is not None: x = x + self.conv1d.bias:如果卷积层有偏置,则加上偏置。

x = self.act(x).to(dtype=dtype):应用激活函数 self.act (SiLU),并确保数据类型一致。

self.conv1d.weight.shape = (out_channels, in_channels // groups, kernel_size) 因为使用了分组卷积,所以参数量会下降。这里 groups = in_channels = d_inner 所以属于深度卷积。

  • 投影出 SSM 参数 (dt, B, C)

x_db = self.x_proj(x):将卷积输出 x (形状 (B, self.d_inner)) 通过 x_proj 线性层投影,输出形状是 (B, self.dt_rank + self.d_state * 2)

dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1): 将 x_db 分割为 dt, B, C

dt = F.linear(dt, self.dt_proj.weight): 将 dt (形状 (B, self.dt_rank)) 与 dt_proj 的权重 (形状 (self.d_inner, self.dt_rank)) 进行线性变换。结果 dt 的形状是 (B, self.d_inner)。注意这里的注释 “Don’t add dt_bias here”,这意味着偏置将在后面的 softplus 激活中添加。

  • SSM 步进 (SSM Step)

同样地,这里可以直接调用高度优化的选择状态更新函数,如果无法调用的话,就只能用 CPU 或 GPU 跑了。

dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)):将 dt 加上 dt_proj 的偏置,然后通过 softplus 激活函数,确保 dt 为正。

torch.einsum("bd,dn->bdn", dt, A) 是爱因斯坦求和约定,这里表示对 dt (形状 (B, D_inner)) 和 A (形状 (D_inner, D_state)) 进行逐元素的乘法,并扩展维度以匹配。结果 (B, D_inner, D_state)

torch.einsum("bd,bn->bdn", dt, B)dt (形状 (B, D_inner)) 和 B (形状 (B, D_state)) 进行逐元素的乘法。结果为 dB 维度为 (B, D_inner, D_state)

ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)更新 SSM 状态。 这是核心的状态空间方程,ssm_state * dA: 衰减旧状态。rearrange(x, "b d -> b d 1") * dB: 将当前输入映射到状态空间。

将这两部分相加得到新的 ssm_state,并将其复制回传入的 ssm_state 变量(原地更新)。

y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C): 计算当前时间步的输出 y。将更新后的 ssm_state 与参数 C 相乘。

最终返回三个变量,分别是:

  1. out: (batch_size, 1, d_inner)
  2. conv_state: 更新后的卷积状态。
  3. ssm_state: 更新后的 SSM 状态。

状态初始化函数

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

从缓存中获取状态

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state

该方法主要用在推理或生成过程中,获取和管理 Mamba 块的内部状态。

首先检查状态是否存在,它使用 inference_params.key_value_memory_dict 字典来存储和检索状态。self.layer_idx 作为键,每个 Mamba 层都有自己的状态。如果 self.layer_idx 不在字典中,说明这是第一次访问该层的状态,需要初始化。

如果状态不存在就初始化,否则就直接返回检索结果,可以在调用时选择是否重置状态。

总结

以上就是 Mamba 这个类的定义了,感觉自己写得有点屎山,主要还是靠 ChatGPT 和 Gemini 老师讲解的,我只能算是个翻译官啊哈哈。

关于更加底层的函数,例如 mamba_inner_fncausal_conv1d_fnselective_state_updatecausal_conv1d_updateselective_scan_fn 等等,这些后续会补充其具体实现。

另外后续还会从同理一遍整个 Mamba 的运行逻辑,To be continued……

Leave a comment