Strassen's algorithm: Difference between revisions

m
→‎{{header|Wren}}: Changed to Wren S/H
m (syntax highlighting fixup automation)
m (→‎{{header|Wren}}: Changed to Wren S/H)
 
(2 intermediate revisions by one other user not shown)
Line 388:
intprint("Strassen multiply: ", Strassen(R,R))
</syntaxhighlight>
 
=={{header|MATLAB}}==
<syntaxhighlight lang="MATLAB}}">
clear all;close all;clc;
 
A = [1, 2; 3, 4];
B = [5, 6; 7, 8];
C = [1, 1, 1, 1; 2, 4, 8, 16; 3, 9, 27, 81; 4, 16, 64, 256];
D = [4, -3, 4/3, -1/4; -13/3, 19/4, -7/3, 11/24; 3/2, -2, 7/6, -1/4; -1/6, 1/4, -1/6, 1/24];
E = [1, 2, 3, 4; 5, 6, 7, 8; 9, 10, 11, 12; 13, 14, 15, 16];
F = eye(4);
 
disp('Regular multiply: ');
disp(A' * B');
 
disp('Strassen multiply: ');
disp(Strassen(A', B'));
 
disp('Regular multiply: ');
disp(C * D);
 
disp('Strassen multiply: ');
disp(Strassen(C, D));
 
disp('Regular multiply: ');
disp(E * F);
 
disp('Strassen multiply: ');
disp(Strassen(E, F));
 
r = sqrt(2)/2;
R = [r, r; -r, r];
 
disp('Regular multiply: ');
disp(R * R);
 
disp('Strassen multiply: ');
disp(Strassen(R, R));
 
 
function C = Strassen(A, B)
n = size(A, 1);
if n == 1
C = A * B;
return
end
A11 = A(1:n/2, 1:n/2);
A12 = A(1:n/2, n/2+1:n);
A21 = A(n/2+1:n, 1:n/2);
A22 = A(n/2+1:n, n/2+1:n);
B11 = B(1:n/2, 1:n/2);
B12 = B(1:n/2, n/2+1:n);
B21 = B(n/2+1:n, 1:n/2);
B22 = B(n/2+1:n, n/2+1:n);
 
P1 = Strassen(A12 - A22, B21 + B22);
P2 = Strassen(A11 + A22, B11 + B22);
P3 = Strassen(A11 - A21, B11 + B12);
P4 = Strassen(A11 + A12, B22);
P5 = Strassen(A11, B12 - B22);
P6 = Strassen(A22, B21 - B11);
P7 = Strassen(A21 + A22, B11);
 
C11 = P1 + P2 - P4 + P6;
C12 = P4 + P5;
C21 = P6 + P7;
C22 = P2 - P3 + P5 - P7;
 
C = [C11 C12; C21 C22];
end
</syntaxhighlight>
{{out}}
<pre>
Regular multiply:
23 31
34 46
 
Strassen multiply:
23 31
34 46
 
Regular multiply:
1.0000 0 -0.0000 -0.0000
0.0000 1.0000 -0.0000 -0.0000
0 0 1.0000 0
0.0000 0 0.0000 1.0000
 
Strassen multiply:
1.0000 0.0000 -0.0000 -0.0000
-0.0000 1.0000 -0.0000 0.0000
0 0 1.0000 0.0000
0 0 -0.0000 1.0000
 
Regular multiply:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
 
Strassen multiply:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
 
Regular multiply:
0 1.0000
-1.0000 0
 
Strassen multiply:
0 1.0000
-1.0000 0
 
</pre>
 
 
=={{header|Nim}}==
Line 952 ⟶ 1,067:
9 10 11 12
13 14 15 16</pre>
 
 
=={{header|Scala}}==
{{trans|Go}}
<syntaxhighlight lang="Scala">
import scala.math
 
object MatrixOperations {
 
type Matrix = Array[Array[Double]]
 
implicit class RichMatrix(val m: Matrix) {
def rows: Int = m.length
def cols: Int = m(0).length
 
def add(m2: Matrix): Matrix = {
require(
m.rows == m2.rows && m.cols == m2.cols,
"Matrices must have the same dimensions."
)
Array.tabulate(m.rows, m.cols)((i, j) => m(i)(j) + m2(i)(j))
}
 
def sub(m2: Matrix): Matrix = {
require(
m.rows == m2.rows && m.cols == m2.cols,
"Matrices must have the same dimensions."
)
Array.tabulate(m.rows, m.cols)((i, j) => m(i)(j) - m2(i)(j))
}
 
def mul(m2: Matrix): Matrix = {
require(m.cols == m2.rows, "Cannot multiply these matrices.")
Array.tabulate(m.rows, m2.cols)((i, j) =>
(0 until m.cols).map(k => m(i)(k) * m2(k)(j)).sum
)
}
 
def toString(p: Int): String = {
val pow = math.pow(10, p)
m.map(row =>
row
.map(value => (math.round(value * pow) / pow).toString)
.mkString("[", ", ", "]")
).mkString("[", ",\n ", "]")
}
}
 
def toQuarters(m: Matrix): Array[Matrix] = {
val r = m.rows / 2
val c = m.cols / 2
val p = params(r, c)
(0 until 4).map { k =>
Array.tabulate(r, c)((i, j) => m(p(k)(0) + i)(p(k)(2) + j))
}.toArray
}
 
def fromQuarters(q: Array[Matrix]): Matrix = {
val r = q(0).rows
val c = q(0).cols
val p = params(r, c)
Array.tabulate(r * 2, c * 2)((i, j) => q((i / r) * 2 + j / c)(i % r)(j % c))
}
 
def strassen(a: Matrix, b: Matrix): Matrix = {
require(
a.rows == a.cols && b.rows == b.cols && a.rows == b.rows,
"Matrices must be square and of equal size."
)
require(
a.rows != 0 && (a.rows & (a.rows - 1)) == 0,
"Size of matrices must be a power of two."
)
 
if (a.rows == 1) {
return a.mul(b)
}
 
val qa = toQuarters(a)
val qb = toQuarters(b)
 
val p1 = strassen(qa(1).sub(qa(3)), qb(2).add(qb(3)))
val p2 = strassen(qa(0).add(qa(3)), qb(0).add(qb(3)))
val p3 = strassen(qa(0).sub(qa(2)), qb(0).add(qb(1)))
val p4 = strassen(qa(0).add(qa(1)), qb(3))
val p5 = strassen(qa(0), qb(1).sub(qb(3)))
val p6 = strassen(qa(3), qb(2).sub(qb(0)))
val p7 = strassen(qa(2).add(qa(3)), qb(0))
 
val q = Array(
p1.add(p2).sub(p4).add(p6),
p4.add(p5),
p6.add(p7),
p2.sub(p3).add(p5).sub(p7)
)
 
fromQuarters(q)
}
 
private def params(r: Int, c: Int): Array[Array[Int]] = {
Array(
Array(0, r, 0, c, 0, 0),
Array(0, r, c, 2 * c, 0, c),
Array(r, 2 * r, 0, c, r, 0),
Array(r, 2 * r, c, 2 * c, r, c)
)
}
 
def main(args: Array[String]): Unit = {
val a: Matrix = Array(Array(1.0, 2.0), Array(3.0, 4.0))
val b: Matrix = Array(Array(5.0, 6.0), Array(7.0, 8.0))
val c: Matrix = Array(
Array(1.0, 1.0, 1.0, 1.0),
Array(2.0, 4.0, 8.0, 16.0),
Array(3.0, 9.0, 27.0, 81.0),
Array(4.0, 16.0, 64.0, 256.0)
)
val d: Matrix = Array(
Array(4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0),
Array(-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0),
Array(3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0),
Array(-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0)
)
val e: Matrix = Array(
Array(1.0, 2.0, 3.0, 4.0),
Array(5.0, 6.0, 7.0, 8.0),
Array(9.0, 10.0, 11.0, 12.0),
Array(13.0, 14.0, 15.0, 16.0)
)
val f: Matrix = Array(
Array(1.0, 0.0, 0.0, 0.0),
Array(0.0, 1.0, 0.0, 0.0),
Array(0.0, 0.0, 1.0, 0.0),
Array(0.0, 0.0, 0.0, 1.0)
)
 
println("Using 'normal' matrix multiplication:")
println(
s" a * b = ${a.mul(b).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
println(s" c * d = ${c.mul(d).toString(6)}")
println(
s" e * f = ${e.mul(f).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
 
println("\nUsing 'Strassen' matrix multiplication:")
println(
s" a * b = ${strassen(a, b).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
println(s" c * d = ${strassen(c, d).toString(6)}")
println(
s" e * f = ${strassen(e, f).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
}
}
</syntaxhighlight>
{{out}}
<pre>
Using 'normal' matrix multiplication:
a * b = [[19.0, 22.0], [43.0, 50.0]]
c * d = [[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]]
e * f = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]
 
Using 'Strassen' matrix multiplication:
a * b = [[19.0, 22.0], [43.0, 50.0]]
c * d = [[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]]
e * f = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]
 
</pre>
 
=={{header|Swift}}==
Line 1,144 ⟶ 1,434:
 
I've used the Phix entry's examples to test the Strassen algorithm implementation.
<syntaxhighlight lang="ecmascriptwren">class Matrix {
construct new(a) {
if (a.type != List || a.count == 0 || a[0].type != List || a[0].count == 0 || a[0][0].type != Num) {
Line 1,314 ⟶ 1,604:
{{libheader|Wren-matrix}}
Since the above version was written, a Matrix module has been added and the following version uses it. The output is exactly the same as before.
<syntaxhighlight lang="ecmascriptwren">import "./matrix" for Matrix
var params = Fn.new { |r, c|
9,476

edits