A document from MCS 275 Spring 2022, instructor David Dumas. You can also get the notebook file.

Quick intro to matplotlib

MCS 275 Spring 2022 - David Dumas

This is a quick tour of basic plotting with matplotlib. For more detail see:

You can install matplotlib on your own machine (e.g. python3 -m pip install matplotlib).
Matplotlib is most often used in a notebook setting.

Matplotlib and numpy are pre-installed in the Google Colab notebook environment.

Important shift in thinking

When working with functions, you are probably used to thinking of x as either a variable or a number.

When working with numpy, most things are vectorized. Taking advantage of that idea, instead of x being just a number, in this notebook we will use that name for a vector containing all the numbers we plan to consider.

We could call it something else, like xvalues or xvec, but in this notebook we chose the shortest reasonable name, simply x.

Import matplotlib

In [4]:
# This submodule is all you need in most cases, and plt is a common
# abbreviated name to use for it.
import matplotlib.pyplot as plt

# For a very few things you need access to the parent module
# This also lets us check the installed version.
# Uncomment the next two lines to import that.

#   import matplotlib as mpl
#   mpl.__version__
In [5]:
import numpy as np
In [ ]:
#plt.style.available
In [ ]:
#plt.style.use("seaborn-whitegrid")

First line plot

In [6]:
# Plot y = sin(x) for x between 0 and 4*pi.

# Sample points: a vector of 100 evenly spaced floats from 0 to 4*pi.
x = np.linspace(start=0, stop=4 * np.pi, num=100)

# numpy functions are vectorized, so apply sin to the whole vector at once.
# (Please avoid the element-by-element version:
#      np.array([np.sin(t) for t in x])  )
y = np.sin(x)

# Begin a new figure (a figure might contain many plots).
plt.figure(figsize=(8, 6))

# plot(vector of x values, vector of y values)
plt.plot(x, y)
plt.ylabel("sin(x)")
plt.xlabel("x")
plt.title("Sine function")

plt.show()  # display the finished figure
In [7]:
# Same sine plot as in the previous cell, now with a marker at each sample.

# Sample points: a vector of 100 evenly spaced floats from 0 to 4*pi.
x = np.linspace(start=0, stop=4 * np.pi, num=100)

# Vectorized evaluation: one call computes sin at every entry of x.
# (Please avoid the element-by-element version:
#      np.array([np.sin(t) for t in x])  )
y = np.sin(x)

# Begin a new figure (a figure might contain many plots).
plt.figure(figsize=(8, 6))

# marker="x" draws an x-shaped marker at every data point, on top of the
# line joining them.
plt.plot(x, y, marker="x")
plt.ylabel("sin(x)")
plt.xlabel("x")
plt.title("Sine function")

plt.show()  # display the finished figure
In [8]:
# y = e^{-x^2} sampled at 100 points between -3 and 3.
x = np.linspace(-3, 3, num=100)
y = np.exp(-(x * x))
plt.plot(x, y)
plt.show()

Multiple plots and styling

In [13]:
# 500 sample points shared by all of the curves below.
x = np.linspace(-3, 3, num=500)

# Three bell-shaped curves with different centers, heights and widths.
y = np.exp(-x * x)                      # f(x) = e^{-x^2}
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)    # g(x) = 1.5 * e^{-3(x-2)^2}
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)    # h(x) = 0.8 * e^{-4(x+1)^2}

plt.figure(figsize=(8, 6))

# Every plt.plot call draws on the same set of axes.
# label= is the text shown for that curve by plt.legend().
plt.plot(x, y, label="Stewart", color="red")
plt.plot(x, y2, label="Tina", color="pink")
plt.plot(x, y3, label="29.5", color="#C4E434")  # color given as hex RRGGBB
plt.plot(x, 0.5 + 0.7 * np.sin(25 * x), label="ruiner")  # no color given

plt.legend()

plt.show()
In [45]:
# Only 50 sample points here, so individual markers are easy to see.
x = np.linspace(-3, 3, num=50)

y = np.exp(-x * x)
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)

plt.plot(x, y, color="#FF0000")                   # solid line, hex color
plt.plot(x, y2, linestyle="dashed", linewidth=3)  # thick dashed line
plt.plot(x, y3, marker="o", linestyle="")         # markers only, no line

plt.show()

Adjusting axes

In [24]:
# Three curves sampled on [-3, 3] ...
x = np.linspace(-3, 3, num=100)

y = np.exp(-x * x)
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)

# ... and a fourth curve sampled on a different interval, [-2, 4].
# Each plot call takes its own x vector, so the intervals need not match.
x2 = np.linspace(-2, 4, num=150)
y4 = 0.5 * np.exp(-(x2 - 2) ** 2) * np.sin(6 * x2)

plt.plot(x, y)
plt.plot(x, y2)
plt.plot(x, y3)

plt.plot(x2, y4)

# Restrict the visible window; data outside these limits is not shown.
plt.xlim(-1, 3)
plt.ylim(0, 1.6)

plt.show()

Parametric plot

In [51]:
# Parametric curves: x and y are both computed from a parameter vector t.
t = np.linspace(start=0, stop=2 * np.pi, num=300)
x = np.cos(t)
y = np.sin(t)
plt.plot(x, y)                       # the curve (cos t, sin t)
plt.plot(np.sin(2 * t), np.cos(t))   # the curve (sin 2t, cos t)
plt.axis("equal")  # same scale on both axes (limits are adjusted to achieve this)
Out[51]:
(-1.0999855104238394,
 1.0999993100201828,
 -1.0999855104238394,
 1.0999993100201828)

Line plots can lie

In [53]:
# tan(x) has vertical asymptotes at x = -pi/2 and x = pi/2, and both of
# those lie inside the sampled interval [-3, 3].
x = np.linspace(start=-3,stop=3,num=100)
y = np.tan(x)
# plt.plot joins consecutive sample points with straight segments, so the
# samples on either side of each asymptote get connected by a segment that
# is not part of the graph of tan.
plt.plot(x,y)
# TODO: Fix.
Out[53]:
[<matplotlib.lines.Line2D at 0x7f1f0334e8e0>]

Adding a legend

In [ ]:
# Three bell-shaped curves, each labeled with its formula.
x = np.linspace(start=-3,stop=3,num=100)
y = np.exp(-x*x)
y2 = 1.5*np.exp(-3*(x-2)**2)
y3 = 0.8*np.exp(-4*(x+1)**2)
plt.figure(figsize=(8,6))
# Text between $...$ in a label is typeset as math.
plt.plot(x,y,label="$e^{-x^2}$")
plt.plot(x,y2,label="$1.5e^{-3(x-2)^2}$")
plt.plot(x,y3,label="$0.8e^{-4(x+1)^2}$")
plt.legend()  # draw a legend box using the label= text of each curve
# Save BEFORE calling plt.show().  Once the figure has been shown it is no
# longer the current figure, so calling savefig afterwards would write an
# empty image instead of this plot.
plt.savefig("three_gaussians.png",dpi=300)  # raster image, 300 dots per inch
plt.savefig("three_gaussians.pdf")  # output format is chosen from the file extension
plt.show()
In [ ]:
# Same three labeled curves, now with per-curve styling.
x = np.linspace(-3, 3, num=100)
y = np.exp(-x * x)
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)
plt.figure(figsize=(8, 6))
# Thick dashed orange line
plt.plot(x, y, color="orange", linewidth=5, linestyle="dashed",
         label="$e^{-x^2}$")
# Color given as hex RRGGBB
plt.plot(x, y2, color="#FF0080", label="$1.5e^{-3(x-2)^2}$")
# Dotted line
plt.plot(x, y3, linestyle="dotted", label="$0.8e^{-4(x+1)^2}$")
plt.legend()  # the legend entries show each curve's style
plt.show()
In [ ]:
# As in the previous cell, but the second curve also gets star markers.
x = np.linspace(-3, 3, num=100)
y = np.exp(-x * x)
y2 = 1.5 * np.exp(-3 * (x - 2) ** 2)
y3 = 0.8 * np.exp(-4 * (x + 1) ** 2)
plt.figure(figsize=(8, 6))
# Thick dashed orange line
plt.plot(x, y, color="orange", linewidth=5, linestyle="dashed",
         label="$e^{-x^2}$")
# marker="*" draws a star at each of the sample points
plt.plot(x, y2, color="#FF0080", marker="*", label="$1.5e^{-3(x-2)^2}$")
# Dotted line
plt.plot(x, y3, linestyle="dotted", label="$0.8e^{-4(x+1)^2}$")
plt.legend()
plt.show()

Scatter plots

In [56]:
# plt.plot draws one and the same marker symbol (same size, shape, color)
# at every data point.  With linestyle="" only the markers are drawn.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])
plt.plot(n, t, color="orange", linestyle="", marker="o")
plt.show()
In [60]:
# plt.scatter lets the marker size and color vary from point to point.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])          # horizontal coordinates
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])    # vertical coordinates
s = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.5])    # per-point size data (scaled below)
c = np.array([1, 2, 3, 5, 8, 20])               # per-point numbers, turned into colors by cmap
plt.scatter(n, t, s=250 * s, c=c, cmap="Pastel2", marker="o")
plt.colorbar()
Out[60]:
<matplotlib.colorbar.Colorbar at 0x7f1f033b53a0>
In [61]:
# Same scatter plot as in the previous cell, with a different colormap.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])          # horizontal coordinates
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])    # vertical coordinates
s = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.5])    # per-point size data (scaled below)
c = np.array([1, 2, 3, 5, 8, 20])               # per-point numbers, turned into colors by cmap
plt.scatter(n, t, s=250 * s, c=c, cmap="seismic", marker="o")
plt.colorbar()
Out[61]:
<matplotlib.colorbar.Colorbar at 0x7f1f032544f0>
In [63]:
# Colors can also be given directly, one color name per point; no colormap
# is involved in that case.
n = np.array([1, 1.5, 2, 2.5, 3.5, 5])          # horizontal coordinates
t = np.array([1.8, 2.6, 3.5, 4.9, 8.8, 8.2])    # vertical coordinates
s = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.5])    # per-point size data (scaled below)
c = np.array(["red", "red", "red", "blue", "blue", "blue"])
plt.scatter(n, t, s=250 * s, c=c, marker="o")
Out[63]:
<matplotlib.collections.PathCollection at 0x7f1f0316e2b0>
In [58]:
# List the names of the available colormaps (each can be passed as cmap=...).
plt.colormaps()
Out[58]:
['Accent',
 'Accent_r',
 'Blues',
 'Blues_r',
 'BrBG',
 'BrBG_r',
 'BuGn',
 'BuGn_r',
 'BuPu',
 'BuPu_r',
 'CMRmap',
 'CMRmap_r',
 'Dark2',
 'Dark2_r',
 'GnBu',
 'GnBu_r',
 'Greens',
 'Greens_r',
 'Greys',
 'Greys_r',
 'OrRd',
 'OrRd_r',
 'Oranges',
 'Oranges_r',
 'PRGn',
 'PRGn_r',
 'Paired',
 'Paired_r',
 'Pastel1',
 'Pastel1_r',
 'Pastel2',
 'Pastel2_r',
 'PiYG',
 'PiYG_r',
 'PuBu',
 'PuBuGn',
 'PuBuGn_r',
 'PuBu_r',
 'PuOr',
 'PuOr_r',
 'PuRd',
 'PuRd_r',
 'Purples',
 'Purples_r',
 'RdBu',
 'RdBu_r',
 'RdGy',
 'RdGy_r',
 'RdPu',
 'RdPu_r',
 'RdYlBu',
 'RdYlBu_r',
 'RdYlGn',
 'RdYlGn_r',
 'Reds',
 'Reds_r',
 'Set1',
 'Set1_r',
 'Set2',
 'Set2_r',
 'Set3',
 'Set3_r',
 'Spectral',
 'Spectral_r',
 'Wistia',
 'Wistia_r',
 'YlGn',
 'YlGnBu',
 'YlGnBu_r',
 'YlGn_r',
 'YlOrBr',
 'YlOrBr_r',
 'YlOrRd',
 'YlOrRd_r',
 'afmhot',
 'afmhot_r',
 'autumn',
 'autumn_r',
 'binary',
 'binary_r',
 'bone',
 'bone_r',
 'brg',
 'brg_r',
 'bwr',
 'bwr_r',
 'cividis',
 'cividis_r',
 'cool',
 'cool_r',
 'coolwarm',
 'coolwarm_r',
 'copper',
 'copper_r',
 'cubehelix',
 'cubehelix_r',
 'flag',
 'flag_r',
 'gist_earth',
 'gist_earth_r',
 'gist_gray',
 'gist_gray_r',
 'gist_heat',
 'gist_heat_r',
 'gist_ncar',
 'gist_ncar_r',
 'gist_rainbow',
 'gist_rainbow_r',
 'gist_stern',
 'gist_stern_r',
 'gist_yarg',
 'gist_yarg_r',
 'gnuplot',
 'gnuplot2',
 'gnuplot2_r',
 'gnuplot_r',
 'gray',
 'gray_r',
 'hot',
 'hot_r',
 'hsv',
 'hsv_r',
 'inferno',
 'inferno_r',
 'jet',
 'jet_r',
 'magma',
 'magma_r',
 'nipy_spectral',
 'nipy_spectral_r',
 'ocean',
 'ocean_r',
 'pink',
 'pink_r',
 'plasma',
 'plasma_r',
 'prism',
 'prism_r',
 'rainbow',
 'rainbow_r',
 'seismic',
 'seismic_r',
 'spring',
 'spring_r',
 'summer',
 'summer_r',
 'tab10',
 'tab10_r',
 'tab20',
 'tab20_r',
 'tab20b',
 'tab20b_r',
 'tab20c',
 'tab20c_r',
 'terrain',
 'terrain_r',
 'turbo',
 'turbo_r',
 'twilight',
 'twilight_r',
 'twilight_shifted',
 'twilight_shifted_r',
 'viridis',
 'viridis_r',
 'winter',
 'winter_r']

Scatter plot real world data

CSV with data about meteorites recovered on earth's surface, adapted from NASA dataset:

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import csv
In [4]:
# Read "meteorites.csv" into `columns`, a dict that maps each column name
# to a numpy vector holding that column's data.

import collections

# First pass: gather every column as a list of the raw strings from the file.
columns = collections.defaultdict(list)
with open("meteorites.csv", "r", newline="") as fp:
    for record in csv.DictReader(fp):
        for name, value in record.items():
            columns[name].append(value)

# Second pass: convert each list to a numeric vector.  The "year" column
# becomes integers; every other column becomes float64.
for name in columns:
    target_dtype = "int" if name == "year" else "float64"
    columns[name] = np.array(columns[name]).astype(target_dtype)
In [5]:
# One dot per meteorite: position = (longitude, latitude),
# dot size computed from mass, dot color from year.
plt.figure(figsize=(15, 10))
plt.scatter(
    x=columns["longitude"],
    y=columns["latitude"],
    s=0.002 * columns["mass"] ** (0.66),  # s sets the area of each dot
    c=columns["year"],
    cmap="PuOr",
    alpha=0.6,
)
plt.colorbar()
plt.show()

Pretty nice. Looks like a world map. Questions:

  • How much influence does population density have on this?
  • What's the likely distribution of actual meteorite landing points?
  • What's that really big one from before 1827 near latitude 75?
In [10]:
# Same scatter plot as before, restricted to longitudes -140..-60 and
# latitudes 10..60.
plt.figure(figsize=(15, 10))
plt.scatter(
    x=columns["longitude"],
    y=columns["latitude"],
    s=0.002 * columns["mass"] ** (0.66),  # s sets the area of each dot
    c=columns["year"],
    cmap="PuOr",
    alpha=0.6,
)
plt.axis("equal")     # same scale on both axes
plt.xlim(-140, -60)   # longitude window
plt.ylim(10, 60)      # latitude window
plt.colorbar()
plt.show()

Quick demo of annotate

In [7]:
# Same meteorite scatter plot as before, plus a text label attached to one point.
plt.figure(figsize=(15,10))
plt.scatter(
    columns["longitude"],
    columns["latitude"],
    s=0.002*columns["mass"]**(0.66),
    alpha=0.6,  # 60% opaque, 40% transparent dots.
    c=columns["year"],
    cmap="PuOr"
)
plt.colorbar()
plt.annotate("Cape York Meteorite (1818)",
             xy=(-64.933,76.13),   # Point we're annotating
             xycoords='data',      # tell matplotlib these coords are in data units
             xytext=(0, 15),  # Where the text goes
             textcoords='offset points',  # tell matplotlib the units and origin for the coords on prev line
                                          # (units = points, origin = the point being annotated)
             horizontalalignment='center',
             verticalalignment='bottom',
            )

plt.show()

Contour and density plots

In [11]:
# Grid of sample points: 100 x-values in [-3, 3] and 80 y-values in [-2, 2].
x = np.linspace(start=-3, stop=3, num=100)
y = np.linspace(start=-2, stop=2, num=80)
# xx and yy both have shape (80, 100): every row of xx is a copy of x,
# and every column of yy is a copy of y.
xx, yy = np.meshgrid(x, y)
In [24]:
# f(x,y) = x**3 - 8*x + 3*y**2 + 0.5*y**3
# Evaluated elementwise on the grid arrays, so zz has the same shape as xx and yy.
zz = xx**3 - 8*xx + 3*yy**2 + 0.5*yy**3  # 80x100 matrix of values of f on the grid
In [25]:
# Where is f(x,y) = 0.2?  Draw just that one level curve.
plt.figure(figsize=(8, 6))
plt.contour(xx, yy, zz, levels=[0.2])
plt.show()
In [31]:
# Contour plot: level curves of f, with the levels chosen automatically
plt.figure(figsize=(8, 6))
plt.contour(xx, yy, zz)
plt.colorbar()
Out[31]:
<matplotlib.colorbar.Colorbar at 0x7fc2fcd013a0>
In [26]:
# Filled contour plot: the regions between level curves are colored in
plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, zz)
plt.colorbar()
Out[26]:
<matplotlib.colorbar.Colorbar at 0x7fc2fd49a190>

Adding labels to contours

plt.clabel adds labels to an existing contour plot. Its argument is the return value of a previous call to plt.contour.

In [27]:
plt.figure(figsize=(8, 6))
# Keep the return value of plt.contour; plt.clabel needs it below.
contours = plt.contour(xx, yy, zz, levels=15, cmap="magma")
plt.title("Contour plot")
plt.clabel(contours)  # add inline labels to the contours
plt.colorbar()
Out[27]:
<matplotlib.colorbar.Colorbar at 0x7fc2fd3c65e0>

Density plots with plt.imshow

In [28]:
plt.figure(figsize=(8, 6))
# extent = [left, right, bottom, top] gives the x and y ranges that the
# image covers.
# origin="lower" means the first row of zz appears at the bottom of the plot.
# That's correct since our meshgrid has smallest y values in the first row.
plt.imshow(
    zz,
    origin="lower",
    extent=[x.min(), x.max(), y.min(), y.max()],
)
plt.title("Density plot")
plt.colorbar()
Out[28]:
<matplotlib.colorbar.Colorbar at 0x7fc2fd559a60>
In [29]:
plt.figure(figsize=(8, 6))
# White contour lines, combined with a density plot of the same data.
contours = plt.contour(xx, yy, zz, levels=15, colors="white")
plt.title("Contour and density plot")
plt.clabel(contours)  # add inline labels to the contours
plt.imshow(
    zz,
    origin="lower",
    extent=[x.min(), x.max(), y.min(), y.max()],
)
plt.colorbar()
Out[29]:
<matplotlib.colorbar.Colorbar at 0x7fc2fd4f3310>

Aside: UIC's colors

In [30]:
# This is adapted from an example in the matplotlib docs:
# https://matplotlib.org/stable/gallery/color/named_colors.html

# This uses lots of matplotlib features we don't cover in MCS 275!

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

def plot_colortable(colors, title, sort_colors=True, emptycols=0):
    """Draw a table of labeled color swatches and return the figure.

    colors      -- dict mapping a display name to a color string; the string
                   is printed under the name and used to draw the swatch
    title       -- text shown at the top left of the figure
    sort_colors -- if True, order the entries by hue, saturation, value and
                   name; otherwise keep the order of the dict
    emptycols   -- how many of the 4 columns to leave unused
    """
    # Layout constants, in pixels (the figure is created at 72 dpi below).
    cell_width = 212
    cell_height = 48
    swatch_width = 64
    margin = 12
    topmargin = 56

    # Decide the order in which the names are listed.
    if sort_colors is True:
        keyed = [
            (tuple(mcolors.rgb_to_hsv(mcolors.to_rgb(color))), name)
            for name, color in colors.items()
        ]
        keyed.sort()
        names = [name for _, name in keyed]
    else:
        names = list(colors)

    # Rows needed to fit all names into the usable columns (rounding up).
    n = len(names)
    ncols = 4 - emptycols
    full_rows, leftover = divmod(n, ncols)
    nrows = full_rows + int(leftover > 0)

    # The figure is always 4 cells wide, even when some columns stay empty.
    width = 4 * cell_width + 2 * margin
    height = nrows * cell_height + margin + topmargin
    dpi = 72

    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    fig.subplots_adjust(
        left=margin / width,
        bottom=margin / height,
        right=(width - margin) / width,
        top=(height - topmargin) / height,
    )
    ax.set_xlim(0, cell_width * 4)
    # The upper y limit comes first, so y increases downward and row 0 is on top.
    ax.set_ylim(cell_height * (nrows - 0.5), -cell_height / 2.)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_axis_off()
    ax.set_title(title, fontsize=24, loc="left", pad=10)

    # Fill the table column by column: each column is filled top to bottom.
    for i, name in enumerate(names):
        col, row = divmod(i, nrows)
        y = row * cell_height

        swatch_start_x = cell_width * col
        swatch_end_x = swatch_start_x + swatch_width
        text_pos_x = swatch_end_x + 7

        # Name and color string, to the right of the swatch
        ax.text(text_pos_x, y, name + "\n" + colors[name], fontsize=14,
                horizontalalignment='left',
                verticalalignment='center')

        # The swatch itself is a short, very thick horizontal line
        ax.hlines(y, swatch_start_x, swatch_end_x,
                  color=colors[name], linewidth=28)

    return fig

# First table: the two primary colors.
# sort_colors=False keeps the colors in the order written here.
plot_colortable(
    {
        "Fire Engine Red":"#D50032",
        "Navy Pier Blue":"#001E62"
    },
    "Primary",
    sort_colors=False,
    emptycols=1
)

# Second table: the remaining five colors.  It needs its own title so the
# two tables can be told apart (it was previously titled "Primary" as well).
plot_colortable(
    {
        "Chicago Blue":"#41B6E6",
        "Champions Gold":"#FFBF3F",
        "UI Health Sky Blue":"#0065AD",
        "Expo White":"#F2F7EB",
        "Steel Gray":"#333333",
    },
    "Secondary",
    sort_colors=False,
    emptycols=1
)

plt.show()