A Flask API for serving scikit-learn models

Scikit-learn is an intuitive and powerful Python machine learning library that makes training and validating many models fairly easy. Scikit-learn models can be persisted (pickled) so they don’t have to be retrained every time they are used. With Flask, you can wrap a pickled model in an API that returns predictions for a given set of input variables.

Before we get into Flask, it’s important to point out that most scikit-learn estimators do not handle categorical variables or missing values. Categorical variables need to be encoded as numeric values. They are typically transformed using OneHotEncoder (OHE) or LabelEncoder. LabelEncoder assigns an integer to each categorical value and replaces every value of the original variable with its integer. The problem with this approach is that a nominal variable is effectively transformed into an ordinal variable, which may fool a model into thinking that the order is meaningful. OHE, on the other hand, does not suffer from this issue, but it tends to explode the number of transformed variables, since a new variable is created for every value of every categorical variable.
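To see the difference, here is a minimal sketch on a hypothetical toy column (the values and the `embarked` name are made up for illustration): LabelEncoder collapses the values into a single integer column, while pandas’ get_dummies creates one indicator column per value.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical toy data: one nominal variable with three distinct values
values = ["S", "C", "Q", "S"]

# LabelEncoder: a single integer column; classes are sorted, so C=0, Q=1, S=2
label_encoded = LabelEncoder().fit_transform(values)
print(label_encoded)  # [2 0 1 2]

# One-hot encoding: one indicator column per distinct value
one_hot = pd.get_dummies(pd.Series(values), prefix="embarked")
print(list(one_hot.columns))  # ['embarked_C', 'embarked_Q', 'embarked_S']
```

The integer codes imply an order (S > Q > C) that the data does not have; the one-hot columns avoid that at the cost of one column per value.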

One thing to know about LabelEncoder is that its mapping depends on the set of values it was fitted on. Let’s say you have a “subscription” variable with “gold” and “platinum” values. LabelEncoder will map these to 0 and 1 respectively. Now if you add the value “free” to the mix, the assignment changes: free is encoded as 0, gold as 1, and platinum as 2, because LabelEncoder assigns integers in sorted order. For this reason, it’s important to keep your original LabelEncoder around for transformation at prediction time.
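A small sketch of the subscription example; the encoder names enc_v1 and enc_v2 are hypothetical and only illustrate fitting on the old and new sets of values.

```python
from sklearn.preprocessing import LabelEncoder

# Fitted on the original two values: gold -> 0, platinum -> 1
enc_v1 = LabelEncoder().fit(["gold", "platinum"])
print(enc_v1.transform(["gold", "platinum"]))  # [0 1]

# Refitted after "free" appears: classes are re-sorted and existing codes shift
enc_v2 = LabelEncoder().fit(["free", "gold", "platinum"])
print(enc_v2.transform(["gold", "platinum"]))  # [1 2]

# At prediction time, reuse the encoder the model was trained with
print(enc_v1.transform(["platinum"]))  # [1]
```

Calling enc_v1.transform(["free"]) would raise a ValueError, since LabelEncoder refuses labels it did not see during fitting.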

For this example I am going to use the Titanic dataset. To simplify things further, I will only use four variables: Age, Sex, Embarked, and Survived.

Sex and Embarked are categorical variables and need to be transformed. Age has missing values, which are typically imputed, meaning they are replaced by a summary statistic such as the median or mean. Missing values can be quite meaningful, and it’s worth investigating what they represent in real-world applications. Here I’m simply going to replace NaNs with 0.
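A minimal sketch of this preprocessing step, assuming a small hypothetical stand-in for the Titanic data (the rows and the extra Fare column are made up). It checks for non-numeric dtypes instead of comparing against “O”, which also covers the dedicated string dtype that newer pandas versions may infer for text columns.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for a few rows of the Titanic data
df = pd.DataFrame({
    "Age": [22.0, np.nan, 26.0, 35.0],
    "Sex": ["male", "female", "female", "male"],
    "Embarked": ["S", "C", np.nan, "S"],
    "Survived": [0, 1, 1, 0],
    "Fare": [7.25, 71.28, 7.92, 8.05],  # not used, dropped below
})

include = ["Age", "Sex", "Embarked", "Survived"]  # the four variables we keep
df_ = df[include].copy()  # copy so columns can be modified safely

categoricals = []
for col in df_.columns:
    if pd.api.types.is_numeric_dtype(df_[col]):
        df_[col] = df_[col].fillna(0)  # numeric: replace NaN with 0
    else:
        categoricals.append(col)  # text/object dtype: treat as categorical

print(categoricals)  # ['Sex', 'Embarked']
```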

This preprocessing step iterates over all columns in df_ (the dataframe restricted to these four variables) and appends the categorical variables (those with data type “O”, i.e. object) to the categoricals list. For non-categorical variables (integers and floats), I’m replacing NaNs with zeros; in this case only Age actually has missing values, since Survived is a complete integer column. Filling NaNs with a single value may have unintended consequences, especially if the replacement value lies within the observed range of the numeric variable. Since zero is not an observed, legitimate age value, filled-in rows can’t be mistaken for real ages, so I’m not introducing bias that way. I would be if I used 40!

Now we’re ready to OHE our categorical variables. Pandas provides a simple function, get_dummies, for creating OHE variables from a given dataframe.

df_ohe = pd.get_dummies(df, columns=categoricals, dummy_na=True)

The nice thing about OHE is that it’s deterministic. A new column is created for every column/value combination, in the format column_value. For instance, for the “Embarked” variable we’re going to get “Embarked_C”, “Embarked_Q”, “Embarked_S”, and “Embarked_nan” (the last one because of dummy_na=True).
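To make the column naming concrete, here is a sketch on a few hypothetical, already preprocessed rows. It also fits a classifier, since the next step pickles a trained model named clf; LogisticRegression and the x/y split are assumptions, as the text does not say which model was used.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

# Hypothetical preprocessed rows (missing ages already replaced with 0)
df_ = pd.DataFrame({
    "Age": [22.0, 0.0, 26.0, 35.0, 54.0, 2.0],
    "Sex": ["male", "female", "female", "male", "male", "female"],
    "Embarked": ["S", "C", np.nan, "Q", "S", "C"],
    "Survived": [0, 1, 1, 0, 0, 1],
})
categoricals = ["Sex", "Embarked"]

# One column per column/value pair, plus a *_nan column per encoded variable
df_ohe = pd.get_dummies(df_, columns=categoricals, dummy_na=True)
print(list(df_ohe.columns))
# ['Age', 'Survived', 'Sex_female', 'Sex_male', 'Sex_nan',
#  'Embarked_C', 'Embarked_Q', 'Embarked_S', 'Embarked_nan']

# Split features and target, then train (model choice is an assumption)
dependent_variable = "Survived"
x = df_ohe.drop(columns=[dependent_variable])
y = df_ohe[dependent_variable]
clf = LogisticRegression()
clf.fit(x, y)
```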

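The trained model is ready to be pickled. I’m going to use joblib, which older versions of scikit-learn bundled as sklearn.externals.joblib and which is now a standalone package.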

import joblib  # formerly sklearn.externals.joblib, removed in scikit-learn 0.23
joblib.dump(clf, 'model.pkl')

That’s it! We have persisted our model. We can load this model into memory in a single line.

clf = joblib.load('model.pkl')

We’re now ready to use Flask to serve our persisted model.

Flask is pretty minimalistic. Here’s what you need to start a bare-bones Flask application (on port 8080 in this case).

from flask import Flask

app = Flask(__name__)

if __name__ == '__main__':
    app.run(port=8080)

We have to do two things: (1) load our persisted model into memory when the application starts, and (2) create an endpoint that takes input variables, transforms them into the appropriate format, and returns predictions.
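A minimal sketch of both steps. The /predict route, the JSON layout (a list of records), and the tiny model trained at the top are assumptions made so the example runs on its own; in practice model.pkl would come from the training step above. This is the naive version: it only works when the request happens to produce exactly the training columns.

```python
import joblib
import pandas as pd
from flask import Flask, jsonify, request
from sklearn.linear_model import LogisticRegression

# Stand-in for the training step (hypothetical toy data), saved as model.pkl
_x = pd.DataFrame({"Age": [22.0, 38.0, 26.0, 35.0],
                   "Sex_female": [0, 1, 1, 0], "Sex_male": [1, 0, 0, 1]})
joblib.dump(LogisticRegression().fit(_x, [0, 1, 1, 0]), "model.pkl")

app = Flask(__name__)
clf = joblib.load("model.pkl")  # (1) load the persisted model once, at startup


@app.route("/predict", methods=["POST"])
def predict():
    # (2) expects a JSON list of records, e.g. [{"Age": 30, "Sex": "male"}]
    query_df = pd.DataFrame(request.get_json())
    query = pd.get_dummies(query_df)  # naive: no column alignment yet
    prediction = clf.predict(query)
    return jsonify({"prediction": prediction.tolist()})

# Serve it with the same app.run(port=8080) guard as the bare-bones example.
```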

A naive endpoint that runs get_dummies directly on the incoming request and passes the result to the classifier only works under ideal circumstances, where the request contains every possible value of every categorical variable. If that’s not the case, get_dummies generates a dataframe with fewer columns than the classifier expects, which results in a runtime error. Missing values in numerical variables also need to be filled using the same method we used during training.

A solution to the missing-columns problem is to persist the list of columns from training. Remember that Python objects (including lists and dictionaries) can be pickled. To do this, I’m going to use joblib, as I did previously, to dump the list of columns into a pkl file.
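A sketch of that, assuming hypothetical one-hot training columns in x and a file name of model_columns.pkl:

```python
import joblib
import pandas as pd

# Hypothetical training features after get_dummies (target already dropped)
x = pd.DataFrame({"Age": [22.0, 38.0], "Sex_female": [0, 1], "Sex_male": [1, 0]})

# Persist the training column order next to the model
model_columns = list(x.columns)
joblib.dump(model_columns, "model_columns.pkl")

# At application startup, load it back alongside the model
model_columns = joblib.load("model_columns.pkl")
print(model_columns)  # ['Age', 'Sex_female', 'Sex_male']
```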

This solution is still not foolproof. If you happen to send values that were not seen in the training set, get_dummies will produce extra columns and you’ll run into an error. For this solution to work, we also need to remove from the query dataframe any columns that are not part of model_columns.
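One way to handle both the missing and the extra columns in a single step is DataFrame.reindex, swapped in here for dropping columns by hand: it keeps only the listed columns, adds any absent ones filled with a value you choose, and orders them as in training. The column list and the request below are hypothetical.

```python
import pandas as pd

# Training-time column list (hypothetical, e.g. loaded from model_columns.pkl)
model_columns = ["Age", "Sex_female", "Sex_male", "Embarked_C", "Embarked_S"]

# A request with an unseen category ("X") and several absent ones
query_df = pd.DataFrame([{"Age": 30, "Sex": "male", "Embarked": "X"}])
query = pd.get_dummies(query_df)
print(list(query.columns))  # ['Age', 'Sex_male', 'Embarked_X']

# Drop extra columns, add missing ones as 0, and restore training order
query = query.reindex(columns=model_columns, fill_value=0)
print(list(query.columns))
# ['Age', 'Sex_female', 'Sex_male', 'Embarked_C', 'Embarked_S']
```

Numeric NaNs in the request still need the same fillna(0) treatment as in training before the aligned frame is passed to clf.predict.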