More Matplotlib

Matplotlib is the dominant plotting and visualization package in Python, and it is important to learn to use it well. In the last lecture, we saw some basic examples in the context of learning NumPy. This week, we dive much deeper, with the goal of understanding how matplotlib represents figures internally.

from matplotlib import pyplot as plt
%matplotlib inline

Figure and Axes

The figure is the highest level of organization of matplotlib objects. If we want, we can create a figure explicitly.

fig = plt.figure()
<Figure size 640x480 with 0 Axes>
fig = plt.figure(figsize=(13, 5))
<Figure size 1300x500 with 0 Axes>
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
fig = plt.figure()
ax = fig.add_axes([0, 0, 0.5, 1])
fig = plt.figure()
ax1 = fig.add_axes([0, 0, 0.5, 1])
ax2 = fig.add_axes([0.5, 0, 0.3, 0.5], facecolor='g')
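The list passed to fig.add_axes is [left, bottom, width, height]. Each value is a fraction of the figure size (0 to 1), measured from the lower-left corner. The minimal sketch below reads the rectangle back with get_position(). Its bounds attribute returns the same four numbers, up to floating-point rounding.

```python
from matplotlib import pyplot as plt

fig = plt.figure()
# left edge at the middle of the figure, bottom edge at the bottom,
# 30% of the figure wide and 50% tall (same rectangle as the green axes above)
ax2 = fig.add_axes([0.5, 0, 0.3, 0.5])

# read the rectangle back in figure-fraction coordinates
left, bottom, width, height = ax2.get_position().bounds
print(left, bottom, width, height)
```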

Subplots

The subplots method is one way to create several axes at once, arranged in a regular grid.

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=3)
fig = plt.figure(figsize=(1, 1))
axes = fig.subplots(nrows=2, ncols=3)
axes
array([[<Axes: >, <Axes: >, <Axes: >],
       [<Axes: >, <Axes: >, <Axes: >]], dtype=object)
axes[0,0]
<Axes: >
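fig.subplots returns a NumPy array of Axes objects with shape (nrows, ncols). Ordinary array indexing therefore selects a single panel, and axes.flat visits every panel in row-major order. The short sketch below titles each panel with its index:

```python
from matplotlib import pyplot as plt

fig = plt.figure()
axes = fig.subplots(nrows=2, ncols=3)
print(axes.shape)  # (2, 3): one Axes per grid cell

# label each panel with its (row, column) index, visiting them row by row
for i, ax in enumerate(axes.flat):
    row, col = divmod(i, axes.shape[1])
    ax.set_title(f'({row}, {col})')

# axes[1, 2] is the bottom-right panel
print(axes[1, 2].get_title())  # (1, 2)
```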

There is a shorthand, plt.subplots, that creates the figure and its axes in a single call.

This is our recommended way to create new figures!

fig, ax = plt.subplots()
ax
<Axes: >
fig, axes = plt.subplots(ncols=2, figsize=(8, 4), subplot_kw={'facecolor': 'g'})
axes
array([<Axes: >, <Axes: >], dtype=object)
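The outputs above have different shapes. A single row of panels comes back as a 1D array, and plt.subplots() with no arguments returns a bare Axes rather than an array. The squeeze argument (default True) controls this. Passing squeeze=False always returns a 2D array, which can simplify code that loops over rows and columns. A sketch:

```python
from matplotlib import pyplot as plt
from matplotlib.axes import Axes

fig, ax = plt.subplots()                             # single panel -> one Axes object
fig2, axes2 = plt.subplots(ncols=2)                  # one row -> 1D array of shape (2,)
fig3, axes3 = plt.subplots(ncols=2, squeeze=False)   # always 2D -> shape (1, 2)

print(isinstance(ax, Axes), axes2.shape, axes3.shape)  # True (2,) (1, 2)
```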

Drawing into Axes

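All plots are drawn into axes. It is easiest to understand how matplotlib works if you use the object-oriented style, which means calling plotting methods such as ax.plot directly on an Axes object instead of using the plt functions.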

# create some data to plot
import numpy as np
x = np.linspace(-np.pi, np.pi, 100)
y = np.cos(x)
z = np.sin(6*x)
plt.plot(x,y)
[<matplotlib.lines.Line2D at 0x7aa095e08f50>]
fig, ax = plt.subplots()
plt.plot(x,y)
[<matplotlib.lines.Line2D at 0x7aa095bb19a0>]
ax
<Axes: >
ax.plot(x, y)
[<matplotlib.lines.Line2D at 0x7aa095bb3ef0>]
fig

This does the same thing as

plt.plot(x, y)
[<matplotlib.lines.Line2D at 0x7aa095c23140>]

The distinction starts to matter when a figure has multiple axes. plt.plot always draws into the current axes, while ax.plot lets us choose the target explicitly.

fig, axes = plt.subplots(figsize=(8, 4), ncols=2)
ax0, ax1 = axes
ax0.plot(x, y)
ax1.plot(x, z)
[<matplotlib.lines.Line2D at 0x7aa095c51ca0>]
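Behind the scenes, pyplot keeps track of a current axes, and functions like plt.plot draw into it. After plt.subplots, the current axes is the last one created, which here is the right-hand panel. plt.gca() returns it and plt.sca() changes it. A short sketch:

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(-np.pi, np.pi, 100)
fig, (ax0, ax1) = plt.subplots(figsize=(8, 4), ncols=2)

# the most recently created axes is "current"
print(plt.gca() is ax1)  # True

# plt.plot draws into the current axes...
plt.plot(x, np.cos(x))
# ...until we explicitly make another axes current
plt.sca(ax0)
plt.plot(x, np.sin(6*x))

print(len(ax0.lines), len(ax1.lines))  # 1 1
```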

Labeling Plots

fig, axes = plt.subplots(figsize=(8, 4), ncols=2)
ax0, ax1 = axes

ax0.plot(x, y)
ax0.set_xlabel('x')
ax0.set_ylabel('y')
ax0.set_title('x vs. y')

ax1.plot(x, z)
ax1.set_xlabel('x')
ax1.set_ylabel('z')
ax1.set_title('x vs. z')

# adjust spacing so the labels and titles of neighboring panels don't overlap
fig.tight_layout()
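Each set_* method used above also has a keyword form in Axes.set, which sets several properties in one call. The compact sketch below labels a panel the same way as the left one above:

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(-np.pi, np.pi, 100)
fig, ax = plt.subplots()
ax.plot(x, np.cos(x))

# one call instead of set_xlabel / set_ylabel / set_title
ax.set(xlabel='x', ylabel='y', title='x vs. y')

print(ax.get_xlabel(), ax.get_ylabel(), ax.get_title())  # x y x vs. y
```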

Customizing Line Plots

fig, ax = plt.subplots()
ax.plot(x, y, x, z)
[<matplotlib.lines.Line2D at 0x7aa09e106240>,
 <matplotlib.lines.Line2D at 0x7aa09e107b60>]
fig, ax = plt.subplots()
ax.plot(x, y)
ax.plot(x, z)
[<matplotlib.lines.Line2D at 0x7aa09e075af0>]
fig, ax = plt.subplots()
ax.plot(x, z)
ax.plot(x, y)
[<matplotlib.lines.Line2D at 0x7aa09e1eca40>]
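Compare the last two figures. Each call to ax.plot takes the next color from the color cycle, so swapping the order of the calls swaps the colors. The line drawn later also sits on top. One way to confirm this is to inspect ax.lines, the list of Line2D objects held by the axes. This sketch assumes the default style is active:

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(-np.pi, np.pi, 100)
y, z = np.cos(x), np.sin(6*x)

# the colors of the active color cycle, in order
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

fig, ax = plt.subplots()
ax.plot(x, z)   # first call -> first cycle color
ax.plot(x, y)   # second call -> second cycle color

for line in ax.lines:
    print(line.get_color())
```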

It’s simple to swap the axes: reversing the order of the arguments puts x on the vertical axis.

fig, ax = plt.subplots()
ax.plot(y, x, z, x)
[<matplotlib.lines.Line2D at 0x7aa09e285f40>,
 <matplotlib.lines.Line2D at 0x7aa09e285a30>]

Plotting y against z gives a “parametric” graph, with x acting as the hidden parameter:

fig, ax = plt.subplots()
ax.plot(y, z)
[<matplotlib.lines.Line2D at 0x7aa0959cf080>]

Line Styles

fig, axes = plt.subplots(figsize=(16, 5), ncols=3)
axes[0].plot(x, y, linestyle='dashed')
axes[0].plot(x, z, linestyle='--')

axes[1].plot(x, y, linestyle='dotted')
axes[1].plot(x, z, linestyle=':')

axes[2].plot(x, y, linestyle='dashdot', linewidth=5)
axes[2].plot(x, z, linestyle='-.', linewidth=0.5)
[<matplotlib.lines.Line2D at 0x7aa0958d4050>]
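The long names and the short codes are interchangeable, and matplotlib stores the short form. Color, marker and line style can also be combined into a single format string passed as the third positional argument, e.g. 'k--' for a black dashed line. A quick check:

```python
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import to_rgb

x = np.linspace(-np.pi, np.pi, 100)
fig, ax = plt.subplots()

# long name; stored internally as the short code '--'
(long_line,) = ax.plot(x, np.cos(x), linestyle='dashed')
# format string: 'k' = black, '--' = dashed
(fmt_line,) = ax.plot(x, np.sin(6*x), 'k--')

print(long_line.get_linestyle(), fmt_line.get_linestyle())  # -- --
print(to_rgb(fmt_line.get_color()))                          # (0.0, 0.0, 0.0)
```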

Colors

As described in the matplotlib colors documentation, there are single-letter codes for commonly used colors:

  • b: blue

  • g: green

  • r: red

  • c: cyan

  • m: magenta

  • y: yellow

  • k: black

  • w: white
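These codes are shorthand for fixed RGB values, and several are not the full-intensity color the name suggests. 'g' is a darker green (0, 0.5, 0), and 'c', 'm' and 'y' use 0.75 rather than 1 for their nonzero channels. The function matplotlib.colors.to_rgb converts any color specification into an (r, g, b) tuple of floats between 0 and 1, which makes this easy to check:

```python
from matplotlib.colors import to_rgb

# print the RGB value behind each single-letter code
for code in 'bgrcmykw':
    print(code, to_rgb(code))
```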

fig, ax = plt.subplots()
ax.plot(x, y, color='k')
ax.plot(x, z, color='r')
[<matplotlib.lines.Line2D at 0x7aa0954c1130>]

Other ways to specify colors:

fig, axes = plt.subplots(figsize=(16, 5), ncols=3)

# grayscale
axes[0].plot(x, y, color='0.8')
axes[0].plot(x, z, color='0.2')

# RGB tuple
axes[1].plot(x, y, color=(1, 0, 0.7))
axes[1].plot(x, z, color=(0, 0.4, 0.3))

# HTML hex code
axes[2].plot(x, y, color='#00dcba')
axes[2].plot(x, z, color='#b029ee')
[<matplotlib.lines.Line2D at 0x7aa0957a5e50>]
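All of these forms are converted to the same internal RGB(A) representation, so they can be mixed freely within one figure. A grayscale string such as '0.8' means equal red, green and blue components of 0.8. A hex code stores each channel as a byte from 00 to ff. The sketch below uses matplotlib.colors.to_rgb and to_hex to convert between them:

```python
import numpy as np
from matplotlib.colors import to_rgb, to_hex

print(to_rgb('0.8'))      # grayscale string -> (0.8, 0.8, 0.8)
print(to_rgb('#00dcba'))  # hex -> each byte divided by 255
print(to_hex((1, 0, 0)))  # RGB tuple -> '#ff0000'

# the hex code and the equivalent 0-1 RGB tuple describe the same color
same = np.allclose(to_rgb('#00dcba'), (0, 0xdc / 255, 0xba / 255))
print(same)  # True
```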

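There is a default color cycle built into matplotlib. Each new line on an axes takes the next color in the cycle, and after the last color the cycle starts over.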

plt.rcParams['axes.prop_cycle']
cycler('color', ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'])
fig, ax = plt.subplots(figsize=(12, 10))
for factor in np.linspace(0.2, 1, 11):
    ax.plot(x, factor*y)
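The loop above draws 11 lines, but the default cycle has only 10 colors, so the last line reuses the first color. The cycle can also be replaced for a single axes with ax.set_prop_cycle. This sketch assumes the default style (10-color cycle) is active:

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(-np.pi, np.pi, 100)
y = np.cos(x)

default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
print(len(default_colors))  # 10 with the default style

fig, ax = plt.subplots()
lines = [ax.plot(x, factor*y)[0] for factor in np.linspace(0.2, 1, 11)]
# the 11th line wraps around to the first color
print(lines[10].get_color() == lines[0].get_color())  # True

# a per-axes cycle overrides the global default
fig2, ax2 = plt.subplots()
ax2.set_prop_cycle(color=['k', 'r'])
lines2 = [ax2.plot(x, factor*y)[0] for factor in (0.5, 1.0, 1.5)]
print([line.get_color() for line in lines2])  # ['k', 'r', 'k']
```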

Markers

There are lots of different markers available in matplotlib!

fig, axes = plt.subplots(figsize=(12, 5), ncols=2)

axes[0].plot(x[:20], y[:20], marker='.')
axes[0].plot(x[:20], z[:20], marker='o')

axes[1].plot(x[:20], z[:20], marker='^',
             markersize=10, markerfacecolor='r',
             markeredgecolor='k')
[<matplotlib.lines.Line2D at 0x7aa095761100>]
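Marker properties are ordinary keyword arguments to plot. Setting linestyle='none' draws only the markers, with no connecting line, which suits discrete data. The available marker codes are the keys of the dictionary matplotlib.lines.Line2D.markers. A small sketch:

```python
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D

x = np.linspace(-np.pi, np.pi, 100)
z = np.sin(6*x)

fig, ax = plt.subplots()
# markers only: no line between the points
(pts,) = ax.plot(x[:20], z[:20], marker='^', linestyle='none',
                 markersize=10, markerfacecolor='r', markeredgecolor='k')

print(pts.get_marker(), pts.get_linestyle(), pts.get_markersize())
print('^' in Line2D.markers)  # True: '^' is a valid marker code
```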

Labels, Ticks, and Gridlines

fig, ax = plt.subplots(figsize=(12, 7))
ax.plot(x, y)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title(r'A complicated math function: $f(x) = \cos(x)$')

ax.set_xticks(np.pi * np.array([-1, 0, 1]))
ax.set_xticklabels([r'$-\pi$', '0', r'$\pi$'])
ax.set_yticks([-1, 0, 1])

ax.set_yticks(np.arange(-1, 1.1, 0.2), minor=True)
#ax.set_xticks(np.arange(-3, 3.1, 0.2), minor=True)

ax.grid(which='minor', linestyle='--')
ax.grid(which='major', linewidth=2)
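Listing tick positions by hand with set_xticks works well for a few major ticks. For regularly spaced ticks, a tick locator from matplotlib.ticker is an alternative. MultipleLocator(base) puts a tick at every multiple of base and keeps doing so if the axis limits change. The sketch below adds minor ticks every 0.2 on both axes:

```python
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator

x = np.linspace(-np.pi, np.pi, 100)
fig, ax = plt.subplots()
ax.plot(x, np.cos(x))

# major ticks at -pi, 0, pi with LaTeX labels
ax.set_xticks(np.pi * np.array([-1, 0, 1]))
ax.set_xticklabels([r'$-\pi$', '0', r'$\pi$'])

# minor ticks every 0.2 data units on both axes
ax.xaxis.set_minor_locator(MultipleLocator(0.2))
ax.yaxis.set_minor_locator(MultipleLocator(0.2))
ax.grid(which='minor', linestyle='--')

print(ax.get_xticks())
```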

Axis Limits

fig, ax = plt.subplots()
ax.plot(x, y, x, z)
ax.set_xlim(-5, 5)
ax.set_ylim(-3, 3)
(-3.0, 3.0)
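set_xlim and set_ylim also accept just one end of the range as a keyword (left/right, bottom/top), and get_xlim/get_ylim return the current limits. Passing the limits in decreasing order flips the direction of the axis. A sketch:

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(-np.pi, np.pi, 100)
fig, ax = plt.subplots()
ax.plot(x, np.cos(x))

# set only the lower end; the upper end keeps its autoscaled value
ax.set_ylim(bottom=-3)
# set both ends, then read them back as a (left, right) tuple
ax.set_xlim(-5, 5)
print(ax.get_xlim())

# passing the limits in reverse order flips the axis direction
ax.set_xlim(5, -5)
print(ax.xaxis_inverted())  # True
```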

Text Annotations

fig, ax = plt.subplots()
ax.plot(x, y)
ax.text(-3, 0.3, 'hello world')
ax.annotate('the maximum', xy=(0, 1),
             xytext=(0, 0), arrowprops={'facecolor': 'k'})
Text(0, 0, 'the maximum')
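ax.text and ax.annotate interpret positions in data coordinates by default, so (-3, 0.3) above is a point on the plot's own x and y scales. Sometimes text should be placed relative to the axes box instead, for example a panel label that stays in the top-left corner whatever the data limits. For that, pass transform=ax.transAxes, in which (0, 0) is the lower-left corner and (1, 1) the upper-right. In the sketch below, the label text 'panel (a)' is just an illustration:

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(-np.pi, np.pi, 100)
fig, ax = plt.subplots()
ax.plot(x, np.cos(x))

# data coordinates (the default): moves if the axis limits change
data_label = ax.text(-3, 0.3, 'hello world')
# axes-fraction coordinates: 5% from the left, 90% of the way up
corner_label = ax.text(0.05, 0.9, 'panel (a)', transform=ax.transAxes)

# annotate: arrow tip at the data point xy, text placed at xytext
ax.annotate('the maximum', xy=(0, 1), xytext=(0, 0),
            arrowprops={'facecolor': 'k'})

print(corner_label.get_transform() is ax.transAxes)  # True
```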

Other 1D Plots

Scatter Plots

fig, ax = plt.subplots()

splot = ax.scatter(y, z, c=x, s=(100*z**2 + 5))
fig.colorbar(splot)
<matplotlib.colorbar.Colorbar at 0x7aa095ac8ef0>
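Unlike plot, scatter can vary the color and size of every point. The c argument is mapped through a colormap; here it is the value of x, which is why the colorbar runs from -π to π. The s argument gives each marker's area in points squared. The returned PathCollection keeps these arrays, and the short sketch below reads them back. The colorbar label is an addition for illustration:

```python
import numpy as np
from matplotlib import pyplot as plt

x = np.linspace(-np.pi, np.pi, 100)
y, z = np.cos(x), np.sin(6*x)

fig, ax = plt.subplots()
sizes = 100*z**2 + 5          # marker areas in points**2
splot = ax.scatter(y, z, c=x, s=sizes, cmap='viridis')
fig.colorbar(splot, ax=ax, label='x')

print(splot.get_offsets().shape)  # (100, 2): one (y, z) position per point
print(splot.get_array().min(), splot.get_array().max())  # the range of x
```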

Bar Plots

labels = ['first', 'second', 'third']
values = [10, 5, 30]

fig, axes = plt.subplots(figsize=(10, 5), ncols=2)
axes[0].bar(labels, values)
axes[1].barh(labels, values)
<BarContainer object of 3 artists>
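bar draws one rectangle per category and returns the BarContainer shown in the output above. In matplotlib 3.4 and later, Axes.bar_label can write each bar's value next to it. A sketch:

```python
from matplotlib import pyplot as plt

labels = ['first', 'second', 'third']
values = [10, 5, 30]

fig, ax = plt.subplots()
bars = ax.bar(labels, values)
# annotate each bar with its height (requires matplotlib >= 3.4)
ax.bar_label(bars)

# the container holds one Rectangle per category
print([float(rect.get_height()) for rect in bars])  # [10.0, 5.0, 30.0]
```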

2D Plotting Methods

imshow

x1d = np.linspace(-2*np.pi, 2*np.pi, 100)
y1d = np.linspace(-np.pi, np.pi, 50)
xx, yy = np.meshgrid(x1d, y1d)
f = np.cos(xx) * np.sin(yy)
print(f.shape)
(50, 100)
fig, ax = plt.subplots(figsize=(12,4), ncols=2)
ax[0].imshow(f)
ax[1].imshow(f, origin='lower')
<matplotlib.image.AxesImage at 0x7aa095763d40>
../../_images/2e7888363aa682dfbd66144ff88c245e62e8e3cbd93bcff6a4d248f67bb4904a.png

pcolormesh

fig, ax = plt.subplots(ncols=2, figsize=(12, 5))
pc0 = ax[0].pcolormesh(x1d, y1d, f)
pc1 = ax[1].pcolormesh(xx, yy, f)
fig.colorbar(pc0, ax=ax[0])
fig.colorbar(pc1, ax=ax[1])
<matplotlib.colorbar.Colorbar at 0x7aa09436bb00>
x_sm, y_sm, f_sm = xx[:10, :10], yy[:10, :10], f[:10, :10]

fig, ax = plt.subplots(figsize=(12,5), ncols=2)

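# C has the same shape as X and Y: with the default shading='auto' (matplotlib >= 3.5),
# the grid points are treated as cell centers ('nearest' shading);
# older versions silently dropped the last row and column of C here
ax[0].pcolormesh(x_sm, y_sm, f_sm, edgecolors='k')

# C is one smaller in each dimension: the grid points are cell corners ('flat' shading)
ax[1].pcolormesh(x_sm, y_sm, f_sm[:-1, :-1], edgecolors='k')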
<matplotlib.collections.QuadMesh at 0x7aa08f70ef90>
y_distorted = y_sm*(1 + 0.1*np.cos(6*x_sm))

plt.figure(figsize=(12,6))
plt.pcolormesh(x_sm, y_distorted, f_sm[:-1, :-1], edgecolors='w')
plt.scatter(x_sm, y_distorted, c='k')
<matplotlib.collections.PathCollection at 0x7aa095e55430>
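The shading argument controls how pcolormesh interprets its coordinates. With shading='flat', X and Y are cell corners, so C must be one smaller in each dimension. With shading='nearest', X and Y are cell centers and have the same shape as C. In recent matplotlib versions (3.5 and later, as far as we know), an explicit 'flat' with same-shape arrays raises a TypeError rather than silently dropping data. The sketch below makes the choice explicit:

```python
import numpy as np
from matplotlib import pyplot as plt

x1d = np.linspace(-2*np.pi, 2*np.pi, 100)
y1d = np.linspace(-np.pi, np.pi, 50)
xx, yy = np.meshgrid(x1d, y1d)
f = np.cos(xx) * np.sin(yy)
x_sm, y_sm, f_sm = xx[:10, :10], yy[:10, :10], f[:10, :10]

fig, ax = plt.subplots(ncols=2, figsize=(12, 5))
# grid points are cell centers: 10 x 10 values -> 10 x 10 cells
near = ax[0].pcolormesh(x_sm, y_sm, f_sm, shading='nearest', edgecolors='k')
# grid points are cell corners: 10 x 10 corners -> 9 x 9 cells
flat = ax[1].pcolormesh(x_sm, y_sm, f_sm[:-1, :-1], shading='flat', edgecolors='k')
print(near.get_array().size, flat.get_array().size)  # 100 81

# 'flat' with same-shape X, Y and C is rejected in recent versions
try:
    ax[1].pcolormesh(x_sm, y_sm, f_sm, shading='flat')
except TypeError as err:
    print('shape mismatch:', err)
```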

contour / contourf

fig, ax = plt.subplots(figsize=(12, 5), ncols=2)

# same thing!
ax[0].contour(x1d, y1d, f)
ax[1].contour(xx, yy, f)
<matplotlib.contour.QuadContourSet at 0x7aa08f65c290>
fig, ax = plt.subplots(figsize=(12, 5), ncols=2)

c0 = ax[0].contour(xx, yy, f, 5)
c1 = ax[1].contour(xx, yy, f, 20)

ax[0].clabel(c0, fmt='%2.1f')
fig.colorbar(c1, ax=ax[1])
<matplotlib.colorbar.Colorbar at 0x7aa08f53ddc0>
fig, ax = plt.subplots(figsize=(12, 5), ncols=2)

clevels = np.arange(-1, 1, 0.2) + 0.1

cf0 = ax[0].contourf(xx, yy, f, clevels, cmap='RdBu_r', extend='both')
cf1 = ax[1].contourf(xx, yy, f, clevels, cmap='inferno', extend='both')

fig.colorbar(cf0, ax=ax[0])
fig.colorbar(cf1, ax=ax[1])
plt.savefig('test_fig.png')
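The fourth positional argument to contour and contourf sets the levels. An integer asks matplotlib to pick roughly that many "nice" levels, while an array gives the exact levels. Either way, the values used are stored in the levels attribute of the returned ContourSet. With extend='both', regions below the lowest level or above the highest level are still filled, and the colorbar gets pointed ends. A sketch:

```python
import numpy as np
from matplotlib import pyplot as plt

x1d = np.linspace(-2*np.pi, 2*np.pi, 100)
y1d = np.linspace(-np.pi, np.pi, 50)
xx, yy = np.meshgrid(x1d, y1d)
f = np.cos(xx) * np.sin(yy)

fig, ax = plt.subplots(ncols=2, figsize=(12, 5))
# integer: matplotlib chooses a handful of "nice" levels automatically
c_auto = ax[0].contour(xx, yy, f, 5)
# array: exactly these levels, -0.9, -0.7, ..., 0.9
clevels = np.arange(-1, 1, 0.2) + 0.1
cf = ax[1].contourf(xx, yy, f, clevels, cmap='RdBu_r', extend='both')
fig.colorbar(cf, ax=ax[1])

print(c_auto.levels)
print(cf.levels)
```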

quiver

u = -np.cos(xx) * np.cos(yy)
v = -np.sin(xx) * np.sin(yy)

fig, ax = plt.subplots(figsize=(12, 7))
ax.contour(xx, yy, f, clevels, cmap='RdBu_r', extend='both', zorder=0)
ax.quiver(xx[::4, ::4], yy[::4, ::4],
           u[::4, ::4], v[::4, ::4], zorder=1)
<matplotlib.quiver.Quiver at 0x7aa08f399370>
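The code does not say where u and v come from, but they are tied to f. Treating f as a streamfunction gives u = -∂f/∂y = -cos(x)cos(y) and v = ∂f/∂x = -sin(x)sin(y). A velocity built this way is perpendicular to the gradient of f, i.e. tangent to its contours, which is why the arrows follow the contour lines. A numerical check with np.gradient:

```python
import numpy as np

x1d = np.linspace(-2*np.pi, 2*np.pi, 100)
y1d = np.linspace(-np.pi, np.pi, 50)
xx, yy = np.meshgrid(x1d, y1d)
f = np.cos(xx) * np.sin(yy)
u = -np.cos(xx) * np.cos(yy)
v = -np.sin(xx) * np.sin(yy)

# finite-difference derivatives; axis 0 is y (rows), axis 1 is x (columns)
dfdy, dfdx = np.gradient(f, y1d, x1d, edge_order=2)

# streamfunction relation: u = -df/dy, v = df/dx (up to discretization error)
print(np.abs(u + dfdy).max(), np.abs(v - dfdx).max())

# velocity dotted with the gradient of f is ~0: the flow runs along the contours
print(np.abs(u*dfdx + v*dfdy).max())
```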

streamplot

fig, ax = plt.subplots(figsize=(12, 7))
ax.streamplot(xx, yy, u, v, density=2, color=(u**2 + v**2))
<matplotlib.streamplot.StreamplotSet at 0x7aa08f399eb0>
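streamplot traces continuous streamlines through the (u, v) field, and color=(u**2 + v**2) colors them by the squared speed. The linewidth argument also accepts a 2D array of the same shape, so line thickness can encode speed too. In this sketch, the linear scaling of width with speed is an arbitrary choice for illustration:

```python
import numpy as np
from matplotlib import pyplot as plt

x1d = np.linspace(-2*np.pi, 2*np.pi, 100)
y1d = np.linspace(-np.pi, np.pi, 50)
xx, yy = np.meshgrid(x1d, y1d)
u = -np.cos(xx) * np.cos(yy)
v = -np.sin(xx) * np.sin(yy)

speed = np.sqrt(u**2 + v**2)
# widest lines (width 3) where the flow is fastest
lw = 3 * speed / speed.max()

fig, ax = plt.subplots(figsize=(12, 7))
strm = ax.streamplot(xx, yy, u, v, density=2, color=speed,
                     linewidth=lw, cmap='viridis')
# the streamlines are a LineCollection, which can drive a colorbar
fig.colorbar(strm.lines, ax=ax, label='speed')
```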