Strassen's algorithm: Difference between revisions
Content added Content deleted
m (promoted to (full) task status, added highligting and whitespace.) |
(A Python implementation) |
||
Line 642: | Line 642: | ||
{{ 0, 1}, |
{{ 0, 1}, |
||
{ -1, 0}} |
{ -1, 0}} |
||
</pre> |
|||
=={{header|Python}}== |
|||
<lang python>"""Matrix multiplication using Strassen's algorithm. Requires Python >= 3.7.""" |
|||
from __future__ import annotations |
|||
from itertools import chain |
|||
from typing import List |
|||
from typing import NamedTuple |
|||
from typing import Optional |
|||
class Shape(NamedTuple): |
|||
rows: int |
|||
cols: int |
|||
class Matrix(List): |
|||
"""A matrix implemented as a two-dimensional list.""" |
|||
@classmethod |
|||
def block(cls, blocks) -> Matrix: |
|||
"""Return a new Matrix assembled from nested blocks.""" |
|||
m = Matrix() |
|||
for hblock in blocks: |
|||
for row in zip(*hblock): |
|||
m.append(list(chain.from_iterable(row))) |
|||
return m |
|||
def dot(self, b: Matrix) -> Matrix: |
|||
"""Return a new Matrix that is the product of this matrix and matrix `b`. |
|||
Uses 'simple' or 'naive' matrix multiplication.""" |
|||
assert self.shape.cols == b.shape.rows |
|||
m = Matrix() |
|||
for row in self: |
|||
new_row = [] |
|||
for c in range(len(b[0])): |
|||
col = [b[r][c] for r in range(len(b))] |
|||
new_row.append(sum(x * y for x, y in zip(row, col))) |
|||
m.append(new_row) |
|||
return m |
|||
def __matmul__(self, b: Matrix) -> Matrix: |
|||
return self.dot(b) |
|||
def __add__(self, b: Matrix) -> Matrix: |
|||
"""Matrix addition.""" |
|||
assert self.shape == b.shape |
|||
rows, cols = self.shape |
|||
return Matrix( |
|||
[[self[i][j] + b[i][j] for j in range(cols)] for i in range(rows)] |
|||
) |
|||
def __sub__(self, b: Matrix) -> Matrix: |
|||
"""Matrix subtraction.""" |
|||
assert self.shape == b.shape |
|||
rows, cols = self.shape |
|||
return Matrix( |
|||
[[self[i][j] - b[i][j] for j in range(cols)] for i in range(rows)] |
|||
) |
|||
def strassen(self, b: Matrix) -> Matrix: |
|||
"""Return a new Matrix that is the product of this matrix and matrix `b`. |
|||
Uses strassen algorithm.""" |
|||
rows, cols = self.shape |
|||
assert rows == cols, "matrices must be square" |
|||
assert self.shape == b.shape, "matrices must be the same shape" |
|||
assert rows and (rows & rows - 1) == 0, "shape must be a power of 2" |
|||
if rows == 1: |
|||
return self.dot(b) |
|||
p = rows // 2 # partition |
|||
a11 = Matrix([n[:p] for n in self[:p]]) |
|||
a12 = Matrix([n[p:] for n in self[:p]]) |
|||
a21 = Matrix([n[:p] for n in self[p:]]) |
|||
a22 = Matrix([n[p:] for n in self[p:]]) |
|||
b11 = Matrix([n[:p] for n in b[:p]]) |
|||
b12 = Matrix([n[p:] for n in b[:p]]) |
|||
b21 = Matrix([n[:p] for n in b[p:]]) |
|||
b22 = Matrix([n[p:] for n in b[p:]]) |
|||
m1 = (a11 + a22).strassen(b11 + b22) |
|||
m2 = (a21 + a22).strassen(b11) |
|||
m3 = a11.strassen(b12 - b22) |
|||
m4 = a22.strassen(b21 - b11) |
|||
m5 = (a11 + a12).strassen(b22) |
|||
m6 = (a21 - a11).strassen(b11 + b12) |
|||
m7 = (a12 - a22).strassen(b21 + b22) |
|||
c11 = m1 + m4 - m5 + m7 |
|||
c12 = m3 + m5 |
|||
c21 = m2 + m4 |
|||
c22 = m1 - m2 + m3 + m6 |
|||
return Matrix.block([[c11, c12], [c21, c22]]) |
|||
def round(self, ndigits: Optional[int] = None) -> Matrix: |
|||
return Matrix([[round(i, ndigits) for i in row] for row in self]) |
|||
@property |
|||
def shape(self) -> Shape: |
|||
cols = len(self[0]) if self else 0 |
|||
return Shape(len(self), cols) |
|||
def examples(): |
|||
a = Matrix( |
|||
[ |
|||
[1, 2], |
|||
[3, 4], |
|||
] |
|||
) |
|||
b = Matrix( |
|||
[ |
|||
[5, 6], |
|||
[7, 8], |
|||
] |
|||
) |
|||
c = Matrix( |
|||
[ |
|||
[1, 1, 1, 1], |
|||
[2, 4, 8, 16], |
|||
[3, 9, 27, 81], |
|||
[4, 16, 64, 256], |
|||
] |
|||
) |
|||
d = Matrix( |
|||
[ |
|||
[4, -3, 4 / 3, -1 / 4], |
|||
[-13 / 3, 19 / 4, -7 / 3, 11 / 24], |
|||
[3 / 2, -2, 7 / 6, -1 / 4], |
|||
[-1 / 6, 1 / 4, -1 / 6, 1 / 24], |
|||
] |
|||
) |
|||
e = Matrix( |
|||
[ |
|||
[1, 2, 3, 4], |
|||
[5, 6, 7, 8], |
|||
[9, 10, 11, 12], |
|||
[13, 14, 15, 16], |
|||
] |
|||
) |
|||
f = Matrix( |
|||
[ |
|||
[1, 0, 0, 0], |
|||
[0, 1, 0, 0], |
|||
[0, 0, 1, 0], |
|||
[0, 0, 0, 1], |
|||
] |
|||
) |
|||
print("Naive matrix multiplication:") |
|||
print(f" a * b = {a @ b}") |
|||
print(f" c * d = {(c @ d).round()}") |
|||
print(f" e * f = {e @ f}") |
|||
print("Strassen's matrix multiplication:") |
|||
print(f" a * b = {a.strassen(b)}") |
|||
print(f" c * d = {c.strassen(d).round()}") |
|||
print(f" e * f = {e.strassen(f)}") |
|||
if __name__ == "__main__": |
|||
examples() |
|||
</lang> |
|||
{{out}} |
|||
<pre> |
|||
Naive matrix multiplication: |
|||
a * b = [[19, 22], [43, 50]] |
|||
c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] |
|||
e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] |
|||
Strassen's matrix multiplication: |
|||
a * b = [[19, 22], [43, 50]] |
|||
c * d = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] |
|||
e * f = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] |
|||
</pre> |
</pre> |
||