Matrix-exponentiation operator

From Rosetta Code
Revision as of 12:23, 22 February 2008 by rosettacode>Badmadevil (→‎{{header|D}}: simplify code a bit and added a note)
Task
Matrix-exponentiation operator
You are encouraged to solve this task according to the task description, using any language you may know.

Most programming languages have a built-in implementation of exponentiation for integers and reals only.

The following programs demonstrate how to implement exponentiation (**) of complex-valued matrices as an operator.

ALGOL 68

main:(
  # Demonstrates a user-defined matrix exponentiation operator (**) on     #
  # complex matrices: vector/matrix products, an identity builder, the **  #
  # operator itself and a formatted matrix printer, followed by a driver   #
  # that prints the example matrix raised to the 24th power.               #

  INT default upb=3;
  # VEC and MAT are 3-element and 3x3 rows of COMPL (bounds 1:default upb) #
  MODE VEC = [default upb]COMPL;
  MODE MAT = [default upb,default upb]COMPL;
 
  # Sum of the products a[i]*b[i] over the bounds of a (no complex         #
  # conjugation); b is assumed to have the same bounds as a.               #
  OP * = (VEC a,b)COMPL: (
      COMPL result:=0;
      FOR i FROM LWB a TO UPB a DO result+:= a[i]*b[i] OD;
      result
    );
 
  # result[j] is row vector a times column j of b (the slice b[,j])        #
  OP * = (VEC a, MAT b)VEC: ( # overload vec times matrix #
      [2 LWB b:2 UPB b]COMPL result;
      FOR j FROM 2 LWB b TO 2 UPB b DO result[j]:=a*b[,j] OD;
      result
    );
 
  # Row k of the result is row k of a multiplied by b, using the           #
  # vector times matrix operator defined just before this one.             #
  OP * = (MAT a, b)MAT: ( # overload matrix times matrix #
      [LWB a:UPB a, 2 LWB b:2 UPB b]COMPL result;
      FOR k FROM LWB result TO UPB result DO result[k,]:=a[k,]*b OD;
      result
    );

  # Returns the upb x upb identity matrix with bounds 1:upb in both        #
  # dimensions; the INT 1 or 0 is widened to COMPL on assignment.          #
  OP IDENTITY = (INT upb)MAT:(
    [upb,upb] COMPL out;
    FOR i TO upb DO 
      FOR j TO upb DO
        out[i,j]:= ( i=j |1|0)
      OD
    OD;
    out
  );
   # This is the task part. #
   # Raises base to the power exponent by binary (square-and-multiply)      #
   # exponentiation. exponent is converted to BITS; bits width ELEM tests   #
   # its lowest bit, and each SHR 1 moves the next bit into that position.  #
   # An exponent of 0 yields the identity.                                  #
   # NOTE(review): a negative exponent is not handled; what BIN does with a #
   # negative INT should be confirmed. out and sq are declared as MAT, i.e. #
   # with fixed 3x3 bounds, so base is presumably expected to be 3x3.       #
   OP ** = (MAT base, INT exponent)MAT: (
     BITS binary exponent:=BIN exponent ;
     # start from base when the lowest bit is set, otherwise from identity #
     MAT out := IF bits width ELEM binary exponent THEN base ELSE IDENTITY UPB base FI;
     MAT sq:=base; # holds base raised to successive powers of two #
 
     # The loop condition shifts first and then continues only while some #
     # bit remains set, so sq is never squared after the last bit.          #
     WHILE 
       binary exponent := binary exponent SHR 1;
       binary exponent /= BIN 0 
     DO
       sq := sq * sq; 
       IF bits width ELEM binary exponent THEN out := out * sq FI
     OD;
     out
   );
  # Prints m with printf: a leading space, then each row in parentheses    #
  # with its elements formatted by compl fmt and separated by commas; rows #
  # are separated by a comma, a new line and two spaces, and the whole     #
  # matrix is enclosed in parentheses and followed by a semicolon.         #
  PROC compl matrix printf= (FORMAT compl fmt, MAT m)VOID:(
    FORMAT vec fmt = $"("n(2 UPB m-1)(f(compl fmt)",")f(compl fmt)")"$;
    FORMAT matrix fmt = $x"("n(UPB m-1)(f(vec fmt)","lxx)f(vec fmt)");"$;
    # finally print the result #
    printf((matrix fmt,m))
  );
 
  FORMAT compl fmt = $-z.z,+z.z"i"$; # one picture per part, 4 frames each (sign, digit, point, digit), 1 decimal; the real part has no leading '+' sign, the imaginary part is always signed and followed by i #
  # The example matrix; the dyadic I operator builds a COMPL from its      #
  # real and imaginary parts.                                              #
  MAT matrix=((sqrt(0.5)I0         , sqrt(0.5)I0        , 0I0),
              (        0I-sqrt(0.5),         0Isqrt(0.5), 0I0),
              (        0I0         ,         0I0        , 0I1));

  # print the heading, then matrix**24, then a final new line #
  printf(($" matrix ** "g(0)":"l$,24));
  compl matrix printf(compl fmt, matrix**24); print(newline)
)

Output:

matrix ** 24:
(( 1.0+0.0i, 0.0+0.0i, 0.0+0.0i),
 ( 0.0+0.0i, 1.0+0.0i, 0.0+0.0i),
 ( 0.0+0.0i, 0.0+0.0i, 1.0+0.0i));

D

This is an implementation in D.

module mxpow ;
import std.stdio ;
import std.string ;
import std.math ;

// Square MSize x MSize matrix with elements of type T (complex real by
// default), stored by value in a static 2-D array. Supports matrix
// multiplication (*) and non-negative integer exponentiation (^ or pow).
struct SqMx(int MSize = 3, T = creal) {
  alias T[MSize][MSize] Ax ;   // storage type: MSize rows of MSize elements
  alias SqMx!(MSize, T) Mx ;   // this instantiation, for brevity
  // Format applied to every element by toString. It is static, so it is
  // shared by all matrices of the same instantiation.
  static string fmt = "%8.3f" ;
  private Ax a ;

  // Builds a matrix by copying the given 2-D array.
  static Mx opCall(Ax a){
    Mx m ;
    m.a[] = a[] ;
    return m ;
  }

  // Returns the MSize x MSize identity matrix.
  static Mx Identity() {
    Mx m ;
    for(int r = 0; r < MSize ; r++)
      for(int c = 0 ; c < MSize ; c++)
        m.a[r][c] = cast(T) (r == c ? 1 : 0) ;
    return m ;
  }

  // Pretty print. Each element is formatted with fmt and elements are
  // joined by ","; rows are joined by "\n " and the whole is wrapped
  // in "<" ... ">".
  string toString() {
    string[MSize] s, t ;
    foreach(i, r; a) {
      foreach (j , e ; r)
        s[j] = format(fmt, e) ;
      t[i] = join(s, ",") ;
    }
    return "<" ~ join(t,"\n ") ~ ">" ;
  }

  // Matrix product this * b (row-by-column).
  Mx opMul(Mx b) {
    Mx d ;
    for(int r = 0 ; r < MSize ; r++)
      for(int c = 0 ; c < MSize ; c++) {
        // Start from an explicit zero: for floating-point or complex T the
        // default initialiser is NaN, not 0.
        d.a[r][c] = cast(T) 0 ;
        for(int k = 0 ; k < MSize ; k++)
          d.a[r][c] += a[r][k]*b.a[k][c] ;
      }
    return d ;
  }

  // This is the task part.

  // D does not have a ** operator, instead, ^ (bitwise Xor) is used.
  // Returns this matrix raised to the power n using binary
  // (square-and-multiply) exponentiation, which needs O(log n) products.
  // n == 0 yields the identity. A negative n throws an Exception.
  Mx opXor(int n){
    if(n < 0)
      throw new Exception("Negative exponent not implemented") ;

    Mx d = Mx.Identity() ;  // accumulated result
    Mx sq ;                 // this ^ (2^k) for the bit being examined
    sq.a[] = this.a[] ;
    while(true) {
      if(n & 1)
        d = d * sq ;
      n >>= 1 ;
      // Stop once no bits remain: squaring again would be unused work.
      if(n == 0)
        break ;
      sq = sq * sq ;
    }
    return d ;
  }
  alias opXor pow ;
}

alias SqMx!() M3 ;

// Demo driver: prints a 3x3 complex matrix raised to the 23rd power
// (through pow) and to the 24th power (through the ^ operator), using a
// narrower element format than the default.
void main() {
  M3.fmt = "%5.2f" ;   // affects every M3 printed below
  real h = sqrt(0.5) ;
  M3 m = M3(cast(M3.Ax)
            [   h + 0*1.0Li,   h + 0*1.0Li, 0.0L + 0.0Li,
             0.0L - h*1.0Li,0.0L + h*1.0Li, 0.0L + 0.0Li,
             0.0L +   0.0Li,0.0L +   0.0Li, 0.0L + 1.0Li]) ;

  M3 p23 = m.pow(23) ;
  writefln("m ^ 23 =\n", p23) ;

  M3 p24 = m ^ 24 ;
  writefln("m ^ 24 =\n", p24) ;
}

Output:

m ^ 23 =
< 0.71+ 0.00i, 0.00+ 0.71i, 0.00+ 0.00i
  0.71+ 0.00i, 0.00+-0.71i, 0.00+ 0.00i
  0.00+ 0.00i, 0.00+ 0.00i, 0.00+-1.00i>
m ^ 24 =
< 1.00+ 0.00i, 0.00+ 0.00i, 0.00+ 0.00i
  0.00+ 0.00i, 1.00+ 0.00i, 0.00+ 0.00i
  0.00+ 0.00i, 0.00+ 0.00i, 1.00+ 0.00i>

NOTE: In D, whether an overloadable binary operator is commutative is fixed by the language. For instance, the arithmetic operators + and * and the bitwise operators &, ^ and | are commutative, while -, /, %, >>, << and >>> are not.
The operator ^ chosen for exponentiation in the code above happens to be commutative, so an expression like 24 ^ m is also legal (it is resolved by calling m.opXor(24)). If such expressions should not be allowed, either choose a non-commutative operator, or implement a corresponding opXXX_r overload that rejects the expression (for example with a static assert or a runtime error).
For details, see Operator Overloading in D.