Basic Plotting with matplotlib

You can show matplotlib figures directly in the notebook by using the %matplotlib notebook and %matplotlib inline magic commands.

%matplotlib notebook provides an interactive environment.

In [2]:
%matplotlib notebook
In [4]:
import matplotlib as mpl
mpl.get_backend()
Out[4]:
'nbAgg'
In [3]:
import matplotlib.pyplot as plt
plt.plot?
In [4]:
# because the default is the line style '-', 
# nothing will be shown if we only pass in one point (3,2)
plt.plot(3, 2)
Out[4]:
[<matplotlib.lines.Line2D at 0x7fa1870a7898>]
In [5]:
# we can pass in '.' to plt.plot to indicate that we want
# the point (3,2) to be indicated with a marker '.'
plt.plot(3, 2, '.')
Out[5]:
[<matplotlib.lines.Line2D at 0x7fa1870c4080>]

Let's see how to make a plot without using the scripting layer.

In [6]:
# First let's set the backend without using mpl.use() from the scripting layer
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure

# create a new figure
fig = Figure()

# associate fig with the backend
canvas = FigureCanvasAgg(fig)

# add a subplot to the fig
ax = fig.add_subplot(111)

# plot the point (3,2)
ax.plot(3, 2, '.')

# save the figure to test.png
# you can see this figure in your Jupyter workspace afterwards by going to
# https://hub.coursera-notebooks.org/
canvas.print_png('test.png')

We can use html cell magic to display the image.

In [7]:
%%html
<img src='test.png' />
In [8]:
# create a new figure
plt.figure()

# plot the point (3,2) using the circle marker
plt.plot(3, 2, 'o')

# get the current axes
ax = plt.gca()

# Set axis properties [xmin, xmax, ymin, ymax]
ax.axis([0,6,0,10])
Out[8]:
[0, 6, 0, 10]
In [9]:
# create a new figure
plt.figure()

# plot the point (1.5, 1.5) using the circle marker
plt.plot(1.5, 1.5, 'o')
# plot the point (2, 2) using the circle marker
plt.plot(2, 2, 'o')
# plot the point (2.5, 2.5) using the circle marker
plt.plot(2.5, 2.5, 'o')
Out[9]:
[<matplotlib.lines.Line2D at 0x7fa18484dba8>]
In [10]:
# get current axes
ax = plt.gca()
# get all the child objects the axes contains
ax.get_children()
Out[10]:
[<matplotlib.lines.Line2D at 0x7fa18484da20>,
 <matplotlib.lines.Line2D at 0x7fa1848829b0>,
 <matplotlib.lines.Line2D at 0x7fa18484dba8>,
 <matplotlib.spines.Spine at 0x7fa1848992b0>,
 <matplotlib.spines.Spine at 0x7fa1848994e0>,
 <matplotlib.spines.Spine at 0x7fa1848996d8>,
 <matplotlib.spines.Spine at 0x7fa1848998d0>,
 <matplotlib.axis.XAxis at 0x7fa184899a90>,
 <matplotlib.axis.YAxis at 0x7fa184824860>,
 <matplotlib.text.Text at 0x7fa1848c2a58>,
 <matplotlib.text.Text at 0x7fa1848c2c50>,
 <matplotlib.text.Text at 0x7fa1848c24e0>,
 <matplotlib.patches.Rectangle at 0x7fa1848c2518>]

Scatterplots

In [11]:
import numpy as np

x = np.array([1,2,3,4,5,6,7,8])
y = x

plt.figure()
plt.scatter(x, y) # similar to plt.plot(x, y, '.'), but the underlying child objects in the axes are not Line2D
Out[11]:
<matplotlib.collections.PathCollection at 0x7fa17d890d30>
In [12]:
import numpy as np

x = np.array([1,2,3,4,5,6,7,8])
y = x

# create a list of colors for each point to have
# ['green', 'green', 'green', 'green', 'green', 'green', 'green', 'red']
colors = ['green']*(len(x)-1)
colors.append('red')

plt.figure()

# plot the point with size 100 and chosen colors
plt.scatter(x, y, s=100, c=colors)
Out[12]:
<matplotlib.collections.PathCollection at 0x7fa17051d128>
In [13]:
# convert the two lists into a list of pairwise tuples
zip_generator = zip([1,2,3,4,5], [6,7,8,9,10])

print(list(zip_generator))
# the above prints:
# [(1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]

zip_generator = zip([1,2,3,4,5], [6,7,8,9,10])
# The single star * unpacks a collection into positional arguments
print(*zip_generator)
# the above prints:
# (1, 6) (2, 7) (3, 8) (4, 9) (5, 10)
[(1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]
(1, 6) (2, 7) (3, 8) (4, 9) (5, 10)
In [14]:
# use zip to convert 5 tuples with 2 elements each to 2 tuples with 5 elements each
print(list(zip((1, 6), (2, 7), (3, 8), (4, 9), (5, 10))))
# the above prints:
# [(1, 2, 3, 4, 5), (6, 7, 8, 9, 10)]


zip_generator = zip([1,2,3,4,5], [6,7,8,9,10])
# let's turn the data back into 2 lists
x, y = zip(*zip_generator) # This is like calling zip((1, 6), (2, 7), (3, 8), (4, 9), (5, 10))
print(x)
print(y)
# the above prints:
# (1, 2, 3, 4, 5)
# (6, 7, 8, 9, 10)
[(1, 2, 3, 4, 5), (6, 7, 8, 9, 10)]
(1, 2, 3, 4, 5)
(6, 7, 8, 9, 10)
In [15]:
plt.figure()
# plot a data series 'Tall students' in red using the first two elements of x and y
plt.scatter(x[:2], y[:2], s=100, c='red', label='Tall students')
# plot a second data series 'Short students' in blue using the last three elements of x and y 
plt.scatter(x[2:], y[2:], s=100, c='blue', label='Short students')
Out[15]:
<matplotlib.collections.PathCollection at 0x7fa17048e438>
In [16]:
# add a label to the x axis
plt.xlabel('The number of times the child kicked a ball')
# add a label to the y axis
plt.ylabel('The grade of the student')
# add a title
plt.title('Relationship between ball kicking and grades')
Out[16]:
<matplotlib.text.Text at 0x7fa1704e73c8>
In [17]:
# add a legend (uses the labels from plt.scatter)
plt.legend()
Out[17]:
<matplotlib.legend.Legend at 0x7fa17d830fd0>
In [18]:
# add the legend to loc=4 (the lower right hand corner), also gets rid of the frame and adds a title
plt.legend(loc=4, frameon=False, title='Legend')
Out[18]:
<matplotlib.legend.Legend at 0x7fa17044aa90>
In [19]:
# get children from current axes (the legend is the second to last item in this list)
plt.gca().get_children()
Out[19]:
[<matplotlib.collections.PathCollection at 0x7fa170488898>,
 <matplotlib.collections.PathCollection at 0x7fa17048e438>,
 <matplotlib.spines.Spine at 0x7fa170535898>,
 <matplotlib.spines.Spine at 0x7fa170535438>,
 <matplotlib.spines.Spine at 0x7fa170530cc0>,
 <matplotlib.spines.Spine at 0x7fa1705304e0>,
 <matplotlib.axis.XAxis at 0x7fa170530048>,
 <matplotlib.axis.YAxis at 0x7fa17d832cc0>,
 <matplotlib.text.Text at 0x7fa1704e73c8>,
 <matplotlib.text.Text at 0x7fa1704e7438>,
 <matplotlib.text.Text at 0x7fa1704e74a8>,
 <matplotlib.legend.Legend at 0x7fa17044aa90>,
 <matplotlib.patches.Rectangle at 0x7fa1704e74e0>]
In [20]:
# get the legend from the current axes
legend = plt.gca().get_children()[-2]
In [21]:
# you can use get_children to navigate through the child artists
legend.get_children()[0].get_children()[1].get_children()[0].get_children()
Out[21]:
[<matplotlib.offsetbox.HPacker at 0x7fa170454c18>,
 <matplotlib.offsetbox.HPacker at 0x7fa170454c88>]
In [22]:
# import the artist class from matplotlib
from matplotlib.artist import Artist

def rec_gc(art, depth=0):
    if isinstance(art, Artist):
        # increase the depth for pretty printing
        print("  " * depth + str(art))
        for child in art.get_children():
            rec_gc(child, depth+2)

# Call this function on the legend artist to see what the legend is made up of
rec_gc(plt.legend())
Legend
    <matplotlib.offsetbox.VPacker object at 0x7fa170464978>
        <matplotlib.offsetbox.TextArea object at 0x7fa170464668>
            Text(0,0,'None')
        <matplotlib.offsetbox.HPacker object at 0x7fa17045ebe0>
            <matplotlib.offsetbox.VPacker object at 0x7fa17045ec18>
                <matplotlib.offsetbox.HPacker object at 0x7fa1704645c0>
                    <matplotlib.offsetbox.DrawingArea object at 0x7fa17045ee48>
                        <matplotlib.collections.PathCollection object at 0x7fa17045efd0>
                    <matplotlib.offsetbox.TextArea object at 0x7fa17045ec50>
                        Text(0,0,'Tall students')
                <matplotlib.offsetbox.HPacker object at 0x7fa170464630>
                    <matplotlib.offsetbox.DrawingArea object at 0x7fa170464390>
                        <matplotlib.collections.PathCollection object at 0x7fa170464550>
                    <matplotlib.offsetbox.TextArea object at 0x7fa170464080>
                        Text(0,0,'Short students')
    FancyBboxPatch(0,0;1x1)

Line Plots

In [6]:
import numpy as np

linear_data = np.array([1,2,3,4,5,6,7,8])
exponential_data = linear_data**2

plt.figure()
# plot the linear data and the exponential data
plt.plot(linear_data, '-o', exponential_data, '-o')
Out[6]:
[<matplotlib.lines.Line2D at 0x7f27a0b12cc0>,
 <matplotlib.lines.Line2D at 0x7f27a0b12e10>]
In [7]:
# plot another series with a dashed red line
plt.plot([22,44,55], '--r')
Out[7]:
[<matplotlib.lines.Line2D at 0x7f27a0ae2358>]
In [8]:
plt.xlabel('Some data')
plt.ylabel('Some other data')
plt.title('A title')
# add a legend with legend entries (because we didn't have labels when we plotted the data series)
plt.legend(['Baseline', 'Competition', 'Us'])
Out[8]:
<matplotlib.legend.Legend at 0x7f27a2ba9978>
In [9]:
# fill the area between the linear data and exponential data
plt.gca().fill_between(range(len(linear_data)), 
                       linear_data, exponential_data, 
                       facecolor='blue', 
                       alpha=0.25)
Out[9]:
<matplotlib.collections.PolyCollection at 0x7f27a0829438>

Let's try working with dates!

In [10]:
plt.figure()

observation_dates = np.arange('2017-01-01', '2017-01-09', dtype='datetime64[D]')

plt.plot(observation_dates, linear_data, '-o',  observation_dates, exponential_data, '-o')
Out[10]:
[<matplotlib.lines.Line2D at 0x7f27a07ddfd0>,
 <matplotlib.lines.Line2D at 0x7f27a07dd208>]

Let's try using pandas

In [8]:
import pandas as pd

plt.figure()
observation_dates = np.arange('2017-01-01', '2017-01-09', dtype='datetime64[D]')
observation_dates = map(pd.to_datetime, observation_dates) # trying to plot a map will result in an error
plt.plot(observation_dates, linear_data, '-o',  observation_dates, exponential_data, '-o')
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/matplotlib/units.py in get_converter(self, x)
    143                 # get_converter
--> 144                 if not np.all(xravel.mask):
    145                     # some elements are not masked

AttributeError: 'numpy.ndarray' object has no attribute 'mask'

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
<ipython-input-8-d8577b79c140> in <module>()
      4 observation_dates = np.arange('2017-01-01', '2017-01-09', dtype='datetime64[D]')
      5 observation_dates = map(pd.to_datetime, observation_dates) # trying to plot a map will result in an error
----> 6 plt.plot(observation_dates, linear_data, '-o',  observation_dates, exponential_data, '-o')

~/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py in plot(*args, **kwargs)
   3259                       mplDeprecation)
   3260     try:
-> 3261         ret = ax.plot(*args, **kwargs)
   3262     finally:
   3263         ax._hold = washold

~/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, *args, **kwargs)
   1715                     warnings.warn(msg % (label_namer, func.__name__),
   1716                                   RuntimeWarning, stacklevel=2)
-> 1717             return func(ax, *args, **kwargs)
   1718         pre_doc = inner.__doc__
   1719         if pre_doc is None:

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_axes.py in plot(self, *args, **kwargs)
   1370         kwargs = cbook.normalize_kwargs(kwargs, _alias_map)
   1371 
-> 1372         for line in self._get_lines(*args, **kwargs):
   1373             self.add_line(line)
   1374             lines.append(line)

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in _grab_next_args(self, *args, **kwargs)
    402                 this += args[0],
    403                 args = args[1:]
--> 404             for seg in self._plot_args(this, kwargs):
    405                 yield seg
    406 

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in _plot_args(self, tup, kwargs)
    382             x, y = index_of(tup[-1])
    383 
--> 384         x, y = self._xy_from_xy(x, y)
    385 
    386         if self.command == 'plot':

~/anaconda3/lib/python3.6/site-packages/matplotlib/axes/_base.py in _xy_from_xy(self, x, y)
    214     def _xy_from_xy(self, x, y):
    215         if self.axes.xaxis is not None and self.axes.yaxis is not None:
--> 216             bx = self.axes.xaxis.update_units(x)
    217             by = self.axes.yaxis.update_units(y)
    218 

~/anaconda3/lib/python3.6/site-packages/matplotlib/axis.py in update_units(self, data)
   1430         """
   1431 
-> 1432         converter = munits.registry.get_converter(data)
   1433         if converter is None:
   1434             return False

~/anaconda3/lib/python3.6/site-packages/matplotlib/units.py in get_converter(self, x)
    155                 if (not isinstance(next_item, np.ndarray) or
    156                     next_item.shape != x.shape):
--> 157                     converter = self.get_converter(next_item)
    158                 return converter
    159 

~/anaconda3/lib/python3.6/site-packages/matplotlib/units.py in get_converter(self, x)
    160         if converter is None:
    161             try:
--> 162                 thisx = safe_first_element(x)
    163             except (TypeError, StopIteration):
    164                 pass

~/anaconda3/lib/python3.6/site-packages/matplotlib/cbook/__init__.py in safe_first_element(obj)
   2309         except TypeError:
   2310             pass
-> 2311         raise RuntimeError("matplotlib does not support generators "
   2312                            "as input")
   2313     return next(iter(obj))

RuntimeError: matplotlib does not support generators as input
In [12]:
plt.figure()
observation_dates = np.arange('2017-01-01', '2017-01-09', dtype='datetime64[D]')
observation_dates = list(map(pd.to_datetime, observation_dates)) # convert the map to a list to get rid of the error
plt.plot(observation_dates, linear_data, '-o',  observation_dates, exponential_data, '-o')
Out[12]:
[<matplotlib.lines.Line2D at 0x7f27986c1eb8>,
 <matplotlib.lines.Line2D at 0x7f2794e640f0>]
In [13]:
x = plt.gca().xaxis

# rotate the tick labels for the x axis
for item in x.get_ticklabels():
    item.set_rotation(45)
In [14]:
# adjust the subplot so the text doesn't run off the image
plt.subplots_adjust(bottom=0.25)
In [15]:
ax = plt.gca()
ax.set_xlabel('Date')
ax.set_ylabel('Units')
ax.set_title('Exponential vs. Linear performance')
Out[15]:
<matplotlib.text.Text at 0x7f2794e3dc88>
In [16]:
# you can add mathematical expressions in any text element
ax.set_title("Exponential ($x^2$) vs. Linear ($x$) performance")
Out[16]:
<matplotlib.text.Text at 0x7f2794e3dc88>

Bar Charts

In [17]:
plt.figure()
xvals = range(len(linear_data))
plt.bar(xvals, linear_data, width = 0.3)
Out[17]:
<Container object of 8 artists>
In [18]:
new_xvals = []

# plot another set of bars, adjusting the new xvals to make up for the first set of bars plotted
for item in xvals:
    new_xvals.append(item+0.3)

plt.bar(new_xvals, exponential_data, width = 0.3 ,color='red')
Out[18]:
<Container object of 8 artists>
In [19]:
from random import randint
linear_err = [randint(0,15) for x in range(len(linear_data))] 

# This will plot a new set of bars with errorbars using the list of random error values
plt.bar(xvals, linear_data, width = 0.3, yerr=linear_err)
Out[19]:
<Container object of 8 artists>
In [20]:
# stacked bar charts are also possible
plt.figure()
xvals = range(len(linear_data))
plt.bar(xvals, linear_data, width = 0.3, color='b')
plt.bar(xvals, exponential_data, width = 0.3, bottom=linear_data, color='r')
Out[20]:
<Container object of 8 artists>
In [21]:
# or use barh for horizontal bar charts
plt.figure()
xvals = range(len(linear_data))
plt.barh(xvals, linear_data, height = 0.3, color='b')
plt.barh(xvals, exponential_data, height = 0.3, left=linear_data, color='r')
Out[21]:
<Container object of 8 artists>
In [22]:
import matplotlib.pyplot as plt
import numpy as np

plt.figure()

languages =['Python', 'SQL', 'Java', 'C++', 'JavaScript']
pos = np.arange(len(languages))
popularity = [56, 39, 34, 34, 29]

plt.bar(pos, popularity, align='center')
plt.xticks(pos, languages)
plt.ylabel('% Popularity')
plt.title('Top 5 Languages for Math & Data \nby % popularity on Stack Overflow', alpha=0.8)

# remove all the ticks (both axes), and tick labels on the Y axis
plt.tick_params(top='off', bottom='off', left='off', right='off', labelleft='off', labelbottom='on')
plt.show()
In [23]:
import matplotlib.pyplot as plt
import numpy as np

plt.figure()

languages =['Python', 'SQL', 'Java', 'C++', 'JavaScript']
pos = np.arange(len(languages))
popularity = [56, 39, 34, 34, 29]

plt.bar(pos, popularity, align='center')
plt.xticks(pos, languages)
plt.ylabel('% Popularity')
plt.title('Top 5 Languages for Math & Data \nby % popularity on Stack Overflow', alpha=0.8)
plt.show()
In [24]:
import matplotlib.pyplot as plt
import numpy as np

plt.figure()

languages =['Python', 'SQL', 'Java', 'C++', 'JavaScript']
pos = np.arange(len(languages))
popularity = [56, 39, 34, 34, 29]

plt.bar(pos, popularity, align='center')
plt.xticks(pos, languages)
plt.ylabel('% Popularity')
plt.title('Top 5 Languages for Math & Data \nby % popularity on Stack Overflow', alpha=0.8)

# remove all the ticks (both axes), and tick labels on the Y axis
plt.tick_params(top='off', bottom='off', left='off', right='off', labelleft='off', labelbottom='on')

# remove the frame of the chart
for spine in plt.gca().spines.values():
    spine.set_visible(False)
plt.show()
In [25]:
import matplotlib.pyplot as plt
import numpy as np

plt.figure()

languages =['Python', 'SQL', 'Java', 'C++', 'JavaScript']
pos = np.arange(len(languages))
popularity = [56, 39, 34, 34, 29]

# change the bar colors to be less bright blue
bars = plt.bar(pos, popularity, align='center', linewidth=0, color='lightslategrey')
# make one bar, the python bar, a contrasting color
bars[0].set_color('#1F77B4')

# soften all labels by turning grey
plt.xticks(pos, languages, alpha=0.8)
plt.ylabel('% Popularity', alpha=0.8)
plt.title('Top 5 Languages for Math & Data \nby % popularity on Stack Overflow', alpha=0.8)

# remove all the ticks (both axes), and tick labels on the Y axis
plt.tick_params(top='off', bottom='off', left='off', right='off', labelleft='off', labelbottom='on')

# remove the frame of the chart
for spine in plt.gca().spines.values():
    spine.set_visible(False)
plt.show()
In [26]:
import matplotlib.pyplot as plt
import numpy as np

plt.figure()

languages =['Python', 'SQL', 'Java', 'C++', 'JavaScript']
pos = np.arange(len(languages))
popularity = [56, 39, 34, 34, 29]

# change the bar colors to be less bright blue
bars = plt.bar(pos, popularity, align='center', linewidth=0, color='lightslategrey')
# make one bar, the python bar, a contrasting color
bars[0].set_color('#1F77B4')

# soften all labels by turning grey
plt.xticks(pos, languages, alpha=0.8)

# TODO: remove the Y label since bars are directly labeled
plt.ylabel('% Popularity', alpha=0.8)
plt.title('Top 5 Languages for Math & Data \nby % popularity on Stack Overflow', alpha=0.8)

# remove all the ticks (both axes), and tick labels on the Y axis
plt.tick_params(top='off', bottom='off', left='off', right='off', labelleft='off', labelbottom='on')

# remove the frame of the chart
for spine in plt.gca().spines.values():
    spine.set_visible(False)
    
# TODO: direct label each bar with Y axis values
plt.show()
In [27]:
import matplotlib.pyplot as plt
import numpy as np

plt.figure()

languages =['Python', 'SQL', 'Java', 'C++', 'JavaScript']
pos = np.arange(len(languages))
popularity = [56, 39, 34, 34, 29]

# change the bar color to be less bright blue
bars = plt.bar(pos, popularity, align='center', linewidth=0, color='lightslategrey')
# make one bar, the python bar, a contrasting color
bars[0].set_color('#1F77B4')

# soften all labels by turning grey
plt.xticks(pos, languages, alpha=0.8)
# remove the Y label since bars are directly labeled
#plt.ylabel('% Popularity', alpha=0.8)
plt.title('Top 5 Languages for Math & Data \nby % popularity on Stack Overflow', alpha=0.8)

# remove all the ticks (both axes), and tick labels on the Y axis
plt.tick_params(top='off', bottom='off', left='off', right='off', labelleft='off', labelbottom='on')

# remove the frame of the chart
for spine in plt.gca().spines.values():
    spine.set_visible(False)
    
# direct label each bar with Y axis values
for bar in bars:
    plt.gca().text(bar.get_x() + bar.get_width()/2, bar.get_height() - 5, str(int(bar.get_height())) + '%', 
                 ha='center', color='w', fontsize=11)
plt.show()