Matrix chain multiplication: Difference between revisions

From Rosetta Code
Content added Content deleted
mNo edit summary
mNo edit summary
Line 15: Line 15:
;Task
;Task


Write a function which, given a list of the successive dimension of matrices A1, A2... An, of arbitrary length, returns the optimal way to compute the matrix product, and the total cost. The list does not duplicate shared dimensions: for the previous example of matrices A,B,C, one will only pass the list [5,6,3,1]. Hence, a product of n matrices is represented by a list of n+1 dimensions.
Write a function which, given a list of the successive dimensions of matrices A1, A2... An, of arbitrary length, returns the optimal way to compute the matrix product, and the total cost. The list does not duplicate shared dimensions: for the previous example of matrices A,B,C, one will only pass the list [5,6,3,1]. Hence, a product of n matrices is represented by a list of n+1 dimensions.


Try this function on the following two lists:
Try this function on the following two lists:

Revision as of 12:22, 12 April 2018

Matrix chain multiplication is a draft programming task. It is not yet considered ready to be promoted as a complete task, for reasons that should be found in its talk page.
Problem

Using the most straightfoward algorithm (which we assume here), computing the product of two matrices of dimensions (n1,n2) and (n2,n3) requires n1*n2*n3 FMA operations. The number of operations required to compute the product of matrices A1, A2... An depends on the order of matrix multiplications, hence on where parens are put. Remember that the matrix product is associative, but not commutative, hence only the parens can be moved.

For instance, with four matrices, one can compute A(B(CD)), A((BC)D), (AB)(CD), (A(BC))D, (AB)C)D. The number of different ways to put the parens is a Catalan number, and grows exponentially with the number of factors.

Here is an example of computation of the total cost, for matrices A(5,6), B(6,3), C(3,1):

  • AB costs 5*6*3=90 and produces a matrix of dimensions (5,3), then (AB)C costs 5*3*1=15. The total cost is 105.
  • BC costs 6*3*1=18 and produces a matrix of dimensions (6,1), then A(BC) costs 5*6*1=30. The total cost is 48.

In this case, computing (AB)C requires more than twice as many operations as A(BC). The difference can be much more dramatic in real cases.

Task

Write a function which, given a list of the successive dimensions of matrices A1, A2... An, of arbitrary length, returns the optimal way to compute the matrix product, and the total cost. The list does not duplicate shared dimensions: for the previous example of matrices A,B,C, one will only pass the list [5,6,3,1]. Hence, a product of n matrices is represented by a list of n+1 dimensions.

Try this function on the following two lists:

  • [1, 5, 25, 30, 100, 70, 2, 1, 100, 250, 1, 1000, 2]
  • [1000, 1, 500, 12, 1, 700, 2500, 3, 2, 5, 14, 10]

To solve the task, it's possible, but not required, to write a function that enumerates all possible ways to parenthesize the product. This is not optimal because of the many duplicated computations, and this task is a classic application of dynamic programming.

See also Matrix chain multiplication on Wikipedia.

Python

Will will solve the task in three steps:

1) Enumerate all ways to parenthesize (in a recursive generator), and for each one compute the cost. Then simply look up the minimal cost.

2) Merge the enumeration and the cost function in a recursive cost optimizing function. The computation is roughly the same, but it's much faster as some steps are removed.

3) The recursive solution has many duplicates computations. Memoize the previous function: this yields a dynamic programming approach.

Enumeration of parenthesizations

<lang python>def parens(n):

   def aux(n, k):
       if n == 1:
           yield k
       elif n == 2:
           yield [k, k + 1]
       else:
           a = []
           for i in range(1, n):
               for u in aux(i, k):
                   for v in aux(n - i, k + i):
                       yield [u, v]
   yield from aux(n, 0)</lang>

Example (in the same order as in the task description)

<lang python>for u in parens(4):

   print(u)

[0, [1, [2, 3]]] [0, [[1, 2], 3]] [[0, 1], [2, 3]] [[0, [1, 2]], 3] [[[0, 1], 2], 3]</lang>

And here is the optimization step:

<lang python>def optim1(a):

   def cost(k):
       if type(k) is int:
           return 0, a[k], a[k + 1]
       else:
           s1, p1, q1 = cost(k[0])
           s2, p2, q2 = cost(k[1])
           assert q1 == p2
           return s1 + s2 + p1 * q1 * q2, p1, q2
   cmin = None
   n = len(a) - 1
   for u in parens(n):
       c, p, q = cost(u)
       if cmin is None or c < cmin:
           cmin = c
           umin = u
   return cmin, umin</lang>

Recursive cost optimization

The previous function optim1 already used recursion, but only to compute the cost of a given parens configuration, whereas another function (a generator actually) provides these configurations. Here we will do both recursively in the same function, avoiding the computation of configurations altogether.

<lang python>def optim2(a):

   def aux(n, k):
       if n == 1:
           p, q = a[k:k + 2]
           return 0, p, q, k
       elif n == 2:
           p, q, r = a[k:k + 3]
           return p * q * r, p, r, [k, k + 1]
       else:
           m = None
           p = a[k]
           q = a[k + n]
           for i in range(1, n):
               s1, p1, q1, u1 = aux(i, k)
               s2, p2, q2, u2 = aux(n - i, k + i)
               assert q1 == p2
               s = s1 + s2 + p1 * q1 * q2
               if m is None or s < m:
                   m = s
                   u = [u1, u2]
           return m, p, q, u
   s, p, q, u = aux(len(a) - 1, 0)
   return s, u</lang>

Memoized recursive call

The only difference between optim2 and optim3 is the @memoize decorator. Yet the algorithm is way faster with this. According to Wikipedia, the complexity falls from O(2^n) to O(n^3). This is confirmed by plotting log(time) vs log(n) for n up to 580 (this needs changing Python's recursion limit).

<lang python>def memoize(f):

   h = {}
   def g(*u):
       if u in h:
           return h[u]
       else:
           r = f(*u)
           h[u] = r
           return r
   return g

def optim3(a):

   @memoize
   def aux(n, k):
       if n == 1:
           p, q = a[k:k + 2]
           return 0, p, q, k
       elif n == 2:
           p, q, r = a[k:k + 3]
           return p * q * r, p, r, [k, k + 1]
       else:
           m = None
           p = a[k]
           q = a[k + n]
           for i in range(1, n):
               s1, p1, q1, u1 = aux(i, k)
               s2, p2, q2, u2 = aux(n - i, k + i)
               assert q1 == p2
               s = s1 + s2 + p1 * q1 * q2
               if m is None or s < m:
                   m = s
                   u = [u1, u2]
           return m, p, q, u
   s, p, q, u = aux(len(a) - 1, 0)
   return s, u</lang>

Putting all together

<lang python>import time

u = [[1, 5, 25, 30, 100, 70, 2, 1, 100, 250, 1, 1000, 2],

    [1000, 1, 500, 12, 1, 700, 2500, 3, 2, 5, 14, 10]]

for a in u:

   print(a)
   print()
   print("function     time       cost   parens  ")
   print("-" * 90)
   for f in [optim1, optim2, optim3]:
       t1 = time.clock()
       s, u = f(a)
       t2 = time.clock()
       print("%s %10.3f %10d   %s" % (f.__name__, 1000 * (t2 - t1), s, u))
   print()</lang>

Output (timings are in milliseconds)

[1, 5, 25, 30, 100, 70, 2, 1, 100, 250, 1, 1000, 2]

function     time       cost   parens
------------------------------------------------------------------------------------------
optim1    838.636      38120   [[[[[[[[0, 1], 2], 3], 4], 5], 6], [7, [8, 9]]], [10, 11]]
optim2     80.628      38120   [[[[[[[[0, 1], 2], 3], 4], 5], 6], [7, [8, 9]]], [10, 11]]
optim3      0.373      38120   [[[[[[[[0, 1], 2], 3], 4], 5], 6], [7, [8, 9]]], [10, 11]]

[1000, 1, 500, 12, 1, 700, 2500, 3, 2, 5, 14, 10]

function     time       cost   parens
------------------------------------------------------------------------------------------
optim1    223.186    1773740   [0, [[[[[[1, 2], 3], [[[4, 5], 6], 7]], 8], 9], 10]]
optim2     27.660    1773740   [0, [[[[[[1, 2], 3], [[[4, 5], 6], 7]], 8], 9], 10]]
optim3      0.307    1773740   [0, [[[[[[1, 2], 3], [[[4, 5], 6], 7]], 8], 9], 10]]