
Machine Learning with R - 1 : Decision Trees using the party package

Fueled by my newfound love of programming in R and my general interest in machine learning, I've been looking at various online resources that integrate the two, i.e., that use R to implement popular machine learning and data mining algorithms. Of course, machine learning is a well-established field in computer science today, and it has produced several powerful algorithms: Bayesian classifiers, decision trees, neural networks and the quintessential support vector machines. In this series of blog posts, I'll work through some of these popular algorithms using R. You can access just this series of blog posts by clicking on the "Machine Learning with R" category in the tag cloud on the right.

Thanks to a brilliant open-source community, R has a vibrant developer base, and as a result it has a pretty cool arsenal of packages that make our job simpler. Our focus in this series is to utilise those packages and get some data crunching done.

This is the first post in the series, and here's what this one is about:

Decision Trees

The decision tree is one of the simplest machine learning algorithms, so it's only logical that we begin our discussion with it. Decision trees are easy to understand, too: given a set of data tuples, we split the set into subsets (often two) depending on a condition on one of the attributes. We then keep splitting each subset until we have reasonably homogeneous buckets at the lowest level, i.e., buckets whose tuples mostly belong to the same class. (This technique is called Hunt's algorithm.)
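To make the idea concrete, here is a minimal sketch of Hunt's algorithm in base R. It is not how the party package used later works (its ctree function chooses splits with statistical tests). The Gini impurity measure, the binary "attribute < threshold" splits, the min_size stopping rule and the helper names gini and hunt are all my own assumptions for illustration.

```r
# Gini impurity of a set of class labels: 0 means the bucket is perfectly homogeneous.
gini <- function(y) {
  p <- table(y) / length(y)
  1 - sum(p^2)
}

# Sketch of Hunt's algorithm: if a bucket is homogeneous (or too small), make it a
# leaf labelled with its majority class; otherwise split it on the attribute/threshold
# pair whose two children have the lowest weighted Gini impurity, and recurse.
hunt <- function(data, target, min_size = 5) {
  y <- data[[target]]
  majority <- names(which.max(table(y)))
  if (length(unique(y)) == 1 || nrow(data) < min_size) {
    return(list(leaf = TRUE, class = majority, n = nrow(data)))
  }
  best <- list(score = gini(y))  # a split must beat the unsplit bucket
  for (attr in setdiff(names(data), target)) {
    values <- sort(unique(data[[attr]]))
    if (length(values) < 2) next
    # Candidate thresholds: midpoints between consecutive distinct values.
    for (t in (head(values, -1) + tail(values, -1)) / 2) {
      left <- data[[attr]] < t
      score <- (sum(left) * gini(y[left]) + sum(!left) * gini(y[!left])) / length(y)
      if (score < best$score) best <- list(score = score, attr = attr, threshold = t)
    }
  }
  if (is.null(best$attr)) {  # no split improves homogeneity: stop here
    return(list(leaf = TRUE, class = majority, n = nrow(data)))
  }
  left <- data[[best$attr]] < best$threshold
  list(leaf = FALSE, attr = best$attr, threshold = best$threshold,
       left  = hunt(data[left, , drop = FALSE], target, min_size),
       right = hunt(data[!left, , drop = FALSE], target, min_size))
}

tree <- hunt(iris, "Species")
str(tree, max.level = 1)  # the root split and its two subtrees
```

On iris, the first split found this way isolates all the setosa flowers in one bucket, which is already perfectly homogeneous.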

An example would be to take the entire human population and split it on the basis of gender. Thus, the first conditional could be "Gender = Male? If yes, put that person in a bucket labelled 'M'; else put that person in a bucket labelled 'F'." Then, for each bucket formed, we could have conditions like "Age > 30? If yes, put the person in this bucket; else put them in that bucket", and so on.
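The bucketing described above can be sketched with base R's split function. The people data frame and the bucket names over30/upto30 are hypothetical, made up purely for this example.

```r
# Hypothetical toy data, for illustration only.
people <- data.frame(
  name   = c("Ann", "Bob", "Cid", "Dee", "Eve", "Fay"),
  gender = c("F", "M", "M", "F", "F", "M"),
  age    = c(25, 42, 31, 35, 19, 28),
  stringsAsFactors = FALSE
)

# First conditional, "Gender = Male?": one bucket per gender ("F" and "M").
by_gender <- split(people, people$gender)

# Second conditional, "Age > 30?", applied inside each gender bucket.
buckets <- lapply(by_gender, function(b) {
  split(b, ifelse(b$age > 30, "over30", "upto30"))
})

buckets$M$over30$name  # "Bob" "Cid"
```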

The following image gives a nice example of using a decision tree to classify fruits. It's quite self-explanatory.

As you can see, each conditional need not produce a binary result. For example, "Color" can be "Green", "Yellow" or "Red".

Decision trees are not just used for classifying data points. They are widely used in flow charts and in any kind of operation where a series of decisions has to be made to arrive at a conclusion. Structurally, a decision tree is simple: a set of logically nested "if... else..." statements (it's easy to see why). Once a decision tree model has been built, the data set is fed into the tree and the leaf-level buckets are analysed to check for homogeneity, and the model with the most homogeneous buckets is finally chosen. In practice, most tree-building algorithms (Hunt's algorithm among them) make this choice greedily, picking at each node a split that makes the resulting buckets as homogeneous as possible. (For example, instead of splitting by "gender" first, it might be a better option to split by "nationality" for a specific occasion.)
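As a sketch of the "nested if... else..." view, here is a hand-written tree for the iris data used later in this post. The function name classify_iris and the thresholds 2.45 and 1.75 are illustrative assumptions, not values learned by a model here. The final table feeds the whole data set through the tree and shows how homogeneous each leaf-level bucket is.

```r
# A decision tree written out as nested if/else statements.
# Thresholds are hand-picked for illustration only.
classify_iris <- function(petal_length, petal_width) {
  if (petal_length < 2.45) {
    "setosa"
  } else if (petal_width < 1.75) {
    "versicolor"
  } else {
    "virginica"
  }
}

# Send every row through the tree, then cross-tabulate bucket vs. actual species:
# a perfectly homogeneous bucket has all of its rows in a single column.
pred <- mapply(classify_iris, iris$Petal.Length, iris$Petal.Width)
print(table(bucket = pred, actual = iris$Species))
```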

Moving on to R

To make our decision tree we'll be using one of the most famous data sets in the machine learning universe: the iris dataset. It contains 150 instances split evenly into 3 classes of 50 instances each, where each class refers to a species of iris plant. The attributes in this dataset are the sepal length, the sepal width, the petal length, and the petal width. You can see the structure of this dataset by simply typing str(iris) at the R prompt.
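For reference, these base R lines inspect the dataset. iris ships with R's built-in datasets package, so nothing needs to be installed.

```r
str(iris)                # 150 observations of 5 variables: 4 numeric measurements + Species
table(iris$Species)      # 50 instances of each of the 3 species
summary(iris[, 1:4])     # ranges of the four measurements
```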

Anyhow, to make a decision tree in R, we need to install the party package. No, the name isn't a misnomer: it refers to the "recursive partytioning" that the package helps us do. Think about what little you know about decision trees, and you'll see that "recursive partytioning" makes a lot of sense.
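A minimal way to install the package from CRAN (only if it isn't already present) and load it:

```r
# Install party only when it is missing, then attach it so ctree() is available.
if (!requireNamespace("party", quietly = TRUE)) {
  install.packages("party")
}
library(party)
```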

Now, we split our entire dataset into two parts: the training set and the testing set. This is a very common practice in machine learning, wherein we train a machine learning algorithm on the training data and then test our model using the testing data.

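To create a split in the iris dataset, we first set a seed (so that the random split is reproducible) and then use the sample function to create a training and a testing sample. Screenshot follows. The sample function is quite easy to understand if you remember a thing or two about undergraduate-level probability and statistics. We draw nrow(iris) labels, one per data point, from the 2 possible sets. Sampling is done with replacement (replace = TRUE), so each label can be drawn many times, and each data point lands in the training set with probability 0.7 and in the testing set with probability 0.3 (that's why the 0.7 and 0.3 values). The training data is therefore roughly 70% of the whole data and the testing data roughly 30%, and every data point ends up in exactly one of the two sets.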
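A sketch of such a split follows. The seed value 1234 and the variable names ind, trainData and testData are my assumptions.

```r
set.seed(1234)  # assumed seed; any fixed value makes the split reproducible

# One label per row of iris, drawn from {1, 2} with replacement:
# label 1 (training) with probability 0.7, label 2 (testing) with probability 0.3.
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))

trainData <- iris[ind == 1, ]  # roughly 70% of the rows
testData  <- iris[ind == 2, ]  # the remaining rows, roughly 30%
```

Because each row's label is random, the sizes are only approximately 70/30. If an exact split is needed, sampling row indices without replacement (for example sample(nrow(iris), 105)) is an alternative.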

Now, we need to build our decision tree. To do that, we first build a formula that describes the dependencies. For this problem, we're trying to build a model that classifies a test data point into one of the three Species classes, i.e., setosa, versicolor or virginica. The input is a tuple consisting of sepal length, sepal width, petal length and petal width.

Thus our formula can be:

Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width

Store this formula in a variable and then use the rather convenient ctree function to build the tree model.
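A sketch of these two steps follows. The split is recreated so the block runs on its own; the seed value and the names ind, trainData and myFormula are assumptions, while iris_ctree is the name used in the text.

```r
library(party)

# Recreate a roughly 70/30 training split (seed and names are assumptions).
set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]

# Species is explained by the four measurements.
myFormula <- Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width

# Fit a conditional inference tree on the training data only.
iris_ctree <- ctree(myFormula, data = trainData)
```

Since this formula uses every other column of iris, Species ~ . would be an equivalent shorthand here.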

As you can see from the screenshot, the ctree function takes as input the formula and the training data set, on which it builds the model. Once it is built, you can type iris_ctree to get a description of the newly built model, or type plot(iris_ctree) to see the tree in all its visual glory.
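A self-contained sketch of the inspection commands; the split and model are rebuilt with an assumed seed and assumed variable names.

```r
library(party)

# Rebuild the training split and the model (seed and names are assumptions).
set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
iris_ctree <- ctree(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
                    data = trainData)

print(iris_ctree)                  # text description: split variables, thresholds, node sizes
plot(iris_ctree)                   # default plot, with the class distribution in each leaf
plot(iris_ctree, type = "simple")  # more compact variant of the same tree
```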

Now that our model is built, we need to check its validity by pitting it against our test data. So we use the predict function to predict the classes of the test data, and then create a matrix (a confusion matrix) comparing the predicted classes with the actual ones. Check screenshot below.
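A self-contained sketch of the evaluation step. The seed and the names trainData, testData, testPred, confusion and accuracy are assumptions; the accuracy line is an extra summary of the same matrix.

```r
library(party)

# Recreate the split and the model (seed and variable names are assumptions).
set.seed(1234)
ind <- sample(2, nrow(iris), replace = TRUE, prob = c(0.7, 0.3))
trainData <- iris[ind == 1, ]
testData  <- iris[ind == 2, ]
iris_ctree <- ctree(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width,
                    data = trainData)

# Predict a species for every test row, then cross-tabulate against the truth.
testPred <- predict(iris_ctree, newdata = testData)
confusion <- table(predicted = testPred, actual = testData$Species)
print(confusion)

# Correct predictions sit on the diagonal; their share is the test accuracy.
accuracy <- sum(diag(confusion)) / sum(confusion)
print(accuracy)
```

Off-diagonal cells show which species get confused with which.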

We thus see that the model has done a fine job of classifying the test set: most predictions fall on the diagonal of the matrix, where the predicted and actual species agree.