Strassen's algorithm

Revision as of 18:19, 24 September 2020 by Petelomax (Phix: added a zero padding comment)

In linear algebra, the Strassen algorithm, named after Volker Strassen, is an algorithm for matrix multiplication. It runs in O(n^(log2 7)) ≈ O(n^2.807) time, which makes it asymptotically faster than the standard O(n^3) algorithm, and it is useful in practice for large matrices. For small matrices the standard algorithm is usually faster. Asymptotically faster algorithms than Strassen's are known, but they would only beat it on extremely large matrices.
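Both entries below split each n×n matrix into four h×h blocks, with h = n/2 and n a power of two. They then form seven block products instead of the eight that plain block multiplication would need. The set of seven products shown here is the one the Julia and Phix code uses; other equivalent variants exist. Written out, with the resulting cost recurrence:

```latex
A=\begin{pmatrix}A_{11}&A_{12}\\ A_{21}&A_{22}\end{pmatrix},\quad
B=\begin{pmatrix}B_{11}&B_{12}\\ B_{21}&B_{22}\end{pmatrix},\quad
C=AB=\begin{pmatrix}C_{11}&C_{12}\\ C_{21}&C_{22}\end{pmatrix}

\begin{aligned}
P_1 &= (A_{12}-A_{22})(B_{21}+B_{22}) \\
P_2 &= (A_{11}+A_{22})(B_{11}+B_{22}) \\
P_3 &= (A_{11}-A_{21})(B_{11}+B_{12}) \\
P_4 &= (A_{11}+A_{12})\,B_{22} \\
P_5 &= A_{11}\,(B_{12}-B_{22}) \\
P_6 &= A_{22}\,(B_{21}-B_{11}) \\
P_7 &= (A_{21}+A_{22})\,B_{11}
\end{aligned}
\qquad
\begin{aligned}
C_{11} &= P_1+P_2-P_4+P_6 \\
C_{12} &= P_4+P_5 \\
C_{21} &= P_6+P_7 \\
C_{22} &= P_2-P_3+P_5-P_7
\end{aligned}

T(n) = 7\,T(n/2) + \Theta(n^2) \;\Longrightarrow\; T(n) = \Theta\!\left(n^{\log_2 7}\right) \approx \Theta\!\left(n^{2.807}\right)
```

The Θ(n²) term comes from the 18 block additions and subtractions done at each level: 10 to prepare the operands of P1 to P7 and 8 to combine them into the C blocks.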

Julia

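The multiplication is denoted by *.
<lang Julia>function Strassen(A::AbstractMatrix, B::AbstractMatrix)
    n = size(A, 1)
    # Base case: a 1x1 product is an ordinary scalar multiplication
    if n == 1
        return A * B
    end
    # The recursion halves n each time, so n must be a power of two
    isodd(n) && throw(ArgumentError("matrix size must be a power of two"))
    h = n ÷ 2  # integer division; n/2 is a Float64 and cannot be used as an index
    # Quadrants are views (no copying); AbstractMatrix in the signature
    # lets these SubArrays be passed straight into the recursive calls
    @views A11 = A[1:h, 1:h]
    @views A12 = A[1:h, h+1:n]
    @views A21 = A[h+1:n, 1:h]
    @views A22 = A[h+1:n, h+1:n]
    @views B11 = B[1:h, 1:h]
    @views B12 = B[1:h, h+1:n]
    @views B21 = B[h+1:n, 1:h]
    @views B22 = B[h+1:n, h+1:n]
    # The seven recursive products
    P1 = Strassen(A12 - A22, B21 + B22)
    P2 = Strassen(A11 + A22, B11 + B22)
    P3 = Strassen(A11 - A21, B11 + B12)
    P4 = Strassen(A11 + A12, B22)
    P5 = Strassen(A11, B12 - B22)
    P6 = Strassen(A22, B21 - B11)
    P7 = Strassen(A21 + A22, B11)
    # Combine them into the four quadrants of the result
    C11 = P1 + P2 - P4 + P6
    C12 = P4 + P5
    C21 = P6 + P7
    C22 = P2 - P3 + P5 - P7
    return [C11 C12; C21 C22]
end</lang>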
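This minimal Python sketch mirrors the Julia entry step by step, so the recursion can be checked without a Julia installation. It assumes square list-of-lists matrices whose size is a power of two. The helpers `add`, `sub` and `split` are hypothetical names chosen for illustration. Unlike the Julia version, the quadrants are copied rather than taken as views.

```python
def add(X, Y):
    """Element-wise sum of two equal-sized matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def sub(X, Y):
    """Element-wise difference X - Y of two equal-sized matrices."""
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def split(M):
    """Return copies of the four h x h quadrants (11, 12, 21, 22) of M."""
    h = len(M) // 2
    return ([row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
            [row[:h] for row in M[h:]], [row[h:] for row in M[h:]])


def strassen(A, B):
    """Multiply two n x n matrices (n a power of two) with Strassen's recursion."""
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    if n % 2:
        raise ValueError("matrix size must be a power of two")
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    # The same seven products as the Julia and Phix entries
    P1 = strassen(sub(A12, A22), add(B21, B22))
    P2 = strassen(add(A11, A22), add(B11, B22))
    P3 = strassen(sub(A11, A21), add(B11, B12))
    P4 = strassen(add(A11, A12), B22)
    P5 = strassen(A11, sub(B12, B22))
    P6 = strassen(A22, sub(B21, B11))
    P7 = strassen(add(A21, A22), B11)
    C11 = add(sub(add(P1, P2), P4), P6)
    C12 = add(P4, P5)
    C21 = add(P6, P7)
    C22 = sub(add(sub(P2, P3), P5), P7)
    # Glue the quadrants back together row by row
    top = [r1 + r2 for r1, r2 in zip(C11, C12)]
    bottom = [r1 + r2 for r1, r2 in zip(C21, C22)]
    return top + bottom


print(strassen([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

Copying quadrants keeps the sketch short. It costs extra allocation at every level, which the Julia `@views` avoids.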

Phix

As noted on Wikipedia, instead of crashing on matrices that are non-square or whose size is not a power of two, you could pad them with zeroes and strip the padding from the result on exit.
<lang Phix>function strassen(sequence a, b)
   integer l = length(a)
   if length(a[1])!=l
   or length(b)!=l
   or length(b[1])!=l then
       crash("two equal square matrices only")
   end if
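   if l=1 then return sq_mul(a,b) end if
   -- an odd size above 1 means l is not a power of two
   -- (remainder(l,1) is always 0, so it could never trigger)
   if remainder(l,2) then
       crash("2^n matrices only")
   end if
   integer h = l/2
   -- split a and b into four h x h quadrants each
   sequence {a11,a12,a21,a22,b11,b12,b21,b22} @= repeat(repeat(0,h),h)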
   for i=1 to h do
       for j=1 to h do
           a11[i][j] = a[i][j]
           a12[i][j] = a[i][j+h]
           a21[i][j] = a[i+h][j]
           a22[i][j] = a[i+h][j+h]
           b11[i][j] = b[i][j]
           b12[i][j] = b[i][j+h]
           b21[i][j] = b[i+h][j]
           b22[i][j] = b[i+h][j+h]
       end for
   end for
   -- the seven recursive products, then the four quadrants of the result
   sequence p1 = strassen(sq_sub(a12,a22), sq_add(b21,b22)),
            p2 = strassen(sq_add(a11,a22), sq_add(b11,b22)),
            p3 = strassen(sq_sub(a11,a21), sq_add(b11,b12)),
            p4 = strassen(sq_add(a11,a12), b22),
            p5 = strassen(a11, sq_sub(b12,b22)),
            p6 = strassen(a22, sq_sub(b21,b11)),
            p7 = strassen(sq_add(a21,a22), b11),

            c11 = sq_add(sq_sub(sq_add(p1,p2),p4),p6),
            c12 = sq_add(p4,p5),
            c21 = sq_add(p6,p7),
            c22 = sq_sub(sq_add(sq_sub(p2,p3),p5),p7),
            c = repeat(repeat(0,l),l)
   -- reassemble the quadrants into the l x l result
   for i=1 to h do
       for j=1 to h do
           c[i][j] = c11[i][j]
           c[i][j+h] = c12[i][j]
           c[i+h][j] = c21[i][j]
           c[i+h][j+h] = c22[i][j]
       end for
   end for
   return c
end function

ppOpt({pp_Nest,1,pp_IntFmt,"%3d",pp_FltFmt,"%3.0f",pp_IntCh,false})

constant A = {{1,2},
             {3,4}},
        B = {{5,6},
             {7,8}}

pp(strassen(A,B))

constant C = { { 1,  1,  1,   1 },
              { 2,  4,  8,  16 },
              { 3,  9, 27,  81 },
              { 4, 16, 64, 256 }},
        D = { {    4,   -3,  4/3, -1/ 4 },
              {-13/3, 19/4, -7/3, 11/24 },
              {  3/2,   -2,  7/6, -1/ 4 },
              { -1/6,  1/4, -1/6,  1/24 }}

pp(strassen(C,D))

constant E = {{ 1, 2, 3, 4},
             { 5, 6, 7, 8},
             { 9,10,11,12},
             {13,14,15,16}},
        F = {{1, 0, 0, 0},
             {0, 1, 0, 0},
             {0, 0, 1, 0},
             {0, 0, 0, 1}}

pp(strassen(E,F))

constant r = sqrt(2)/2,
         R = {{ r,r},
             {-r,r}}

pp(strassen(R,R))</lang>

Output:
{{ 19, 22},
 { 43, 50}}
{{  1,  0,  0,  0},
 {  0,  1,  0,  0},
 {  0,  0,  1,  0},
 {  0,  0,  0,  1}}
{{  1,  2,  3,  4},
 {  5,  6,  7,  8},
 {  9, 10, 11, 12},
 { 13, 14, 15, 16}}
{{  0,  1},
 { -1,  0}}
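The Phix entry mentions, following Wikipedia, that zero padding could replace the crash for unsuitable sizes. Below is a minimal Python sketch of that idea. It assumes list-of-lists matrices of conforming shapes (m×k times k×p) with numeric entries. Both operands are copied into the top-left corner of N×N zero matrices, where N is the smallest power of two not below m, k and p. They are then multiplied with the same seven-product recursion as the Julia and Phix entries, and the result is cropped back to m×p. The names `madd`, `next_pow2`, `pad` and `strassen_padded` are hypothetical, not part of either entry.

```python
def madd(X, Y, s=1):
    """Element-wise X + s*Y (s = -1 gives subtraction)."""
    return [[x + s * y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def strassen(A, B):
    """Strassen product of two n x n matrices, n a power of two."""
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    h = n // 2

    def quad(M, i, j):
        # h x h block in block-row i, block-column j
        return [row[j * h:(j + 1) * h] for row in M[i * h:(i + 1) * h]]

    A11, A12, A21, A22 = quad(A, 0, 0), quad(A, 0, 1), quad(A, 1, 0), quad(A, 1, 1)
    B11, B12, B21, B22 = quad(B, 0, 0), quad(B, 0, 1), quad(B, 1, 0), quad(B, 1, 1)
    P1 = strassen(madd(A12, A22, -1), madd(B21, B22))
    P2 = strassen(madd(A11, A22), madd(B11, B22))
    P3 = strassen(madd(A11, A21, -1), madd(B11, B12))
    P4 = strassen(madd(A11, A12), B22)
    P5 = strassen(A11, madd(B12, B22, -1))
    P6 = strassen(A22, madd(B21, B11, -1))
    P7 = strassen(madd(A21, A22), B11)
    C11 = madd(madd(madd(P1, P2), P4, -1), P6)
    C12 = madd(P4, P5)
    C21 = madd(P6, P7)
    C22 = madd(madd(madd(P2, P3, -1), P5), P7, -1)
    return ([r1 + r2 for r1, r2 in zip(C11, C12)] +
            [r1 + r2 for r1, r2 in zip(C21, C22)])


def next_pow2(n):
    """Smallest power of two that is >= n (for n >= 1)."""
    p = 1
    while p < n:
        p *= 2
    return p


def pad(M, size):
    """Copy M into the top-left corner of a size x size zero matrix."""
    out = [[0] * size for _ in range(size)]
    for i, row in enumerate(M):
        out[i][:len(row)] = row
    return out


def strassen_padded(A, B):
    """Multiply an m x k matrix by a k x p matrix via zero padding."""
    m, k, p = len(A), len(B), len(B[0])
    if len(A[0]) != k:
        raise ValueError("inner dimensions differ")
    size = next_pow2(max(m, k, p))
    C = strassen(pad(A, size), pad(B, size))
    return [row[:p] for row in C[:m]]  # strip the zero padding


print(strassen_padded([[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10], [11, 12]]))
# [[58, 64], [139, 154]]
```

Padding is safe because the extra zero rows and columns contribute nothing to any entry of the top-left m×p block. The price is that a size just above a power of two, such as 129, nearly doubles N (to 256 here).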