Author: Kyriacos Kyriacou, kyr

Tutorial

Bring your plots to life with Matplotlib animations

Early on in our journey at mlcourse.ai we were taught how to create plots describing our data so that we could gain insight into patterns and relationships that might exist. As the course progressed, we found ourselves constantly creating plots to assist us in all kinds of tasks: whether it was for exploratory data analysis, visualizing results from machine learning (ML) algorithms such as decision trees, or comparing results using validation curves, plots were prominent everywhere.

In this tutorial, we'll take plotting a step further by learning how to animate plots, using a library we're all familiar with, matplotlib.

Matplotlib provides two classes for creating animations: FuncAnimation and ArtistAnimation. With the former, an animation is created by repeatedly calling an update function, whereas with the latter frames are precomputed and then animated. We'll be taking a look at both in this tutorial.

When using either of the animation classes, it is imperative to keep a reference to the instance object. The animation is advanced by a timer which the Animation object holds the only reference to. If you don't hold that reference, the object will be garbage collected, which stops the animation. This differs from other matplotlib plotting functions, where you can call plt.plot() without storing the result, so it is an important point to remember.

Let's look at some code for animating the cosine function, and then we'll dive into explaining each function and the FuncAnimation API.

In [11]:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(rc={'figure.figsize': (10, 7)})

import warnings
warnings.filterwarnings('ignore')

from matplotlib.animation import FuncAnimation
from matplotlib.animation import ArtistAnimation

# to display animation in notebook
from IPython.display import HTML

In [12]:

fig, ax = plt.subplots()
ax.set_xlim((0, 2))
ax.set_ylim((-2, 2))
ax.grid(True)

# create a line which we'll animate
line, = plt.plot([], [], lw=2)

# function used to draw a clear frame
def init():
    line.set_data([], [])
    return line,

# called sequentially on each frame
def update(frame):
    x = np.linspace(0, 2, 1000)
    y = np.cos(2 * np.pi * (x - 0.01 * frame))
    line.set_data(x, y)
    return line,

# hide static plot
plt.close()

ani = FuncAnimation(fig, update, frames=100, interval=20, init_func=init, blit=True)
HTML(ani.to_html5_video())
# ani.save('cosine_function.mp4', fps=30)

First, we create a figure and axes, and then set the limits of the y and x coordinates. Then, we create a line (with no points), which we'll be animating in our plot. If we wanted to animate another line, we'd also create it here.

def init():
    line.set_data([], [])
    return line,

Here, we define the init function, which is used to draw a clear frame. The function must return a tuple of all objects that will be animated. In this case it's a single line, so we return a tuple of length 1.

The update function is the guts of the animation. It is called sequentially on each frame of the animation, and it is here that we calculate the new coordinates of the object that we are animating. Like init, it returns a tuple of the objects that changed.

The frame parameter is the current frame number. The update function can receive additional parameters through the fargs parameter of FuncAnimation.
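As a rough illustration of fargs, the sketch below passes two extra values to the update function. The names amplitude and speed are hypothetical and not part of the original example; the rest mirrors the cosine animation above.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

fig, ax = plt.subplots()
ax.set_xlim((0, 2))
ax.set_ylim((-2, 2))
line, = ax.plot([], [], lw=2)

def update(frame, amplitude, speed):
    # frame is supplied by FuncAnimation on every call;
    # amplitude and speed (hypothetical names) are taken from fargs
    x = np.linspace(0, 2, 1000)
    y = amplitude * np.cos(2 * np.pi * (x - speed * frame))
    line.set_data(x, y)
    return line,

# hide static plot
plt.close()

# fargs is a tuple of the extra positional arguments passed after frame
ani = FuncAnimation(fig, update, frames=100, interval=20,
                    fargs=(1.5, 0.01), blit=True)
```

Changing the tuple given to fargs changes the wave without touching the update function itself.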

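Finally, we create the FuncAnimation object, passing it the figure, the update function, the number of frames, the delay between frames in milliseconds (interval), the init function, and blit=True so that only the artists that changed are redrawn. As you can see, we save a reference to the animation in the ani variable. If we didn't, then no animation would take place.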

Using IPython.display.HTML, we display the animation in the Jupyter notebook.

Alternatively, we could have saved the animation to a file using:

ani.save('cosine_function.mp4', fps=30)

If you have problems saving the animation, ensure you have ffmpeg or mencoder installed.
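A quick way to check which movie writers matplotlib can find on your machine is to query its writer registry. This is only a diagnostic sketch; the messages printed are illustrative.

```python
from matplotlib import animation

# names of the movie writers matplotlib can use on this machine
available = animation.writers.list()
print(available)

# to_html5_video() and saving .mp4 files rely on the ffmpeg writer by default
if animation.writers.is_available('ffmpeg'):
    print("ffmpeg found: saving to .mp4 should work")
else:
    print("ffmpeg not found: install it, or save a GIF with writer='pillow'")
```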

From now on, instead of having to call HTML each time, we will set a parameter so that simply calling the animation object will display the animation in the notebook.
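The cell that sets this parameter is not shown here. One way to get this behaviour, assuming the setting meant is matplotlib's animation.html option, is:

```python
import matplotlib.pyplot as plt

# render an Animation as HTML5 video whenever it is the last expression
# in a notebook cell ('jshtml' would give an interactive player instead)
plt.rcParams['animation.html'] = 'html5'
```

With this in place, ending a cell with ani is enough to display the animation.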

Now that we're familiar with animating a single object, let's go further by animating multiple objects. Let's add a sine line and display the frame number:

In [14]:

fig, ax = plt.subplots()
ax.set_xlim((0, 2))
ax.set_ylim((-2, 2))
ax.grid(True)

frame_text = ax.text(0.05, 1.9, '', bbox=dict(facecolor='white', alpha=1))
cos_line, = plt.plot([], [], lw=2, c='b', label='cos')
sin_line, = plt.plot([], [], lw=2, c='r', label='sin')
plt.legend()

# function used to draw a clear frame
def init():
    cos_line.set_data([], [])
    sin_line.set_data([], [])
    frame_text.set_text('')
    return cos_line, sin_line, frame_text

# called sequentially on each frame
def update(frame):
    x = np.linspace(0, 2, 1000)
    cos_x = np.cos(2 * np.pi * (x - 0.01 * frame))
    sin_x = np.sin(2 * np.pi * (x - 0.01 * frame))
    cos_line.set_data(x, cos_x)
    sin_line.set_data(x, sin_x)
    frame_text.set_text('frame = %d' % frame)
    return cos_line, sin_line, frame_text

# hide static plot
plt.close()

ani = FuncAnimation(fig, update, frames=100, interval=20, init_func=init, blit=True)
ani

Even though I'm sure looking at a cos/sin animated plot is a lot of fun, let's create some animations which can be useful for us in our ML journey.

In Topic 7, we learned about clustering algorithms, and in particular K-Means. However, we were only able to plot the final clusters after all iterations had completed. It would be more interesting to see how clusters form and change throughout each iteration. Sounds like an animation is perfect for this!

We'll be using the iris dataset for this task. Since the iris dataset consists of four features, let's use PCA to reduce that down to 2, so we can plot it in 2 dimensions, just as we did in Topic 7:
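The PCA cell is not shown here. A minimal sketch of this step, assuming scikit-learn's load_iris and PCA and the hypothetical variable names X and X_pca, could look like this:

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

iris = load_iris()
X = iris.data  # 150 samples, 4 features

# project the data onto its first two principal components
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
print(X_pca.shape)
```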

Unfortunately, sklearn does not provide a way to view the history of the centroids. That is, we can only see the final centroid locations after n iterations, but we want to see the centroids at each iteration. There are two ways around this: we can either create our own implementation of KMeans which stores and returns the centroid locations after each iteration, or, with a little ingenuity, we can use KMeans from sklearn to get the centroids on each iteration.

I'll initialize the centroids myself, and then repeatedly call KMeans with a fixed random state and those initial centroids, incrementing the max_iter parameter by 1 on each call and fitting to the data. This way, on each iteration of our loop, we can look at the centroids. It'll make more sense when you read the code.
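The idea can be sketched as follows. This is not the original cell: picking the initial centroids from random data points, the seed 17 and the names init and centroid_history are assumptions made for illustration.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

X = PCA(n_components=2).fit_transform(load_iris().data)

# assumption: use three randomly chosen data points as initial centroids
rng = np.random.RandomState(17)
init = X[rng.choice(len(X), size=3, replace=False)]

centroid_history = [init]
for n_iter in range(1, 11):
    # same starting centroids every time, but one more iteration allowed,
    # so cluster_centers_ is the state after exactly n_iter updates
    # (or the converged centroids, if KMeans stopped earlier)
    km = KMeans(n_clusters=3, init=init, n_init=1,
                max_iter=n_iter, random_state=17)
    km.fit(X)
    centroid_history.append(km.cluster_centers_)

# how far the centroids moved between consecutive snapshots
shifts = [np.linalg.norm(b - a)
          for a, b in zip(centroid_history, centroid_history[1:])]
print(np.round(shifts, 4))
```

The printed shifts typically shrink towards zero as the algorithm converges. Refitting from scratch on every step is wasteful, but it is cheap on a dataset of this size.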

First, create the initial centroids and place them in the init array. Then, loop over the number of iterations we want to run KMeans for.

In each iteration, we perform KMeans up to that iteration, and obtain the cluster predictions and the new centroids for that iteration. We draw each cluster using plt.plot, and the new centroid locations with plt.scatter. All objects returned from plt.text, plt.scatter and plt.plot are artists; the artists that make up one frame are collected in a list, and that list is appended to ims.
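Put together, the steps above might look like the sketch below. It is an approximation under stated assumptions rather than the original cell: the colours, marker styles, seed and the choice of random data points as initial centroids are all illustrative.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

X = PCA(n_components=2).fit_transform(load_iris().data)
n_clusters, n_frames = 3, 10
colors = ['b', 'r', 'g']

# assumption: random data points serve as the initial centroids
rng = np.random.RandomState(17)
init = X[rng.choice(len(X), size=n_clusters, replace=False)]

fig, ax = plt.subplots()
ims = []  # one list of artists per frame
for i in range(1, n_frames + 1):
    km = KMeans(n_clusters=n_clusters, init=init, n_init=1,
                max_iter=i, random_state=17)
    labels = km.fit_predict(X)
    centers = km.cluster_centers_

    artists = []
    for k in range(n_clusters):
        # ax.plot returns a list of lines; keep the single Line2D
        points, = ax.plot(X[labels == k, 0], X[labels == k, 1],
                          'o', c=colors[k], alpha=0.5)
        artists.append(points)
    # zorder keeps the centroids drawn on top of the data points
    artists.append(ax.scatter(centers[:, 0], centers[:, 1],
                              marker='X', s=200, c='k', zorder=3))
    artists.append(ax.text(0.02, 0.95, 'iteration = %d' % i,
                           transform=ax.transAxes))
    ims.append(artists)

# hide static plot
plt.close()

ani = ArtistAnimation(fig, ims, interval=500, blit=True, repeat_delay=1000)
```

ArtistAnimation shows only the artists of the current frame and hides the rest, which is why every frame needs its own complete set of artists.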

The Matplotlib gallery has many examples of animations. I'll show one below which displays 3-D capabilities:

In [20]:

# taken from https://matplotlib.org/gallery/animation/random_walk.html
import mpl_toolkits.mplot3d.axes3d as p3

# Fixing random state for reproducibility
np.random.seed(19680801)


def Gen_RandLine(length, dims=2):
    """
    Create a line using a random walk algorithm

    length is the number of points for the line.
    dims is the number of dimensions the line has.
    """
    lineData = np.empty((dims, length))
    lineData[:, 0] = np.random.rand(dims)
    for index in range(1, length):
        # scaling the random numbers by 0.1 so
        # movement is small compared to position.
        # subtraction by 0.5 is to change the range to [-0.5, 0.5]
        # to allow a line to move backwards.
        step = ((np.random.rand(dims) - 0.5) * 0.1)
        lineData[:, index] = lineData[:, index - 1] + step

    return lineData


def update_lines(num, dataLines, lines):
    for line, data in zip(lines, dataLines):
        # NOTE: there is no .set_data() for 3 dim data...
        line.set_data(data[0:2, :num])
        line.set_3d_properties(data[2, :num])
    return lines

# Attaching 3D axis to the figure
# (add_subplot is used because Axes3D(fig) no longer adds itself
# to the figure in recent matplotlib releases)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# Fifty lines of random 3-D lines
data = [Gen_RandLine(25, 3) for index in range(50)]

# Creating fifty line objects.
# NOTE: Can't pass empty arrays into 3d version of plot()
lines = [ax.plot(dat[0, 0:1], dat[1, 0:1], dat[2, 0:1])[0] for dat in data]

# Setting the axes properties
ax.set_xlim3d([0.0, 1.0])
ax.set_xlabel('X')

ax.set_ylim3d([0.0, 1.0])
ax.set_ylabel('Y')

ax.set_zlim3d([0.0, 1.0])
ax.set_zlabel('Z')

ax.set_title('3D Test')

plt.close()

# Creating the Animation object
ani = FuncAnimation(fig, update_lines, 25, fargs=(data, lines), interval=50, blit=False)
ani

Matplotlib provides great capability for animations straight out of the box! We saw that through usage of the FuncAnimation or the ArtistAnimation class, we have a lot of power in our hands for producing animations. We can plot mathematical functions, clusters and decision boundaries from our ML algorithms and even animate 3-D plots.

Animations can be used to help us understand how our ML algorithms work, or they can be used to produce awesome animations to show off in your next blog post. Either way, the possibilities are endless.