MATLAB-like Waterfall Plots In Python
In the Python scientific world, matplotlib is the top dog when it comes to plotting. There are other great packages, like Bokeh or plotly, but matplotlib is still the most popular and most widely used. The reason is probably that its creators designed it to mimic plotting in MATLAB as closely as possible. Considering this, it is all the more surprising that there is no equivalent to MATLAB's waterfall plot in matplotlib. At least it seems so. When I tried to make a waterfall-like plot in matplotlib, I found out that it is indeed possible, but the feature is a bit hidden. In this post, I will show you how you can easily make use of it.
What are waterfall plots in MATLAB?
Suppose you have a partial differential equation depending on a time variable t and a space variable x. When you solve this equation numerically, you get a solution f(t, x). For each discrete time step, you get a function of a single variable, x. A waterfall plot is a practical way to visualize these solutions. It's a 3D plot, with axes t and x in the horizontal plane and f(t, x) plotted vertically. The good thing about waterfall plots is that they are simple. One could easily make a 3D surface plot, but all too often, the complexity of the surface plot hides what is truly going on. In a waterfall plot, however, there is just one curve for each time step: the surface is sliced along planes of constant t. That doesn't look as fancy as more complicated plots, but it gives a much better overview. An example of a waterfall plot is displayed below. In MATLAB, creating such a plot is just a matter of calling the waterfall function. In matplotlib, there is no obvious way to do that directly. But it turns out that it can be done easily after all, if you know how. That's what we will turn to in the next section.
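Conceptually, a waterfall plot is nothing more than a loop over time steps that draws one 3D curve per step. The following minimal sketch does exactly that by hand with the plot method of a 3D Axes; the drifting-Gaussian data is a made-up assumption, chosen only so there is something to look at.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up example data: a Gaussian bump drifting to the right over time
t = np.linspace(0, 2, 11)                     # 11 time steps -> 11 slices
x = np.linspace(-2, 2, 200)
Y = np.exp(-(x[None, :] - t[:, None]) ** 2)   # Y[i, j] = value at (t[i], x[j])

ax = plt.figure().add_subplot(projection='3d')
# One curve per time step, each lying in the plane of constant t = t_i
for t_i, y_i in zip(t, Y):
    ax.plot(np.full_like(x, t_i), x, y_i, color='C0')
ax.set_xlabel('t')
ax.set_ylabel('x')
plt.show()
```

This works, but managing the loop and the choice of slices by hand quickly gets tedious, which is where plot_wireframe comes in.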
Waterfall plots with matplotlib
The key to making waterfall plots in matplotlib is the plot_wireframe function. However, by default, it doesn't make slices, but creates a ... well ... wireframe grid. But there is an option in that function that allows us to turn it into a waterfall plot. Let's go step by step. Suppose we have an exact solution for a PDE and we code it by defining the function f below. In practice, of course, we would have to solve the PDE by a suitable numerical method.
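import numpy as np

def f(t, x):  # a Gaussian bump that drifts along x = t**2 - 1 and decays over time
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)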
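Because this example f is given in closed form, its behavior can be read off directly: the exponent vanishes where x = t**2 - 1, so the bump peaks there with height 1/(1 + t**2), and the denominator 0.5**2 * (t + 1) makes it wider as t grows. A quick numerical check, assuming only NumPy:

```python
import numpy as np

def f(t, x):
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)

for t_ in (0.0, 0.5, 1.0):
    x_peak = t_**2 - 1                        # where the exponent is zero
    # f at the peak should equal the predicted height 1 / (1 + t**2)
    print(t_, x_peak, f(t_, x_peak), 1 / (1 + t_**2))
```

One consequence for reading the plots: for t > sqrt(3) the peak location t**2 - 1 exceeds 2, so the bump has moved past the right edge of the x-range used below.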
Now we create an array with the data that we want to plot. We start with 1D arrays for the coordinate axes and then turn them into 2D meshed grids:
import numpy as np
t = np.linspace(0, 2, 100)
x = np.linspace(-2, 2, 100)
y = [f(t_, x) for t_ in t] # y is a list of 1D-arrays, one per time step
Y = np.asarray(y) # now Y is a 2D array
T, X = np.meshgrid(t, x, indexing='ij') # ij important
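Since f is built only from NumPy's elementwise operations, the list comprehension can also be replaced by evaluating f on the meshed grids directly. This sketch, which repeats the definitions so it runs standalone, checks that both routes give the same array of shape (len(t), len(x)):

```python
import numpy as np

def f(t, x):
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)

t = np.linspace(0, 2, 100)
x = np.linspace(-2, 2, 100)
T, X = np.meshgrid(t, x, indexing='ij')

# Row-by-row construction: one 1D array per time step
Y_rows = np.asarray([f(t_, x) for t_ in t])
# Vectorized construction: evaluate f on the 2D grids in one call
Y_grid = f(T, X)

print(Y_rows.shape, Y_grid.shape)   # both (len(t), len(x))
print(np.allclose(Y_rows, Y_grid))
```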
Note that when we turned the 1D arrays x and t into the meshed grid arrays X and T, we used the option indexing='ij'. With this option, the first input varies along the rows and the second along the columns, so that T[i, j] = t[i] and X[i, j] = x[j], exactly matching the layout Y[i, j] = f(t[i], x[j]). This is important because with the default indexing='xy', T and X would come out transposed relative to Y, and you would have to play around with the arguments of plot_wireframe until you get the order of the axes right. With indexing='ij', everything works as you would intuitively expect. And not only for plot_wireframe! In fact, I recommend using indexing='ij' whenever you work with meshgrid. In my experience, it is always the more intuitive way, and it has saved me a lot of time in the past.
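The difference between the two indexing modes is easiest to see with grids of unequal length. In this small sketch (the sizes 3 and 4 are arbitrary), 'ij' indexing yields arrays of shape (len(t), len(x)), while the default 'xy' indexing yields their transposes:

```python
import numpy as np

t = np.linspace(0, 2, 3)    # 3 time steps
x = np.linspace(-2, 2, 4)   # 4 spatial points

T_ij, X_ij = np.meshgrid(t, x, indexing='ij')
T_xy, X_xy = np.meshgrid(t, x)            # default: indexing='xy'

print(T_ij.shape)   # (3, 4): rows follow t, columns follow x
print(T_xy.shape)   # (4, 3): the axes are swapped
print(T_ij[:, 0])   # t varies down the rows
print(np.array_equal(T_xy, T_ij.T))
```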
Now we can start plotting. If you are using Jupyter, don't forget the %matplotlib notebook magic, because it lets you interact with the displayed plot (in JupyterLab, use %matplotlib widget instead, which requires the ipympl package). This is very convenient, in particular for 3D plots, where you want to change the viewing angle easily. Let's first create a normal wireframe plot:
import matplotlib.pyplot as plt
%matplotlib notebook
ax = plt.figure().add_subplot(projection='3d')
ax.plot_wireframe(T, X, Y)
ax.set_xlabel('t')
ax.set_ylabel('x')
plt.show()
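Outside of Jupyter, for example in a plain script, there is no interactive rotation. A sketch of a non-interactive alternative, assuming you want a fixed camera angle and a saved image: ax.view_init sets the elevation and azimuth in degrees, and fig.savefig writes the figure to disk. The angles and the file name wireframe.png are arbitrary choices.

```python
import numpy as np
import matplotlib.pyplot as plt

def f(t, x):
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)

t = np.linspace(0, 2, 100)
x = np.linspace(-2, 2, 100)
T, X = np.meshgrid(t, x, indexing='ij')
Y = f(T, X)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot_wireframe(T, X, Y)
ax.set_xlabel('t')
ax.set_ylabel('x')
# Fix the camera angle in code instead of rotating the plot with the mouse
ax.view_init(elev=30, azim=-60)
fig.savefig('wireframe.png', dpi=150)   # hypothetical output file name
```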
Well, as I said, this doesn't give a good overview. Now let's turn it into a waterfall plot. plot_wireframe has two options called rstride (row stride) and cstride (column stride). The trick is to set the column stride to 0. This suppresses the lines drawn along the columns, i.e. the lines of constant x running in the t-direction, so that only the curves of constant t remain:
import matplotlib.pyplot as plt
%matplotlib notebook
ax = plt.figure().add_subplot(projection='3d')
ax.plot_wireframe(T, X, Y, cstride=0)
ax.set_xlabel('t')
ax.set_ylabel('x')
plt.show()
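This is slightly better, but it still displays too much information: since we only set cstride, rstride defaults to 1, and every one of the 100 time steps gets its own curve. Let's reduce the number of slices plotted in the t-direction (row direction) by setting an rstride value: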
import matplotlib.pyplot as plt
%matplotlib notebook
ax = plt.figure().add_subplot(projection='3d')
ax.plot_wireframe(T, X, Y, rstride=5, cstride=0)
ax.set_xlabel('t')
ax.set_ylabel('x')
plt.show()
Et voilà, we have the waterfall plot that we are used to from MATLAB. And by changing the rstride value, we even gain extra flexibility: it controls how many time slices are drawn.
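Current matplotlib versions also accept rcount and ccount, which give the maximum number of sampled rows and columns instead of a stride (they cannot be combined with rstride/cstride in the same call). Setting ccount=0 suppresses the column lines just like cstride=0, and rcount then fixes the number of time slices regardless of how finely t is discretized. A sketch, assuming the same f and grids as above, repeated so it runs on its own:

```python
import numpy as np
import matplotlib.pyplot as plt

def f(t, x):
    return np.exp(-(x - t**2 + 1)**2 / (0.5**2 * (t + 1))) / (1 + t**2)

t = np.linspace(0, 2, 100)
x = np.linspace(-2, 2, 100)
T, X = np.meshgrid(t, x, indexing='ij')
Y = f(T, X)

ax = plt.figure().add_subplot(projection='3d')
# rcount caps the number of sampled rows (time slices);
# ccount=0 switches off the column lines, like cstride=0
wire = ax.plot_wireframe(T, X, Y, rcount=20, ccount=0)
ax.set_xlabel('t')
ax.set_ylabel('x')
plt.show()
```

With 100 time steps, rcount=20 should give a plot close to the rstride=5 version; the exact slices drawn are chosen by matplotlib's own downsampling.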
Conclusion
As we have seen, MATLAB-like waterfall plots can be created in Python with matplotlib's plot_wireframe function. The key steps are to use the indexing='ij' option when turning the 1D arrays into 2D mesh grids, to set cstride=0 so that only the curves of constant t are drawn, and to use rstride to control the spacing of the slices. By following these steps, Python users can create simple, yet effective, waterfall plots to visualize solutions to partial differential equations.
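If you need waterfall plots regularly, the recipe can be wrapped in a small helper. The function waterfall and its n_slices parameter below are hypothetical names for illustration, not part of matplotlib; the sketch assumes Y is laid out as Y[i, j] = f(t[i], x[j]) and uses rcount/ccount so the caller picks the number of slices rather than a stride:

```python
import numpy as np
import matplotlib.pyplot as plt

def waterfall(ax, t, x, Y, n_slices=20):
    """Hypothetical helper: MATLAB-style waterfall plot of Y[i, j] = f(t[i], x[j])."""
    Y = np.asarray(Y)
    if Y.shape != (len(t), len(x)):
        raise ValueError("Y must have shape (len(t), len(x))")
    T, X = np.meshgrid(t, x, indexing='ij')   # 'ij' so T and X match Y's layout
    # ccount=0: no lines of constant x; rcount: at most n_slices curves of constant t
    wire = ax.plot_wireframe(T, X, Y, rcount=n_slices, ccount=0)
    ax.set_xlabel('t')
    ax.set_ylabel('x')
    return wire

# Usage with made-up data
t = np.linspace(0, 2, 100)
x = np.linspace(-2, 2, 80)
Y = np.exp(-(x[None, :] - t[:, None]) ** 2)
ax = plt.figure().add_subplot(projection='3d')
waterfall(ax, t, x, Y)
plt.show()
```

The shape check catches the most common mistake, a transposed Y, before anything is drawn.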