Transforming Your Matplotlib Plots to Traditional Style


Matplotlib is a powerful library for quickly creating production-ready plots. However, it has certain opinions on how a plot should appear by default. While it’s always possible to customize these plots, it can sometimes be challenging to achieve the desired result. One common issue is that the axes are placed at the borders of the plot by default. If you prefer a more traditional layout, with the axes passing through the origin and arrows at their ends, some adjustments need to be made.

Daily Dose of Scientific Python

In Matplotlib, the line we typically think of as an “axis” is drawn by an object called a “spine.” To achieve the desired appearance, we must manipulate the Spine objects. The primary tasks are to move the left and bottom spines to the origin and to hide the top and right spines. Afterward, we can plot the arrowheads. Here’s one method for achieving this:

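import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()

# Move the left and bottom spines so they pass through the origin
ax.spines[["left", "bottom"]].set_position(("data", 0))
# Hide the top and right spines
ax.spines[["top", "right"]].set_visible(False)

# Arrowheads: black triangle markers at the right end of the x-axis
# and at the top end of the y-axis
ax.plot(1, 0, ">k", transform=ax.get_yaxis_transform(), clip_on=False)
ax.plot(0, 1, "^k", transform=ax.get_xaxis_transform(), clip_on=False)

# Example curve
x = np.linspace(-0.5, 1., 100)
ax.plot(x, np.sin(x*np.pi), lw=2)

# Axis labels placed next to the arrowheads
ax.text(1, -0.1, "$x$", transform=ax.get_yaxis_transform())
ax.text(-0.1, 1, "$f(x)$", transform=ax.get_xaxis_transform())

plt.show()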

Let’s break down the code step by step. First, we create the figure and axes with plt.subplots(). The fig object represents the entire figure, and ax represents a single plot within it. Matplotlib calls such a plot an “Axes,” which is unfortunate terminology because it is easily confused with an individual axis. Next, we move the left and bottom spines to x=0 and y=0. The set_position method is called with the argument ("data", 0), which places each spine at the value 0 in the data coordinate system. We then hide the top and right spines with set_visible(False).
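As a minimal sketch of the spine step on its own, the snippet below sets each spine individually instead of indexing with a list, and then reads the settings back. The list form ax.spines[["left", "bottom"]] needs a reasonably recent Matplotlib (version 3.4 or later, as far as I know), so the per-spine form may be useful on older installations. The Agg backend line is an assumption made only so the sketch runs without a display; drop it to get an interactive window.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; remove for interactive use
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Same effect as the list-indexed calls, one spine at a time
ax.spines["left"].set_position(("data", 0))
ax.spines["bottom"].set_position(("data", 0))
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

# Read the settings back to confirm what was applied
left_position = ax.spines["left"].get_position()
visibility = {name: spine.get_visible() for name, spine in ax.spines.items()}

print(left_position)
print(visibility)
```

Only the left and bottom spines should remain visible, both anchored at the data value 0.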

To draw arrows at the ends of the axes, we use the ax.plot() function. The transform parameter plays a crucial role here. Matplotlib uses different coordinate systems, such as data coordinates (which correspond to the actual data values) and axes coordinates (which range from 0 to 1, representing the proportion of the axes’ length).

In the code above, we want to place the arrows precisely at the end of each axis. To achieve this, we use a blended coordinate system, which combines the data and axes coordinate systems. The ax.get_yaxis_transform() and ax.get_xaxis_transform() methods return the required blended transforms. With ax.get_yaxis_transform(), x is given in axes coordinates and y in data coordinates, so the point (1, 0) is the right edge of the plot at the height y=0. With ax.get_xaxis_transform(), it is the other way around, so (0, 1) is the top edge of the plot at x=0.
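To make the blended coordinates concrete, the following sketch converts the arrowhead position (1, 0) into pixel coordinates and compares it with the pure axes and pure data transforms. The axis limits are arbitrary values chosen for illustration, and the Agg backend line is an assumption made only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; remove for interactive use
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(-2, 3)   # arbitrary limits for illustration
ax.set_ylim(-1, 4)

# Blended transform: x in axes coordinates, y in data coordinates
blended = ax.get_yaxis_transform()
arrow_px = blended.transform((1, 0))

# Reference points in the two pure coordinate systems
right_edge_px = ax.transAxes.transform((1, 0))  # right edge of the axes
origin_px = ax.transData.transform((0, 0))      # data origin

print("arrow:     ", arrow_px)
print("right edge:", right_edge_px)
print("origin:    ", origin_px)
```

The arrow's horizontal pixel position matches the right edge of the axes, and its vertical pixel position matches the data origin, whatever limits are chosen.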

We also set the clip_on parameter to False to prevent the arrow markers from being clipped, as they extend slightly beyond the axes limits.

Lastly, we add labels to the axes using the ax.text() function, which allows precise positioning of the text. While ax.set_xlabel() and ax.set_ylabel() can be used to label the axes, they don't provide the same level of control over the label positions. In this case, the ax.text() function works well for aligning the labels with the new axes positions.
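One possible variation on the label placement, sketched here under my own assumptions, uses the ha and va alignment keywords of ax.text() to pin a chosen side of each label to a point just past the arrowhead. The 1.02 offsets and the axis limits are illustrative values, not taken from the example above, and the Agg backend line is there only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; remove for interactive use
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.spines[["left", "bottom"]].set_position(("data", 0))
ax.spines[["top", "right"]].set_visible(False)
ax.set_xlim(-1, 1)   # arbitrary limits for illustration
ax.set_ylim(-1, 1)

# x label: just right of the axes edge, vertically centered on y=0
x_label = ax.text(1.02, 0, "$x$", transform=ax.get_yaxis_transform(),
                  ha="left", va="center")
# y label: just above the axes edge, horizontally centered on x=0
y_label = ax.text(0, 1.02, "$f(x)$", transform=ax.get_xaxis_transform(),
                  ha="center", va="bottom")

fig.savefig("labels.png")  # hypothetical output file name
```

Because the alignment is relative to the anchor point, the labels stay attached to the axis ends when the data limits change.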