Matplotlib is the go-to library for visualization in Python. It’s simple and well known. But one part of Matplotlib has confused me for a long time: you will see any one of the three code snippets below, and they seem to do the same thing. What are the differences between these approaches, and what is each one for? I’ve been able to create essentially the same chart with any of them. I want to clear up this confusion, for you and (mainly) for myself. I have also included a Jupyter Notebook at the bottom that might be easier to read than this.
import matplotlib.pyplot as plt
plt.plot(x, y)

ax = plt.subplot()
ax.plot(x, y)

fig = plt.figure()
visualization = fig.add_subplot(111)
visualization.plot(x, y)
Matplotlib creates objects that fit within a hierarchy: figures, axes, and the axis objects inside each axes. We can think of a figure as a canvas on which you can paint multiple plots; it is the overall window or page on which everything is drawn. The “plots” on that canvas are called axes (an Axes object, not to be confused with an axis). An Axes is what we generally think of as a plot: it contains two axis objects (three for a 3D plot) along with a title, an x label, and a y label. A figure can hold multiple axes. Subplots are simply axes arranged in a grid within one figure, and the functions that create them usually take three arguments: nrows, ncols, and index. All three snippets above end up drawing on an Axes; they differ only in whether the figure and axes are created for you behind the scenes or explicitly by you.
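To make the hierarchy concrete, here is a small sketch that runs the three snippets from the top of the post and then inspects what they created. It assumes `x` and `y` are simple NumPy arrays (hypothetical sample data) and uses the non-interactive Agg backend so it runs without a display. All three routes produce an `Axes`, and the Figure → Axes → Axis chain can be walked through attributes.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
import numpy as np

# Hypothetical sample data standing in for the post's x and y
x = np.linspace(0, 10, 50)
y = x ** 2

# 1. Pure pyplot: plt.plot draws on the "current" Axes, creating one if needed
plt.figure()
lines = plt.plot(x, y)
ax1 = plt.gca()  # the Axes that plt.plot just drew on

# 2. plt.subplot() hands the Axes back to you
plt.figure()
ax2 = plt.subplot()
ax2.plot(x, y)

# 3. Create the Figure first, then add an Axes to it
fig3 = plt.figure()
ax3 = fig3.add_subplot(111)
ax3.plot(x, y)

# Walk the hierarchy: Figure -> Axes -> Axis
print(type(ax3.figure).__name__)                             # Figure
print(fig3.axes == [ax3])                                    # the Figure keeps a list of its Axes
print(type(ax3.xaxis).__name__, type(ax3.yaxis).__name__)    # XAxis YAxis
```

In other words, the three snippets differ mainly in how much pyplot does for you implicitly; once you hold an `Axes`, you are calling the same methods either way.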
The simplest approach ignores the figure entirely and just calls plt.plot; pyplot quietly creates a figure and an axes for you and draws on them. Use this when you are plotting only one thing.
plt.plot(x, y)
plt.title("Simple Plot")
plt.xlabel("X")
plt.ylabel("Y")
plt.show()
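For comparison, a sketch of the same simple plot written against the Axes object directly. The pyplot functions map onto Axes setters (`plt.title` becomes `ax.set_title`, and so on). As before, `x` and `y` are hypothetical sample data and the Agg backend is assumed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; use plt.show() in an interactive session
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)  # hypothetical sample data
y = x ** 2

fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_title("Simple Plot")   # plt.title(...)  -> ax.set_title(...)
ax.set_xlabel("X")            # plt.xlabel(...) -> ax.set_xlabel(...)
ax.set_ylabel("Y")            # plt.ylabel(...) -> ax.set_ylabel(...)
```

The explicit version is a little longer, but it stays unambiguous once there are several axes in play, because nothing depends on which axes pyplot currently considers "current".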
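You can do the same thing while creating the figure explicitly, which lets you control properties such as its size (figsize, given in inches as width and height).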
plt.figure(figsize=(15,5))
plt.plot(x, y)
plt.show()
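Because figsize is measured in inches, the on-screen size in pixels also depends on the figure’s dpi (dots per inch). A minimal sketch, assuming a dpi of 100 chosen explicitly here rather than relying on your local default:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(15, 5), dpi=100)  # width, height in inches; dpi set explicitly
width_in, height_in = fig.get_size_inches()
width_px, height_px = width_in * fig.dpi, height_in * fig.dpi
print(width_px, height_px)  # 1500.0 500.0
```

Note that saved files can use a different resolution, since `savefig` accepts its own `dpi` argument.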
If you want multiple plots, you can add subplots, specifying where each one goes in an invisible grid: plt.subplot(nrows, ncols, index), where index starts at 1 and counts left to right, then top to bottom.
plt.subplot(1,2,1)
plt.plot(x, y)
plt.title("Subplot I")
plt.xlabel("x")
plt.ylabel("y")

plt.subplot(1,2,2)
plt.plot(x, y)
plt.title("Subplot II")
plt.xlabel("x")
plt.ylabel("y")

plt.suptitle("Subplots")
plt.show()
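The same grid positions can be requested with `fig.add_subplot`, which is where the `111` in the third snippet at the top comes from: a three-digit integer is shorthand for the three arguments, so `111` means one row, one column, first position. A sketch (the variable names are only illustrative):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt

fig = plt.figure()
left = fig.add_subplot(1, 2, 1)   # same grid cell as plt.subplot(1, 2, 1)
right = fig.add_subplot(1, 2, 2)  # index 2 is the next cell to the right

# 111 is shorthand for (1, 1, 1): a single Axes filling the figure
fig_a = plt.figure()
a = fig_a.add_subplot(111)
fig_b = plt.figure()
b = fig_b.add_subplot(1, 1, 1)
print(a.get_position().bounds == b.get_position().bounds)  # same placement either way
```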
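Finally, you can bring it all together with plt.subplots (note the s), which creates the figure and a grid of axes in one call and returns both. You then call plotting methods on each axes directly; on an Axes, the labelling methods are set_title, set_xlabel, and so on, rather than title and xlabel.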
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10,10))
ax[0,1].plot(x, y)
ax[1,0].plot(x, y)
ax[0,1].set_title("Subplot I")
ax[1,0].set_title("Subplot II")
plt.show()
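The second value returned by `plt.subplots` is a NumPy array of Axes, indexed as `[row, col]`. A sketch showing how to fill every panel with `axes.flat` and how the array shape changes when there is only one row; `x` and `y` are hypothetical sample data.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)  # hypothetical sample data
y = x ** 2

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
print(axes.shape)  # (2, 2): index as axes[row, col]

# axes.flat walks the grid row by row, handy for filling every panel
for i, ax in enumerate(axes.flat):
    ax.plot(x, y)
    ax.set_title(f"Subplot {i + 1}")

# With a single row the array is squeezed to 1-D, so index ax_row[0], not ax_row[0, 0]
fig2, ax_row = plt.subplots(1, 2)
print(ax_row.shape)  # (2,)
```

Passing `squeeze=False` to `plt.subplots` keeps the array 2-D even for a single row or column, which can simplify code that loops over arbitrary grid sizes.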
For reference, here is how to make a few other common chart types: a histogram, a scatter plot with fixed axis limits, and a 3D scatter plot.
plt.hist(data)
plt.title("Histogram")
plt.show()
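`hist` also returns the numbers behind the bars, which is useful for checking what was drawn. A sketch, assuming `data` is 1,000 normally distributed samples (hypothetical) and 20 bins instead of the default 10:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=1000)  # hypothetical sample data

fig, ax = plt.subplots()
counts, bin_edges, patches = ax.hist(data, bins=20)  # bins defaults to 10 if omitted
ax.set_title("Histogram")

print(len(bin_edges))  # 21: one more edge than there are bins
print(counts.sum())    # 1000.0: every sample lands in some bin
```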
plt.xlim(0,10)
plt.ylim(0,5000)
plt.scatter(x, y)
plt.title("Scatter")
plt.xlabel("x")
plt.ylabel("y")
plt.show()
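The same scatter plot in the explicit style, as a sketch. `ax.set` forwards each keyword to the matching setter (`xlim` to `set_xlim`, `title` to `set_title`, and so on), which keeps the limits and labels on one line. The data are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)  # hypothetical sample data
y = x ** 3

fig, ax = plt.subplots()
points = ax.scatter(x, y)
# Each keyword is passed to the corresponding set_* method
ax.set(xlim=(0, 10), ylim=(0, 5000), title="Scatter", xlabel="x", ylabel="y")
print(ax.get_xlim(), ax.get_ylim())
```

Setting limits explicitly turns off autoscaling on that axis, so the view stays fixed whether you set the limits before or after plotting.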
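from mpl_toolkits import mplot3d  # registers the 3D projection; optional since Matplotlib 3.2

ax = plt.axes(projection='3d')
ax.scatter3D(x, y)  # z coordinates default to 0 when omitted
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()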
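Since the snippet above passes only x and y, every point sits on the z = 0 plane. A sketch that supplies a third coordinate, using a hypothetical helix as sample data and `fig.add_subplot(projection="3d")` to get the 3D Axes:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical sample data with a third coordinate: a helix
t = np.linspace(0, 4 * np.pi, 100)
x, y, z = np.cos(t), np.sin(t), t

fig = plt.figure()
ax = fig.add_subplot(projection="3d")  # the Axes3D class gains set_zlabel, view_init, ...
ax.scatter(x, y, z)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
print(ax.name)  # 3d
```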
There are numerous ways to style your plots. Here are some examples of doing so manually.
# linewidth is needed for the figure's green edge to be visible (the default is 0)
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10,10), edgecolor="green", linewidth=2)

ax[0,1].plot(x, y, color='orange', linestyle=':', linewidth=1)
ax[0,1].grid(color='red', linestyle='-', linewidth=1)
ax[0,1].set_title("Plot")

ax[0,0].plot(x, y, color='green', linestyle='--', linewidth=1)
ax[0,0].set_title("Plot")

ax[1,0].scatter(x, y, color='magenta', marker='^')
ax[1,0].grid(color='purple', linestyle='-', linewidth=1)
ax[1,0].set_title("Scatter")

ax[1,1].scatter(x, y, color='cyan', marker='*')
ax[1,1].set_title("Scatter")
plt.show()
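When the same manual styling repeats across panels, one option is to keep the keyword arguments in plain dictionaries and unpack them with `**`. This is just ordinary Python, not a Matplotlib feature; the preset names and the data below are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)  # hypothetical sample data
y = np.sqrt(x)

# Hypothetical style presets: plain dicts of keyword arguments for Axes.plot
dotted_orange = dict(color="orange", linestyle=":", linewidth=1)
dashed_green = dict(color="green", linestyle="--", linewidth=1)

fig, (left, right) = plt.subplots(1, 2, figsize=(10, 4))
left.plot(x, y, **dotted_orange)
right.plot(x, y, **dashed_green)
right.grid(color="red", linestyle="-", linewidth=1)
```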
You don’t have to style everything manually. Matplotlib ships with built-in style sheets, and you can also create your own.
plt.style.use('fivethirtyeight')
plt.hist(data)
plt.title("Histogram")
plt.show()
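`plt.style.use` changes the settings for the rest of the session. A sketch of two related options: listing the built-in sheets, and applying a style only inside a `with` block via `plt.style.context`. It also writes a tiny custom style sheet, which is a plain text file of `key: value` rcParams lines; the file name and the settings in it are hypothetical examples.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt

print("fivethirtyeight" in plt.style.available)  # names of the built-in sheets

# A hypothetical custom style sheet: one rcParams "key: value" per line
style_path = os.path.join(tempfile.mkdtemp(), "my_style.mplstyle")
with open(style_path, "w") as f:
    f.write("axes.grid: True\nlines.linewidth: 3\n")

# The style applies only inside the with-block; settings are restored afterwards
with plt.style.context(style_path):
    grid_inside = plt.rcParams["axes.grid"]
    width_inside = plt.rcParams["lines.linewidth"]
grid_after = plt.rcParams["axes.grid"]
```

Style names and file paths can be mixed in a list, for example `plt.style.context(["fivethirtyeight", style_path])`, with later entries overriding earlier ones.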
There are other libraries and styling systems as well, including Seaborn, which I highly recommend, plus Plotly, Bokeh, and so forth.