Matrix vector multiplication along array axes

In a current project I have a large multidimensional array of shape (I,J,K,N) and a square matrix of shape (N,N).

I need to perform a matrix vector multiplication of the last axis of the array with the square matrix.

So the obvious solution would be:

    for i in range(I):
        for j in range(J):
            for k in range(K):
                arr[i,j,k] = mat.dot(arr[i,j,k])

but of course this is rather slow. So I also tried numpy's tensordot but had little success. I would expect that something like:

arr = tensordot(mat,arr,axes=((0,1),(3)))

should do the trick but I get a shape mismatch error.

Does anyone have a better solution, or know how to use tensordot correctly?

Thank you!


This should do what your loops do, but with vectorized looping:

    import numpy as np
    from numpy.core.umath_tests import matrix_multiply

    arr[..., np.newaxis] = matrix_multiply(mat, arr[..., np.newaxis])

matrix_multiply and its sister inner1d are hidden, undocumented gems of numpy, although a full set of linear algebra gufuncs should see the light with numpy 1.8. matrix_multiply does matrix multiplication on the last two dimensions of its inputs and broadcasts over the rest. The only tricky part is adding an extra dimension, so that it sees column vectors when multiplying, and adding it again when assigning back into the array, so that there is no shape mismatch.
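
For readers on a recent NumPy, where `numpy.core.umath_tests` may no longer be importable, the same broadcasting idea can be sketched with `np.matmul` or `np.einsum`. The shapes and random data below are made up for illustration, and this is only a sketch of the technique, not the original answer's code:

```python
import numpy as np

# Hypothetical small shapes, chosen only for illustration.
I, J, K, N = 2, 3, 4, 5
rng = np.random.default_rng(0)
arr = rng.standard_normal((I, J, K, N))
mat = rng.standard_normal((N, N))

# Reference result: the explicit triple loop from the question.
expected = np.empty_like(arr)
for i in range(I):
    for j in range(J):
        for k in range(K):
            expected[i, j, k] = mat.dot(arr[i, j, k])

# Same trick as above: add a trailing axis so each vector becomes an
# (N, 1) column, let matmul broadcast over the leading axes, then drop
# the extra axis again.
via_matmul = np.matmul(mat, arr[..., np.newaxis])[..., 0]

# einsum spells out the contraction directly:
# out[..., n] = sum over m of mat[n, m] * arr[..., m]
via_einsum = np.einsum('nm,...m->...n', mat, arr)

print(np.allclose(via_matmul, expected), np.allclose(via_einsum, expected))
```

Both variants return a new array rather than writing into `arr` in place; assign the result back if the in-place behaviour of the loop is needed.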


I think your for loop is wrong, and for this case dot seems to be enough:

    # a is your IJKN
    # b is your NN
    c = dot(a, b)

Here c will be an IJKN array. If you want to sum over the last dimension to get the IJK array:

    arr = dot(a,b).sum(axis=3)
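
One caveat worth checking: `dot(a, b)` contracts the last axis of `a` with the first axis of `b`, so it reproduces the loop in the question (`mat.dot(v)` for each vector) only when `b` is transposed. The sketch below, using made-up shapes and random data, compares `np.dot` with a transpose and a `tensordot` call against that loop:

```python
import numpy as np

# Hypothetical small shapes, chosen only for illustration.
I, J, K, N = 2, 3, 4, 5
rng = np.random.default_rng(1)
arr = rng.standard_normal((I, J, K, N))
mat = rng.standard_normal((N, N))

# Reference result: the explicit triple loop from the question.
loop = np.empty_like(arr)
for i in range(I):
    for j in range(J):
        for k in range(K):
            loop[i, j, k] = mat.dot(arr[i, j, k])

# dot(arr, mat.T)[i, j, k, n] = sum over m of arr[i, j, k, m] * mat[n, m]
via_dot = np.dot(arr, mat.T)

# tensordot: contract axis 3 of arr with axis 1 of mat; the remaining
# axes are (I, J, K) from arr followed by axis 0 of mat.
via_tensordot = np.tensordot(arr, mat, axes=([3], [1]))

# Without the transpose, each vector is multiplied by mat.T instead.
untransposed = np.dot(arr, mat)

print(np.allclose(via_dot, loop), np.allclose(via_tensordot, loop))
print(np.allclose(untransposed, loop))
```

For a symmetric `mat` the transpose makes no difference, which may be why the plain `dot(a, b)` form appears to work in some cases.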


