Strassen's algorithm
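In linear algebra, the Strassen algorithm, named after Volker Strassen, is an algorithm for matrix multiplication. It splits each operand into 2×2 blocks and forms the product with seven block multiplications instead of eight. Applied recursively, this gives a running time of O(n^log2(7)) ≈ O(n^2.807) rather than the O(n^3) of the standard algorithm. It is useful in practice for large matrices, but it is slower than the fastest known algorithms for extremely large matrices.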
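Both implementations on this page split A and B into four half-size blocks and combine seven recursive products. The block and product names below match the ones used in the code. One level of the recursion computes:

```latex
\begin{aligned}
P_1 &= (A_{12} - A_{22})(B_{21} + B_{22}), & P_2 &= (A_{11} + A_{22})(B_{11} + B_{22}),\\
P_3 &= (A_{11} - A_{21})(B_{11} + B_{12}), & P_4 &= (A_{11} + A_{12})\,B_{22},\\
P_5 &= A_{11}\,(B_{12} - B_{22}), & P_6 &= A_{22}\,(B_{21} - B_{11}),\\
P_7 &= (A_{21} + A_{22})\,B_{11}, \\[4pt]
C_{11} &= P_1 + P_2 - P_4 + P_6, & C_{12} &= P_4 + P_5,\\
C_{21} &= P_6 + P_7, & C_{22} &= P_2 - P_3 + P_5 - P_7.
\end{aligned}
```

Expanding, for example, C12 = P4 + P5 = A11·B22 + A12·B22 + A11·B12 − A11·B22 = A11·B12 + A12·B22 recovers the ordinary block formula. Seven half-size products plus a quadratic number of additions give the cost recurrence T(n) = 7·T(n/2) + Θ(n^2), whose solution is T(n) = Θ(n^log2(7)).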
Julia
The multiplication is denoted by *.
<lang Julia>function Strassen(A::AbstractMatrix, B::AbstractMatrix)
    # AbstractMatrix (not Matrix) so the recursive calls accept the @views sub-blocks
    n = size(A, 1)
    size(A) == size(B) == (n, n) && ispow2(n) ||
        throw(ArgumentError("two equal square 2^k matrices only"))
    if n == 1
        return A * B
    end
    h = n ÷ 2    # integer half size; n/2 would be a Float64 and cannot be used as an index
    @views A11 = A[1:h, 1:h]
    @views A12 = A[1:h, h+1:n]
    @views A21 = A[h+1:n, 1:h]
    @views A22 = A[h+1:n, h+1:n]
    @views B11 = B[1:h, 1:h]
    @views B12 = B[1:h, h+1:n]
    @views B21 = B[h+1:n, 1:h]
    @views B22 = B[h+1:n, h+1:n]

    # the seven recursive half-size products
    P1 = Strassen(A12 - A22, B21 + B22)
    P2 = Strassen(A11 + A22, B11 + B22)
    P3 = Strassen(A11 - A21, B11 + B12)
    P4 = Strassen(A11 + A12, B22)
    P5 = Strassen(A11, B12 - B22)
    P6 = Strassen(A22, B21 - B11)
    P7 = Strassen(A21 + A22, B11)

    # recombine into the four quadrants of the result
    C11 = P1 + P2 - P4 + P6
    C12 = P4 + P5
    C21 = P6 + P7
    C22 = P2 - P3 + P5 - P7

    return [C11 C12; C21 C22]
end</lang>
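Both implementations on this page only handle matrices whose size is a power of 2. The following minimal pure-Python sketch uses the same seven products. It removes that restriction by zero-padding both operands up to the next power of 2 and cropping the result, which is an assumption of this sketch and not part of either implementation here. The helper names (strassen, _strassen_pow2, _pad, _add, _sub) are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def _add(X: Matrix, Y: Matrix) -> Matrix:
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def _sub(X: Matrix, Y: Matrix) -> Matrix:
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def _strassen_pow2(A: Matrix, B: Matrix) -> Matrix:
    """Recursive Strassen product; len(A) must be a power of 2."""
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    h = n // 2

    def quad(M: Matrix, r: int, c: int) -> Matrix:
        # h x h block of M whose top-left corner is (r, c)
        return [row[c:c + h] for row in M[r:r + h]]

    A11, A12, A21, A22 = quad(A, 0, 0), quad(A, 0, h), quad(A, h, 0), quad(A, h, h)
    B11, B12, B21, B22 = quad(B, 0, 0), quad(B, 0, h), quad(B, h, 0), quad(B, h, h)

    # same seven products as the Julia and Phix versions
    P1 = _strassen_pow2(_sub(A12, A22), _add(B21, B22))
    P2 = _strassen_pow2(_add(A11, A22), _add(B11, B22))
    P3 = _strassen_pow2(_sub(A11, A21), _add(B11, B12))
    P4 = _strassen_pow2(_add(A11, A12), B22)
    P5 = _strassen_pow2(A11, _sub(B12, B22))
    P6 = _strassen_pow2(A22, _sub(B21, B11))
    P7 = _strassen_pow2(_add(A21, A22), B11)

    C11 = _add(_sub(_add(P1, P2), P4), P6)
    C12 = _add(P4, P5)
    C21 = _add(P6, P7)
    C22 = _sub(_add(_sub(P2, P3), P5), P7)

    # stitch the quadrants back into one n x n matrix
    top = [r1 + r2 for r1, r2 in zip(C11, C12)]
    bottom = [r1 + r2 for r1, r2 in zip(C21, C22)]
    return top + bottom


def _pad(M: Matrix, m: int) -> Matrix:
    """Embed M in the top-left corner of an m x m zero matrix."""
    n = len(M)
    return [list(row) + [0] * (m - n) for row in M] + [[0] * m for _ in range(m - n)]


def strassen(A: Matrix, B: Matrix) -> Matrix:
    n = len(A)
    if n == 0 or len(B) != n or any(len(r) != n for r in A) or any(len(r) != n for r in B):
        raise ValueError("two equal square matrices only")
    m = 1
    while m < n:  # smallest power of 2 that is >= n
        m *= 2
    # [[A,0],[0,0]] * [[B,0],[0,0]] == [[A*B,0],[0,0]], so cropping is exact
    C = _strassen_pow2(_pad(A, m), _pad(B, m))
    return [row[:n] for row in C[:n]]
```

For example, strassen([[1, 2], [3, 4]], [[5, 6], [7, 8]]) gives the same [[19, 22], [43, 50]] as the first Phix test. A 3×3 input is padded to 4×4 internally. Zero-padding is only one way to handle odd sizes; peeling off the last row and column is a common alternative.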
Phix
<lang Phix>function strassen(sequence a, b)
    integer l = length(a)
    if length(a[1])!=l
    or length(b)!=l
    or length(b[1])!=l then
        crash("two equal square matrices only")
    end if
    if l=1 then return sq_mul(a,b) end if
    -- remainder(l,1) is always 0; checking evenness at every level enforces 2^n
    if remainder(l,2) then crash("2^n matrices only") end if
    integer h = l/2
    sequence {a11,a12,a21,a22,b11,b12,b21,b22} @= repeat(repeat(0,h),h)
    for i=1 to h do
        for j=1 to h do
            a11[i][j] = a[i][j]
            a12[i][j] = a[i][j+h]
            a21[i][j] = a[i+h][j]
            a22[i][j] = a[i+h][j+h]
            b11[i][j] = b[i][j]
            b12[i][j] = b[i][j+h]
            b21[i][j] = b[i+h][j]
            b22[i][j] = b[i+h][j+h]
        end for
    end for
    sequence p1 = strassen(sq_sub(a12,a22), sq_add(b21,b22)),
             p2 = strassen(sq_add(a11,a22), sq_add(b11,b22)),
             p3 = strassen(sq_sub(a11,a21), sq_add(b11,b12)),
             p4 = strassen(sq_add(a11,a12), b22),
             p5 = strassen(a11, sq_sub(b12,b22)),
             p6 = strassen(a22, sq_sub(b21,b11)),
             p7 = strassen(sq_add(a21,a22), b11),
             c11 = sq_add(sq_sub(sq_add(p1,p2),p4),p6),
             c12 = sq_add(p4,p5),
             c21 = sq_add(p6,p7),
             c22 = sq_sub(sq_add(sq_sub(p2,p3),p5),p7),
             c = repeat(repeat(0,l),l)
    for i=1 to h do
        for j=1 to h do
            c[i][j] = c11[i][j]
            c[i][j+h] = c12[i][j]
            c[i+h][j] = c21[i][j]
            c[i+h][j+h] = c22[i][j]
        end for
    end for
    return c
end function
ppOpt({pp_Nest,1,pp_IntFmt,"%3d",pp_FltFmt,"%3.0f",pp_IntCh,false})
constant A = {{1,2},
              {3,4}},
         B = {{5,6},
              {7,8}}
pp(strassen(A,B))
constant C = {{ 1, 1, 1, 1 },
              { 2, 4, 8, 16 },
              { 3, 9, 27, 81 },
              { 4, 16, 64, 256 }},
         D = {{ 4, -3, 4/3, -1/4 },
              { -13/3, 19/4, -7/3, 11/24 },
              { 3/2, -2, 7/6, -1/4 },
              { -1/6, 1/4, -1/6, 1/24 }}
pp(strassen(C,D))
constant E = {{ 1, 2, 3, 4},
              { 5, 6, 7, 8},
              { 9,10,11,12},
              {13,14,15,16}},
         F = {{1, 0, 0, 0},
              {0, 1, 0, 0},
              {0, 0, 1, 0},
              {0, 0, 0, 1}}
pp(strassen(E,F))
constant r = sqrt(2)/2,
         R = {{ r,r},
              {-r,r}}
pp(strassen(R,R))</lang>
- Output:
{{ 19, 22},
 { 43, 50}}
{{ 1, 0, 0, 0},
 { 0, 1, 0, 0},
 { 0, 0, 1, 0},
 { 0, 0, 0, 1}}
{{ 1, 2, 3, 4},
 { 5, 6, 7, 8},
 { 9, 10, 11, 12},
 { 13, 14, 15, 16}}
{{ 0, 1},
 { -1, 0}}
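The test inputs above are only 2×2 and 4×4, which are far too small for Strassen's savings to show. Both implementations here recurse all the way down to 1×1 blocks. The Python sketch below counts scalar operations for that fully recursive scheme. For a 4×4 product it uses 49 multiplications instead of 64, but 198 additions and subtractions instead of 48. The function names are hypothetical. The figure of 18 half-size additions per level is read off the P1..P7 formulas (10 additions) and the C11..C22 formulas (8 additions) used in the code on this page.

```python
from typing import Tuple


def _check_pow2(n: int) -> None:
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of 2")


def strassen_mults(n: int) -> int:
    """Scalar multiplications for fully recursive Strassen on n x n matrices."""
    _check_pow2(n)
    # M(1) = 1, M(n) = 7 M(n/2), i.e. 7**log2(n) = n**log2(7)
    return 1 if n == 1 else 7 * strassen_mults(n // 2)


def strassen_adds(n: int) -> int:
    """Scalar additions/subtractions: 18 half-size matrix additions per level."""
    _check_pow2(n)
    if n == 1:
        return 0
    h = n // 2
    return 7 * strassen_adds(h) + 18 * h * h


def schoolbook_counts(n: int) -> Tuple[int, int]:
    """(multiplications, additions) for the standard triple-loop product."""
    return n ** 3, n * n * (n - 1)


if __name__ == "__main__":
    for n in (2, 4, 8, 64, 1024):
        sm, sa = schoolbook_counts(n)
        print(n, strassen_mults(n), strassen_adds(n), sm, sa)
```

Multiplications fall by a factor of 7/8 per level, but the extra additions dominate at small sizes. This is why practical implementations usually stop recursing at some size and switch to the standard method. The crossover size depends on the machine and is not something these counts determine.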