Strassen's algorithm: Difference between revisions

m
→‎{{header|Wren}}: Changed to Wren S/H
m (Added references)
m (→‎{{header|Wren}}: Changed to Wren S/H)
 
(5 intermediate revisions by 4 users not shown)
Line 24:
{{trans|Wren}}
Rather than use a library such as gonum, we create a simple Matrix type which is adequate for this task.
<langsyntaxhighlight lang="go">package main
 
import (
Line 188:
fmt.Printf(" c * d = %v\n", strassen(c, d).toString(6))
fmt.Printf(" e * f = %v\n", strassen(e, f))
}</langsyntaxhighlight>
 
{{out}}
Line 206:
===With dynamic padding===
Because Julia uses column major in matrices, sometimes the code uses the adjoint of a matrix in order to match examples as written.
<langsyntaxhighlight lang="julia">"""
Strassen's matrix multiplication algorithm.
Use dynamic padding in order to reduce required auxiliary memory.
Line 323:
intprint("Regular multiply: ", R * R)
intprint("Strassen multiply: ", strassen(R,R))
</langsyntaxhighlight>{{out}}
<pre>
Regular multiply: [19 43; 22 50]
Line 337:
===Recursive===
Output is the same as the dynamically padded version.
<langsyntaxhighlight Julialang="julia">function Strassen(A, B)
n = size(A, 1)
if n == 1
Line 387:
intprint("Regular multiply: ", R * R)
intprint("Strassen multiply: ", Strassen(R,R))
</syntaxhighlight>
</lang>
 
=={{header|MATLAB}}==
<syntaxhighlight lang="MATLAB}}">
clear all;close all;clc;
 
A = [1, 2; 3, 4];
B = [5, 6; 7, 8];
C = [1, 1, 1, 1; 2, 4, 8, 16; 3, 9, 27, 81; 4, 16, 64, 256];
D = [4, -3, 4/3, -1/4; -13/3, 19/4, -7/3, 11/24; 3/2, -2, 7/6, -1/4; -1/6, 1/4, -1/6, 1/24];
E = [1, 2, 3, 4; 5, 6, 7, 8; 9, 10, 11, 12; 13, 14, 15, 16];
F = eye(4);
 
disp('Regular multiply: ');
disp(A' * B');
 
disp('Strassen multiply: ');
disp(Strassen(A', B'));
 
disp('Regular multiply: ');
disp(C * D);
 
disp('Strassen multiply: ');
disp(Strassen(C, D));
 
disp('Regular multiply: ');
disp(E * F);
 
disp('Strassen multiply: ');
disp(Strassen(E, F));
 
r = sqrt(2)/2;
R = [r, r; -r, r];
 
disp('Regular multiply: ');
disp(R * R);
 
disp('Strassen multiply: ');
disp(Strassen(R, R));
 
 
function C = Strassen(A, B)
n = size(A, 1);
if n == 1
C = A * B;
return
end
A11 = A(1:n/2, 1:n/2);
A12 = A(1:n/2, n/2+1:n);
A21 = A(n/2+1:n, 1:n/2);
A22 = A(n/2+1:n, n/2+1:n);
B11 = B(1:n/2, 1:n/2);
B12 = B(1:n/2, n/2+1:n);
B21 = B(n/2+1:n, 1:n/2);
B22 = B(n/2+1:n, n/2+1:n);
 
P1 = Strassen(A12 - A22, B21 + B22);
P2 = Strassen(A11 + A22, B11 + B22);
P3 = Strassen(A11 - A21, B11 + B12);
P4 = Strassen(A11 + A12, B22);
P5 = Strassen(A11, B12 - B22);
P6 = Strassen(A22, B21 - B11);
P7 = Strassen(A21 + A22, B11);
 
C11 = P1 + P2 - P4 + P6;
C12 = P4 + P5;
C21 = P6 + P7;
C22 = P2 - P3 + P5 - P7;
 
C = [C11 C12; C21 C22];
end
</syntaxhighlight>
{{out}}
<pre>
Regular multiply:
23 31
34 46
 
Strassen multiply:
23 31
34 46
 
Regular multiply:
1.0000 0 -0.0000 -0.0000
0.0000 1.0000 -0.0000 -0.0000
0 0 1.0000 0
0.0000 0 0.0000 1.0000
 
Strassen multiply:
1.0000 0.0000 -0.0000 -0.0000
-0.0000 1.0000 -0.0000 0.0000
0 0 1.0000 0.0000
0 0 -0.0000 1.0000
 
Regular multiply:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
 
Strassen multiply:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
 
Regular multiply:
0 1.0000
-1.0000 0
 
Strassen multiply:
0 1.0000
-1.0000 0
 
</pre>
 
 
=={{header|Nim}}==
{{trans|Go}}
{{trans|Wren}}
<langsyntaxhighlight Nimlang="nim">import math, sequtils, strutils
 
type Matrix = seq[seq[float]]
Line 531 ⟶ 646:
echo " a * b = ", strassen(a, b).toString(10)
echo " c * d = ", strassen(c, d).toString(6)
echo " e * f = ", strassen(e, f).toString(10)</langsyntaxhighlight>
 
{{out}}
Line 546 ⟶ 661:
=={{header|Phix}}==
As noted on wp, you could pad with zeroes, and strip them on exit, instead of crashing for non-square 2<sup><small>n</small></sup> matrices.
<!--<langsyntaxhighlight Phixlang="phix">(phixonline)-->
<span style="color: #008080;">with</span> <span style="color: #008080;">javascript_semantics</span>
<span style="color: #008080;">function</span> <span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #004080;">sequence</span> <span style="color: #000000;">a</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">b</span><span style="color: #0000FF;">)</span>
Line 629 ⟶ 744:
<span style="color: #0000FF;">{-</span><span style="color: #000000;">r</span><span style="color: #0000FF;">,</span><span style="color: #000000;">r</span><span style="color: #0000FF;">}}</span>
<span style="color: #7060A8;">pp</span><span style="color: #0000FF;">(</span><span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #000000;">R</span><span style="color: #0000FF;">,</span><span style="color: #000000;">R</span><span style="color: #0000FF;">))</span>
<!--</langsyntaxhighlight>-->
{{out}}
Matches that of [[Matrix_multiplication#Phix]], when given the same inputs. Note that a few "-0" show up in the second one (the identity matrix) under pwa/p2js.
Line 648 ⟶ 763:
 
=={{header|Python}}==
<langsyntaxhighlight lang="python">"""Matrix multiplication using Strassen's algorithm. Requires Python >= 3.7."""
 
from __future__ import annotations
Line 814 ⟶ 929:
if __name__ == "__main__":
examples()
</syntaxhighlight>
</lang>
 
{{out}}
Line 831 ⟶ 946:
Special thanks go to the module author, [https://github.com/frithnanth Fernando Santagata], on showing how to deal with a pass-by-value case.
{{trans|Julia}}
<syntaxhighlight lang="raku" perl6line># 20210126 Raku programming solution
 
use Math::Libgsl::Constants;
Line 924 ⟶ 1,039:
1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16
1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1
=end code</langsyntaxhighlight>
{{out}}
<pre>Regular multiply:
Line 952 ⟶ 1,067:
9 10 11 12
13 14 15 16</pre>
 
 
=={{header|Scala}}==
{{trans|Go}}
<syntaxhighlight lang="Scala">
import scala.math
 
object MatrixOperations {
 
type Matrix = Array[Array[Double]]
 
implicit class RichMatrix(val m: Matrix) {
def rows: Int = m.length
def cols: Int = m(0).length
 
def add(m2: Matrix): Matrix = {
require(
m.rows == m2.rows && m.cols == m2.cols,
"Matrices must have the same dimensions."
)
Array.tabulate(m.rows, m.cols)((i, j) => m(i)(j) + m2(i)(j))
}
 
def sub(m2: Matrix): Matrix = {
require(
m.rows == m2.rows && m.cols == m2.cols,
"Matrices must have the same dimensions."
)
Array.tabulate(m.rows, m.cols)((i, j) => m(i)(j) - m2(i)(j))
}
 
def mul(m2: Matrix): Matrix = {
require(m.cols == m2.rows, "Cannot multiply these matrices.")
Array.tabulate(m.rows, m2.cols)((i, j) =>
(0 until m.cols).map(k => m(i)(k) * m2(k)(j)).sum
)
}
 
def toString(p: Int): String = {
val pow = math.pow(10, p)
m.map(row =>
row
.map(value => (math.round(value * pow) / pow).toString)
.mkString("[", ", ", "]")
).mkString("[", ",\n ", "]")
}
}
 
def toQuarters(m: Matrix): Array[Matrix] = {
val r = m.rows / 2
val c = m.cols / 2
val p = params(r, c)
(0 until 4).map { k =>
Array.tabulate(r, c)((i, j) => m(p(k)(0) + i)(p(k)(2) + j))
}.toArray
}
 
def fromQuarters(q: Array[Matrix]): Matrix = {
val r = q(0).rows
val c = q(0).cols
val p = params(r, c)
Array.tabulate(r * 2, c * 2)((i, j) => q((i / r) * 2 + j / c)(i % r)(j % c))
}
 
def strassen(a: Matrix, b: Matrix): Matrix = {
require(
a.rows == a.cols && b.rows == b.cols && a.rows == b.rows,
"Matrices must be square and of equal size."
)
require(
a.rows != 0 && (a.rows & (a.rows - 1)) == 0,
"Size of matrices must be a power of two."
)
 
if (a.rows == 1) {
return a.mul(b)
}
 
val qa = toQuarters(a)
val qb = toQuarters(b)
 
val p1 = strassen(qa(1).sub(qa(3)), qb(2).add(qb(3)))
val p2 = strassen(qa(0).add(qa(3)), qb(0).add(qb(3)))
val p3 = strassen(qa(0).sub(qa(2)), qb(0).add(qb(1)))
val p4 = strassen(qa(0).add(qa(1)), qb(3))
val p5 = strassen(qa(0), qb(1).sub(qb(3)))
val p6 = strassen(qa(3), qb(2).sub(qb(0)))
val p7 = strassen(qa(2).add(qa(3)), qb(0))
 
val q = Array(
p1.add(p2).sub(p4).add(p6),
p4.add(p5),
p6.add(p7),
p2.sub(p3).add(p5).sub(p7)
)
 
fromQuarters(q)
}
 
private def params(r: Int, c: Int): Array[Array[Int]] = {
Array(
Array(0, r, 0, c, 0, 0),
Array(0, r, c, 2 * c, 0, c),
Array(r, 2 * r, 0, c, r, 0),
Array(r, 2 * r, c, 2 * c, r, c)
)
}
 
def main(args: Array[String]): Unit = {
val a: Matrix = Array(Array(1.0, 2.0), Array(3.0, 4.0))
val b: Matrix = Array(Array(5.0, 6.0), Array(7.0, 8.0))
val c: Matrix = Array(
Array(1.0, 1.0, 1.0, 1.0),
Array(2.0, 4.0, 8.0, 16.0),
Array(3.0, 9.0, 27.0, 81.0),
Array(4.0, 16.0, 64.0, 256.0)
)
val d: Matrix = Array(
Array(4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0),
Array(-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0),
Array(3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0),
Array(-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0)
)
val e: Matrix = Array(
Array(1.0, 2.0, 3.0, 4.0),
Array(5.0, 6.0, 7.0, 8.0),
Array(9.0, 10.0, 11.0, 12.0),
Array(13.0, 14.0, 15.0, 16.0)
)
val f: Matrix = Array(
Array(1.0, 0.0, 0.0, 0.0),
Array(0.0, 1.0, 0.0, 0.0),
Array(0.0, 0.0, 1.0, 0.0),
Array(0.0, 0.0, 0.0, 1.0)
)
 
println("Using 'normal' matrix multiplication:")
println(
s" a * b = ${a.mul(b).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
println(s" c * d = ${c.mul(d).toString(6)}")
println(
s" e * f = ${e.mul(f).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
 
println("\nUsing 'Strassen' matrix multiplication:")
println(
s" a * b = ${strassen(a, b).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
println(s" c * d = ${strassen(c, d).toString(6)}")
println(
s" e * f = ${strassen(e, f).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
}
}
</syntaxhighlight>
{{out}}
<pre>
Using 'normal' matrix multiplication:
a * b = [[19.0, 22.0], [43.0, 50.0]]
c * d = [[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]]
e * f = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]
 
Using 'Strassen' matrix multiplication:
a * b = [[19.0, 22.0], [43.0, 50.0]]
c * d = [[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]]
e * f = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]
 
</pre>
 
=={{header|Swift}}==
[https://github.com/hollance/Matrix/blob/master/Matrix.swift '''Matrix Class'''] by [https://github.com/ozgurshn ozgurshn]
<langsyntaxhighlight lang="swift">
// Matrix Strassen Multiplication
func strassenMultiply(matrix1: Matrix, matrix2: Matrix) -> Matrix {
Line 966 ⟶ 1,256:
var sqrMatrix1 = Matrix(rows: pwr2, columns: pwr2)
var sqrMatrix2 = Matrix(rows: pwr2, columns: pwr2)
 
// fill square matrix 1 with values
for i in 0..<matrix1.rows {
for j in 0..<matrix1.columns{
Line 972 ⟶ 1,263:
}
}
// fill square matrix 2 with values
for i in 0..<matrix2.rows {
for j in 0..<matrix2.columns{
Line 991 ⟶ 1,283:
// Calculate next power of 2
func nextPowerOfTwo(num: Int) -> Int {
// formula for next power of 2
return Int(pow(2,(ceil(log2(Double(num))))))
}
Line 1,001 ⟶ 1,294:
let rowHalf = matrix1.rows / 2
// Strassen Formula https://www.geeksforgeeks.org/easy-way-remember-strassens-matrix-equation/
// p1 = a(f-h) p2 = (a+b)h
// p2 = (c+d)e p4 = d(g-e)
// p5 = (a+d)(e+h) p6 = (b-d)(g+h)
// p7 = (a-c)(e+f)
|a b| x |e f| = |(p5+p4-p2+p6) (p1+p2)|
|c d| |g h| |(p3+p4) (p1+p5-p3-p7)|
Matrix 1 Matrix 2 Result
 
// create empty matrices for a, b, c, d, e, f, g, h
var a = Matrix(rows: rowHalf, columns: rowHalf)
var b = Matrix(rows: rowHalf, columns: rowHalf)
Line 1,009 ⟶ 1,311:
var g = Matrix(rows: rowHalf, columns: rowHalf)
var h = Matrix(rows: rowHalf, columns: rowHalf)
 
// fill the matrices with values
for i in 0..<rowHalf {
for j in 0..<rowHalf {
Line 1,023 ⟶ 1,326:
}
let p1 = strassenFormula(matrix1: a, matrix2: (f - h)) // a * (f - h)
let p2p1 = strassenFormula(matrix1: (a + b), matrix2: (f - h)) // (a + b) * h
// (a + b) * h
let p3 = strassenFormula(matrix1: (c + d), matrix2: e) // (c + d) * e
let p4p2 = strassenFormula(matrix1: d(a + b), matrix2: (g - e)h) // d * (g - e)
let// p5 = strassenFormula(matrix1: (ac + d), matrix2:* (e + h)) // (a + d) * (e + h)
let p6p3 = strassenFormula(matrix1: (bc -+ d), matrix2: (g + h)e) // (b - d) * (g + h)
let// p7d =* strassenFormula(matrix1: (ag - c), matrix2: (e + f)) // (a - c) * (e + f)
let p4 = strassenFormula(matrix1: d, matrix2: (g - e))
 
let// result11 = p5(a + p4d) -* p2(e + p6h) // p5 + p4 - p2 + p6
let p5 = strassenFormula(matrix1: (a + d), matrix2: (e + h))
let result12 = p1 + p2 // p1 + p2
let// result21(b =- p3d) +* p4(g + h) // p3 + p4
let result22p6 = p1strassenFormula(matrix1: + p5(b - p3d), -matrix2: p7 // p1(g + p5 - p3 - p7h))
// (a - c) * (e + f)
let p7 = strassenFormula(matrix1: (a - c), matrix2: (e + f))
// p5 + p4 - p2 + p6
let result11 = p5 + p4 - p2 + p6
// p1 + p2
let result12 = p1 + p2
// p3 + p4
let result21 = p3 + p4
// p1 + p5 - p3 - p7
let result22 = p1 + p5 - p3 - p7
 
// create an empty matrix for result and fill with values
var result = Matrix(rows: matrix1.rows, columns: matrix1.rows)
for i in 0..<rowHalf {
Line 1,093 ⟶ 1,408:
print(result3.description)
}
main()</langsyntaxhighlight>
 
{{out}}
Line 1,116 ⟶ 1,431:
 
=={{header|Wren}}==
{{libheader|Wren-fmt}}
Wren doesn't currently have a matrix module so I've written a rudimentary Matrix class with sufficient functionality to complete this task.
 
I've used the Phix entry's examples to test the Strassen algorithm implementation.
<syntaxhighlight lang="wren">class Matrix {
<lang ecmascript>import "/fmt" for Fmt
 
class Matrix {
construct new(a) {
if (a.type != List || a.count == 0 || a[0].type != List || a[0].count == 0 || a[0][0].type != Num) {
Line 1,275 ⟶ 1,587:
System.print(" a * b = %(strassen.call(a, b))")
System.print(" c * d = %(strassen.call(c, d).toString(6))")
System.print(" e * f = %(strassen.call(e, f))")</langsyntaxhighlight>
 
{{out}}
Line 1,289 ⟶ 1,601:
e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
</pre>
<br>
{{libheader|Wren-matrix}}
Since the above version was written, a Matrix module has been added and the following version uses it. The output is exactly the same as before.
<syntaxhighlight lang="wren">import "./matrix" for Matrix
var params = Fn.new { |r, c|
return [
[0...r, 0...c, 0, 0],
[0...r, c...2*c, 0, c],
[r...2*r, 0...c, r, 0],
[r...2*r, c...2*c, r, c]
]
}
var toQuarters = Fn.new { |m|
var r = (m.numRows/2).floor
var c = (m.numCols/2).floor
var p = params.call(r, c)
var quarters = []
for (k in 0..3) {
var q = List.filled(r, null)
for (i in p[k][0]) {
q[i - p[k][2]] = List.filled(c, 0)
for (j in p[k][1]) q[i - p[k][2]][j - p[k][3]] = m[i, j]
}
quarters.add(Matrix.new(q))
}
return quarters
}
var fromQuarters = Fn.new { |q|
var r = q[0].numRows
var c = q[0].numCols
var p = params.call(r, c)
r = r * 2
c = c * 2
var m = List.filled(r, null)
for (i in 0...c) m[i] = List.filled(c, 0)
for (k in 0..3) {
for (i in p[k][0]) {
for (j in p[k][1]) m[i][j] = q[k][i - p[k][2], j - p[k][3]]
}
}
return Matrix.new(m)
}
var strassen // recursive
strassen = Fn.new { |a, b|
if (!a.isSquare || !b.isSquare || !a.sameSize(b)) {
Fiber.abort("Matrices must be square and of equal size.")
}
if (a.numRows == 0 || (a.numRows & (a.numRows - 1)) != 0) {
Fiber.abort("Size of matrices must be a power of two.")
}
if (a.numRows == 1) return a * b
var qa = toQuarters.call(a)
var qb = toQuarters.call(b)
var p1 = strassen.call(qa[1] - qa[3], qb[2] + qb[3])
var p2 = strassen.call(qa[0] + qa[3], qb[0] + qb[3])
var p3 = strassen.call(qa[0] - qa[2], qb[0] + qb[1])
var p4 = strassen.call(qa[0] + qa[1], qb[3])
var p5 = strassen.call(qa[0], qb[1] - qb[3])
var p6 = strassen.call(qa[3], qb[2] - qb[0])
var p7 = strassen.call(qa[2] + qa[3], qb[0])
var q = List.filled(4, null)
q[0] = p1 + p2 - p4 + p6
q[1] = p4 + p5
q[2] = p6 + p7
q[3] = p2 - p3 + p5 - p7
return fromQuarters.call(q)
}
var a = Matrix.new([ [1,2], [3, 4] ])
var b = Matrix.new([ [5,6], [7, 8] ])
var c = Matrix.new([ [1, 1, 1, 1], [2, 4, 8, 16], [3, 9, 27, 81], [4, 16, 64, 256] ])
var d = Matrix.new([ [4, -3, 4/3, -1/4], [-13/3, 19/4, -7/3, 11/24],
[3/2, -2, 7/6, -1/4], [-1/6, 1/4, -1/6, 1/24] ])
var e = Matrix.new([ [1, 2, 3, 4], [5, 6, 7, 8], [9,10,11,12], [13,14,15,16] ])
var f = Matrix.new([ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1] ])
System.print("Using 'normal' matrix multiplication:")
System.print(" a * b = %(a * b)")
System.print(" c * d = %((c * d).toString(6))")
System.print(" e * f = %(e * f)")
System.print("\nUsing 'Strassen' matrix multiplication:")
System.print(" a * b = %(strassen.call(a, b))")
System.print(" c * d = %(strassen.call(c, d).toString(6))")
System.print(" e * f = %(strassen.call(e, f))")</syntaxhighlight>
9,476

edits