博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Strassen矩阵算法分析及其C++实现 递归分治法(转)
阅读量:6092 次
发布时间:2019-06-20

本文共 2715 字,大约阅读时间需要 9 分钟。

对于矩阵乘法 C = A × B,通常的做法是将矩阵进行分块相乘,如下图所示:

从上图可以看出这种分块相乘总共用了8次乘法,当然对于子矩阵相乘(如A0×B0),还可以继续递归使用分块相乘。对于中小矩阵来说,很适合使用这种分块乘法,但是对于大矩阵来说,递归的次数较多,如果能减少每次分块乘法的次数,那么性能将可以得到很好的提高。

Strassen矩阵乘法就是采用了一个简单的运算技巧,将上面的8次矩阵相乘变成了7次乘法,看别小看这减少的1次乘法,因为每递归1次,性能就提高了1/8,比如对于1024*1024的矩阵,第1次先分解成7次512*512的矩阵相乘,对于512*512的矩阵,又可以继续递归分解成256*256的矩阵相乘,…,一直递归下去,假设分解到64*64的矩阵大小后就不再递归,那么所花的时间将是分块矩阵乘法的(7/8) * (7/8) * (7/8) * (7/8) = 0.586倍,提高了快接近一倍。当然这是理论上的值,因为实际上strassen乘法增加了其他运算开销,实际性能会略低一点。

由上可见,Strassen矩阵乘法是通过递归实现的,它将一般情况下二阶矩阵乘法(可扩展到n阶,但Strassen矩阵乘法要求n是2的幂)所需的8次乘法降低为7次,其C++实现代码如下:

下面就是Strassen矩阵乘法的实现方法,

    M1 = (A0 + A3) × (B0 + B3)

   M2 = (A2 + A3) × B0

    M3 = A0 × (B1 - B3)

    M4 = A3 × (B2 - B0)

    M5 = (A0 + A1) × B3

    M6 = (A2 - A0) × (B0 + B1)

    M7 = (A1 - A3) × (B2 + B3)

    C0 = M1 + M4 - M5 + M7

    C1 = M3 + M5

    C2 = M2 + M4

    C3 = M1 - M2 + M3 + M6

在求解M1,M2,M3,M4,M5,M6,M7时需要使用7次矩阵乘法,其他都是矩阵加法和减法。

下面看看Strassen矩阵乘法的串行实现伪代码:

Serial_StrassenMultiply(A, B, C)

{

    T1 = A0 + A3;

    T2 = B0 + B3;

    StrassenMultiply(T1, T2, M1);

    T1 = A2 + A3;

    StrassenMultiply(T1, B0, M2);

    T1 = (B1 - B3);

    StrassenMultiply (A0, T1, M3);

 

    T1 = B2 - B0;

    StrassenMultiply(A3, T1, M4);

 

   T1 = A0 + A1;

   StrassenMultiply(T1, B3, M5);       

   

    T1 = A2 – A0;

    T2 = B0 + B1;

    StrassenMultiply(T1, T2, M6);

    T1 = A1 – A3;

    T2 = B2 + B3;

    StrassenMultiply(T1, T2, M7);

    C0 = M1 + M4 - M5 + M7

    C1 = M3 + M5

    C2 = M2 + M4

    C3 = M1 - M2 + M3 + M6

}

 

#include 
using namespace std; const int N = 6; //Define the size of the Matrix template
void Strassen(int n, T A[][N], T B[][N], T C[][N]); template
void input(int n, T p[][N]); template
void output(int n, T C[][N]); int main() { //Define three Matrices int A[N][N],B[N][N],C[N][N]; //对A和B矩阵赋值,随便赋值都可以,测试用 for(int i=0; i
void input(int n, T p[][N]) { for(int i=0; i
>p[i][j]; } } } template
void output(int n, T C[][N]) { cout<<"The Output Matrix is :"<
void Matrix_Multiply(T A[][N], T B[][N], T C[][N]) { //Calculating A*B->C for(int i=0; i<2; i++) { for(int j=0; j<2; j++) { C[i][j] = 0; for(int t=0; t<2; t++) { C[i][j] = C[i][j] + A[i][t]*B[t][j]; } } } } template
void Matrix_Add(int n, T X[][N], T Y[][N], T Z[][N]) { for(int i=0; i
void Matrix_Sub(int n, T X[][N], T Y[][N], T Z[][N]) { for(int i=0; i
void Strassen(int n, T A[][N], T B[][N], T C[][N]) { T A11[N][N], A12[N][N], A21[N][N], A22[N][N]; T B11[N][N], B12[N][N], B21[N][N], B22[N][N]; T C11[N][N], C12[N][N], C21[N][N], C22[N][N]; T M1[N][N], M2[N][N], M3[N][N], M4[N][N], M5[N][N], M6[N][N], M7[N][N]; T AA[N][N], BB[N][N]; if(n == 2) { //2-order Matrix_Multiply(A, B, C); } else { //将矩阵A和B分成阶数相同的四个子矩阵,即分治思想。 for(int i=0; i

//原文请看

转载地址:http://wbmwa.baihongyu.com/

你可能感兴趣的文章
[译] 可维护的 ETL:使管道更容易支持和扩展的技巧
查看>>
### 继承 ###
查看>>
数组扩展方法之求和
查看>>
astah-professional-7_2_0安装
查看>>
函数是对象-有属性有方法
查看>>
uva 10107 - What is the Median?
查看>>
Linux下基本栈溢出攻击【转】
查看>>
c# 连等算式都在做什么
查看>>
使用c:forEach 控制5个换行
查看>>
java web轻量级开发面试教程摘录,java web面试技巧汇总,如何准备Spring MVC方面的面试...
查看>>
使用ansible工具部署ceph
查看>>
linux系列博文---->深入理解linux启动运行原理(一)
查看>>
Android反编译(一) 之反编译JAVA源码
查看>>
结合当前公司发展情况,技术团队情况,设计一个适合的技术团队绩效考核机制...
查看>>
python-45: opener 的使用
查看>>
cad图纸转换完成的pdf格式模糊应该如何操作?
查看>>
Struts2与Struts1区别
查看>>
网站内容禁止复制解决办法
查看>>
Qt多线程
查看>>
我的友情链接
查看>>