并行与分布式系统中的并行随机算法:并行化Sherman-Morrison公式及其在矩阵求逆中的应用
1. 算法背景与问题描述
在许多科学计算和工程应用中,经常需要求解线性方程组 \(Ax = b\) 或计算矩阵的逆 \(A^{-1}\)。当矩阵 \(A\) 是一个大规模稠密矩阵时,直接计算其逆矩阵的复杂度很高(通常为 \(O(n^3)\))。Sherman-Morrison公式提供了一种在已知矩阵 \(A\) 的逆 \(A^{-1}\) 的情况下,当 \(A\) 发生一个低秩更新(例如,\(A\) 加上一个外积 \(uv^T\),其中 \(u\) 和 \(v\) 是列向量)时,高效计算新矩阵 \((A + uv^T)^{-1}\) 的方法。其公式如下:
\[(A + uv^T)^{-1} = A^{-1} - \frac{A^{-1}uv^TA^{-1}}{1 + v^T A^{-1} u} \]
这个公式的原始计算复杂度是 \(O(n^2)\),远低于重新求逆的 \(O(n^3)\)。
然而,在并行与分布式环境中,当矩阵规模极大,或者需要处理一系列连续的低秩更新(即 \(A\) 被多次更新为 \(A + u_1v_1^T + u_2v_2^T + ...\))时,如何高效、并行地应用Sherman-Morrison公式(或其推广形式Woodbury矩阵恒等式)就成为了一个挑战。我们的目标是设计一个并行算法,能够利用多个处理器(或多台机器)协同工作,加速连续低秩更新下矩阵逆的迭代计算过程。
2. 核心思路与公式推导
首先,我们明确要解决的核心计算问题:
已知 \(A_0^{-1}\),以及一系列秩为1的更新对 \(\{(u_k, v_k)\}_{k=1}^m\)。我们需要计算:
\[A_k = A_{k-1} + u_k v_k^T, \quad k = 1, ..., m \]
并最终得到 \(A_m^{-1}\)。
串行算法是直接连续应用m次Sherman-Morrison公式:
- 初始化 \(M_0 = A_0^{-1}\)。
- 对于 \(k=1\) 到 \(m\):
a. 计算标量 \(c_k = 1 + v_k^T M_{k-1} u_k\)。
b. 计算向量 \(w_k = M_{k-1} u_k\)。
c. 计算向量 \(z_k^T = v_k^T M_{k-1}\) (或者等价地,\(z_k = M_{k-1}^T v_k\),如果 \(M_{k-1}\) 对称)。
d. 更新逆矩阵:\(M_k = M_{k-1} - (w_k z_k^T) / c_k\)。
这里的主要计算瓶颈在于步骤b、c、d中的矩阵-向量乘法和秩1矩阵更新,每一步的复杂度为 \(O(n^2)\)。串行总复杂度为 \(O(mn^2)\)。
并行化的关键在于:能否将多个更新对的计算进行某种“组合”或“聚合”,减少顺序依赖,从而让不同的处理器能同时处理不同的更新对,最后再将结果正确合并?
3. 并行化策略:基于分块与聚合的并行Sherman-Morrison算法
一个有效的并行策略是利用Woodbury矩阵恒等式的推广形式,将多个秩1更新“批量”处理。我们将m个更新对组合成两个矩阵:
令 \(U = [u_1, u_2, ..., u_m]\) 是一个 \(n \times m\) 矩阵。
令 \(V = [v_1, v_2, ..., v_m]\) 也是一个 \(n \times m\) 矩阵。
那么,总更新可以写为:
\[A_m = A_0 + UV^T \]
这称为一个秩为 \(m\) 的更新(尽管 \(m\) 可能远小于 \(n\))。推广的Woodbury公式(或称Sherman-Morrison-Woodbury公式)为:
\[(A_0 + UV^T)^{-1} = A_0^{-1} - A_0^{-1}U(I_m + V^T A_0^{-1}U)^{-1}V^T A_0^{-1} \]
其中 \(I_m\) 是 \(m \times m\) 的单位矩阵。
并行算法设计步骤如下:
步骤1:数据划分与任务分配
- 假设有 \(p\) 个处理器。
- 将初始逆矩阵 \(M_0 = A_0^{-1}\) 按行分块(或按二维块循环分块,取决于通信模式),每个处理器存储一部分行块,记为 \(M_0^{(i)}\)。
- 将更新矩阵 \(U\) 和 \(V\) 也按列分块(因为每一列对应一个更新对)。具体地,将m个更新对分成 \(p\) 组,每组大约 \(m/p\) 个更新对,分别构成 \(U^{(i)}\) 和 \(V^{(i)}\),并分配给第 \(i\) 个处理器。
步骤2:本地矩阵-矩阵乘法(高度并行)
- 每个处理器 \(i\) 并行计算两个局部结果:
- \(W^{(i)} = M_0^{(i)} U^{(i)}\)。 这里,\(M_0^{(i)}\) 是 \(n_i \times n\) 的行块(\(n_i\) 是该处理器负责的行数),\(U^{(i)}\) 是 \(n \times m_i\),结果 \(W^{(i)}\) 是 \(n_i \times m_i\)。
- \(S^{(i)} = (V^{(i)})^T M_0^{(i)}\)。 由于 \(M_0\) 通常对称(在很多应用场景下A是对称的),我们可以利用这一点,或者直接计算 \(Z^{(i)} = M_0^{(i)} V^{(i)}\),则 \(S^{(i)} = (Z^{(i)})^T\)。这一步需要计算 \(V^{(i)}\) 与本地 \(M_0^{(i)}\) 的乘积。
注意:为了计算核心公式中的 \(V^T M_0 U\)(它是一个 \(m \times m\) 的矩阵),我们需要全局聚合。观察 \(V^T M_0 U = \sum_{i=1}^p (V^{(i)})^T M_0^{(i)} U\)。但 \(U\) 是完整的。一个更高效、通信量更少的方法是:
- 每个处理器 \(i\) 先计算 \(X^{(i)} = (V^{(i)})^T W^{(i)}\)。因为 \(W^{(i)} = M_0^{(i)} U^{(i)}\),而 \(U^{(i)}\) 只是 \(U\) 的一部分,所以 \(X^{(i)}\) 是 \(m_i \times m_i\) 的块对角部分,并不能得到完整的 \(V^T M_0 U\)。
- 实际上,为了得到完整的 \(m \times m\) 矩阵 \(C = V^T M_0 U\),我们需要全局归约(All-Reduce)操作。每个处理器 \(i\) 计算它对整个矩阵 \(C\) 的部分贡献。这需要将本地的 \(W^{(i)}\) 和 \(V^{(i)}\) 的信息与其他处理器共享。
步骤3:全局聚合以构建核心矩阵(需要通信)
- 计算 \(C = V^T M_0 U\)。这可以通过两步完成:
a. 每个处理器 \(i\) 计算其对 \(C\) 的贡献:\(C^{(i)} = (V^{(i)})^T (M_0 U)\)。但 \(M_0 U\) 是全局的。更可行的方案是执行一个全收集(All-Gather)操作,让每个处理器获得完整的 \(W = M_0 U\)。由于 \(W\) 是 \(n \times m\),当 \(m\) 不大时,这个通信开销是可接受的。
b. 所有处理器进行 All-Gather 操作,交换它们计算出的 \(W^{(i)}\) 行块,使每个处理器获得完整的 \(W\)。
c. 然后,每个处理器 \(i\) 利用本地的 \(V^{(i)}\) 和完整的 \(W\),计算 \(C^{(i)} = (V^{(i)})^T W\)。注意,\(C^{(i)}\) 是 \(m_i \times m\) 的。
d. 通过一个归约-广播(Reduce-Scatter 或 All-Reduce)操作,将所有 \(C^{(i)}\) 按行(对应 \(V^{(i)}\) 的列索引)求和,得到完整的 \(m \times m\) 矩阵 \(C\),并分发给所有处理器(或每个处理器得到一块,但后续求逆需要完整的 \(C\),所以通常选择 All-Reduce 让每个处理器都有完整的 \(C\))。
步骤4:求解核心小矩阵的逆(可并行复制计算)
- 每个处理器(或指定一个主处理器)计算一个小矩阵的逆:\(D = (I_m + C)^{-1}\)。由于 \(m\) 通常远小于 \(n\),这是一个 \(O(m^3)\) 的计算,可以由每个处理器独立计算(因为上一步通过 All-Reduce 使得所有处理器都拥有 \(C\)),这是一个“冗余计算”,但成本低。或者由一个处理器计算后广播结果。
步骤5:并行计算最终更新项
- 计算最终更新项 \(M_0 U D V^T M_0\)。这可以分解为几个矩阵乘法,并利用已有的分块进行并行计算。
a. 计算 \(Y = W D\),其中 \(W = M_0U\) 是 \(n \times m\),\(D\) 是 \(m \times m\)。由于每个处理器拥有完整的 \(W\)(来自步骤3b)和完整的 \(D\),可以独立计算本行块对应的 \(Y^{(i)} = W^{(i)} D\)。这里 \(W^{(i)}\) 是处理器 \(i\) 负责的 \(n_i\) 行,所以 \(Y^{(i)}\) 也是 \(n_i \times m\)。
b. 计算 \(Y V^T = (W D) V^T\)。注意,我们需要从 \(Y V^T\) 中减去以更新 \(M_0\)。由于 \(V\) 是按列分块存储的,我们需要计算 \(Y (V^T) = \sum_{j=1}^p Y^{(i)} (V^{(j)})^T\)。这又是一个需要通信的操作。
c. 一个高效的策略是:将 \(Y\) 按行分块(我们已经有了 \(Y^{(i)}\)),将 \(V^T\) 按行分块(对应 \(V\) 的列分块,每个处理器 \(j\) 拥有 \(V^{(j)}\),其转置是 \(m_j \times n\) 的块)。为了计算全局的 \(Y V^T\),处理器 \(i\) 需要与所有处理器 \(j\) 通信,获取 \(V^{(j)}\) 来形成本地的部分和。这可以通过 All-Gather 将 \(V\) 收集到所有处理器上来实现(当 \(m\) 不大,且 \(V\) 是 \(n \times m\) 时可行),或者通过更复杂的多对多通信。
d. 更简洁的方法:注意到最终更新是对 \(M_0\) 的一个修正。每个处理器只需要更新它负责的 \(M_0\) 的行块。处理器 \(i\) 负责的最终逆矩阵块为:
\[ M_m^{(i)} = M_0^{(i)} - Y^{(i)} (V^T M_0^{(i)}) \]
这里 $Y^{(i)} = W^{(i)} D$ 是已知的。而 $V^T M_0^{(i)}$ 可以这样计算:由于 $M_0^{(i)}$ 是行块,$V^T$ 是 $m \times n$。我们可以再次利用 All-Gather 获取完整的 $V$(或 $V^T$),然后本地计算 $V^T M_0^{(i)}$,得到一个 $m \times n_i$ 的矩阵(注意,这里的 $n_i$ 是列数,对应于 $M_0^{(i)}$ 的列块,但 $M_0^{(i)}$ 是一个瘦高的行块,其列数是完整的 $n$)。实际上,$V^T M_0^{(i)}$ 是 $m \times n$ 矩阵的一个列块。这仍然需要完整的 $V^T$ 信息。
考虑到通信复杂性,一个更实际且常用的并行化方法是在共享内存环境下(如多核CPU),利用OpenMP或Pthreads进行并行化:
4. 共享内存多核并行化方案
在共享内存系统中,所有处理器核心共享同一份内存,可以访问完整的矩阵 \(M_0, U, V\)。这大大简化了通信问题。
步骤1:将m个更新对划分为若干个任务块。
- 将更新索引 \(1, 2, ..., m\) 划分为 \(t\) 个块(\(t\) 可以是线程数或略多),每个块包含一组连续的更新对。
步骤2:并行计算中间矩阵 \(W = M_0 U\) 和 \(Z = M_0 V\)。
- 由于 \(M_0\) 是 \(n \times n\),\(U\) 和 \(V\) 是 \(n \times m\),计算 \(W = M_0 U\) 和 \(Z = M_0 V\) 是标准的矩阵-矩阵乘法(Level 3 BLAS)。我们可以使用高度优化的并行BLAS库(如Intel MKL, OpenBLAS)来并行计算这两个乘积。这些库内部会使用分块、循环展开等技术,并利用多线程并行计算。
步骤3:并行计算核心矩阵 \(C = V^T W\)。
- 注意到 \(C = V^T W = (M_0 V)^T (M_0 U) = Z^T W\)。由于 \(Z\) 和 \(W\) 已经计算出,且都是 \(n \times m\),计算 \(C = Z^T W\) 也是一个矩阵乘法(\(m \times n\) 乘以 \(n \times m\),得到 \(m \times m\))。同样,可以使用并行BLAS库的
GEMM例程来计算。
步骤4:计算小矩阵 \(D = (I + C)^{-1}\)。
- 由于 \(m\) 较小,可以在单个线程上串行计算其逆,或者使用LAPACK的
GETRF(LU分解)和GETRI(求逆)例程。也可以使用多线程LAPACK,但开销可能不显著。
步骤5:并行计算最终更新。
- 计算 \(Y = W D\) (\(n \times m\) 乘以 \(m \times m\)),同样使用并行
GEMM。 - 计算最终更新 \(M_m = M_0 - Y Z^T\)。这是两个矩阵的相减,其中 \(Y Z^T\) 是一个秩为 \(m\) 的矩阵。我们可以并行地对 \(M_0\) 的各个元素进行更新。具体地,可以按行或按列分块,每个线程负责更新 \(M_0\) 的一个连续块。公式为 \(M_m = M_0 - Y Z^T\),这是一个 \(n \times n\) 的矩阵的秩 \(m\) 更新,也可以看作一个矩阵乘法(\(Y Z^T\))后接一个矩阵减法。并行BLAS的
GEMM可以高效计算 \(Y Z^T\),然后一个并行的矩阵减法(或使用AXPY类的操作)即可完成。
5. 算法复杂度与优化
- 时间复杂度:串行算法的复杂度为 \(O(mn^2)\)。在并行版本中,步骤2、3、5中的矩阵乘法是主要成本,理想情况下,在 \(p\) 个处理器上,这些矩阵乘法可以近似达到 \(O(n^2 m / p)\) 的效率(如果 \(n\) 远大于 \(m\) 和 \(p\))。计算 \(D\) 的 \(O(m^3)\) 步骤通常是次要的。
- 空间复杂度:需要存储 \(M_0, U, V, W, Z, C, D, Y\) 等矩阵。主要开销是 \(M_0\) 的 \(O(n^2)\) 和 \(W, Z, Y\) 的 \(O(nm)\)。在分布式内存中,需要仔细分块以适应当地内存。
- 优化要点:
- 当 \(m\) 非常大时(接近 \(n\)),Woodbury公式的优势会减弱,直接重新求逆可能更划算。因此,此算法适用于 \(m \ll n\) 的情况。
- 在分布式内存实现中,如果 \(m\) 很小,步骤3和5中的 All-Gather 操作通信数据量为 \(O(nm)\),如果 \(nm\) 远小于 \(n^2\),则是可接受的。需要权衡计算收益和通信开销。
- 对于一系列连续的更新,可以采用“延迟更新”或“批量处理”策略,积累一定数量的更新对 \((u_k, v_k)\) 后,再一次性使用上述批量并行算法,而不是逐个更新,这样可以摊薄通信和同步开销。
总结
并行化Sherman-Morrison公式的核心思想是将多个连续的秩1更新聚合成一个秩为 \(m\) 的更新,然后利用Woodbury公式,将计算主体转化为一系列可以并行执行的稠密矩阵乘法运算。在共享内存系统中,可以充分利用多线程BLAS库实现高效并行。在分布式内存系统中,需要对矩阵进行分块,并通过集合通信操作(如All-Gather, All-Reduce)来协调各处理器间的数据,在 \(m\) 较小、更新批量足够大时,能有效加速大规模矩阵在低秩修正后的求逆更新过程。