Mamba 代码详解
前几天在看 Mamba 模型的基本原理框架,本来以为已经掌握了,但是后续在看 Mamba-2 的时候,发现一些内容和之前我的理解对不上(尤其是输入和输出的维度部分),所以感觉还是需要结合代码理解比较好。昨天花了一天的时间配环境,踩了不少坑(笨死了!)这里也简单提一嘴好了。然后会对 Mamba 官方代码逐行分析。

环境的配置
一开始是在本地部署的,本地是 Windows11 操作系统,一直配不好,后来看网上的教程和官方 README.md 的要求都是 Linux 系统,于是转移阵地,在课题组的服务器上操作。这个服务器的 CUDA 版本是 12.6,因为 Mamba 官方要求是 CUDA 11.6+ 我担心配太新的版本会有一些版本兼容的问题,并且网上的博客基本使用的也都是较低版本,所以我跟着一篇比较靠谱的博客,选了 CUDA 11.8 版本。
下面是我的环境配置:
Python 3.12.9
cuda 11.8
torch 2.2.2
torchaudio 2.2.2
torchvision 0.17.2
causal_conv1d 1.5.0.post8
mamba_ssm 2.2.4
安装完毕之后总是会出现 ImportError xxxx selective_scan_cuda.cpython-xxx-linux-gnu.so undefined symbol
报错,网上的回复基本都说是版本兼容的问题,我重新装了好几次都不行,最后终于看到一篇博客,查到了问题所在,如果大家在配环境时出现问题可以直接看这篇博客。(真的知道问题出在哪的时候,我感觉自己像个煞笔)一定要注意下载 mamba_ssm 和 causal_conv1d 时看看是 True 还是 False 版本。
代码解读
原代码有两百多行,一下子全部贴出来看起来也费劲,所以这里就分几个部分看。
导入依赖库
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
导入必须的库,这没什么好说的。
类初始化
Mamba 类集成了 nn.Module 模块,这是它的类的初始化,信息量还是比较大的。我们逐步看下去,首先是定义时传入的参数部分:
class Mamba(nn.Module):
def __init__(
self,
d_model,
d_state=16,
d_conv=4,
expand=2,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
conv_bias=True,
bias=False,
use_fast_path=True, # Fused kernel options
layer_idx=None,
device=None,
dtype=None,
):
模型维度的相关参数:
- d_model: 是模型的嵌入维度,即输入和输出的特征维度。
- d_state: 是 SSM 的内部状态维度,这意味着模型中间传递的状态维度是
(batch_size, d_state)
- expand: 是内部扩展因子,Mamba 块内部的扩展层会将
d_model
维度的数据扩展到d_model * expand
维度进行处理,然后再投影回d_model
卷积层相关参数:
在前一篇博客中,我们有提到卷积的作用是支持并行计算,但是好像这里的卷积层和前面我们提到的不太一样。在 Mamba 块中,SSM 前会有一个短的因果卷积层,用于捕获局部特征和信息。下面是它的参数:
- d_conv: 是因果卷积层的核大小。
- conv_bias: 控制因果卷积层是否使用偏置项。
选择性机制相关参数
选择扫描算法是 Mamba 的核心,是其时变特征的体现,所以这一块的参数比较重要也比较多,主要是针对 \(\Delta\) 的设置:
- dt_rank: 是 \(\Delta\) 参数的秩,其决定了计算 \(\Delta\) 时所使用的线性投影的维度,默认是
d_model / 16
向上取整,较低的秩可以减少参数量,提高效率,但是可能会限制模型的表达能力。 - dt_min: 是 \(\Delta\) 参数的最小值,确保其不会过小,避免数值不稳定或信息丢失过快。
- dt_max: 是 \(\Delta\) 参数的最大值,确保其不会过大,防止信息保留过久或数值溢出。
- dt_init: 是 \(\Delta\) 参数的初始化方法,API 中默认提供
"random"
和"constant"
两种方法 - dt_scale: 是 \(\Delta\) 初始化值的缩放因子,在 \(\Delta\) 初始化后,会乘以这个缩放因子,用于微调其初始值的分布范围。
- dt_init_floor: 是 \(\Delta\) 初始化的下限,避免其初始化出现极小值,提高数值稳定性。
其他通用参数
- bias: 控制 MLP 层和输出投影层是否使用偏置项
- use_fast_path: 是否使用优化的融合核(fused kernel),之前我们提到了什么是融合核,简单来说就是从底层实现的加速算子,使用之后可以显著提高训练和推理的速度。
- layer_idx: 当前 Mamba 块的层索引,通常在堆叠多个 Mamba 块时使用,可以为每个块提供唯一的标识。
- device 和 dtype 就不多说了。
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1,
**factory_kwargs,
)
self.activation = "silu"
self.act = nn.SiLU()
self.x_proj = nn.Linear(
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = self.dt_rank**-0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(self.dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_proj.bias._no_reinit = True
# S4D real initialization
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
# D "skip" parameter
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
self.D._no_weight_decay = True
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
上面是类初始化函数里面的内容,初始化一些属性和网络。
- factory_kwargs
创建了一个字典,用于存储 device
和 dtype
参数,作为关键字(**factory_kwargs
)传递给后续创建的 nn.Linear
和 nn.Conv1d
层,确保这些层的权重和偏置都在指定的设备上,并使用指定的数据类型。这是一种方便地将设备和数据类型设置应用到多个模块的方法。
- self.d_inner
计算内部隐藏层的维度,使用 d_model
乘以 expand
因子得到,即 Mamba 块内部会首先将输入维度扩展到 d_inner
进行处理。
- self.in_proj
输入投影层,是一个全连接层,输入维度是 self.d_model
,输出维度为 self.d_inner * 2
,这是因为 Mamba 块的输入通常会被分为两部分,一个用于卷积和 SSM 路径,另一个用于门控——残差连接
- self.conv1d
定义一个因果一维卷积层,输入输出通道数都是 self.d_inner
,其中 groups=self.d_inner
意味着分组卷积,由于组数等于输入和输出的通道数,因此每组的通道数是 1,实际上实现了深度可分离卷积(Depthwise Separable Convolution)中的深度卷积部分,每个输入通道独立地进行卷积,不与其他通道混合。
另外,padding=d_conv - 1
是输入的填充大小,是为了保持输入和输出的大小一致,并且保证了卷积是因果的,即输出的每个元素只依赖于输入中当前及之前的元素。
- self.act
定义激活层,使用 SiLU
作为激活函数,其表达式为:
- self.x_proj
定义一个名为 x_proj
的投影层,其从卷积的输出中投影 SSM 的三个关键参数:\(\Delta、B、C\),输入维度是 self.d_inner
,输出维度为 self.dt_rank + self.d_state * 2
- self.dt_proj
定义一个名为 dt_proj
的 \(\Delta\) 投影层,输入维度为 self.dt_rank
,输出维度为 self.d_inner
,其输出就是控制 SSM 动态的 \(\Delta\) 参数。
接着对该层进行特殊初始化,dt_init_std
是标准差,其值为 self.dt_rank ** -0.5 + dt_scale
,如果使用 constant
初始化方法,则该层所有权重都被设置为标准差,如果选择 random
则使用均匀分布初始化权重。
这种精心设置的初始化使得 \(\Delta\) 参数在训练初期就具有合适的尺度,从而稳定训练过程,并提高收敛速度。
- dt
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
这部分代码是在初始化 dt_proj
层的偏置(bias
),目的是让经过 F.softplus
激活后的 \(\Delta\) 值(实际上是 dt_proj.bias
经过 softplus
后的值)落在 [dt_min, dt_max]
之间。
(math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
: 这一部分是数学技巧,用于将 [0, 1]
的随机数映射到一个对数尺度上的范围,使得其指数化后(torch.exp
)的值落在 [dt_min, dt_max]
之间。
.clamp(min=dt_init_floor)
: 确保所有生成的 dt
值至少为 dt_init_floor
,以避免数值不稳定。
这种初始化确保了 \(\Delta\) 参数在训练开始时就处于一个合理的动态范围,这对于 SSM 的稳定性至关重要。
- inv_dt
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_proj.bias._no_reinit = True
inv_dt = dt + torch.log(-torch.expm1(-dt))
: 计算 softplus
函数(\(f(x)=\log(1+e^x)\))的逆函数。因为在 Mamba 的前向传播中,\(\Delta\) 是通过 F.softplus(dt_proj.bias)
得到的,为了让初始化的 dt
值通过 softplus
后能达到目标范围,我们需要将目标 dt
值通过 softplus
的逆函数来初始化 dt_proj.bias
。
expm1(x)
是 \(e^x - 1\),但是数值稳定性更好。是基于泰勒展开计算的。
当 t 接近 0 时,\(e^t\)的值会非常接近 1,在计算机中,浮点数有有限的精度。当一个非常接近 1 的数减去 1 时,就会发生灾难性抵消 (catastrophic cancellation)。这意味着,exp(t) 的许多有效数字在减去 1 后会丢失,导致结果的相对误差非常大,精度急剧下降。
例如,如果 exp(t) 计算结果是 1.0000000000000001 (假设这是双精度浮点数的最小可表示差异),那么 exp(t) - 1 的结果就是 0.0000000000000001。这看起来没问题,但如果 exp(t) 稍微有点误差,比如 1.0000000000000002,那么减去 1 后的结果就是 0.0000000000000002。对于这种非常小的结果,原始的微小误差会变得非常显著,导致结果的相对误差很大。
log(-torch.expm1(-dt))
是 \(\log(-(e^{-dt - 1})) = \log(1-e^{-dt})\)- 所以
inv_dt = dt + log(1 - exp(-dt))
是softplus(x)
的逆函数,之所以和手推公式有差异,是为了避免 \(e^x\) 导致的溢出,详情可见该博客
with torch.no_grad(): self.dt_proj.bias.copy_(inv_dt)
: 在不记录梯度的情况下,将计算出的 inv_dt
值复制到 dt_proj
层的偏置项。no_grad()
确保这个初始化操作不会被视为模型计算图的一部分。
self.dt_proj.bias._no_reinit = True
: 这是一个自定义属性,通常用于标记这个偏置项已经进行了特殊的初始化,在后续的模型重置或加载预训练权重时,可能需要跳过对其的默认初始化(例如,将所有偏置归零)。
接下来的代码是在初始化 SSM 的核心参数 \(A\)
# S4D real initialization
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=self.d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
repeat 操作构建了一个维度为 (self.d_inner, self.d_state)
的矩阵,每一行的数字从 1 开始到 self.d_state
.contiguous()
是为了确保张量在内存中是连续的,这对于某些操作(尤其是 CUDA 核)很重要。
A_log = torch.log(A)
: 计算 A 的对数。在 Mamba 的实现中,通常会维护 A 的对数形式,因为 A 参数在 SSM 中会经历指数化操作,在对数空间操作可以提高数值稳定性。
self.A_log._no_weight_decay = True
: 这是一个自定义属性,标记这个参数不应该应用权重衰减(weight decay)。在 SSM 中,A 参数通常不希望被权重衰减影响,因为它代表了固定的状态空间动态。
权重衰减就是一种正则化技术,用于防止过拟合的,通过在损失函数中加入参数的 L2 范数避免参数过大。
- self.D
初始化 SSM 的 D 参数,也被称为跳跃连接(skip connection)参数,同样设置其不应用权重衰减。
- self.out_proj
定义一个输出投影层,输入维度为 self.d_inner
,输出维度为 self.d_model
总而言之,初始化代码主要是针对 SSM 的参数尤其是 A 和 D 进行特殊初始化,对于 \(\Delta\) 和 B,C,则使用神经网络得到,同时定义了一些需要使用到的网络,例如输入和输出的线性投影层,局部特征提取的因果深度卷积层。
前向传播
Mamba 的前向传播函数是一个精心设计的流程,它将传统的序列处理与选择性状态空间模型相结合。它首先通过线性投影 (in_proj
) 扩展输入维度,然后通过一个短卷积捕获局部信息。接着,核心的选择性状态空间模型 (selective_scan_fn
) 根据输入动态生成参数 (Δ, B, C),并进行高效的序列扫描,以捕捉长距离依赖。最后,一个门控机制(通过 z
路径)与 SSM 输出结合,并通过输出投影 (out_proj
) 将结果映射回原始维度。
为了提高性能,Mamba 大量使用了融合核 (mamba_inner_fn
),它将多个操作合并到单个 GPU 核中,从而减少了内存访问和计算开销。在无法使用融合核或进行单步推理时,它会回退到由 PyTorch 模块组成的“慢路径”实现。
下面就是 Mamba 的前向传播函数的原实现:
def forward(self, hidden_states, inference_params=None):
"""
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
batch, seqlen, dim = hidden_states.shape
conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, cnov_state, ssm_state)
return out
# We do matmul and transpose BLH -> HBL at the same time
xz = rearrange(
self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
"d (b l) -> b d l",
l=seqlen,
)
if self.in_proj.bias is not None:
xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# In the backward pass we write dx and dz next to each other to avoid torch.cat
if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states
out = mamba_inner_fn(
xz,
self.conv1d.weight,
self.conv1d.bias,
self.x_proj.weight,
self.dt_proj.weight,
self.out_proj.weight,
self.out_proj.bias,
A,
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
else:
x, z = xz.chunk(2, dim=1)
# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
if causal_conv1d_fn is None:
x = self.act(self.conv1d(x)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
x = causal_conv1d_fn(
x=x,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
# We're careful here about the layout, to avoid extra transposes.
# We want dt to have d as the slowest moving dimension
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.t()
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
assert self.activation in ["silu", "swish"]
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
)
if ssm_state is not None:
y, last_state = y
ssm_state.copy_(last_state)
y = rearrange(y, "b d l -> b l d")
out = self.out_proj(y)
return out
主要解释一些关键的,或者难以理解的步骤吧。
首先是这个前向传播函数接受两个输入,一个是 hidden_states
,维度为 (batch_size, len_seq, d_model)
,另一个是 inference_params
,表示推理模式下的参数,主要是为了维护和更新内部状态(如卷积和 SMM 的状态),输出是和输入同维度的张量。
如果是推理模式的话,首先调用内部方法 self._get_states_from_cache
从缓存中获取或初始化状态。inference_params.seqlen_offset > 0
表示模型正在处理序列的中间或末尾部分,这时我们直接执行 self.step
方法更新下一个时间步的状态。
然后我们遇到了一个比较复杂的运算,在运算过程中,矩阵的维度上做了许多的变化,我们细细来讲。
首先是将输入 hidden_states
的维度从 (batch_size, len_seq, d_model)
转化为 (d_model, batch_size * seq_len)
,然后通过 self.in_proj
层,还记得该层输入为 d_model
,输出为 d_inner * 2
,所以我们可以知道其 weight
矩阵应该是 (d_inner * 2, d_model)
,所以矩阵运算得到的结果再 reshape 的结果是 (batch_size, d_inner * 2, seq_len)
,这就是 xz
的维度。
这一步操作后,下面又加上了 bias 其实就是手动拆解了
self.in_proj(hidden_states)
,通过显示控制张量的形状变换和矩阵乘法的执行,使得开发者可以精确地调整数据流,以适应某些特定的计算模式或利用底层库的优势,改变内存访问模式,提高性能。
然后我们可以观察到,当我们使用 A_log
来复原 A 时除了取了指数之外,还加了个负号,其确保了 A 是负数,负号是为了确保 S6 模型在离散化后保持稳定性,并控制状态信息的衰减,从而使得模型能够有效地处理长序列并进行稳定的训练。
在连续时间系统中,为了保证系统的稳定性(即避免状态在时间步长上无限增长,导致数值不稳定或发散),A 矩阵的特征值(eigenvalues)的实部必须是负的。当将其离散化后,为了保持稳定性,离散化后的 A_bar 矩阵的特征值的模(magnitude)必须小于 1。
关于这一点,动力系统中有讲过相似的理论来着,可惜我已经忘得差不多了(悲
如果满足使用优化过的融合核的条件,就调用高度优化的 mamba_inner_fn
,其为一个单一体的 CUDA 核函数,通过将 Mamba 块的多个操作(卷积、SSM 扫描、门控等)融合在一起,以提高计算效率和减少内存开销。其传入参数部分,由于 B 和 C 参数是根据输入 x 动态生成,所以传入 None
,该分支将直接计算 out
并跳过后续的 Pytorch 操作。
否则如果不满足条件,则利用慢路径(Pytorch)实现。
首先将矩阵 xz
分成两个形状为 (batch_size, self.d_inner, len_seq)
的矩阵,其中x
用于卷积核 SSM 路径,而 z
用于门控机制。
如果 conv_state
存在,则更新卷积状态,F.pad
用于处理序列长度小于卷积核大小的情况,使用零进行填充,该操作确保了 conv_state
始终包含最近 d_conv-1
个输入元素,以便进行因果卷积。
如果没有可用的优化因果卷积函数,则使用 Pytorch 自带的卷积层处理序列,由于该卷积使用了 padding 操作使得输出和输入的长度一样,所以需要使用截断操作,之后使用激活函数。
而如果可以使用优化因果卷积函数,则调用该函数,注意这里只支持 silu
和 swish
两种激活函数,并且调用函数之后,权重被重新排列,因为 causal_conv1d_fn
可能期望不同的权重布局。
接下来就是比较关键的部分,即通过输入 x
得到参数 dt, B, C
,注意这里 x
的维度是 (batch_size, self.d_inner, len_seq)
,我们首先将 x 进行重排得到 (batch_size * len_seq, self.d_inner)
的矩阵,然后通过 x_proj
投影层得到维度为 (batch_size * len_seq, dt_rank + d_state * 2)
的输出结果,然后将它分成三部分分别赋值给 dt, B, C
。
随后使用 dt_proj.weight
和 dt
的转置进行计算,得到的结果为 (d_inner, dt_rank) x (dt_rank, batch_size * len_seq) -> (d_inner, batch_size * len_seq)
,最后将三个参数进行维度变化,使其符合要求。三者最终的维度为 (batch_size, d_inner / d_state, len_seq)
接着执行选择性扫描操作,这一部分的实现到时候再讲,输出的 y
应该是个列表,其中有两个元素,一个是所有状态的张量,维度是 (batch_size, d_inner, len_seq)
,另一个是最后一个时刻的状态,维度是 (batch_size, d_state)
,ssm_state.copy_(last_state)
是为了将更新后的 last_state
复制回 ssm_state
,以便在下一个时间步使用。
最后我们将 y
进行 reshape 操作,变成维度为 (batch_size, len_seq, d_inner)
的矩阵,然后通过输出线性层 self.out_proj
得到最终结果 out
,其维度为 (batch_size, len_seq, d_inner)
单步推理
Mamba 模型定义了 step
方法,主要用于单步推理,每次处理一个时间步的输入并更新模型的内部状态。
step 方法是 Mamba 模型在推理模式下进行自回归生成的关键。它实现了 Mamba 块的单时间步计算:
- 输入投影: 处理单个 token 的输入并将其分解为 x 和 z。
- 卷积状态更新: 更新并应用因果卷积,利用
conv_state
来记忆局部历史信息。 - SSM 参数生成: 根据当前输入动态生成 SSM 参数
dt、B、C
。 - SSM 状态更新和输出: 使用新的参数和旧的
ssm_state
来更新 SSM 内部状态,并计算当前时间步的 SSM 输出。 - 门控和残差连接: 将 SSM 输出与门控机制 (
z
) 和 D (skip) 连接结合。 - 输出投影: 将最终结果投影回模型维度。
def step(self, hidden_states, conv_state, ssm_state):
dtype = hidden_states.dtype
assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
x, z = xz.chunk(2, dim=-1) # (B D)
# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = x
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
# PyTorch nn.Conv1d 层的权重维度标准形式是:(out_channels, in_channels / groups, kernel_size)
if self.conv1d.bias is not None:
x = x + self.conv1d.bias
x = self.act(x).to(dtype=dtype)
else:
x = causal_conv1d_update(
x,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
# Don't add dt_bias here
dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
# SSM step
if selective_state_update is None:
# Discretize A and B
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
dB = torch.einsum("bd,bn->bdn", dt, B)
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
y = y + self.D.to(dtype) * x
y = y * self.act(z) # (B D)
else:
y = selective_state_update(
ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
)
out = self.out_proj(y)
return out.unsqueeze(1), conv_state, ssm_state
step
方法接受三个输入,分别是当前时间步的输入 hidden_states
,以及上一个时间步的卷积状态 conv_state
和 SSM 状态 ssm_state
- 输入投影 (in_proj)
hidden_states.squeeze(1)
:将输入 hidden_states
从 (B, 1, D)
形状中移除长度维度,变为 (B, D)
。
self.in_proj(...)
:通过 in_proj
线性层进行投影,输出维度是 self.d_inner * 2
。因此 xz
的形状是 (B, self.d_inner * 2)
。
x, z = xz.chunk(2, dim=-1)
: 将 xz 沿最后一个维度分成两部分:
x
(形状 (B, self.d_inner)
) 用于卷积和 SSM 路径。
z
(形状 (B, self.d_inner)
) 用于门控。
- 卷积步进 (Conv Step)
如果没有高度优化的 causal_conv1d_update
函数(例如在 CPU 上):
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
: 更新卷积状态。conv_state
的形状是 (B, self.d_inner, d_conv-1)
(这里假设 conv_state
预先初始化为 d_conv-1
长度)。torch.roll
会将 conv_state
向左(负方向)平移一位,最左边的元素被移除,为新输入腾出空间。
conv_state[:, :, -1] = x
:将当前时间步的输入 x
放入 conv_state
的最右端(最新的位置)。
x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)
:执行卷积操作。conv_state
包含了当前输入和之前的 d_conv-1
个输入。将其与重排后的卷积核 (self.conv1d.weight)
相乘,并在最后一个维度求和,得到卷积输出。
if self.conv1d.bias is not None: x = x + self.conv1d.bias
:如果卷积层有偏置,则加上偏置。
x = self.act(x).to(dtype=dtype)
:应用激活函数 self.act
(SiLU),并确保数据类型一致。
self.conv1d.weight.shape = (out_channels, in_channels // groups, kernel_size)
因为使用了分组卷积,所以参数量会下降。这里groups = in_channels = d_inner
所以属于深度卷积。
- 投影出 SSM 参数 (dt, B, C)
x_db = self.x_proj(x)
:将卷积输出 x
(形状 (B, self.d_inner)
) 通过 x_proj
线性层投影,输出形状是 (B, self.dt_rank + self.d_state * 2)
。
dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
: 将 x_db
分割为 dt, B, C
。
dt = F.linear(dt, self.dt_proj.weight)
: 将 dt
(形状 (B, self.dt_rank)
) 与 dt_proj
的权重 (形状 (self.d_inner, self.dt_rank)
) 进行线性变换。结果 dt
的形状是 (B, self.d_inner)
。注意这里的注释 “Don’t add dt_bias here”,这意味着偏置将在后面的 softplus 激活中添加。
- SSM 步进 (SSM Step)
同样地,这里可以直接调用高度优化的选择状态更新函数,如果无法调用的话,就只能用 CPU 或 GPU 跑了。
dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
:将 dt
加上 dt_proj
的偏置,然后通过 softplus
激活函数,确保 dt
为正。
torch.einsum("bd,dn->bdn", dt, A)
是爱因斯坦求和约定,这里表示对 dt
(形状 (B, D_inner)
) 和 A (形状 (D_inner, D_state)
) 进行逐元素的乘法,并扩展维度以匹配。结果 (B, D_inner, D_state)
。
torch.einsum("bd,bn->bdn", dt, B)
对 dt
(形状 (B, D_inner)
) 和 B
(形状 (B, D_state)
) 进行逐元素的乘法。结果为 dB
维度为 (B, D_inner, D_state)
。
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
:更新 SSM 状态。 这是核心的状态空间方程,ssm_state * dA
: 衰减旧状态。rearrange(x, "b d -> b d 1") * dB
: 将当前输入映射到状态空间。
将这两部分相加得到新的 ssm_state
,并将其复制回传入的 ssm_state
变量(原地更新)。
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
: 计算当前时间步的输出 y
。将更新后的 ssm_state
与参数 C
相乘。
最终返回三个变量,分别是:
out
:(batch_size, 1, d_inner)
conv_state
: 更新后的卷积状态。ssm_state
: 更新后的 SSM 状态。
状态初始化函数
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
device = self.out_proj.weight.device
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
conv_state = torch.zeros(
batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype
)
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
# ssm_dtype = torch.float32
ssm_state = torch.zeros(
batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype
)
return conv_state, ssm_state
从缓存中获取状态
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
assert self.layer_idx is not None
if self.layer_idx not in inference_params.key_value_memory_dict:
batch_shape = (batch_size,)
conv_state = torch.zeros(
batch_size,
self.d_model * self.expand,
self.d_conv,
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.d_model * self.expand,
self.d_state,
device=self.dt_proj.weight.device,
dtype=self.dt_proj.weight.dtype,
# dtype=torch.float32,
)
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
else:
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
# TODO: What if batch size changes between generation, and we reuse the same states?
if initialize_states:
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state
该方法主要用在推理或生成过程中,获取和管理 Mamba 块的内部状态。
首先检查状态是否存在,它使用 inference_params.key_value_memory_dict
字典来存储和检索状态。self.layer_idx
作为键,每个 Mamba 层都有自己的状态。如果 self.layer_idx
不在字典中,说明这是第一次访问该层的状态,需要初始化。
如果状态不存在就初始化,否则就直接返回检索结果,可以在调用时选择是否重置状态。
总结
以上就是 Mamba 这个类的定义了,感觉自己写得有点屎山,主要还是靠 ChatGPT 和 Gemini 老师讲解的,我只能算是个翻译官啊哈哈。
关于更加底层的函数,例如 mamba_inner_fn
,causal_conv1d_fn
,selective_state_update
,causal_conv1d_update
,selective_scan_fn
等等,这些后续会补充其具体实现。
另外后续还会从同理一遍整个 Mamba 的运行逻辑,To be continued……
Leave a comment