Strassen's algorithm: Difference between revisions

m
→‎{{header|Wren}}: Changed to Wren S/H
m (→‎{{header|Wren}}: Changed to Wren S/H)
 
(13 intermediate revisions by 7 users not shown)
Line 1:
{{draft task}}
 
;Description
In linear algebra, the Strassen algorithm,   (named after Volker Strassen),   is an algorithm for matrix multiplication. It is faster than the standard matrix multiplication algorithm and is useful in practice for large matrices, but would be slower than the fastest known algorithms for extremely large matrices.
 
It is faster than the standard matrix multiplication algorithm and is useful in practice for large matrices,   but would be slower than the fastest known algorithms for extremely large matrices.
 
 
;Task
Write a routine, function, procedure etc. in your language to implement the Strassen algorithm for matrix multiplication.
 
While practical implementations of Strassen's algorithm usually switch to standard methods of matrix multiplication for small enough submatricessub-matrices (currently anything <less 512x512than &nbsp; 512&times;512 &nbsp; according to wpWikipedia), &nbsp; for the purposes of this task you should not switch until reaching a size of 1 or 2.
 
 
;Related task
:* [[Matrix multiplication]]
 
 
;See also
:* [[wp:Strassen algorithm|Wikipedia article]]
<br><br>
 
 
=={{header|Go}}==
{{trans|Wren}}
Rather than use a library such as gonum, we create a simple Matrix type which is adequate for this task.
<langsyntaxhighlight lang="go">package main
 
import (
Line 182 ⟶ 188:
fmt.Printf(" c * d = %v\n", strassen(c, d).toString(6))
fmt.Printf(" e * f = %v\n", strassen(e, f))
}</langsyntaxhighlight>
 
{{out}}
Line 200 ⟶ 206:
===With dynamic padding===
Because Julia uses column major in matrices, sometimes the code uses the adjoint of a matrix in order to match examples as written.
<langsyntaxhighlight lang="julia">"""
Strassen's matrix multiplication algorithm.
Use dynamic padding in order to reduce required auxiliary memory.
Line 317 ⟶ 323:
intprint("Regular multiply: ", R * R)
intprint("Strassen multiply: ", strassen(R,R))
</langsyntaxhighlight>{{out}}
<pre>
Regular multiply: [19 43; 22 50]
Line 331 ⟶ 337:
===Recursive===
Output is the same as the dynamically padded version.
<langsyntaxhighlight Julialang="julia">function Strassen(A, B)
n = size(A, 1)
if n == 1
Line 381 ⟶ 387:
intprint("Regular multiply: ", R * R)
intprint("Strassen multiply: ", Strassen(R,R))
</syntaxhighlight>
</lang>
 
=={{header|MATLAB}}==
<syntaxhighlight lang="MATLAB}}">
clear all;close all;clc;
 
A = [1, 2; 3, 4];
B = [5, 6; 7, 8];
C = [1, 1, 1, 1; 2, 4, 8, 16; 3, 9, 27, 81; 4, 16, 64, 256];
D = [4, -3, 4/3, -1/4; -13/3, 19/4, -7/3, 11/24; 3/2, -2, 7/6, -1/4; -1/6, 1/4, -1/6, 1/24];
E = [1, 2, 3, 4; 5, 6, 7, 8; 9, 10, 11, 12; 13, 14, 15, 16];
F = eye(4);
 
disp('Regular multiply: ');
disp(A' * B');
 
disp('Strassen multiply: ');
disp(Strassen(A', B'));
 
disp('Regular multiply: ');
disp(C * D);
 
disp('Strassen multiply: ');
disp(Strassen(C, D));
 
disp('Regular multiply: ');
disp(E * F);
 
disp('Strassen multiply: ');
disp(Strassen(E, F));
 
r = sqrt(2)/2;
R = [r, r; -r, r];
 
disp('Regular multiply: ');
disp(R * R);
 
disp('Strassen multiply: ');
disp(Strassen(R, R));
 
 
function C = Strassen(A, B)
n = size(A, 1);
if n == 1
C = A * B;
return
end
A11 = A(1:n/2, 1:n/2);
A12 = A(1:n/2, n/2+1:n);
A21 = A(n/2+1:n, 1:n/2);
A22 = A(n/2+1:n, n/2+1:n);
B11 = B(1:n/2, 1:n/2);
B12 = B(1:n/2, n/2+1:n);
B21 = B(n/2+1:n, 1:n/2);
B22 = B(n/2+1:n, n/2+1:n);
 
P1 = Strassen(A12 - A22, B21 + B22);
P2 = Strassen(A11 + A22, B11 + B22);
P3 = Strassen(A11 - A21, B11 + B12);
P4 = Strassen(A11 + A12, B22);
P5 = Strassen(A11, B12 - B22);
P6 = Strassen(A22, B21 - B11);
P7 = Strassen(A21 + A22, B11);
 
C11 = P1 + P2 - P4 + P6;
C12 = P4 + P5;
C21 = P6 + P7;
C22 = P2 - P3 + P5 - P7;
 
C = [C11 C12; C21 C22];
end
</syntaxhighlight>
{{out}}
<pre>
Regular multiply:
23 31
34 46
 
Strassen multiply:
23 31
34 46
 
Regular multiply:
1.0000 0 -0.0000 -0.0000
0.0000 1.0000 -0.0000 -0.0000
0 0 1.0000 0
0.0000 0 0.0000 1.0000
 
Strassen multiply:
1.0000 0.0000 -0.0000 -0.0000
-0.0000 1.0000 -0.0000 0.0000
0 0 1.0000 0.0000
0 0 -0.0000 1.0000
 
Regular multiply:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
 
Strassen multiply:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
 
Regular multiply:
0 1.0000
-1.0000 0
 
Strassen multiply:
0 1.0000
-1.0000 0
 
</pre>
 
 
=={{header|Nim}}==
{{trans|Go}}
{{trans|Wren}}
<langsyntaxhighlight Nimlang="nim">import math, sequtils, strutils
 
type Matrix = seq[seq[float]]
Line 525 ⟶ 646:
echo " a * b = ", strassen(a, b).toString(10)
echo " c * d = ", strassen(c, d).toString(6)
echo " e * f = ", strassen(e, f).toString(10)</langsyntaxhighlight>
 
{{out}}
Line 540 ⟶ 661:
=={{header|Phix}}==
As noted on wp, you could pad with zeroes, and strip them on exit, instead of crashing for non-square 2<sup><small>n</small></sup> matrices.
<!--<syntaxhighlight lang="phix">(phixonline)-->
<lang Phix>function strassen(sequence a, b)
<span style="color: #008080;">with</span> <span style="color: #008080;">javascript_semantics</span>
integer l = length(a)
<span style="color: #008080;">function</span> <span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #004080;">sequence</span> <span style="color: #000000;">a</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">b</span><span style="color: #0000FF;">)</span>
if length(a[1])!=l
<span style="color: #004080;">integer</span> <span style="color: #000000;">l</span> <span style="color: #0000FF;">=</span> <span style="color: #7060A8;">length</span><span style="color: #0000FF;">(</span><span style="color: #000000;">a</span><span style="color: #0000FF;">)</span>
or length(b)!=l
<span style="color: #008080;">if</span> <span style="color: #7060A8;">length</span><span style="color: #0000FF;">(</span><span style="color: #000000;">a</span><span style="color: #0000FF;">[</span><span style="color: #000000;">1</span><span style="color: #0000FF;">])!=</span><span style="color: #000000;">l</span>
or length(b[1])!=l then
<span style="color: #008080;">or</span> <span style="color: #7060A8;">length</span><span style="color: #0000FF;">(</span><span style="color: #000000;">b</span><span style="color: #0000FF;">)!=</span><span style="color: #000000;">l</span>
crash("two equal square matrices only")
<span style="color: #008080;">or</span> <span style="color: #7060A8;">length</span><span style="color: #0000FF;">(</span><span style="color: #000000;">b</span><span style="color: #0000FF;">[</span><span style="color: #000000;">1</span><span style="color: #0000FF;">])!=</span><span style="color: #000000;">l</span> <span style="color: #008080;">then</span>
end if
<span style="color: #7060A8;">crash</span><span style="color: #0000FF;">(</span><span style="color: #008000;">"two equal square matrices only"</span><span style="color: #0000FF;">)</span>
if l=1 then return sq_mul(a,b) end if
<span style="color: #008080;">end</span> <span style="color: #008080;">if</span>
if remainder(l,1) then
<span style="color: #008080;">if</span> <span style="color: #000000;">l</span><span style="color: #0000FF;">=</span><span style="color: #000000;">1</span> <span style="color: #008080;">then</span> <span style="color: #008080;">return</span> <span style="color: #7060A8;">sq_mul</span><span style="color: #0000FF;">(</span><span style="color: #000000;">a</span><span style="color: #0000FF;">,</span><span style="color: #000000;">b</span><span style="color: #0000FF;">)</span> <span style="color: #008080;">end</span> <span style="color: #008080;">if</span>
crash("2^n matrices only")
<span style="color: #008080;">if</span> <span style="color: #7060A8;">remainder</span><span style="color: #0000FF;">(</span><span style="color: #000000;">l</span><span style="color: #0000FF;">,</span><span style="color: #000000;">1</span><span style="color: #0000FF;">)</span> <span style="color: #008080;">then</span>
end if
<span style="color: #7060A8;">crash</span><span style="color: #0000FF;">(</span><span style="color: #008000;">"2^n matrices only"</span><span style="color: #0000FF;">)</span>
integer h = l/2
<span style="color: #008080;">end</span> <span style="color: #008080;">if</span>
sequence {a11,a12,a21,a22,b11,b12,b21,b22} @= repeat(repeat(0,h),h)
<span style="color: #004080;">integer</span> <span style="color: #000000;">h</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">l</span><span style="color: #0000FF;">/</span><span style="color: #000000;">2</span>
for i=1 to h do
<span style="color: #004080;">sequence</span> <span style="color: #0000FF;">{</span><span style="color: #000000;">a11</span><span style="color: #0000FF;">,</span><span style="color: #000000;">a12</span><span style="color: #0000FF;">,</span><span style="color: #000000;">a21</span><span style="color: #0000FF;">,</span><span style="color: #000000;">a22</span><span style="color: #0000FF;">,</span><span style="color: #000000;">b11</span><span style="color: #0000FF;">,</span><span style="color: #000000;">b12</span><span style="color: #0000FF;">,</span><span style="color: #000000;">b21</span><span style="color: #0000FF;">,</span><span style="color: #000000;">b22</span><span style="color: #0000FF;">}</span> <span style="color: #0000FF;">=</span> <span style="color: #7060A8;">repeat</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">repeat</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">repeat</span><span style="color: #0000FF;">(</span><span style="color: #000000;">0</span><span style="color: #0000FF;">,</span><span style="color: #000000;">h</span><span style="color: #0000FF;">),</span><span style="color: #000000;">h</span><span style="color: #0000FF;">),</span><span style="color: #000000;">8</span><span style="color: #0000FF;">)</span>
for j=1 to h do
<span style="color: #008080;">for</span> <span style="color: #000000;">i</span><span style="color: #0000FF;">=</span><span style="color: #000000;">1</span> <span style="color: #008080;">to</span> <span style="color: #000000;">h</span> <span style="color: #008080;">do</span>
a11[i][j] = a[i][j]
<span style="color: #008080;">for</span> <span style="color: #000000;">j</span><span style="color: #0000FF;">=</span><span style="color: #000000;">1</span> <span style="color: #008080;">to</span> <span style="color: #000000;">h</span> <span style="color: #008080;">do</span>
a12[i][j] = a[i][j+h]
<span style="color: #000000;">a11</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">a</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span>
a21[i][j] = a[i+h][j]
<span style="color: #000000;">a12</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">a</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">]</span>
a22[i][j] = a[i+h][j+h]
<span style="color: #000000;">a21</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">a</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span>
b11[i][j] = b[i][j]
<span style="color: #000000;">a22</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">a</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">]</span>
b12[i][j] = b[i][j+h]
<span style="color: #000000;">b11</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">b</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span>
b21[i][j] = b[i+h][j]
<span style="color: #000000;">b12</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">b</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">]</span>
b22[i][j] = b[i+h][j+h]
<span style="color: #000000;">b21</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">b</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span>
end for
<span style="color: #000000;">b22</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">b</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">]</span>
end for
<span style="color: #008080;">end</span> <span style="color: #008080;">for</span>
sequence p1 = strassen(sq_sub(a12,a22), sq_add(b21,b22)),
<span style="color: #008080;">end</span> <span style="color: #008080;">for</span>
p2 = strassen(sq_add(a11,a22), sq_add(b11,b22)),
<span style="color: #004080;">sequence</span> <span style="color: #000000;">p1</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">sq_sub</span><span style="color: #0000FF;">(</span><span style="color: #000000;">a12</span><span style="color: #0000FF;">,</span><span style="color: #000000;">a22</span><span style="color: #0000FF;">),</span> <span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #000000;">b21</span><span style="color: #0000FF;">,</span><span style="color: #000000;">b22</span><span style="color: #0000FF;">)),</span>
p3 = strassen(sq_sub(a11,a21), sq_add(b11,b12)),
<span style="color: #000000;">p2</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #000000;">a11</span><span style="color: #0000FF;">,</span><span style="color: #000000;">a22</span><span style="color: #0000FF;">),</span> <span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #000000;">b11</span><span style="color: #0000FF;">,</span><span style="color: #000000;">b22</span><span style="color: #0000FF;">)),</span>
p4 = strassen(sq_add(a11,a12), b22),
<span style="color: #000000;">p3</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">sq_sub</span><span style="color: #0000FF;">(</span><span style="color: #000000;">a11</span><span style="color: #0000FF;">,</span><span style="color: #000000;">a21</span><span style="color: #0000FF;">),</span> <span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #000000;">b11</span><span style="color: #0000FF;">,</span><span style="color: #000000;">b12</span><span style="color: #0000FF;">)),</span>
p5 = strassen(a11, sq_sub(b12,b22)),
<span style="color: #000000;">p4</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #000000;">a11</span><span style="color: #0000FF;">,</span><span style="color: #000000;">a12</span><span style="color: #0000FF;">),</span> <span style="color: #000000;">b22</span><span style="color: #0000FF;">),</span>
p6 = strassen(a22, sq_sub(b21,b11)),
<span style="color: #000000;">p5</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #000000;">a11</span><span style="color: #0000FF;">,</span> <span style="color: #7060A8;">sq_sub</span><span style="color: #0000FF;">(</span><span style="color: #000000;">b12</span><span style="color: #0000FF;">,</span><span style="color: #000000;">b22</span><span style="color: #0000FF;">)),</span>
p7 = strassen(sq_add(a21,a22), b11),
<span style="color: #000000;">p6</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #000000;">a22</span><span style="color: #0000FF;">,</span> <span style="color: #7060A8;">sq_sub</span><span style="color: #0000FF;">(</span><span style="color: #000000;">b21</span><span style="color: #0000FF;">,</span><span style="color: #000000;">b11</span><span style="color: #0000FF;">)),</span>
<span style="color: #000000;">p7</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #000000;">a21</span><span style="color: #0000FF;">,</span><span style="color: #000000;">a22</span><span style="color: #0000FF;">),</span> <span style="color: #000000;">b11</span><span style="color: #0000FF;">),</span>
c11 = sq_add(sq_sub(sq_add(p1,p2),p4),p6),
c12 = sq_add(p4,p5),
<span style="color: #000000;">c11</span> <span style="color: #0000FF;">=</span> <span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">sq_sub</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #000000;">p1</span><span style="color: #0000FF;">,</span><span style="color: #000000;">p2</span><span style="color: #0000FF;">),</span><span style="color: #000000;">p4</span><span style="color: #0000FF;">),</span><span style="color: #000000;">p6</span><span style="color: #0000FF;">),</span>
c21 = sq_add(p6,p7),
<span style="color: #000000;">c12</span> <span style="color: #0000FF;">=</span> <span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #000000;">p4</span><span style="color: #0000FF;">,</span><span style="color: #000000;">p5</span><span style="color: #0000FF;">),</span>
c22 = sq_sub(sq_add(sq_sub(p2,p3),p5),p7),
<span style="color: #000000;">c21</span> <span style="color: #0000FF;">=</span> <span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #000000;">p6</span><span style="color: #0000FF;">,</span><span style="color: #000000;">p7</span><span style="color: #0000FF;">),</span>
c = repeat(repeat(0,l),l)
<span style="color: #000000;">c22</span> <span style="color: #0000FF;">=</span> <span style="color: #7060A8;">sq_sub</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">sq_add</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">sq_sub</span><span style="color: #0000FF;">(</span><span style="color: #000000;">p2</span><span style="color: #0000FF;">,</span><span style="color: #000000;">p3</span><span style="color: #0000FF;">),</span><span style="color: #000000;">p5</span><span style="color: #0000FF;">),</span><span style="color: #000000;">p7</span><span style="color: #0000FF;">),</span>
for i=1 to h do
<span style="color: #000000;">c</span> <span style="color: #0000FF;">=</span> <span style="color: #7060A8;">repeat</span><span style="color: #0000FF;">(</span><span style="color: #7060A8;">repeat</span><span style="color: #0000FF;">(</span><span style="color: #000000;">0</span><span style="color: #0000FF;">,</span><span style="color: #000000;">l</span><span style="color: #0000FF;">),</span><span style="color: #000000;">l</span><span style="color: #0000FF;">)</span>
for j=1 to h do
<span style="color: #008080;">for</span> <span style="color: #000000;">i</span><span style="color: #0000FF;">=</span><span style="color: #000000;">1</span> <span style="color: #008080;">to</span> <span style="color: #000000;">h</span> <span style="color: #008080;">do</span>
c[i][j] = c11[i][j]
<span style="color: #008080;">for</span> <span style="color: #000000;">j</span><span style="color: #0000FF;">=</span><span style="color: #000000;">1</span> <span style="color: #008080;">to</span> <span style="color: #000000;">h</span> <span style="color: #008080;">do</span>
c[i][j+h] = c12[i][j]
<span style="color: #000000;">c</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">c11</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span>
c[i+h][j] = c21[i][j]
<span style="color: #000000;">c</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">c12</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span>
c[i+h][j+h] = c22[i][j]
<span style="color: #000000;">c</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">c21</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span>
end for
<span style="color: #000000;">c</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">+</span><span style="color: #000000;">h</span><span style="color: #0000FF;">]</span> <span style="color: #0000FF;">=</span> <span style="color: #000000;">c22</span><span style="color: #0000FF;">[</span><span style="color: #000000;">i</span><span style="color: #0000FF;">][</span><span style="color: #000000;">j</span><span style="color: #0000FF;">]</span>
end for
<span style="color: #008080;">end</span> <span style="color: #008080;">for</span>
return c
<span style="color: #008080;">end</span> <span style="color: #008080;">for</span>
end function
<span style="color: #008080;">return</span> <span style="color: #000000;">c</span>
 
<span style="color: #008080;">end</span> <span style="color: #008080;">function</span>
ppOpt({pp_Nest,1,pp_IntFmt,"%3d",pp_FltFmt,"%3.0f",pp_IntCh,false})
 
<span style="color: #7060A8;">ppOpt</span><span style="color: #0000FF;">({</span><span style="color: #004600;">pp_Nest</span><span style="color: #0000FF;">,</span><span style="color: #000000;">1</span><span style="color: #0000FF;">,</span><span style="color: #004600;">pp_IntFmt</span><span style="color: #0000FF;">,</span><span style="color: #008000;">"%3d"</span><span style="color: #0000FF;">,</span><span style="color: #004600;">pp_FltFmt</span><span style="color: #0000FF;">,</span><span style="color: #008000;">"%3.0f"</span><span style="color: #0000FF;">,</span><span style="color: #004600;">pp_IntCh</span><span style="color: #0000FF;">,</span><span style="color: #004600;">false</span><span style="color: #0000FF;">})</span>
constant A = {{1,2},
{3,4}},
<span style="color: #008080;">constant</span> <span style="color: #000000;">A</span> <span style="color: #0000FF;">=</span> <span style="color: #0000FF;">{{</span><span style="color: #000000;">1</span><span style="color: #0000FF;">,</span><span style="color: #000000;">2</span><span style="color: #0000FF;">},</span>
B = {{5,6},
<span style="color: #0000FF;">{</span><span style="color: #000000;">3</span><span style="color: #0000FF;">,</span><span style="color: #000000;">4</span><span style="color: #0000FF;">}},</span>
{7,8}}
<span style="color: #000000;">B</span> <span style="color: #0000FF;">=</span> <span style="color: #0000FF;">{{</span><span style="color: #000000;">5</span><span style="color: #0000FF;">,</span><span style="color: #000000;">6</span><span style="color: #0000FF;">},</span>
pp(strassen(A,B))
<span style="color: #0000FF;">{</span><span style="color: #000000;">7</span><span style="color: #0000FF;">,</span><span style="color: #000000;">8</span><span style="color: #0000FF;">}}</span>
 
<span style="color: #7060A8;">pp</span><span style="color: #0000FF;">(</span><span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #000000;">A</span><span style="color: #0000FF;">,</span><span style="color: #000000;">B</span><span style="color: #0000FF;">))</span>
constant C = { { 1, 1, 1, 1 },
{ 2, 4, 8, 16 },
<span style="color: #008080;">constant</span> <span style="color: #000000;">C</span> <span style="color: #0000FF;">=</span> <span style="color: #0000FF;">{</span> <span style="color: #0000FF;">{</span> <span style="color: #000000;">1</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">1</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">1</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">1</span> <span style="color: #0000FF;">},</span>
{ 3, 9, 27, 81 },
<span style="color: #0000FF;">{</span> <span style="color: #000000;">2</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">4</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">8</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">16</span> <span style="color: #0000FF;">},</span>
{ 4, 16, 64, 256 }},
<span style="color: #0000FF;">{</span> <span style="color: #000000;">3</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">9</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">27</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">81</span> <span style="color: #0000FF;">},</span>
D = { { 4, -3, 4/3, -1/ 4 },
<span style="color: #0000FF;">{</span> <span style="color: #000000;">4</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">16</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">64</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">256</span> <span style="color: #0000FF;">}},</span>
{-13/3, 19/4, -7/3, 11/24 },
<span style="color: #000000;">D</span> <span style="color: #0000FF;">=</span> <span style="color: #0000FF;">{</span> <span style="color: #0000FF;">{</span> <span style="color: #000000;">4</span><span style="color: #0000FF;">,</span> <span style="color: #0000FF;">-</span><span style="color: #000000;">3</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">4</span><span style="color: #0000FF;">/</span><span style="color: #000000;">3</span><span style="color: #0000FF;">,</span> <span style="color: #0000FF;">-</span><span style="color: #000000;">1</span><span style="color: #0000FF;">/</span> <span style="color: #000000;">4</span> <span style="color: #0000FF;">},</span>
{ 3/2, -2, 7/6, -1/ 4 },
<span style="color: #0000FF;">{-</span><span style="color: #000000;">13</span><span style="color: #0000FF;">/</span><span style="color: #000000;">3</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">19</span><span style="color: #0000FF;">/</span><span style="color: #000000;">4</span><span style="color: #0000FF;">,</span> <span style="color: #0000FF;">-</span><span style="color: #000000;">7</span><span style="color: #0000FF;">/</span><span style="color: #000000;">3</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">11</span><span style="color: #0000FF;">/</span><span style="color: #000000;">24</span> <span style="color: #0000FF;">},</span>
{ -1/6, 1/4, -1/6, 1/24 }}
<span style="color: #0000FF;">{</span> <span style="color: #000000;">3</span><span style="color: #0000FF;">/</span><span style="color: #000000;">2</span><span style="color: #0000FF;">,</span> <span style="color: #0000FF;">-</span><span style="color: #000000;">2</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">7</span><span style="color: #0000FF;">/</span><span style="color: #000000;">6</span><span style="color: #0000FF;">,</span> <span style="color: #0000FF;">-</span><span style="color: #000000;">1</span><span style="color: #0000FF;">/</span> <span style="color: #000000;">4</span> <span style="color: #0000FF;">},</span>
pp(strassen(C,D))
<span style="color: #0000FF;">{</span> <span style="color: #0000FF;">-</span><span style="color: #000000;">1</span><span style="color: #0000FF;">/</span><span style="color: #000000;">6</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">1</span><span style="color: #0000FF;">/</span><span style="color: #000000;">4</span><span style="color: #0000FF;">,</span> <span style="color: #0000FF;">-</span><span style="color: #000000;">1</span><span style="color: #0000FF;">/</span><span style="color: #000000;">6</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">1</span><span style="color: #0000FF;">/</span><span style="color: #000000;">24</span> <span style="color: #0000FF;">}}</span>
 
<span style="color: #7060A8;">pp</span><span style="color: #0000FF;">(</span><span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #000000;">C</span><span style="color: #0000FF;">,</span><span style="color: #000000;">D</span><span style="color: #0000FF;">))</span>
constant E = {{ 1, 2, 3, 4},
{ 5, 6, 7, 8},
<span style="color: #008080;">constant</span> <span style="color: #000000;">F</span> <span style="color: #0000FF;">=</span> <span style="color: #0000FF;">{{</span> <span style="color: #000000;">1</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">2</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">3</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">4</span><span style="color: #0000FF;">},</span>
{ 9,10,11,12},
<span style="color: #0000FF;">{</span> <span style="color: #000000;">5</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">6</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">7</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">8</span><span style="color: #0000FF;">},</span>
{13,14,15,16}},
<span style="color: #0000FF;">{</span> <span style="color: #000000;">9</span><span style="color: #0000FF;">,</span><span style="color: #000000;">10</span><span style="color: #0000FF;">,</span><span style="color: #000000;">11</span><span style="color: #0000FF;">,</span><span style="color: #000000;">12</span><span style="color: #0000FF;">},</span>
F = {{1, 0, 0, 0},
<span style="color: #0000FF;">{</span><span style="color: #000000;">13</span><span style="color: #0000FF;">,</span><span style="color: #000000;">14</span><span style="color: #0000FF;">,</span><span style="color: #000000;">15</span><span style="color: #0000FF;">,</span><span style="color: #000000;">16</span><span style="color: #0000FF;">}},</span>
{0, 1, 0, 0},
<span style="color: #000000;">G</span> <span style="color: #0000FF;">=</span> <span style="color: #0000FF;">{{</span><span style="color: #000000;">1</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">0</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">0</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">0</span><span style="color: #0000FF;">},</span>
{0, 0, 1, 0},
<span style="color: #0000FF;">{</span><span style="color: #000000;">0</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">1</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">0</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">0</span><span style="color: #0000FF;">},</span>
{0, 0, 0, 1}}
<span style="color: #0000FF;">{</span><span style="color: #000000;">0</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">0</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">1</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">0</span><span style="color: #0000FF;">},</span>
pp(strassen(E,F))
<span style="color: #0000FF;">{</span><span style="color: #000000;">0</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">0</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">0</span><span style="color: #0000FF;">,</span> <span style="color: #000000;">1</span><span style="color: #0000FF;">}}</span>
 
<span style="color: #7060A8;">pp</span><span style="color: #0000FF;">(</span><span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #000000;">F</span><span style="color: #0000FF;">,</span><span style="color: #000000;">G</span><span style="color: #0000FF;">))</span>
constant r = sqrt(2)/2,
R = {{ r,r},
<span style="color: #008080;">constant</span> <span style="color: #000000;">r</span> <span style="color: #0000FF;">=</span> <span style="color: #7060A8;">sqrt</span><span style="color: #0000FF;">(</span><span style="color: #000000;">2</span><span style="color: #0000FF;">)/</span><span style="color: #000000;">2</span><span style="color: #0000FF;">,</span>
{-r,r}}
<span style="color: #000000;">R</span> <span style="color: #0000FF;">=</span> <span style="color: #0000FF;">{{</span> <span style="color: #000000;">r</span><span style="color: #0000FF;">,</span><span style="color: #000000;">r</span><span style="color: #0000FF;">},</span>
pp(strassen(R,R))</lang>
<span style="color: #0000FF;">{-</span><span style="color: #000000;">r</span><span style="color: #0000FF;">,</span><span style="color: #000000;">r</span><span style="color: #0000FF;">}}</span>
<span style="color: #7060A8;">pp</span><span style="color: #0000FF;">(</span><span style="color: #000000;">strassen</span><span style="color: #0000FF;">(</span><span style="color: #000000;">R</span><span style="color: #0000FF;">,</span><span style="color: #000000;">R</span><span style="color: #0000FF;">))</span>
<!--</syntaxhighlight>-->
{{out}}
Matches that of [[Matrix_multiplication#Phix]], when given the same inputs. Note that a few "-0" show up in the second one (the identity matrix) under pwa/p2js.
<pre>
{{ 19, 22},
Line 636 ⟶ 760:
{{ 0, 1},
{ -1, 0}}
</pre>
 
=={{header|Python}}==
<syntaxhighlight lang="python">"""Matrix multiplication using Strassen's algorithm. Requires Python >= 3.7."""
 
from __future__ import annotations
from itertools import chain
from typing import List
from typing import NamedTuple
from typing import Optional
 
 
class Shape(NamedTuple):
rows: int
cols: int
 
 
class Matrix(List):
"""A matrix implemented as a two-dimensional list."""
 
@classmethod
def block(cls, blocks) -> Matrix:
"""Return a new Matrix assembled from nested blocks."""
m = Matrix()
for hblock in blocks:
for row in zip(*hblock):
m.append(list(chain.from_iterable(row)))
 
return m
 
def dot(self, b: Matrix) -> Matrix:
"""Return a new Matrix that is the product of this matrix and matrix `b`.
Uses 'simple' or 'naive' matrix multiplication."""
assert self.shape.cols == b.shape.rows
m = Matrix()
for row in self:
new_row = []
for c in range(len(b[0])):
col = [b[r][c] for r in range(len(b))]
new_row.append(sum(x * y for x, y in zip(row, col)))
m.append(new_row)
return m
 
def __matmul__(self, b: Matrix) -> Matrix:
return self.dot(b)
 
def __add__(self, b: Matrix) -> Matrix:
"""Matrix addition."""
assert self.shape == b.shape
rows, cols = self.shape
return Matrix(
[[self[i][j] + b[i][j] for j in range(cols)] for i in range(rows)]
)
 
def __sub__(self, b: Matrix) -> Matrix:
"""Matrix subtraction."""
assert self.shape == b.shape
rows, cols = self.shape
return Matrix(
[[self[i][j] - b[i][j] for j in range(cols)] for i in range(rows)]
)
 
def strassen(self, b: Matrix) -> Matrix:
"""Return a new Matrix that is the product of this matrix and matrix `b`.
Uses strassen algorithm."""
rows, cols = self.shape
 
assert rows == cols, "matrices must be square"
assert self.shape == b.shape, "matrices must be the same shape"
assert rows and (rows & rows - 1) == 0, "shape must be a power of 2"
 
if rows == 1:
return self.dot(b)
 
p = rows // 2 # partition
 
a11 = Matrix([n[:p] for n in self[:p]])
a12 = Matrix([n[p:] for n in self[:p]])
a21 = Matrix([n[:p] for n in self[p:]])
a22 = Matrix([n[p:] for n in self[p:]])
 
b11 = Matrix([n[:p] for n in b[:p]])
b12 = Matrix([n[p:] for n in b[:p]])
b21 = Matrix([n[:p] for n in b[p:]])
b22 = Matrix([n[p:] for n in b[p:]])
 
m1 = (a11 + a22).strassen(b11 + b22)
m2 = (a21 + a22).strassen(b11)
m3 = a11.strassen(b12 - b22)
m4 = a22.strassen(b21 - b11)
m5 = (a11 + a12).strassen(b22)
m6 = (a21 - a11).strassen(b11 + b12)
m7 = (a12 - a22).strassen(b21 + b22)
 
c11 = m1 + m4 - m5 + m7
c12 = m3 + m5
c21 = m2 + m4
c22 = m1 - m2 + m3 + m6
 
return Matrix.block([[c11, c12], [c21, c22]])
 
def round(self, ndigits: Optional[int] = None) -> Matrix:
return Matrix([[round(i, ndigits) for i in row] for row in self])
 
@property
def shape(self) -> Shape:
cols = len(self[0]) if self else 0
return Shape(len(self), cols)
 
 
def examples():
a = Matrix(
[
[1, 2],
[3, 4],
]
)
b = Matrix(
[
[5, 6],
[7, 8],
]
)
c = Matrix(
[
[1, 1, 1, 1],
[2, 4, 8, 16],
[3, 9, 27, 81],
[4, 16, 64, 256],
]
)
d = Matrix(
[
[4, -3, 4 / 3, -1 / 4],
[-13 / 3, 19 / 4, -7 / 3, 11 / 24],
[3 / 2, -2, 7 / 6, -1 / 4],
[-1 / 6, 1 / 4, -1 / 6, 1 / 24],
]
)
e = Matrix(
[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16],
]
)
f = Matrix(
[
[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
)
 
print("Naive matrix multiplication:")
print(f" a * b = {a @ b}")
print(f" c * d = {(c @ d).round()}")
print(f" e * f = {e @ f}")
 
print("Strassen's matrix multiplication:")
print(f" a * b = {a.strassen(b)}")
print(f" c * d = {c.strassen(d).round()}")
print(f" e * f = {e.strassen(f)}")
 
 
if __name__ == "__main__":
examples()
</syntaxhighlight>
 
{{out}}
<pre>
Naive matrix multiplication:
a * b = [[19, 22], [43, 50]]
c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
Strassen's matrix multiplication:
a * b = [[19, 22], [43, 50]]
c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
</pre>
 
Line 641 ⟶ 946:
Special thanks go to the module author, [https://github.com/frithnanth Fernando Santagata], on showing how to deal with a pass-by-value case.
{{trans|Julia}}
<syntaxhighlight lang="raku" perl6line># 20210126 Raku programming solution
 
use Math::Libgsl::Constants;
Line 734 ⟶ 1,039:
1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16
1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1
=end code</langsyntaxhighlight>
{{out}}
<pre>Regular multiply:
Line 762 ⟶ 1,067:
9 10 11 12
13 14 15 16</pre>
 
 
=={{header|Scala}}==
{{trans|Go}}
<syntaxhighlight lang="Scala">
import scala.math
 
object MatrixOperations {
 
type Matrix = Array[Array[Double]]
 
implicit class RichMatrix(val m: Matrix) {
def rows: Int = m.length
def cols: Int = m(0).length
 
def add(m2: Matrix): Matrix = {
require(
m.rows == m2.rows && m.cols == m2.cols,
"Matrices must have the same dimensions."
)
Array.tabulate(m.rows, m.cols)((i, j) => m(i)(j) + m2(i)(j))
}
 
def sub(m2: Matrix): Matrix = {
require(
m.rows == m2.rows && m.cols == m2.cols,
"Matrices must have the same dimensions."
)
Array.tabulate(m.rows, m.cols)((i, j) => m(i)(j) - m2(i)(j))
}
 
def mul(m2: Matrix): Matrix = {
require(m.cols == m2.rows, "Cannot multiply these matrices.")
Array.tabulate(m.rows, m2.cols)((i, j) =>
(0 until m.cols).map(k => m(i)(k) * m2(k)(j)).sum
)
}
 
def toString(p: Int): String = {
val pow = math.pow(10, p)
m.map(row =>
row
.map(value => (math.round(value * pow) / pow).toString)
.mkString("[", ", ", "]")
).mkString("[", ",\n ", "]")
}
}
 
def toQuarters(m: Matrix): Array[Matrix] = {
val r = m.rows / 2
val c = m.cols / 2
val p = params(r, c)
(0 until 4).map { k =>
Array.tabulate(r, c)((i, j) => m(p(k)(0) + i)(p(k)(2) + j))
}.toArray
}
 
def fromQuarters(q: Array[Matrix]): Matrix = {
val r = q(0).rows
val c = q(0).cols
val p = params(r, c)
Array.tabulate(r * 2, c * 2)((i, j) => q((i / r) * 2 + j / c)(i % r)(j % c))
}
 
def strassen(a: Matrix, b: Matrix): Matrix = {
require(
a.rows == a.cols && b.rows == b.cols && a.rows == b.rows,
"Matrices must be square and of equal size."
)
require(
a.rows != 0 && (a.rows & (a.rows - 1)) == 0,
"Size of matrices must be a power of two."
)
 
if (a.rows == 1) {
return a.mul(b)
}
 
val qa = toQuarters(a)
val qb = toQuarters(b)
 
val p1 = strassen(qa(1).sub(qa(3)), qb(2).add(qb(3)))
val p2 = strassen(qa(0).add(qa(3)), qb(0).add(qb(3)))
val p3 = strassen(qa(0).sub(qa(2)), qb(0).add(qb(1)))
val p4 = strassen(qa(0).add(qa(1)), qb(3))
val p5 = strassen(qa(0), qb(1).sub(qb(3)))
val p6 = strassen(qa(3), qb(2).sub(qb(0)))
val p7 = strassen(qa(2).add(qa(3)), qb(0))
 
val q = Array(
p1.add(p2).sub(p4).add(p6),
p4.add(p5),
p6.add(p7),
p2.sub(p3).add(p5).sub(p7)
)
 
fromQuarters(q)
}
 
private def params(r: Int, c: Int): Array[Array[Int]] = {
Array(
Array(0, r, 0, c, 0, 0),
Array(0, r, c, 2 * c, 0, c),
Array(r, 2 * r, 0, c, r, 0),
Array(r, 2 * r, c, 2 * c, r, c)
)
}
 
def main(args: Array[String]): Unit = {
val a: Matrix = Array(Array(1.0, 2.0), Array(3.0, 4.0))
val b: Matrix = Array(Array(5.0, 6.0), Array(7.0, 8.0))
val c: Matrix = Array(
Array(1.0, 1.0, 1.0, 1.0),
Array(2.0, 4.0, 8.0, 16.0),
Array(3.0, 9.0, 27.0, 81.0),
Array(4.0, 16.0, 64.0, 256.0)
)
val d: Matrix = Array(
Array(4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0),
Array(-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0),
Array(3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0),
Array(-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0)
)
val e: Matrix = Array(
Array(1.0, 2.0, 3.0, 4.0),
Array(5.0, 6.0, 7.0, 8.0),
Array(9.0, 10.0, 11.0, 12.0),
Array(13.0, 14.0, 15.0, 16.0)
)
val f: Matrix = Array(
Array(1.0, 0.0, 0.0, 0.0),
Array(0.0, 1.0, 0.0, 0.0),
Array(0.0, 0.0, 1.0, 0.0),
Array(0.0, 0.0, 0.0, 1.0)
)
 
println("Using 'normal' matrix multiplication:")
println(
s" a * b = ${a.mul(b).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
println(s" c * d = ${c.mul(d).toString(6)}")
println(
s" e * f = ${e.mul(f).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
 
println("\nUsing 'Strassen' matrix multiplication:")
println(
s" a * b = ${strassen(a, b).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
println(s" c * d = ${strassen(c, d).toString(6)}")
println(
s" e * f = ${strassen(e, f).map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}"
)
}
}
</syntaxhighlight>
{{out}}
<pre>
Using 'normal' matrix multiplication:
a * b = [[19.0, 22.0], [43.0, 50.0]]
c * d = [[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]]
e * f = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]
 
Using 'Strassen' matrix multiplication:
a * b = [[19.0, 22.0], [43.0, 50.0]]
c * d = [[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0]]
e * f = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]
 
</pre>
 
=={{header|Swift}}==
[https://github.com/hollance/Matrix/blob/master/Matrix.swift '''Matrix Class'''] by [https://github.com/ozgurshn ozgurshn]
<syntaxhighlight lang="swift">
// Matrix Strassen Multiplication
func strassenMultiply(matrix1: Matrix, matrix2: Matrix) -> Matrix {
precondition(matrix1.columns == matrix2.columns,
"Two matrices can only be matrix multiplied if one has dimensions mxn & the other has dimensions nxp where m, n, p are in R")
 
// Transform to square matrix
let maxColumns = Swift.max(matrix1.rows, matrix1.columns, matrix2.rows, matrix2.columns)
let pwr2 = nextPowerOfTwo(num: maxColumns)
var sqrMatrix1 = Matrix(rows: pwr2, columns: pwr2)
var sqrMatrix2 = Matrix(rows: pwr2, columns: pwr2)
// fill square matrix 1 with values
for i in 0..<matrix1.rows {
for j in 0..<matrix1.columns{
sqrMatrix1[i, j] = matrix1[i, j]
}
}
// fill square matrix 2 with values
for i in 0..<matrix2.rows {
for j in 0..<matrix2.columns{
sqrMatrix2[i, j] = matrix2[i, j]
}
}
// Get strassen result and transfer to array with proper size
let formulaResult = strassenFormula(matrix1: sqrMatrix1, matrix2: sqrMatrix2)
var finalResult = Matrix(rows: matrix1.rows, columns: matrix2.columns)
for i in 0..<finalResult.rows{
for j in 0..<finalResult.columns {
finalResult[i, j] = formulaResult[i, j]
}
}
return finalResult
}
 
// Calculate next power of 2
func nextPowerOfTwo(num: Int) -> Int {
// formula for next power of 2
return Int(pow(2,(ceil(log2(Double(num))))))
}
 
// Multiply Matrices Using Strassen Formula
func strassenFormula(matrix1: Matrix, matrix2: Matrix) -> Matrix {
precondition(matrix1.rows == matrix1.columns && matrix2.rows == matrix2.columns, "Matrices need to be square")
guard matrix1.rows > 1 && matrix2.rows > 1 else { return matrix1 * matrix2 }
 
let rowHalf = matrix1.rows / 2
// Strassen Formula https://www.geeksforgeeks.org/easy-way-remember-strassens-matrix-equation/
// p1 = a(f-h) p2 = (a+b)h
// p2 = (c+d)e p4 = d(g-e)
// p5 = (a+d)(e+h) p6 = (b-d)(g+h)
// p7 = (a-c)(e+f)
|a b| x |e f| = |(p5+p4-p2+p6) (p1+p2)|
|c d| |g h| |(p3+p4) (p1+p5-p3-p7)|
Matrix 1 Matrix 2 Result
 
// create empty matrices for a, b, c, d, e, f, g, h
var a = Matrix(rows: rowHalf, columns: rowHalf)
var b = Matrix(rows: rowHalf, columns: rowHalf)
var c = Matrix(rows: rowHalf, columns: rowHalf)
var d = Matrix(rows: rowHalf, columns: rowHalf)
var e = Matrix(rows: rowHalf, columns: rowHalf)
var f = Matrix(rows: rowHalf, columns: rowHalf)
var g = Matrix(rows: rowHalf, columns: rowHalf)
var h = Matrix(rows: rowHalf, columns: rowHalf)
// fill the matrices with values
for i in 0..<rowHalf {
for j in 0..<rowHalf {
a[i, j] = matrix1[i, j]
b[i, j] = matrix1[i, j+rowHalf]
c[i, j] = matrix1[i+rowHalf, j]
d[i, j] = matrix1[i+rowHalf, j+rowHalf]
e[i, j] = matrix2[i, j]
f[i, j] = matrix2[i, j+rowHalf]
g[i, j] = matrix2[i+rowHalf, j]
h[i, j] = matrix2[i+rowHalf, j+rowHalf]
}
}
// a * (f - h)
let p1 = strassenFormula(matrix1: a, matrix2: (f - h))
// (a + b) * h
let p2 = strassenFormula(matrix1: (a + b), matrix2: h)
// (c + d) * e
let p3 = strassenFormula(matrix1: (c + d), matrix2: e)
// d * (g - e)
let p4 = strassenFormula(matrix1: d, matrix2: (g - e))
// (a + d) * (e + h)
let p5 = strassenFormula(matrix1: (a + d), matrix2: (e + h))
// (b - d) * (g + h)
let p6 = strassenFormula(matrix1: (b - d), matrix2: (g + h))
// (a - c) * (e + f)
let p7 = strassenFormula(matrix1: (a - c), matrix2: (e + f))
// p5 + p4 - p2 + p6
let result11 = p5 + p4 - p2 + p6
// p1 + p2
let result12 = p1 + p2
// p3 + p4
let result21 = p3 + p4
// p1 + p5 - p3 - p7
let result22 = p1 + p5 - p3 - p7
 
// create an empty matrix for result and fill with values
var result = Matrix(rows: matrix1.rows, columns: matrix1.rows)
for i in 0..<rowHalf {
for j in 0..<rowHalf {
result[i, j] = result11[i, j]
result[i, j+rowHalf] = result12[i, j]
result[i+rowHalf, j] = result21[i, j]
result[i+rowHalf, j+rowHalf] = result22[i, j]
}
}
 
return result
}
 
func main(){
// Matrix Class https://github.com/hollance/Matrix/blob/master/Matrix.swift
var a = Matrix(rows: 2, columns: 2)
a[row: 0] = [1, 2]
a[row: 1] = [3, 4]
 
var b = Matrix(rows: 2, columns: 2)
b[row: 0] = [5, 6]
b[row: 1] = [7, 8]
var c = Matrix(rows: 4, columns: 4)
c[row: 0] = [1, 1, 1,1]
c[row: 1] = [2, 4, 8, 16]
c[row: 2] = [3, 9, 27, 81]
c[row: 3] = [4, 16, 64, 256]
var d = Matrix(rows: 4, columns: 4)
d[row: 0] = [4, -3, Double(4/3), Double(-1/4)]
d[row: 1] = [Double(-13/3), Double(19/4), Double(-7/3), Double(11/24)]
d[row: 2] = [Double(3/2), Double(-2), Double(7/6), Double(-1/4)]
d[row: 3] = [Double(-1/6), Double(1/4), Double(-1/6), Double(1/24)]
var e = Matrix(rows: 4, columns: 4)
e[row: 0] = [1, 2, 3, 4]
e[row: 1] = [5, 6, 7, 8]
e[row: 2] = [9, 10, 11, 12]
e[row: 3] = [13, 14, 15, 16]
var f = Matrix(rows: 4, columns: 4)
f[row: 0] = [1, 0, 0, 0]
f[row: 1] = [0, 1, 0 ,0]
f[row: 2] = [0 ,0 ,1, 0]
f[row: 3] = [0, 0 ,0 ,1]
let result1 = strassenMultiply(matrix1: a, matrix2: b)
print("AxB")
print(result1.description)
let result2 = strassenMultiply(matrix1: c, matrix2: d)
print("CxD")
print(result2.description)
let result3 = strassenMultiply(matrix1: e, matrix2: f)
print("ExF")
print(result3.description)
}
main()</syntaxhighlight>
 
{{out}}
<pre>
AxB
19 22
43 50
 
CxD
1 -1 0 0
0 -6 2 0
3 -27 12 0
16 -76 36 0
 
ExF
 
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
</pre>
 
=={{header|Wren}}==
{{libheader|Wren-fmt}}
Wren doesn't currently have a matrix module so I've written a rudimentary Matrix class with sufficient functionality to complete this task.
 
I've used the Phix entry's examples to test the Strassen algorithm implementation.
<syntaxhighlight lang="wren">class Matrix {
<lang ecmascript>import "/fmt" for Fmt
 
class Matrix {
construct new(a) {
if (a.type != List || a.count == 0 || a[0].type != List || a[0].count == 0 || a[0][0].type != Num) {
Line 923 ⟶ 1,587:
System.print(" a * b = %(strassen.call(a, b))")
System.print(" c * d = %(strassen.call(c, d).toString(6))")
System.print(" e * f = %(strassen.call(e, f))")</langsyntaxhighlight>
 
{{out}}
Line 937 ⟶ 1,601:
e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
</pre>
<br>
{{libheader|Wren-matrix}}
Since the above version was written, a Matrix module has been added and the following version uses it. The output is exactly the same as before.
<syntaxhighlight lang="wren">import "./matrix" for Matrix
var params = Fn.new { |r, c|
return [
[0...r, 0...c, 0, 0],
[0...r, c...2*c, 0, c],
[r...2*r, 0...c, r, 0],
[r...2*r, c...2*c, r, c]
]
}
var toQuarters = Fn.new { |m|
var r = (m.numRows/2).floor
var c = (m.numCols/2).floor
var p = params.call(r, c)
var quarters = []
for (k in 0..3) {
var q = List.filled(r, null)
for (i in p[k][0]) {
q[i - p[k][2]] = List.filled(c, 0)
for (j in p[k][1]) q[i - p[k][2]][j - p[k][3]] = m[i, j]
}
quarters.add(Matrix.new(q))
}
return quarters
}
var fromQuarters = Fn.new { |q|
var r = q[0].numRows
var c = q[0].numCols
var p = params.call(r, c)
r = r * 2
c = c * 2
var m = List.filled(r, null)
for (i in 0...c) m[i] = List.filled(c, 0)
for (k in 0..3) {
for (i in p[k][0]) {
for (j in p[k][1]) m[i][j] = q[k][i - p[k][2], j - p[k][3]]
}
}
return Matrix.new(m)
}
var strassen // recursive
strassen = Fn.new { |a, b|
if (!a.isSquare || !b.isSquare || !a.sameSize(b)) {
Fiber.abort("Matrices must be square and of equal size.")
}
if (a.numRows == 0 || (a.numRows & (a.numRows - 1)) != 0) {
Fiber.abort("Size of matrices must be a power of two.")
}
if (a.numRows == 1) return a * b
var qa = toQuarters.call(a)
var qb = toQuarters.call(b)
var p1 = strassen.call(qa[1] - qa[3], qb[2] + qb[3])
var p2 = strassen.call(qa[0] + qa[3], qb[0] + qb[3])
var p3 = strassen.call(qa[0] - qa[2], qb[0] + qb[1])
var p4 = strassen.call(qa[0] + qa[1], qb[3])
var p5 = strassen.call(qa[0], qb[1] - qb[3])
var p6 = strassen.call(qa[3], qb[2] - qb[0])
var p7 = strassen.call(qa[2] + qa[3], qb[0])
var q = List.filled(4, null)
q[0] = p1 + p2 - p4 + p6
q[1] = p4 + p5
q[2] = p6 + p7
q[3] = p2 - p3 + p5 - p7
return fromQuarters.call(q)
}
var a = Matrix.new([ [1,2], [3, 4] ])
var b = Matrix.new([ [5,6], [7, 8] ])
var c = Matrix.new([ [1, 1, 1, 1], [2, 4, 8, 16], [3, 9, 27, 81], [4, 16, 64, 256] ])
var d = Matrix.new([ [4, -3, 4/3, -1/4], [-13/3, 19/4, -7/3, 11/24],
[3/2, -2, 7/6, -1/4], [-1/6, 1/4, -1/6, 1/24] ])
var e = Matrix.new([ [1, 2, 3, 4], [5, 6, 7, 8], [9,10,11,12], [13,14,15,16] ])
var f = Matrix.new([ [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1] ])
System.print("Using 'normal' matrix multiplication:")
System.print(" a * b = %(a * b)")
System.print(" c * d = %((c * d).toString(6))")
System.print(" e * f = %(e * f)")
System.print("\nUsing 'Strassen' matrix multiplication:")
System.print(" a * b = %(strassen.call(a, b))")
System.print(" c * d = %(strassen.call(c, d).toString(6))")
System.print(" e * f = %(strassen.call(e, f))")</syntaxhighlight>
9,476

edits