Strassen's algorithm

From Rosetta Code
Revision as of 14:33, 25 September 2020 by PureFox (Converted this to a draft task (presumably it's intended to be one?).)
Strassen's algorithm is a draft programming task. It is not yet considered ready to be promoted as a complete task, for reasons that should be found in its talk page.
Description

In linear algebra, the Strassen algorithm, named after Volker Strassen, is an algorithm for matrix multiplication. It is asymptotically faster than the standard O(n³) matrix multiplication algorithm and is useful in practice for large matrices, but would be slower than the fastest known algorithms for extremely large matrices.
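The speed-up comes from the block recursion. Split each n×n operand into four n/2×n/2 blocks. The conventional block product then needs eight half-size multiplications. Strassen's scheme needs only seven, the products P1 to P7 in the implementations on this page. The price is eighteen block additions and subtractions, each costing Θ(n²). Assume n is a power of two and write T(n) for the cost of one n×n product. The two recurrences then solve as follows:

```latex
\begin{aligned}
\text{conventional:}\quad T(n) &= 8\,T(n/2) + \Theta(n^2) = \Theta\bigl(n^{\log_2 8}\bigr) = \Theta(n^3)\\
\text{Strassen:}\quad T(n) &= 7\,T(n/2) + \Theta(n^2) = \Theta\bigl(n^{\log_2 7}\bigr) \approx \Theta(n^{2.807})
\end{aligned}
```

Both results follow from the master theorem. In each case the exponent of the recursive term (log₂ 8 = 3, log₂ 7 ≈ 2.807) exceeds the exponent 2 of the addition cost.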

Task

Write a routine, function, procedure etc. in your language to implement the Strassen algorithm for matrix multiplication.

See also


Julia

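The built-in matrix multiplication is denoted by *; here it is used only for the 1×1 base case. Like the Phix version, this handles only two equal square matrices whose size is a power of two.
<lang Julia>function Strassen(A::AbstractMatrix, B::AbstractMatrix)
    # AbstractMatrix (not Matrix) so that the quadrant views passed to
    # the recursive calls, e.g. B22 and A11, are accepted
    n = size(A, 1)
    size(A) == size(B) == (n, n) && ispow2(n) ||
        throw(ArgumentError("two equal square 2^k×2^k matrices only"))
    if n == 1
        return A * B
    end
    h = n ÷ 2  # integer half: n/2 is a Float64 and is not a valid index
    # split both operands into four h×h quadrants (views, so nothing is copied)
    @views A11 = A[1:h, 1:h]
    @views A12 = A[1:h, h+1:n]
    @views A21 = A[h+1:n, 1:h]
    @views A22 = A[h+1:n, h+1:n]
    @views B11 = B[1:h, 1:h]
    @views B12 = B[1:h, h+1:n]
    @views B21 = B[h+1:n, 1:h]
    @views B22 = B[h+1:n, h+1:n]
    # Strassen's seven half-size products
    P1 = Strassen(A12 - A22, B21 + B22)
    P2 = Strassen(A11 + A22, B11 + B22)
    P3 = Strassen(A11 - A21, B11 + B12)
    P4 = Strassen(A11 + A12, B22)
    P5 = Strassen(A11, B12 - B22)
    P6 = Strassen(A22, B21 - B11)
    P7 = Strassen(A21 + A22, B11)
    # recombine into the quadrants of C = A*B
    C11 = P1 + P2 - P4 + P6
    C12 = P4 + P5
    C21 = P6 + P7
    C22 = P2 - P3 + P5 - P7
    return [C11 C12; C21 C22]
end</lang>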
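Recursing all the way down to 1×1 blocks tends to be slow in practice. For small blocks, the extra additions outweigh the multiplication saved, which is why Strassen's method pays off mainly for large matrices. A common refinement is to stop the recursion at some block size and multiply the remaining blocks conventionally. Below is a minimal Julia sketch of that hybrid. The function name `strassen_hybrid` and the `leaf` cutoff (default 64) are assumptions for illustration. They are not tuned values and not part of the task.

```julia
# Hybrid Strassen sketch (hypothetical name and cutoff): Strassen recursion for
# large blocks, ordinary multiplication once a block is `leaf`×`leaf` or smaller.
# The default leaf = 64 is an untuned placeholder, not a recommended value.
function strassen_hybrid(A::AbstractMatrix, B::AbstractMatrix; leaf::Int = 64)
    n = size(A, 1)
    size(A) == size(B) == (n, n) && ispow2(n) ||
        throw(ArgumentError("two equal square 2^k×2^k matrices only"))
    n <= max(leaf, 1) && return A * B      # small block: conventional product
    h = n ÷ 2
    # quadrants as views, so nothing is copied
    @views begin
        A11, A12 = A[1:h, 1:h], A[1:h, h+1:n]
        A21, A22 = A[h+1:n, 1:h], A[h+1:n, h+1:n]
        B11, B12 = B[1:h, 1:h], B[1:h, h+1:n]
        B21, B22 = B[h+1:n, 1:h], B[h+1:n, h+1:n]
    end
    mul(X, Y) = strassen_hybrid(X, Y; leaf = leaf)   # recurse with the same cutoff
    # the same seven products and recombination as the Julia entry above
    P1 = mul(A12 - A22, B21 + B22)
    P2 = mul(A11 + A22, B11 + B22)
    P3 = mul(A11 - A21, B11 + B12)
    P4 = mul(A11 + A12, B22)
    P5 = mul(A11, B12 - B22)
    P6 = mul(A22, B21 - B11)
    P7 = mul(A21 + A22, B11)
    C11 = P1 + P2 - P4 + P6
    C12 = P4 + P5
    C21 = P6 + P7
    C22 = P2 - P3 + P5 - P7
    return [C11 C12; C21 C22]
end
```

With `leaf = 1` this reduces to the fully recursive version. The best cutoff depends on the machine and element type and would have to be measured.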

Phix

As noted on Wikipedia, you could pad with zeroes, and strip them on exit, instead of crashing for matrices that are not square with a power-of-two (2^n) size.
<lang Phix>function strassen(sequence a, b)
   integer l = length(a)
   if length(a[1])!=l
   or length(b)!=l
   or length(b[1])!=l then
       crash("two equal square matrices only")
   end if
   if l=1 then return sq_mul(a,b) end if
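   -- odd sizes above 1 cannot be halved; recursing from any non-power-of-two
   -- size eventually reaches one, so this rejects all non-2^n matrices
   if remainder(l,2) then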
       crash("2^n matrices only")
   end if
   integer h = l/2
   sequence {a11,a12,a21,a22,b11,b12,b21,b22} @= repeat(repeat(0,h),h)
   for i=1 to h do
       for j=1 to h do
           a11[i][j] = a[i][j]
           a12[i][j] = a[i][j+h]
           a21[i][j] = a[i+h][j]
           a22[i][j] = a[i+h][j+h]
           b11[i][j] = b[i][j]
           b12[i][j] = b[i][j+h]
           b21[i][j] = b[i+h][j]
           b22[i][j] = b[i+h][j+h]
       end for
   end for
   sequence p1 = strassen(sq_sub(a12,a22), sq_add(b21,b22)),
            p2 = strassen(sq_add(a11,a22), sq_add(b11,b22)),
            p3 = strassen(sq_sub(a11,a21), sq_add(b11,b12)),
            p4 = strassen(sq_add(a11,a12), b22),
            p5 = strassen(a11, sq_sub(b12,b22)),
            p6 = strassen(a22, sq_sub(b21,b11)),
            p7 = strassen(sq_add(a21,a22), b11),

            c11 = sq_add(sq_sub(sq_add(p1,p2),p4),p6),
            c12 = sq_add(p4,p5),
            c21 = sq_add(p6,p7),
            c22 = sq_sub(sq_add(sq_sub(p2,p3),p5),p7),
            c = repeat(repeat(0,l),l)
   for i=1 to h do
       for j=1 to h do
           c[i][j] = c11[i][j]
           c[i][j+h] = c12[i][j]
           c[i+h][j] = c21[i][j]
           c[i+h][j+h] = c22[i][j]
       end for
   end for
   return c
end function

ppOpt({pp_Nest,1,pp_IntFmt,"%3d",pp_FltFmt,"%3.0f",pp_IntCh,false})

constant A = {{1,2},
             {3,4}},
        B = {{5,6},
             {7,8}}

pp(strassen(A,B))

constant C = { { 1,  1,  1,   1 },
              { 2,  4,  8,  16 },
              { 3,  9, 27,  81 },
              { 4, 16, 64, 256 }},
        D = { {    4,   -3,  4/3, -1/ 4 },
              {-13/3, 19/4, -7/3, 11/24 },
              {  3/2,   -2,  7/6, -1/ 4 },
              { -1/6,  1/4, -1/6,  1/24 }}

pp(strassen(C,D))

constant E = {{ 1, 2, 3, 4},
             { 5, 6, 7, 8},
             { 9,10,11,12},
             {13,14,15,16}},
        F = {{1, 0, 0, 0},
             {0, 1, 0, 0},
             {0, 0, 1, 0},
             {0, 0, 0, 1}}

pp(strassen(E,F))

constant r = sqrt(2)/2,
        R = {{ r,r},
             {-r,r}}

pp(strassen(R,R))</lang>

Output:
{{ 19, 22},
 { 43, 50}}
{{  1,  0,  0,  0},
 {  0,  1,  0,  0},
 {  0,  0,  1,  0},
 {  0,  0,  0,  1}}
{{  1,  2,  3,  4},
 {  5,  6,  7,  8},
 {  9, 10, 11, 12},
 { 13, 14, 15, 16}}
{{  0,  1},
 { -1,  0}}
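The zero-padding idea from the Phix note can be sketched in Julia. It becomes a wrapper around any multiplication routine that accepts only equal 2^k×2^k square inputs, such as the Strassen functions on this page. The name `padded_multiply` is hypothetical. The routine is passed in as the argument `mul`, so the block also runs on its own with Julia's built-in `*` as a stand-in.

```julia
# Hypothetical wrapper (not from the task page): multiply conformable matrices of any
# size using `mul`, a routine assumed to accept only equal 2^k×2^k square inputs.
function padded_multiply(mul, A::AbstractMatrix, B::AbstractMatrix)
    m, k = size(A)
    k2, p = size(B)
    k == k2 || throw(DimensionMismatch("inner dimensions differ: $k vs $k2"))
    s = nextpow(2, max(m, k, p))            # smallest power of two covering every dimension
    T = promote_type(eltype(A), eltype(B))
    Ap = zeros(T, s, s); Ap[1:m, 1:k] = A   # embed A in the top-left corner
    Bp = zeros(T, s, s); Bp[1:k, 1:p] = B   # embed B likewise
    # the zero rows/columns contribute nothing, so the top-left m×p block is A*B
    return mul(Ap, Bp)[1:m, 1:p]            # strip the padding
end
```

For example, `padded_multiply(Strassen, A, B)` would multiply a 3×5 matrix by a 5×2 matrix through an 8×8 Strassen product. The cost of this approach is that padding can nearly double each dimension (65 becomes 128).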