基于Transformer的图像分类算法:Swin Transformer
字数 2913 2025-12-18 22:56:10
基于Transformer的图像分类算法:Swin Transformer
题目描述
在计算机视觉领域,图像分类是一项基础任务,旨在将输入图像分配到一个预定义的类别标签中。随着Transformer架构在自然语言处理中的巨大成功,研究人员开始探索将其应用于视觉任务。然而,标准的视觉Transformer模型在处理高分辨率图像时,其自注意力机制的计算复杂度与图像尺寸的平方成正比,导致内存和计算开销巨大,难以应用于密集预测任务或作为通用视觉骨干网络。
本题目要求详细讲解Swin Transformer算法。它是由微软亚洲研究院在2021年提出的,旨在构建一个可以像卷积神经网络那样,作为通用视觉骨干网络使用的Transformer架构。其核心创新在于引入了层级结构和移位窗口,有效解决了计算复杂度问题,并在图像分类、目标检测、语义分割等多个视觉任务上取得了当时的领先性能。你需要理解其如何通过“窗口”和“移位”的巧妙设计,在局部计算自注意力以降低复杂度,同时通过层级构建实现多尺度特征表示。
解题过程循序渐进讲解
-
问题背景与动机
- 标准Vision Transformer的局限性:ViT将图像分割成一系列不重叠的Patch(图像块),然后将它们线性投影为Token序列送入Transformer编码器。然而,其自注意力机制是全局的,每个Token都需要与其他所有Token交互,计算复杂度为 O(N²),其中N是Token的数量。对于高分辨率图像(如用于检测、分割的任务),N会很大,导致计算成本过高。
- 卷积神经网络的优势:CNN通过局部连接(卷积核)、下采样(池化)和层级结构,高效地处理图像并构建多尺度特征图。Swin Transformer的目标是借鉴这些思想,设计一个具有类似归纳偏置的Transformer架构。
-
核心思想:层级结构与移位窗口
- 层级特征图:Swin Transformer构建了类似于CNN的金字塔结构。它从小的Patch Token开始(例如4x4像素为一个Patch),通过“Patch Merging”操作,在多个阶段逐渐合并相邻的Token,从而减少Token数量、扩大每个Token的感受野,最终形成多分辨率特征图(例如,1/4, 1/8, 1/16, 1/32原图大小)。这使得Swin Transformer可以直接替换CNN骨干网络(如ResNet),用于需要多尺度特征的下游任务。
- 基于窗口的自注意力:为了将全局自注意力的二次复杂度降低到线性,Swin Transformer将图像特征图均匀地划分为不重叠的、固定大小的窗口(例如,MxM个Patch)。自注意力计算仅在每个窗口内部独立进行。这样,计算复杂度就从全局的 O((HW)²) 降低为窗口内的 O(M² * (HW/M²)) = O(HW * M²)。其中H、W是特征图的高和宽(以Patch计),M是窗口大小。M是固定常数,因此复杂度变为与图像尺寸HW呈线性关系。
-
关键创新:移位窗口多头自注意力
- 窗口划分的限制:固定窗口划分虽然降低了计算量,但也完全割裂了不同窗口之间的信息交流,模型无法捕获窗口间的依赖关系,这会限制建模能力。
- 移位窗口方案:为了解决这个问题,Swin Transformer在连续的Transformer块中交替使用两种窗口划分策略。
- 规则窗口划分:第一个模块采用常规的均匀网格划分。
- 移位窗口划分:第二个模块将窗口从规则划分的位置,在水平和垂直方向上各移位
⌊M/2⌋个Patch。这样,新的窗口将包含来自上一层不同规则窗口的Token,从而在相邻窗口之间建立了连接。
- 高效批处理计算:移位后窗口大小不一致,且数量变多(从
(H/M) * (W/M)个变为(H/M+1) * (W/M+1)个),直接计算会导致效率低下。Swin Transformer采用了巧妙的循环移位+掩码机制:将移位后的特征图在边缘进行循环填充,然后仍按MxM大小进行均匀窗口划分。此时,有些窗口包含来自不相邻区域的Token,通过在自注意力计算中引入一个可学习的相对位置偏置和特定的注意力掩码,来屏蔽这些不应关联的Token对,确保自注意力只在物理空间上相邻的区域内部计算。计算完成后,再将循环移位的部分移回原处。这个方案使得所有窗口大小统一,可以利用高效的批处理进行计算。
-
网络架构详解
- Patch Partition & Linear Embedding:输入图像首先被分割成不重叠的4x4像素块,每个块的特征维度是4x4x3=48。然后通过一个线性投影层,将其映射到一个指定的维度C。
- Swin Transformer Block:这是基本构建模块,由两个子模块串联组成。
- 基于(移位)窗口的多头自注意力:如前所述,每个Block要么使用W-MSA(规则窗口划分),要么使用SW-MSA(移位窗口划分),二者交替出现。自注意力公式中加入了相对位置偏置B:
Attention(Q, K, V) = SoftMax((QK^T)/√d + B) V。 - 两层MLP:每个注意力层和MLP层之前都应用了层归一化,之后使用了残差连接,即模块结构为:
LayerNorm -> (W/SW-MSA) -> Residual -> LayerNorm -> MLP -> Residual。
- 基于(移位)窗口的多头自注意力:如前所述,每个Block要么使用W-MSA(规则窗口划分),要么使用SW-MSA(移位窗口划分),二者交替出现。自注意力公式中加入了相对位置偏置B:
- Stage构建:网络通常包含4个Stage。
- Stage 1:在Patch Merging之后,Token数量为
(H/4) * (W/4),特征维度为C。堆叠N1个Swin Transformer Block。 - Stage 2, 3, 4:每个Stage开始时执行Patch Merging。它将相邻的2x2个Patch的特征拼接起来(维度变为4C),然后通过一个线性层将其投影到2C维度,从而实现下采样(Token数量减半,特征维度翻倍)。然后堆叠Ni个Swin Transformer Block。
- Stage 1:在Patch Merging之后,Token数量为
- 输出:对于图像分类任务,在最后一个Stage输出的特征图上应用全局平均池化,然后接一个分类头(全连接层)。
-
主要优势与影响
- 线性计算复杂度:使Transformer能够处理高分辨率图像,成为实用的视觉骨干网络。
- 多尺度特征表示:层级结构使其能够生成适用于检测、分割等多种任务的通用特征图。
- 强大的建模能力:通过局部窗口注意力和窗口间的信息交换,有效地建模了视觉元素的局部性和层次性,在多项基准测试中取得了SOTA结果。
- 广泛的应用:Swin Transformer迅速成为计算机视觉领域的标志性工作之一,其设计思想被后续许多研究借鉴,广泛应用于分类、检测、分割、视频理解等任务。
通过以上步骤,Swin Transformer成功地将Transformer的计算复杂度从图像尺寸的平方降低到线性,并构建了层级化的多尺度特征表示,从而克服了标准ViT的主要缺点,确立了Transformer作为通用视觉骨干网络的可行性。