Strassen's algorithm: Difference between revisions

From Rosetta Code
Content added Content deleted
(Added Go)
Line 193: Line 193:


=={{header|Julia}}==
=={{header|Julia}}==
{{incorrect|Julia|A11,B11 declared twice, A22,B22 undefined. No output.}}
The multiplication is denoted by *
The multiplication is denoted by *
<lang Julia>
<lang Julia>

Revision as of 18:46, 26 September 2020

Strassen's algorithm is a draft programming task. It is not yet considered ready to be promoted as a complete task, for reasons that should be found in its talk page.
Description

In linear algebra, the Strassen algorithm, named after Volker Strassen, is an algorithm for matrix multiplication. It is faster than the standard matrix multiplication algorithm and is useful in practice for large matrices, but would be slower than the fastest known algorithms for extremely large matrices.

Task

Write a routine, function, procedure etc. in your language to implement the Strassen algorithm for matrix multiplication.

See also


Go

Translation of: Wren

Rather than use a library such as gonum, we create a simple Matrix type which is adequate for this task. <lang go>package main

import (

   "fmt"
   "log"
   "math"

)

type Matrix [][]float64

func (m Matrix) rows() int { return len(m) } func (m Matrix) cols() int { return len(m[0]) }

func (m Matrix) add(m2 Matrix) Matrix {

   if m.rows() != m2.rows() || m.cols() != m2.cols() {
       log.Fatal("Matrices must have the same dimensions.")
   }
   c := make(Matrix, m.rows())
   for i := 0; i < m.rows(); i++ {
       c[i] = make([]float64, m.cols())
       for j := 0; j < m.cols(); j++ {
           c[i][j] = m[i][j] + m2[i][j]
       }
   }
   return c

}

func (m Matrix) sub(m2 Matrix) Matrix {

   if m.rows() != m2.rows() || m.cols() != m2.cols() {
       log.Fatal("Matrices must have the same dimensions.")
   }
   c := make(Matrix, m.rows())
   for i := 0; i < m.rows(); i++ {
       c[i] = make([]float64, m.cols())
       for j := 0; j < m.cols(); j++ {
           c[i][j] = m[i][j] - m2[i][j]
       }
   }
   return c

}

func (m Matrix) mul(m2 Matrix) Matrix {

   if m.cols() != m2.rows() {
       log.Fatal("Cannot multiply these matrices.")
   }
   c := make(Matrix, m.rows())
   for i := 0; i < m.rows(); i++ {
       c[i] = make([]float64, m2.cols())
       for j := 0; j < m2.cols(); j++ {
           for k := 0; k < m2.rows(); k++ {
               c[i][j] += m[i][k] * m2[k][j]
           }
       }
   }
   return c

}

func (m Matrix) toString(p int) string {

   s := make([]string, m.rows())
   pow := math.Pow(10, float64(p))
   for i := 0; i < m.rows(); i++ {
       t := make([]string, m.cols())
       for j := 0; j < m.cols(); j++ {
           r := math.Round(m[i][j]*pow) / pow
           t[j] = fmt.Sprintf("%g", r)
           if t[j] == "-0" {
               t[j] = "0"
           }
       }
       s[i] = fmt.Sprintf("%v", t)
   }
   return fmt.Sprintf("%v", s)

}

func params(r, c int) [4][6]int {

   return [4][6]int{
       {0, r, 0, c, 0, 0},
       {0, r, c, 2 * c, 0, c},
       {r, 2 * r, 0, c, r, 0},
       {r, 2 * r, c, 2 * c, r, c},
   }

}

func toQuarters(m Matrix) [4]Matrix {

   r := m.rows() / 2
   c := m.cols() / 2
   p := params(r, c)
   var quarters [4]Matrix
   for k := 0; k < 4; k++ {
       q := make(Matrix, r)
       for i := p[k][0]; i < p[k][1]; i++ {
           q[i-p[k][4]] = make([]float64, c)
           for j := p[k][2]; j < p[k][3]; j++ {
               q[i-p[k][4]][j-p[k][5]] = m[i][j]
           }
       }
       quarters[k] = q
   }
   return quarters

}

func fromQuarters(q [4]Matrix) Matrix {

   r := q[0].rows()
   c := q[0].cols()
   p := params(r, c)
   r *= 2
   c *= 2
   m := make(Matrix, r)
   for i := 0; i < c; i++ {
       m[i] = make([]float64, c)
   }
   for k := 0; k < 4; k++ {
       for i := p[k][0]; i < p[k][1]; i++ {
           for j := p[k][2]; j < p[k][3]; j++ {
               m[i][j] = q[k][i-p[k][4]][j-p[k][5]]
           }
       }
   }
   return m

}

func strassen(a, b Matrix) Matrix {

   if a.rows() != a.cols() || b.rows() != b.cols() || a.rows() != b.rows() {
       log.Fatal("Matrices must be square and of equal size.")
   }
   if a.rows() == 0 || (a.rows()&(a.rows()-1)) != 0 {
       log.Fatal("Size of matrices must be a power of two.")
   }
   if a.rows() == 1 {
       return a.mul(b)
   }
   qa := toQuarters(a)
   qb := toQuarters(b)
   p1 := strassen(qa[1].sub(qa[3]), qb[2].add(qb[3]))
   p2 := strassen(qa[0].add(qa[3]), qb[0].add(qb[3]))
   p3 := strassen(qa[0].sub(qa[2]), qb[0].add(qb[1]))
   p4 := strassen(qa[0].add(qa[1]), qb[3])
   p5 := strassen(qa[0], qb[1].sub(qb[3]))
   p6 := strassen(qa[3], qb[2].sub(qb[0]))
   p7 := strassen(qa[2].add(qa[3]), qb[0])
   var q [4]Matrix
   q[0] = p1.add(p2).sub(p4).add(p6)
   q[1] = p4.add(p5)
   q[2] = p6.add(p7)
   q[3] = p2.sub(p3).add(p5).sub(p7)
   return fromQuarters(q)

}

func main() {

   a := Matrix{{1, 2}, {3, 4}}
   b := Matrix{{5, 6}, {7, 8}}
   c := Matrix{{1, 1, 1, 1}, {2, 4, 8, 16}, {3, 9, 27, 81}, {4, 16, 64, 256}}
   d := Matrix{{4, -3, 4.0 / 3, -1.0 / 4}, {-13.0 / 3, 19.0 / 4, -7.0 / 3, 11.0 / 24},
       {3.0 / 2, -2, 7.0 / 6, -1.0 / 4}, {-1.0 / 6, 1.0 / 4, -1.0 / 6, 1.0 / 24}}
   e := Matrix{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}
   f := Matrix{{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}}
   fmt.Println("Using 'normal' matrix multiplication:")
   fmt.Printf("  a * b = %v\n", a.mul(b))
   fmt.Printf("  c * d = %v\n", c.mul(d).toString(6))
   fmt.Printf("  e * f = %v\n", e.mul(f))
   fmt.Println("\nUsing 'Strassen' matrix multiplication:")
   fmt.Printf("  a * b = %v\n", strassen(a, b))
   fmt.Printf("  c * d = %v\n", strassen(c, d).toString(6))
   fmt.Printf("  e * f = %v\n", strassen(e, f))

}</lang>

Output:
Using 'normal' matrix multiplication:
  a * b = [[19 22] [43 50]]
  c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]]
  e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]]

Using 'Strassen' matrix multiplication:
  a * b = [[19 22] [43 50]]
  c * d = [[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]]
  e * f = [[1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16]]

Julia

This example is incorrect. Please fix the code and remove this message.

Details: A11,B11 declared twice, A22,B22 undefined. No output.

The multiplication is denoted by * <lang Julia> function Strassen(A::Matrix, B::Matrix)

   n = size(A, 1)
   if n == 1
       return A * B
   end
   @views A11 = A[1:n/2, 1:n/2]
   @views A12 = A[1:n/2, n/2+1:n]
   @views A21 = A[n/2+1:n, 1:n/2]
   @views A11 = A[n/2+1:n, n/2+1:n]
   @views B11 = B[1:n/2, 1:n/2]
   @views B12 = B[1:n/2, n/2+1:n]
   @views B21 = B[n/2+1:n, 1:n/2]
   @views B11 = B[n/2+1:n, n/2+1:n]
   P1 = Strassen(A12 - A22, B21 + B22)
   P2 = Strassen(A11 + A22, B11 + B22)
   P3 = Strassen(A11 - A21, B11 + B12)
   P4 = Strassen(A11 + A12, B22)
   P5 = Strassen(A11, B12 - B22)
   P6 = Strassen(A22, B21 - B11)
   P7 = Strassen(A21 + A22, B11)
   C11 = P1 + P2 - P4 + P6
   C12 = P4 + P5
   C21 = P6 + P7
   C22 = P2 - P3 + P5 - P7
   return [C11 C12; C21 C22]

end</lang>

Phix

As noted on wp, you could pad with zeroes, and strip them on exit, instead of crashing for non-square 2n matrices. <lang Phix>function strassen(sequence a, b)

   integer l = length(a)
   if length(a[1])!=l
   or length(b)!=l
   or length(b[1])!=l then
       crash("two equal square matrices only")
   end if
   if l=1 then return sq_mul(a,b) end if
   if remainder(l,1) then
       crash("2^n matrices only")
   end if
   integer h = l/2
   sequence {a11,a12,a21,a22,b11,b12,b21,b22} @= repeat(repeat(0,h),h)
   for i=1 to h do
       for j=1 to h do
           a11[i][j] = a[i][j]
           a12[i][j] = a[i][j+h]
           a21[i][j] = a[i+h][j]
           a22[i][j] = a[i+h][j+h]
           b11[i][j] = b[i][j]
           b12[i][j] = b[i][j+h]
           b21[i][j] = b[i+h][j]
           b22[i][j] = b[i+h][j+h]
       end for
   end for
   sequence p1 = strassen(sq_sub(a12,a22), sq_add(b21,b22)),
            p2 = strassen(sq_add(a11,a22), sq_add(b11,b22)),
            p3 = strassen(sq_sub(a11,a21), sq_add(b11,b12)),
            p4 = strassen(sq_add(a11,a12), b22),
            p5 = strassen(a11, sq_sub(b12,b22)),
            p6 = strassen(a22, sq_sub(b21,b11)),
            p7 = strassen(sq_add(a21,a22), b11),

            c11 = sq_add(sq_sub(sq_add(p1,p2),p4),p6),
            c12 = sq_add(p4,p5),
            c21 = sq_add(p6,p7),
            c22 = sq_sub(sq_add(sq_sub(p2,p3),p5),p7),
            c = repeat(repeat(0,l),l)
   for i=1 to h do
       for j=1 to h do
           c[i][j] = c11[i][j]
           c[i][j+h] = c12[i][j]
           c[i+h][j] = c21[i][j]
           c[i+h][j+h] = c22[i][j]
       end for
   end for
   return c

end function

ppOpt({pp_Nest,1,pp_IntFmt,"%3d",pp_FltFmt,"%3.0f",pp_IntCh,false})

constant A = {{1,2},

             {3,4}},
        B = {{5,6},
             {7,8}}

pp(strassen(A,B))

constant C = { { 1, 1, 1, 1 },

              { 2,  4,  8,  16 },
              { 3,  9, 27,  81 },
              { 4, 16, 64, 256 }},
        D = { {    4,   -3,  4/3, -1/ 4 },
              {-13/3, 19/4, -7/3, 11/24 },
              {  3/2,   -2,  7/6, -1/ 4 },
              { -1/6,  1/4, -1/6,  1/24 }}

pp(strassen(C,D))

constant E = {{ 1, 2, 3, 4},

             { 5, 6, 7, 8},
             { 9,10,11,12},
             {13,14,15,16}},
        F = {{1, 0, 0, 0},
             {0, 1, 0, 0},
             {0, 0, 1, 0},
             {0, 0, 0, 1}}

pp(strassen(E,F))

constant r = sqrt(2)/2,

        R = {{ r,r},
             {-r,r}}

pp(strassen(R,R))</lang>

Output:
{{ 19, 22},
 { 43, 50}}
{{  1,  0,  0,  0},
 {  0,  1,  0,  0},
 {  0,  0,  1,  0},
 {  0,  0,  0,  1}}
{{  1,  2,  3,  4},
 {  5,  6,  7,  8},
 {  9, 10, 11, 12},
 { 13, 14, 15, 16}}
{{  0,  1},
 { -1,  0}}

Wren

Library: Wren-fmt

Wren doesn't currently have a matrix module so I've written a rudimentary Matrix class with sufficient functionality to complete this task.

I've used the Phix entry's examples to test the Strassen algorithm implementation. <lang ecmascript>import "/fmt" for Fmt

class Matrix {

   construct new(a) {
        if (a.type != List || a.count == 0 || a[0].type != List || a[0].count == 0 || a[0][0].type != Num) {
           Fiber.abort("Argument must be a non-empty two dimensional list of numbers.")
        }
        _a  = a
   }
   rows { _a.count }
   cols { _a[0].count }
   +(b) {
       if (b.type != Matrix) Fiber.abort("Argument must be a matrix.")
       if ((this.rows != b.rows) || (this.cols != b.cols)) {
           Fiber.abort("Matrices must have the same dimensions.")
       }
       var c = List.filled(rows, null)
       for (i in 0...rows) {
           c[i] = List.filled(cols, 0)
           for (j in 0...cols) c[i][j] = _a[i][j] + b[i, j]
       }
       return Matrix.new(c)
   }
   - { this * -1 }
   -(b) { this + (-b) }
   *(b) {
       var c = List.filled(rows, null)
       if (b is Num) {
           for (i in 0...rows) {
               c[i] = List.filled(cols, 0)
               for (j in 0...cols) c[i][j] = _a[i][j] * b
           }
       } else if (b is Matrix) {
           if (this.cols != b.rows) Fiber.abort("Cannot multiply these matrices.")
           for (i in 0...rows) {
               c[i] = List.filled(b.cols, 0)
               for (j in 0...b.cols) {
                   for (k in 0...b.rows) c[i][j] = c[i][j] + _a[i][k] * b[k, j]
               }
           }
       } else {
           Fiber.abort("Argument must be a matrix or a number.")
       }
       return Matrix.new(c)
   }
   [i] { _a[i].toList }
   [i, j] { _a[i][j] }
   toString { _a.toString }
   // rounds all elements to 'p' places
   toString(p) {
       var s = List.filled(rows, "")
       var pow = 10.pow(p)
       for (i in 0...rows) {
           var t = List.filled(cols, "")
           for (j in 0...cols) {
               var r = (_a[i][j]*pow).round / pow
               t[j] = r.toString
               if (t[j] == "-0") t[j] = "0"
           }
           s[i] = t.toString
       }
       return s
   }

}

var params = Fn.new { |r, c|

   return [
       [0...r, 0...c, 0, 0],
       [0...r, c...2*c, 0, c],
       [r...2*r, 0...c, r, 0],
       [r...2*r, c...2*c, r, c]
   ]

}

var toQuarters = Fn.new { |m|

   var r = (m.rows/2).floor
   var c = (m.cols/2).floor
   var p = params.call(r, c)
   var quarters = []
   for (k in 0..3) {
       var q = List.filled(r, null)
       for (i in p[k][0]) {
           q[i - p[k][2]] = List.filled(c, 0)
           for (j in p[k][1]) q[i - p[k][2]][j - p[k][3]] = m[i, j]
       }
       quarters.add(Matrix.new(q))
   }
   return quarters

}

var fromQuarters = Fn.new { |q|

   var r = q[0].rows
   var c = q[0].cols
   var p = params.call(r, c)
   r = r * 2
   c = c * 2
   var m = List.filled(r, null)
   for (i in 0...c) m[i] = List.filled(c, 0)
   for (k in 0..3) {
       for (i in p[k][0]) {
           for (j in p[k][1]) m[i][j] = q[k][i - p[k][2], j - p[k][3]]
       }
   }
   return Matrix.new(m)

}

var strassen // recursive strassen = Fn.new { |a, b|

   if (a.rows != a.cols || b.rows != b.cols || a.rows != b.rows) {
       Fiber.abort("Matrices must be square and of equal size.")
   }
   if (a.rows == 0 || (a.rows & (a.rows - 1)) != 0) {
       Fiber.abort("Size of matrices must be a power of two.")
   }
   if (a.rows == 1) return a * b
   var qa = toQuarters.call(a)
   var qb = toQuarters.call(b)
   System.write("") // guard against VM recursion bug
   var p1 = strassen.call(qa[1] - qa[3], qb[2] + qb[3])
   var p2 = strassen.call(qa[0] + qa[3], qb[0] + qb[3])
   var p3 = strassen.call(qa[0] - qa[2], qb[0] + qb[1])
   var p4 = strassen.call(qa[0] + qa[1], qb[3])
   var p5 = strassen.call(qa[0], qb[1] - qb[3])
   var p6 = strassen.call(qa[3], qb[2] - qb[0])
   var p7 = strassen.call(qa[2] + qa[3], qb[0])
   var q = List.filled(4, null)
   q[0] = p1 + p2 - p4 + p6
   q[1] = p4 + p5
   q[2] = p6 + p7
   q[3] = p2 - p3 + p5 - p7
   return fromQuarters.call(q)

}

var a = Matrix.new([ [1,2], [3, 4] ]) var b = Matrix.new([ [5,6], [7, 8] ]) var c = Matrix.new([ [1, 1, 1, 1], [2, 4, 8, 16], [3, 9, 27, 81], [4, 16, 64, 256] ]) var d = Matrix.new([ [4, -3, 4/3, -1/4], [-13/3, 19/4, -7/3, 11/24],

                    [3/2, -2, 7/6, -1/4], [-1/6, 1/4, -1/6, 1/24] ])

var e = Matrix.new([ [1, 2, 3, 4], [5, 6, 7, 8], [9,10,11,12], [13,14,15,16] ]) var f = Matrix.new([ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1] ]) System.print("Using 'normal' matrix multiplication:") System.print(" a * b = %(a * b)") System.print(" c * d = %((c * d).toString(6))") System.print(" e * f = %(e * f)") System.print("\nUsing 'Strassen' matrix multiplication:") System.print(" a * b = %(strassen.call(a, b))") System.print(" c * d = %(strassen.call(c, d).toString(6))") System.print(" e * f = %(strassen.call(e, f))")</lang>

Output:
Using 'normal' matrix multiplication:
  a * b = [[19, 22], [43, 50]]
  c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
  e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]

Using 'Strassen' matrix multiplication:
  a * b = [[19, 22], [43, 50]]
  c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
  e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]