Mamaba 模型学习记录,原视频:AI大讲堂:革了Transformer的小命?专业拆解【Mamba模型】_哔哩哔哩_bilibili

Transformer 的问题

transformer 架构

Transformer 模型的核心可以说是注意力机制,但是这个注意力机制有一个致命的缺点:为了平衡计算量和更好的处理长序列任务,通常会限制注意力的窗口大小。但这样就无法关注窗口以外的内容了,这就一定程度上造成了视野狭窄,不能关注到所有的文本内容。
虽然说理论上增大窗口的大小一定程度上可以解决这个问题,但是随着窗口的增大,计算的复杂度会呈平方增长 O(n^2),因为窗口中的每个元素,都要和窗口内的所有元素进行注意力分数计算(包括自身)。
视频中提到 Transformer 本质上通过位置编码,把序列数据空间化,通过计算空间相关度反向建模时序相关度。但是呢,这种方法也忽略了数据内在结构的关联(如语法信息),采用一视同仁的暴力关联模式,虽然直接简单,但是效率低下,冗余度高,不易训练。
对于时序数据,但是采用了空间化实现注意力机制,在当年虽然为了充分利用 GPU 的并行能力,很有效,但并不是万能的,也是有问题的。从某种程度上讲,SSM 类模型思考问题的初衷和视角,就是让长序列数据建模回归传统,Mamba 便是其中的佼佼者。

时序状态空间模型 SSM

Mamba 是基于结构化状态空间序列模型(SSMs)的,Combining Recurrent, Convolutional, and Continuous-time Models with Linear State Space Layers 这是 2021 年提出的工作,模型还是时序的,本质上依然是一个 RNN。
LSSL 模型示意图

1. 连续空间的时序建模

在控制理论、信号处理或者线性系统等领域,连续空间状态模型都可以用来建模解决时序问题。被称为 Linear tiem-invariant(LTI)系统,公式如下:

h(t)=Ah(t)+Bx(t)y(t)=Ch(t)\begin{align} & h'(t)=Ah(t)+Bx(t) \\ & y(t)=Ch(t) \end{align}

上方为状态方程:对于输入 xx,先乘输入矩阵 BB,再加上状态矩阵 AA 乘状态向量 h(t)h(t),得到状态向量的导数 h(t)h'(t)
下方为观测方程:状态向量 h(t)h(t) 乘输出矩阵 CC,得到输出结果 y(t)y(t)
这里的时不变是因为 ABCDABCD 是固定的,这也是一种假设,而且是强假设。DD 没写是因为在许多实际系统中可以为 0。transformer 是没有这种假设的。

2. 时序离散化与 RNN

连续系统不方便计算机进行处理,完全连续并不存在,通常我们的采样也会间隔一定很短的时间,以近似连续。对于连续系统通常采用离散化展开,即沿时间轴拉长,公式和上方类似。导数改为了不同时刻角标,形成了递归过程。

ht=Aht1+Bxtyt=Cht\begin{align} & h_{t}= \overline{A}h_{t-1}+\overline{B}x_{t} \\ & y_{t}=Ch_{t} \end{align}

这里有一种方法可以从连续系统转换为离散系统的 ABC 参数的对应关系,这个方法叫"零阶保持(Zero-Order Hold,ZOH)"。

A=exp(ΔA)B=(ΔA)1(exp(ΔA)I)ΔB\begin{align} & \overline{A}=exp(\Delta{A}) \\ & \overline{B}={(\Delta{A})^{-1}}(\exp(\Delta{A})-I)\cdot{\Delta{B}} \end{align}

除此以外,还有很多离散化的方法,离散化主要是为了方便计算机进行处理,同时也是 Mamba 的一个技巧,通过这种方法,就转化成类似 RNN 的门控机制。

3. 并行化处理与 CNN

SSM 与 RNN 相比,除了时序建模,它最大的特点就是通过卷积实现了计算上的并行化。