The matrix multiplication function, numpy.dot(), takes only two arguments. That means multiplying more than two arrays together requires nested function calls, which are hard to read:

```python
dot(dot(dot(a,b),c),d)
```

compared with infix notation, where you could simply write

```python
a*b*c*d
```

There are a couple of ways to define an 'mdot' function that acts like dot but accepts more than two arguments. Using one of these lets you write the above expression as

```python
mdot(a,b,c,d)
```

## Using reduce

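The simplest way is to use reduce. On Python 3, reduce is no longer a builtin and must be imported from functools:

```python
from functools import reduce  # works on Python 2.6+ and Python 3
import numpy

def mdot(*args):
    return reduce(numpy.dot, args)
```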

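Or use the equivalent explicit loop (reduce was moved out of the builtins in Python 3, and an explicit loop is often considered more readable):

```python
from numpy import dot

def mdot(*args):
    ret = args[0]
    for a in args[1:]:
        ret = dot(ret, a)  # accumulate the product from the left
    return ret
```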

Both versions always give left-to-right associativity, i.e. the expression is evaluated as (((a*b)*c)*d).

You can also make a right-associative version of the loop:

```python
from numpy import dot

def mdotr(*args):
    ret = args[-1]
    for a in reversed(args[:-1]):
        ret = dot(a, ret)  # accumulate the product from the right
    return ret
```

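which evaluates as (a*(b*(c*d))). Matrix multiplication is associative, so both orders give the same result up to rounding. But sometimes you'd like finer control, since the order in which the matrix products are performed determines the sizes of the intermediate results and can have a big impact on performance. The next version gives that control.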
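To make the performance point concrete, here is a small sketch that counts the scalar multiplications each association order needs, using the standard estimate that multiplying an n×m matrix by an m×p matrix costs n·m·p multiplications. The function names `left_cost` and `right_cost` and the example shapes are made up for illustration; they are not part of NumPy.

```python
def left_cost(shapes):
    """Scalar multiplications for (((A1*A2)*A3)*...); shapes are (rows, cols)."""
    rows, cols = shapes[0]
    total = 0
    for r, c in shapes[1:]:
        if r != cols:
            raise ValueError("inner dimensions do not match")
        total += rows * cols * c  # (rows x cols) times (cols x c)
        cols = c                  # intermediate result is rows x c
    return total


def right_cost(shapes):
    """Scalar multiplications for (...*(A(n-1)*An)); shapes are (rows, cols)."""
    rows, cols = shapes[-1]
    total = 0
    for r, c in reversed(shapes[:-1]):
        if c != rows:
            raise ValueError("inner dimensions do not match")
        total += r * c * cols     # (r x c) times (rows x cols), where c == rows
        rows = r                  # intermediate result is r x cols
    return total


shapes = [(1000, 2), (2, 1000), (1000, 1)]
print(left_cost(shapes))   # 3000000
print(right_cost(shapes))  # 4000
```

For these shapes the left-to-right order first builds a 1000×1000 intermediate, while the right-to-left order never produces anything larger than 2×1, which is why the costs differ by a factor of 750.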

## Controlling order of evaluation

If we're willing to give up NumPy's ability to treat tuples as arrays, we can use tuples as grouping constructs. This version of mdot allows syntax like this:

```python
mdot(a,((b,c),d))
```

to control the order in which the pairwise dot calls are made.

```python
import numpy

def mdot(*args):
    """Multiply all the arguments using matrix product rules.
    The output is equivalent to multiplying the arguments one by one
    from left to right using dot().
    Precedence can be controlled by creating tuples of arguments,
    for instance mdot(a,((b,c),d)) multiplies as (a*((b*c)*d)).
    Note that this means the output of dot(a,b) and mdot(a,b) will differ if
    a or b is a pure tuple of numbers.
    """
    if len(args) == 1:
        return args[0]
    elif len(args) == 2:
        return _mdot_r(args[0], args[1])
    else:
        # Group all but the last argument into a tuple: left associativity.
        return _mdot_r(args[:-1], args[-1])

def _mdot_r(a, b):
    """Recursive helper for mdot"""
    # A tuple is a grouping construct: evaluate its contents first.
    if isinstance(a, tuple):
        if len(a) > 1:
            a = mdot(*a)
        else:
            a = a[0]
    if isinstance(b, tuple):
        if len(b) > 1:
            b = mdot(*b)
        else:
            b = b[0]
    return numpy.dot(a, b)
```
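To see how the tuple rules group the operands without doing any arithmetic, the sketch below mirrors the structure of mdot and _mdot_r but builds a string instead of calling numpy.dot. `trace_mdot` and `_trace_r` are hypothetical names for illustration, and the arguments are assumed to be strings (standing in for arrays) or nested tuples of strings.

```python
def trace_mdot(*args):
    """Return a string showing how mdot would group its arguments."""
    if len(args) == 1:
        return args[0]
    elif len(args) == 2:
        return _trace_r(args[0], args[1])
    else:
        # Same left-associative grouping as mdot.
        return _trace_r(args[:-1], args[-1])


def _trace_r(a, b):
    """Like _mdot_r, but combines operand names instead of multiplying."""
    if isinstance(a, tuple):
        a = trace_mdot(*a) if len(a) > 1 else a[0]
    if isinstance(b, tuple):
        b = trace_mdot(*b) if len(b) > 1 else b[0]
    return "(%s*%s)" % (a, b)


print(trace_mdot("a", "b", "c", "d"))        # (((a*b)*c)*d)
print(trace_mdot("a", (("b", "c"), "d")))    # (a*((b*c)*d))
```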

## Multiply

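Note that the elementwise multiplication function numpy.multiply has the same two-argument limitation as numpy.dot. Like other ufuncs, it treats a third positional argument as the output array (`out`), so multiply(a,b,c) does not compute a*b*c. The same generalized forms can be defined for multiply.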

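Left-associative versions, with reduce:

```python
from functools import reduce
import numpy

def mmultiply(*args):
    return reduce(numpy.multiply, args)
```

or with an explicit loop:

```python
from numpy import multiply

def mmultiply(*args):
    ret = args[0]
    for a in args[1:]:
        ret = multiply(ret, a)  # accumulate the product from the left
    return ret
```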

Right-associative version:

```python
from numpy import multiply

def mmultiplyr(*args):
    ret = args[-1]
    for a in reversed(args[:-1]):
        ret = multiply(a, ret)  # accumulate the product from the right
    return ret
```

Version using tuples to control order of evaluation:

```python
import numpy

def mmultiply(*args):
    """Multiply all the arguments using elementwise product.
    The output is equivalent to multiplying the arguments one by one
    from left to right using multiply().
    Precedence can be controlled by creating tuples of arguments,
    for instance mmultiply(a,((b,c),d)) multiplies as (a*((b*c)*d)).
    Note that this means the output of multiply(a,b) and mmultiply(a,b) will differ if
    a or b is a pure tuple of numbers.
    """
    if len(args) == 1:
        return args[0]
    elif len(args) == 2:
        return _mmultiply_r(args[0], args[1])
    else:
        # Group all but the last argument into a tuple: left associativity.
        return _mmultiply_r(args[:-1], args[-1])

def _mmultiply_r(a, b):
    """Recursive helper for mmultiply"""
    # A tuple is a grouping construct: evaluate its contents first.
    if isinstance(a, tuple):
        if len(a) > 1:
            a = mmultiply(*a)
        else:
            a = a[0]
    if isinstance(b, tuple):
        if len(b) > 1:
            b = mmultiply(*b)
        else:
            b = b[0]
    return numpy.multiply(a, b)
```
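Since the tuple-based mdot and mmultiply differ only in the two-argument function they call, one way to avoid writing the same code twice is a small factory that takes that function as a parameter. This is a sketch, not a NumPy API: `make_multi` is a hypothetical name, and it follows the same tuple-grouping rules (and the same caveat about pure tuples of numbers) as the versions above.

```python
import numpy


def make_multi(binary_op):
    """Build an mdot-style function from any two-argument function.

    The returned function combines its arguments left to right, and
    tuples act as grouping constructs, as in the tuple-based mdot.
    """
    def multi(*args):
        if len(args) == 1:
            return args[0]
        # Two args: combine them directly. More: group all but the last.
        a = args[0] if len(args) == 2 else args[:-1]
        b = args[-1]
        # Resolve tuple groups recursively before applying the operation.
        if isinstance(a, tuple):
            a = multi(*a)
        if isinstance(b, tuple):
            b = multi(*b)
        return binary_op(a, b)
    return multi


mdot = make_multi(numpy.dot)
mmultiply = make_multi(numpy.multiply)
```

With these definitions, mdot(a, ((b, c), d)) performs the same pairwise numpy.dot calls, in the same order, as the hand-written tuple version, and any other binary function can be generalized the same way.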

Cookbook/MultiDot (last edited 2007-03-24 21:32:12 by BillBaxter)