MATLAB-like Waterfall Plots In Python

Photo by Mike Lewis HeadSmart Media on Unsplash

In the Python scientific world, matplotlib is the top dog when it comes to plotting. There are other great packages, like Bokeh or plotly, but matplotlib is still the most popular and most widely used. The reason is probably that its creators designed it to mimic plotting in MATLAB as closely as possible. Considering this, it is all the more surprising that matplotlib has no equivalent of MATLAB’s waterfall plot. At least it seems so. When I tried to make a waterfall-like plot in matplotlib, I found out that it is indeed possible, but the feature is a bit hidden. In this post, I will show you how you can easily make use of it.

What are waterfall plots in MATLAB?

Suppose you have a partial differential equation depending on a time variable t and a space variable x. When you solve this equation numerically, you get a solution f(t,x). For each discrete time step, you get a function of the single variable x. A waterfall plot is a practical way to visualize these solutions. It’s a 3D plot, with axes t and x in the horizontal plane and f(t,x) plotted vertically. The good thing about waterfall plots is that they are simple. One could easily make a 3D surface plot, but all too often, the complexity of the surface plot hides what is truly going on. In a waterfall plot, however, there is just one curve for each time step: the surface is sliced along planes of constant t. That doesn’t look as fancy as more complicated plots, but it gives a much better overview.

An example of a waterfall plot is displayed below. In MATLAB, creating that plot is just a matter of calling the waterfall function. In matplotlib, there is no obvious way to do that directly. But it turns out that it is easy after all, if you know how. That's what we will turn to in the next section.

MATLAB-like waterfall plot
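To make the definition concrete, here is a minimal hand-rolled sketch that draws a waterfall plot slice by slice with ax.plot, one curve per time step in the plane t = t_i. It is not the plot_wireframe technique this post is about, only an illustration of what a waterfall plot is. The test function f is the one introduced in the next section; the choice of 11 slices and the output file name are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the sketch also runs as a plain script
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old matplotlib)


def f(t, x):
    # Test function used later in this post.
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)


x = np.linspace(-2, 2, 100)
t_slices = np.linspace(0, 2, 11)  # arbitrary choice: 11 time steps

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
for t_i in t_slices:
    # One curve per time step: t is held fixed at t_i while x varies.
    ax.plot(np.full_like(x, t_i), x, f(t_i, x), color='C0')
ax.set_xlabel('t')
ax.set_ylabel('x')
ax.set_zlabel('f(t, x)')
fig.savefig("waterfall_by_hand.png")  # hypothetical output file name
```

The loop makes the structure explicit, but you have to manage the slices yourself; plot_wireframe, discussed next, does the slicing for you from a 2D grid.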

Waterfall plots with matplotlib

The key to making waterfall plots in matplotlib is the plot_wireframe method of 3D axes. By default, it doesn't draw slices but creates a ... well ... wireframe grid. However, two of its options allow us to turn it into a waterfall plot. Let's go step by step. Suppose we have an exact solution of a PDE and code it as the function f below. In practice, of course, we would have to solve the PDE with a suitable numerical method.

import numpy as np

def f(t, x):  # stand-in for the (exact) PDE solution
    return np.exp(-(x-t**2+1)**2/(0.5**2*(t+1))) / (1+t**2)
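It helps to know what f looks like before plotting it. The exponent vanishes at x = t**2 - 1, so each time slice is a Gaussian bump centred there with peak height 1/(1 + t**2), and the denominator 0.5**2*(t+1) makes the bump wider as t grows. The following quick numerical check of those claims is only a sketch; the helper peak and the sample times are chosen for illustration.

```python
import numpy as np


def f(t, x):
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)


def peak(t_, x):
    """Return the location and height of the maximum of f(t_, .) on the grid x."""
    y = f(t_, x)
    k = np.argmax(y)
    return x[k], y[k]


x = np.linspace(-2, 2, 401)  # grid spacing 0.01
for t_ in (0.0, 1.0, 1.5):
    x_peak, height = peak(t_, x)
    # Expect x_peak close to t_**2 - 1 and height close to 1 / (1 + t_**2).
    print(f"t={t_}: peak at x={x_peak:.2f} (expected {t_**2 - 1:.2f}), "
          f"height {height:.3f} (expected {1 / (1 + t_**2):.3f})")
```

Since the centre moves from x = -1 at t = 0 to x = 3 at t = 2, it leaves the plotted window x ∈ [-2, 2] for t > √3 ≈ 1.73, so only the tail of the bump is visible in the late slices.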

Now we create the arrays holding the data that we want to plot. We start with 1D arrays for the coordinate axes and then turn them into 2D meshed grids:

import numpy as np

t = np.linspace(0, 2, 100)
x = np.linspace(-2, 2, 100)

y = [f(t_, x) for t_ in t] # y is a list of 1D-arrays, one per time step

Y = np.asarray(y) # now Y is a 2D array

T, X = np.meshgrid(t, x, indexing='ij') # ij important
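Because f is built only from NumPy operations, it also accepts whole arrays. Assuming f stays vectorized like this (a real PDE solver would more likely hand you one row per time step, as in the loop above), a sketch of an alternative is to evaluate f directly on the meshed grids and check that the result matches the row-by-row construction:

```python
import numpy as np


def f(t, x):
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)


t = np.linspace(0, 2, 100)
x = np.linspace(-2, 2, 100)

# Row-by-row version from the text: row i holds f(t[i], x).
Y_loop = np.asarray([f(t_, x) for t_ in t])

# Vectorized alternative: with indexing='ij', T[i, j] = t[i] and X[i, j] = x[j].
T, X = np.meshgrid(t, x, indexing='ij')
Y_grid = f(T, X)

print(Y_loop.shape, Y_grid.shape)   # (100, 100) (100, 100)
print(np.allclose(Y_loop, Y_grid))  # True
```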

Note that when we turned the 1D arrays t and x into the meshed grid arrays T and X, we used the option indexing='ij'. With this option, the first index runs over t and the second over x, i.e. T[i, j] = t[i] and X[i, j] = x[j], exactly like Y[i, j] = f(t[i], x[j]). The default, indexing='xy', swaps the roles of the two indices, so the grids would be transposed relative to Y. This matters because otherwise the arguments of plot_wireframe would be nonintuitive and you would have to play around until you get the order of the axes right. With indexing='ij', everything works as you would intuitively expect, and not only for plot_wireframe! In fact, I recommend using indexing='ij' whenever you work with meshgrid. In my experience, it is always the more intuitive choice, and it has saved me a lot of time in the past.
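The difference between the two indexing modes is easy to see on a small, deliberately non-square grid. A sketch (the sizes 3 and 4 are arbitrary):

```python
import numpy as np

t = np.linspace(0, 2, 3)   # 3 time points
x = np.linspace(-2, 2, 4)  # 4 space points

T_ij, X_ij = np.meshgrid(t, x, indexing='ij')
T_xy, X_xy = np.meshgrid(t, x)  # default: indexing='xy'

print(T_ij.shape)   # (3, 4): rows follow t, columns follow x
print(T_xy.shape)   # (4, 3): the transposed layout
print(T_ij[:, 0])   # [0. 1. 2.]: moving down a column steps through t
```

With the square 100 × 100 grid used in this post, the default 'xy' order would not raise an error: T and X would simply be transposed relative to Y, and plot_wireframe would silently draw the wrong surface.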

Now we can start plotting. If you are using the classic Jupyter Notebook, don’t forget the %matplotlib notebook statement, because it lets you interact with the displayed plot. This is very convenient, in particular for 3D plots, where you want to change the view perspective easily. In JupyterLab (and Notebook 7, which is built on it), %matplotlib notebook does not work; use %matplotlib widget from the ipympl package instead. Let's first create a normal wireframe plot:

import matplotlib.pyplot as plt
%matplotlib notebook

ax = plt.figure().add_subplot(projection='3d')

ax.plot_wireframe(T, X, Y)
ax.set_xlabel('t')
ax.set_ylabel('x')
plt.show()
A normal wireframe plot
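Outside Jupyter, for example in a plain Python script, there is no interactive rotation. A sketch of the same wireframe plot that sets the perspective in code with ax.view_init and writes the figure to a file instead; the angles and the file name wireframe.png are arbitrary choices for illustration.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render straight to a file
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old matplotlib)


def f(t, x):
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)


t = np.linspace(0, 2, 100)
x = np.linspace(-2, 2, 100)
T, X = np.meshgrid(t, x, indexing='ij')
Y = f(T, X)  # same values as the row-by-row construction in the text

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot_wireframe(T, X, Y)
ax.set_xlabel('t')
ax.set_ylabel('x')
ax.view_init(elev=20, azim=-120)  # elevation and azimuth in degrees (example values)
out = Path("wireframe.png")       # hypothetical output file
fig.savefig(out, dpi=150)
plt.close(fig)
```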

Well, as I said, this doesn’t give a good overview. Now let’s turn it into a waterfall plot. plot_wireframe has two options called rstride (row stride) and cstride (column stride). The trick is to set the column stride to 0. This suppresses the lines drawn along the columns of Y, i.e. the curves of constant x that run in the t-direction, so that only the rows, one curve per time step, remain:

import matplotlib.pyplot as plt
%matplotlib notebook

ax = plt.figure().add_subplot(projection='3d')

ax.plot_wireframe(T, X, Y, cstride=0)
ax.set_xlabel('t')
ax.set_ylabel('x')
plt.show()
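To see what cstride=0 actually does, one can count the curves matplotlib draws. plot_wireframe returns a Line3DCollection; after the figure has been drawn once, its get_segments method returns one projected polyline per curve. A sketch, run off-screen with the Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering, no window needed
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old matplotlib)


def f(t, x):
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)


t = np.linspace(0, 2, 100)
x = np.linspace(-2, 2, 100)
T, X = np.meshgrid(t, x, indexing='ij')
Y = f(T, X)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# cstride=0: no lines along the columns of Y (constant x). Since a stride was
# given, rstride keeps its default of 1, so every row (time step) is one curve.
curves = ax.plot_wireframe(T, X, Y, cstride=0)
fig.canvas.draw()  # the 3D lines are projected to 2D segments at draw time
n_curves = len(curves.get_segments())
print(n_curves)  # 100, one per time step
plt.close(fig)
```

Note that rstride and cstride cannot both be 0; matplotlib raises a ValueError in that case.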

This is slightly better, but it still displays too much information. Let’s reduce the number of slices in the t-direction (row direction) by setting an rstride value, which keeps only every rstride-th row (here, every fifth time step):

import matplotlib.pyplot as plt
%matplotlib notebook

ax = plt.figure().add_subplot(projection='3d')

ax.plot_wireframe(T, X, Y, rstride=5, cstride=0)
ax.set_xlabel('t')
ax.set_ylabel('x')
plt.show()
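If you make waterfall plots often, the two options can be wrapped in a small helper. The function name waterfall and its signature are hypothetical, not part of matplotlib; extra keyword arguments are passed on to plot_wireframe, which forwards them to the underlying line collection (e.g. color or linewidth).

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering, no window needed
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old matplotlib)


def waterfall(ax, T, X, Y, step=5, **line_kwargs):
    """Hypothetical helper: draw every `step`-th row of Y as one curve.

    Assumes T, X, Y come from np.meshgrid(..., indexing='ij'), so that each
    row of Y is one time step.
    """
    return ax.plot_wireframe(T, X, Y, rstride=step, cstride=0, **line_kwargs)


def f(t, x):
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)


t = np.linspace(0, 2, 100)
x = np.linspace(-2, 2, 100)
T, X = np.meshgrid(t, x, indexing='ij')
Y = f(T, X)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
curves = waterfall(ax, T, X, Y, step=5, color='k', linewidth=0.8)
ax.set_xlabel('t')
ax.set_ylabel('x')
fig.canvas.draw()  # project the 3D lines so they can be counted
n_curves = len(curves.get_segments())
plt.close(fig)
```

As an alternative to strides, plot_wireframe also accepts rcount and ccount, the maximum number of rows and columns to draw; strides and counts cannot be combined in one call.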

Et voilà, we have the waterfall plot that we are used to from MATLAB. And by changing the rstride value, we even get extra flexibility!

Conclusion

In conclusion, we have seen that MATLAB-like waterfall plots can be created in Python with matplotlib's plot_wireframe function. The key steps are to use the indexing='ij' option when turning the 1D coordinate arrays into 2D meshed grids, to set cstride=0 so that only the time slices are drawn, and to use rstride to control how many of those slices are shown. With these steps, Python users can create simple yet effective waterfall plots to visualize solutions of partial differential equations.