A document from MCS 275 Spring 2022, instructor David Dumas. You can also get the notebook file.

Quick intro to matplotlib

MCS 275 Spring 2022 - David Dumas

This is a quick tour of basic plotting with matplotlib. For more detail, see the official matplotlib documentation (https://matplotlib.org/stable/).

You can install matplotlib on your own machine (e.g. python3 -m pip install matplotlib).
Matplotlib is most often used in a notebook setting.

Matplotlib and numpy are pre-installed in the Google Colab notebook environment.

Important shift in thinking

When working with functions, you are probably used to thinking of x as either a variable or a number.

When working with numpy, most things are vectorized. Taking advantage of that idea, instead of x being just a single number, in this notebook we will use that name for a vector containing all the numbers we plan to consider.

We could call it something else, like xvalues or xvec, but in this notebook we chose the shortest reasonable name, simply x.

Import matplotlib

In [4]:
# This submodule is all you need in most cases, and plt is a common
# abbreviated name to use for it.
import matplotlib.pyplot as plt

# For a very few things you need access to the parent module
# This also lets us check the installed version.
# Uncomment the next two lines to import that.

#   import matplotlib as mpl
#   mpl.__version__
In [5]:
import numpy as np
In [ ]:
# Uncomment the next line to display the names of the available plot styles.
#plt.style.available
In [ ]:
# Uncomment the next line to apply a named style to the figures made afterwards.
# NOTE(review): matplotlib 3.6+ renamed the seaborn styles (this one became
# "seaborn-v0_8-whitegrid") -- confirm against the installed version.
#plt.style.use("seaborn-whitegrid")

First line plot

In [6]:
# Goal: draw the graph of y = sin(x) for x between 0 and 4*pi.

# Sample points: a vector of 100 evenly spaced floats from 0 to 4*pi
x = np.linspace(0, 4 * np.pi, num=100)

# Vectorized evaluation: np.sin is applied to the whole vector at once.
# (Please avoid the element-by-element version,
#  y = np.array([np.sin(t) for t in x]) -- same values, needless loop.)
y = np.sin(x)

plt.figure(figsize=(8, 6))  # start a new figure (it might contain many plots)

plt.plot(x, y)  # plot(vector of x values, vector of y values)
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.title("Sine function")

plt.show()  # display the finished figure
In [7]:
# Same sine plot as before, but with a marker drawn at every sample point
# so the 100 individual data points become visible.

# Sample points: a vector of 100 evenly spaced floats from 0 to 4*pi
x = np.linspace(0, 4 * np.pi, num=100)

# Vectorized evaluation (do not loop over x element by element):
y = np.sin(x)

plt.figure(figsize=(8, 6))  # start a new figure (it might contain many plots)

# plot(vector of x values, vector of y values); marker="x" adds an
# x-shaped symbol at each data point on top of the connecting line
plt.plot(x, y, marker="x")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.title("Sine function")

plt.show()  # display the finished figure
In [8]:
# Bell curve y = e^{-x^2}, sampled at 100 evenly spaced points of [-3, 3].
# No plt.figure() call here: plt.plot creates a figure when none is active.
x = np.linspace(-3, 3, num=100)
y = np.exp(-x**2)
plt.plot(x, y)
plt.show()

Multiple plots and styling

In [13]:
# 500 sample points in [-3, 3], shared by all four curves below
x = np.linspace(-3, 3, num=500)

y = np.exp(-x**2)                     # f(x) = e^{-x^2}
y2 = 1.5 * np.exp(-3 * (x - 2)**2)    # g(x) = 1.5 * e^{-3(x-2)^2}
y3 = 0.8 * np.exp(-4 * (x + 1)**2)    # h(x) = 0.8 * e^{-4(x+1)^2}

plt.figure(figsize=(8, 6))

# Each plt.plot call adds one more curve to the same set of axes.
# label= is the text that plt.legend() will show for that curve.
plt.plot(x, y, label="Stewart", color="red")      # color given by name
plt.plot(x, y2, label="Tina", color="pink")
plt.plot(x, y3, label="29.5", color="#C4E434")    # color given as hex RRGGBB
plt.plot(x, 0.5 + 0.7 * np.sin(25 * x), label="ruiner")  # no color given

plt.legend()

plt.show()
In [45]:
# Only 50 sample points this time, so individual markers are easy to see
x = np.linspace(-3, 3, num=50)

y = np.exp(-x**2)
y2 = 1.5 * np.exp(-3 * (x - 2)**2)
y3 = 0.8 * np.exp(-4 * (x + 1)**2)

plt.plot(x, y, color="#FF0000")                    # color as hex RRGGBB
plt.plot(x, y2, linestyle="dashed", linewidth=3)   # thick dashed line
plt.plot(x, y3, marker="o", linestyle="")          # markers only, no connecting line

plt.show()

Adjusting axes

In [24]:
# Three curves sampled on [-3, 3] ...
x = np.linspace(-3, 3, num=100)

y = np.exp(-x**2)
y2 = 1.5 * np.exp(-3 * (x - 2)**2)
y3 = 0.8 * np.exp(-4 * (x + 1)**2)

# ... and a fourth one that uses its own sample points on [-2, 4].
# Curves on the same axes do not need to share their x values.
x2 = np.linspace(-2, 4, num=150)
y4 = 0.5 * np.exp(-(x2 - 2)**2) * np.sin(6 * x2)

plt.plot(x, y)
plt.plot(x, y2)
plt.plot(x, y3)

plt.plot(x2, y4)

# Choose the visible window explicitly instead of letting matplotlib
# fit the limits to the data.
plt.xlim(-1, 3)    # x runs from -1 to 3
plt.ylim(0, 1.6)   # y runs from 0 to 1.6

plt.show()

Parametric plot

In [51]:
# Parameter values: 300 samples of t between 0 and 2*pi
t = np.linspace(0, 2 * np.pi, num=300)

# First curve: the points (cos t, sin t)
x = np.cos(t)
y = np.sin(t)
plt.plot(x, y)

# Second curve: the points (sin 2t, cos t), computed inline
plt.plot(np.sin(2 * t), np.cos(t))

# Equal scaling: one unit along x is drawn as long as one unit along y.
# As the last expression of the cell, the return value of plt.axis
# (the axis limits) is shown as the cell output.
plt.axis("equal")
Out[51]:
(-1.0999855104238394,
 1.0999993100201828,
 -1.0999855104238394,
 1.0999993100201828)

Line plots can lie

In [53]:
# tan(x) is undefined at x = -pi/2 and x = pi/2, both inside [-3, 3].
# plt.plot simply joins consecutive sample points with straight segments,
# so the samples on either side of each of those points get connected.
x = np.linspace(-3, 3, num=100)
y = np.tan(x)
plt.plot(x, y)
# TODO: Fix the plot so it does not draw a connection across those points.
Out[53]:
[<matplotlib.lines.Line2D at 0x7f1f0334e8e0>]

Adding a legend

In [ ]:
# Three Gaussian-type curves on [-3, 3], each labeled with its formula.
x = np.linspace(start=-3,stop=3,num=100)
y = np.exp(-x*x)
y2 = 1.5*np.exp(-3*(x-2)**2)
y3 = 0.8*np.exp(-4*(x+1)**2)
plt.figure(figsize=(8,6))
# Labels wrapped in $...$ are typeset as math in the legend.
plt.plot(x,y,label="$e^{-x^2}$")
plt.plot(x,y2,label="$1.5e^{-3(x-2)^2}$")
plt.plot(x,y3,label="$0.8e^{-4(x+1)^2}$")
plt.legend()
# Save BEFORE calling plt.show(): once show() has displayed the figure,
# pyplot no longer treats it as the current figure, so a later savefig
# would write out a new, empty canvas instead of this plot.
plt.savefig("three_gaussians.png",dpi=300)  # raster image at 300 dots per inch
plt.savefig("three_gaussians.pdf")          # output format follows the file extension
plt.show()
In [ ]:
# The same three curves, now with per-curve styling.
x = np.linspace(-3, 3, num=100)
y = np.exp(-x**2)
y2 = 1.5 * np.exp(-3 * (x - 2)**2)
y3 = 0.8 * np.exp(-4 * (x + 1)**2)

plt.figure(figsize=(8, 6))
plt.plot(
    x, y,
    label="$e^{-x^2}$",
    color="orange", linestyle="dashed", linewidth=5,  # thick dashed orange line
)
plt.plot(x, y2, label="$1.5e^{-3(x-2)^2}$", color="#FF0080")   # color as hex RRGGBB
plt.plot(x, y3, label="$0.8e^{-4(x+1)^2}$", linestyle="dotted")
plt.legend()  # the legend entries reproduce each curve's style
plt.show()
In [ ]:
# Styled curves again; this time the second curve also gets a
# star-shaped marker at each of its data points.
x = np.linspace(-3, 3, num=100)
y = np.exp(-x**2)
y2 = 1.5 * np.exp(-3 * (x - 2)**2)
y3 = 0.8 * np.exp(-4 * (x + 1)**2)

plt.figure(figsize=(8, 6))
plt.plot(
    x, y,
    label="$e^{-x^2}$",
    color="orange", linestyle="dashed", linewidth=5,
)
plt.plot(x, y2, label="$1.5e^{-3(x-2)^2}$", color="#FF0080", marker="*")
plt.plot(x, y3, label="$0.8e^{-4(x+1)^2}$", linestyle="dotted")
plt.legend()
plt.show()

Scatter plots

In [56]:
# plt.plot draws one and the same marker symbol (same size, shape
# and color) at every data point.  With linestyle="" no connecting
# line is drawn, so only the markers remain.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])
plt.plot(n, t, color="orange", linestyle="", marker="o")
plt.show()
In [60]:
# plt.scatter, unlike plt.plot, accepts a separate size and color
# value for every data point.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])
s = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.5])   # one size value per point
c = np.array([1, 2, 3, 5, 8, 20])              # one number per point, turned into a color by cmap
plt.scatter(
    n,
    t,
    s=250 * s,        # s= sets the area of each dot
    c=c,
    cmap="Pastel2",
    marker="o",
)
# Add a scale showing which color stands for which value of c.
# (Last expression of the cell, so its repr appears as the cell output.)
plt.colorbar()
Out[60]:
<matplotlib.colorbar.Colorbar at 0x7f1f033b53a0>
In [61]:
# Same scatter plot as before; only the colormap is different.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])
s = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.5])   # one size value per point
c = np.array([1, 2, 3, 5, 8, 20])              # one number per point, turned into a color by cmap
plt.scatter(
    n,
    t,
    s=250 * s,        # s= sets the area of each dot
    c=c,
    cmap="seismic",
    marker="o",
)
# Color scale for c (last expression, so its repr is the cell output)
plt.colorbar()
Out[61]:
<matplotlib.colorbar.Colorbar at 0x7f1f032544f0>
In [63]:
# Here c holds a color name for each point instead of a number,
# so no colormap is passed.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])
s = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.5])   # one size value per point
c = np.array(["red", "red", "red", "blue", "blue", "blue"])
# Last expression of the cell: the object returned by scatter is shown as output.
plt.scatter(n, t, s=250 * s, c=c, marker="o")
Out[63]:
<matplotlib.collections.PathCollection at 0x7f1f0316e2b0>
In [58]:
# Show the list of colormap names; these are the strings that can be
# passed as cmap= (e.g. "Pastel2" and "seismic" used above).
plt.colormaps()
Out[58]:
['Accent',
 'Accent_r',
 'Blues',
 'Blues_r',
 'BrBG',
 'BrBG_r',
 'BuGn',
 'BuGn_r',
 'BuPu',
 'BuPu_r',
 'CMRmap',
 'CMRmap_r',
 'Dark2',
 'Dark2_r',
 'GnBu',
 'GnBu_r',
 'Greens',
 'Greens_r',
 'Greys',
 'Greys_r',
 'OrRd',
 'OrRd_r',
 'Oranges',
 'Oranges_r',
 'PRGn',
 'PRGn_r',
 'Paired',
 'Paired_r',
 'Pastel1',
 'Pastel1_r',
 'Pastel2',
 'Pastel2_r',
 'PiYG',
 'PiYG_r',
 'PuBu',
 'PuBuGn',
 'PuBuGn_r',
 'PuBu_r',
 'PuOr',
 'PuOr_r',
 'PuRd',
 'PuRd_r',
 'Purples',
 'Purples_r',
 'RdBu',
 'RdBu_r',
 'RdGy',
 'RdGy_r',
 'RdPu',
 'RdPu_r',
 'RdYlBu',
 'RdYlBu_r',
 'RdYlGn',
 'RdYlGn_r',
 'Reds',
 'Reds_r',
 'Set1',
 'Set1_r',
 'Set2',
 'Set2_r',
 'Set3',
 'Set3_r',
 'Spectral',
 'Spectral_r',
 'Wistia',
 'Wistia_r',
 'YlGn',
 'YlGnBu',
 'YlGnBu_r',
 'YlGn_r',
 'YlOrBr',
 'YlOrBr_r',
 'YlOrRd',
 'YlOrRd_r',
 'afmhot',
 'afmhot_r',
 'autumn',
 'autumn_r',
 'binary',
 'binary_r',
 'bone',
 'bone_r',
 'brg',
 'brg_r',
 'bwr',
 'bwr_r',
 'cividis',
 'cividis_r',
 'cool',
 'cool_r',
 'coolwarm',
 'coolwarm_r',
 'copper',
 'copper_r',
 'cubehelix',
 'cubehelix_r',
 'flag',
 'flag_r',
 'gist_earth',
 'gist_earth_r',
 'gist_gray',
 'gist_gray_r',
 'gist_heat',
 'gist_heat_r',
 'gist_ncar',
 'gist_ncar_r',
 'gist_rainbow',
 'gist_rainbow_r',
 'gist_stern',
 'gist_stern_r',
 'gist_yarg',
 'gist_yarg_r',
 'gnuplot',
 'gnuplot2',
 'gnuplot2_r',
 'gnuplot_r',
 'gray',
 'gray_r',
 'hot',
 'hot_r',
 'hsv',
 'hsv_r',
 'inferno',
 'inferno_r',
 'jet',
 'jet_r',
 'magma',
 'magma_r',
 'nipy_spectral',
 'nipy_spectral_r',
 'ocean',
 'ocean_r',
 'pink',
 'pink_r',
 'plasma',
 'plasma_r',
 'prism',
 'prism_r',
 'rainbow',
 'rainbow_r',
 'seismic',
 'seismic_r',
 'spring',
 'spring_r',
 'summer',
 'summer_r',
 'tab10',
 'tab10_r',
 'tab20',
 'tab20_r',
 'tab20b',
 'tab20b_r',
 'tab20c',
 'tab20c_r',
 'terrain',
 'terrain_r',
 'turbo',
 'turbo_r',
 'twilight',
 'twilight_r',
 'twilight_shifted',
 'twilight_shifted_r',
 'viridis',
 'viridis_r',
 'winter',
 'winter_r']

Scatter plot real world data

CSV with data about meteorites recovered on Earth's surface, adapted from a NASA dataset. The code below expects the file to be saved as meteorites.csv in the working directory.

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import csv
In [4]:
# Read "meteorites.csv" column by column.  When this cell is done,
# `columns` maps each column name to a numpy vector of that column's data.

import collections

# Pass 1: gather the raw text entries; column name -> list of strings
columns = collections.defaultdict(list)
with open("meteorites.csv", "r", newline="") as fp:
    for record in csv.DictReader(fp):
        for name, entry in record.items():
            columns[name].append(entry)

# Pass 2: turn every list of strings into a numeric vector.  The "year"
# column becomes integers; every other column becomes float64.
for name in columns:
    dtype = "int" if name == "year" else "float64"
    columns[name] = np.array(columns[name]).astype(dtype)
In [5]:
# One dot per meteorite: position = (longitude, latitude),
# dot size grows with mass, dot color encodes the year.
plt.figure(figsize=(15, 10))
plt.scatter(
    columns["longitude"],
    columns["latitude"],
    s=0.002 * columns["mass"] ** 0.66,  # s sets the area of each dot
    c=columns["year"],
    cmap="PuOr",
    alpha=0.6,
)
plt.colorbar()  # color scale for the year values
plt.show()

Pretty nice. Looks like a world map. Questions:

  • How much influence does population density have on this?
  • What's the likely distribution of actual meteorite landing points?
  • What's that really big one from before 1827 near latitude 75?
In [10]:
# Same scatter plot as before, restricted to a smaller window:
# longitude -140 to -60, latitude 10 to 60.
plt.figure(figsize=(15, 10))
plt.scatter(
    columns["longitude"],
    columns["latitude"],
    s=0.002 * columns["mass"] ** 0.66,  # s sets the area of each dot
    c=columns["year"],
    cmap="PuOr",
    alpha=0.6,
)
plt.axis("equal")      # request equal scaling on the two axes
plt.xlim(-140, -60)    # longitude window
plt.ylim(10, 60)       # latitude window
plt.colorbar()
plt.show()

Quick demo of annotate

In [7]:
plt.figure(figsize=(15, 10))
plt.scatter(
    columns["longitude"],
    columns["latitude"],
    s=0.002 * columns["mass"] ** 0.66,
    c=columns["year"],
    cmap="PuOr",
    alpha=0.6,  # dots are 60% opaque, 40% transparent
)
plt.colorbar()

# Attach a text label to one particular data point.
plt.annotate(
    "Cape York Meteorite (1818)",
    xy=(-64.933, 76.13),           # the point being annotated: (longitude, latitude)
    xycoords="data",               # tells matplotlib that xy is given in data units
    xytext=(0, 15),                # where the text goes
    textcoords="offset points",    # tells matplotlib how to read xytext:
                                   # units = points, origin = the annotated point
    horizontalalignment="center",
    verticalalignment="bottom",
)

plt.show()

Contour and density plots

In [11]:
# Build a rectangular grid of sample points covering [-3, 3] x [-2, 2].
x = np.linspace(-3, 3, num=100)   # 100 values along the x direction
y = np.linspace(-2, 2, num=80)    # 80 values along the y direction
# xx and yy are matrices giving the x and the y coordinate of each grid point.
xx, yy = np.meshgrid(x, y)
In [24]:
# Evaluate f(x,y) = x**3 - 8x + 3*y**2 + 0.5*y**3 at every grid point.
# The arithmetic is vectorized, so zz is an 80x100 matrix just like xx and yy.
zz = xx**3 - 8 * xx + 3 * yy**2 + 0.5 * yy**3
In [25]:
# Where is f(x,y) = 0.2?  Draw just that single level curve.
plt.figure(figsize=(8, 6))
plt.contour(xx, yy, zz, levels=[0.2])  # a list selects exactly these levels
plt.show()
In [31]:
# Contour plot with the levels chosen by matplotlib
plt.figure(figsize=(8, 6))
plt.contour(xx, yy, zz)
# Color scale for the contour levels.  As the last expression of the
# cell, its repr is shown as the cell output.
plt.colorbar()
Out[31]:
<matplotlib.colorbar.Colorbar at 0x7fc2fcd013a0>
In [26]:
# Filled contour plot: contourf colors the regions between levels
# instead of drawing only the level curves.
plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, zz)
# Color scale (last expression, so its repr is the cell output)
plt.colorbar()
Out[26]:
<matplotlib.colorbar.Colorbar at 0x7fc2fd49a190>

Adding labels to contours

plt.clabel adds labels to an existing contour plot. Its argument is the return value of a previous call to plt.contour.

In [27]:
plt.figure(figsize=(8, 6))
# Keep the return value of plt.contour: plt.clabel needs it below.
contours = plt.contour(xx, yy, zz, levels=15, cmap="magma")
plt.clabel(contours)  # add inline labels to the contours
plt.title("Contour plot")
# Color scale (last expression, so its repr is the cell output)
plt.colorbar()
Out[27]:
<matplotlib.colorbar.Colorbar at 0x7fc2fd3c65e0>