Can we predict flu deaths with Machine Learning and R?

Among the many R packages, there is the outbreaks package. It contains datasets on epidemics, one of which is from the 2013 outbreak of influenza A H7N9 in China, as analysed by Kucharski et al. (2014).

I will be using their data as an example to test whether we can use Machine Learning algorithms for predicting disease outcome.

To do so, I selected and extracted features from the raw data, including age, days between onset and outcome, gender, whether the patients were hospitalised, etc. Missing values were imputed and different modeling algorithms were used to predict outcome (death or recovery). The models’ prediction accuracy, sensitivity and specificity were then compared.
The prepared dataset was then divided into training and testing subsets. The test subset contained all cases with an unknown outcome. Before I applied the models to the test data, I further split the training data into validation subsets.

The tested modeling algorithms were similarly successful at predicting the outcomes of the validation data. To decide on final classifications, I compared the predictions from all models and defined the outcome “Death” or “Recovery” as a function of all models, while classifications with a low prediction probability were flagged as “uncertain”. Accounting for this uncertainty led to a 100% correct classification of the validation set.

The test cases with unknown outcome were then classified with the same algorithms. Of the 57 unknown cases, 14 were classified as “Recovery”, 10 as “Death” and 33 as uncertain.

In Part 2, I look at how extreme gradient boosting performs on this dataset.

The data

The dataset contains case ID, date of onset, date of hospitalisation, date of outcome, gender, age, province and of course the outcome: Death or Recovery.
I can already see that there are a couple of missing values in the data, which I will deal with later.

# install and load package
if (!require("outbreaks")) install.packages("outbreaks")
library(outbreaks)

# back up original dataset in case something goes awry along the way
fluH7N9.china.2013_backup <- fluH7N9.china.2013

# convert ? to NAs
fluH7N9.china.2013$age[which(fluH7N9.china.2013$age == "?")] <- NA

# create a new column with case ID
fluH7N9.china.2013$case.ID <- paste("case", fluH7N9.china.2013$case.ID, sep = "_")
head(fluH7N9.china.2013)

Before I start preparing the data for Machine Learning, I want to get an idea of the distribution of the data points and their different variables by plotting.

Most provinces have only a handful of cases, so I am combining them into the category “Other” and keeping only Jiangsu, Shanghai and Zhejiang as separate provinces.

# gather for plotting with ggplot2
library(tidyr)
fluH7N9.china.2013_gather <- fluH7N9.china.2013 %>%
  gather(Group, Date, date.of.onset:date.of.outcome)

# rearrange group order
fluH7N9.china.2013_gather$Group <- factor(fluH7N9.china.2013_gather$Group,
                                          levels = c("date.of.onset", "date.of.hospitalisation", "date.of.outcome"))

# rename groups
library(plyr)
fluH7N9.china.2013_gather$Group <- mapvalues(fluH7N9.china.2013_gather$Group,
                                             from = c("date.of.onset", "date.of.hospitalisation", "date.of.outcome"),
                                             to = c("Date of onset", "Date of hospitalisation", "Date of outcome"))

# renaming provinces
fluH7N9.china.2013_gather$province <- mapvalues(fluH7N9.china.2013_gather$province,
                                                from = c("Anhui", "Beijing", "Fujian", "Guangdong", "Hebei",
                                                         "Henan", "Hunan", "Jiangxi", "Shandong", "Taiwan"),
                                                to = rep("Other", 10))

# add a level for unknown gender
levels(fluH7N9.china.2013_gather$gender) <- c(levels(fluH7N9.china.2013_gather$gender), "unknown")
fluH7N9.china.2013_gather$gender[is.na(fluH7N9.china.2013_gather$gender)] <- "unknown"

# rearrange province order so that Other is last
fluH7N9.china.2013_gather$province <- factor(fluH7N9.china.2013_gather$province,
                                             levels = c("Jiangsu", "Shanghai", "Zhejiang", "Other"))

# convert age to numeric
fluH7N9.china.2013_gather$age <- as.numeric(as.character(fluH7N9.china.2013_gather$age))

In the dataset, there are more male than female cases and, correspondingly, we see more deaths, recoveries and unknown outcomes in men than in women. This could become a problem for modeling later on, because the inherent likelihoods for each outcome are not directly comparable between the sexes.

Most unknown outcomes were recorded in Zhejiang. Similarly to gender, we don’t have an equal distribution of data points across provinces either.

When we look at the age distribution it is obvious that people who died tended to be slightly older than those who recovered. The density curve of unknown outcomes is more similar to that of death than of recovery, suggesting that among these people there might have been more deaths than recoveries.

And lastly, I want to plot how many days passed between onset, hospitalisation and outcome for each case.

ggplot(data = fluH7N9.china.2013_gather, aes(x = Date, y = age, color = outcome)) +
  geom_point(aes(shape = gender), size = 1.5, alpha = 0.6) +
  geom_path(aes(group = case.ID)) +
  facet_wrap(~ province, ncol = 2) +
  my_theme() +
  scale_shape_manual(values = c(15, 16, 17)) +
  scale_color_brewer(palette = "Set1", na.value = "grey50") +
  scale_fill_brewer(palette = "Set1") +
  labs(
    color = "Outcome",
    shape = "Gender",
    x = "Date in 2013",
    y = "Age",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Dataset from 'outbreaks' package (Kucharski et al. 2014)",
    caption = "\nTime from onset of flu to outcome."
  )

This plot shows that there are many missing values in the dates, so it is hard to draw a general conclusion.

Features

In Machine Learning speak, features are the variables used for model training. Choosing the right features dramatically influences the accuracy of the model.

Because we don’t have many features, I am keeping age as it is, but I am also generating new features:

- from the date information I am calculating the days between onset and outcome and between onset and hospitalisation
- I am converting gender into numeric values with 1 for female and 0 for male
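These two derived features can be sketched in base R. The column names below are assumed to match the prepared dataset, and the values are toy examples for illustration:

```r
# toy rows standing in for the flu dataset (column names assumed)
df <- data.frame(
  date.of.onset           = as.Date(c("2013-02-19", "2013-03-02")),
  date.of.hospitalisation = as.Date(c("2013-02-27", "2013-03-07")),
  date.of.outcome         = as.Date(c("2013-03-04", NA)),
  gender                  = factor(c("f", "m"))
)

# days between onset and outcome / onset and hospitalisation
df$days.onset.to.outcome  <- as.numeric(df$date.of.outcome - df$date.of.onset)
df$days.onset.to.hospital <- as.numeric(df$date.of.hospitalisation - df$date.of.onset)

# gender as numeric: 1 for female, 0 for male (NA stays NA)
df$gender_f <- ifelse(df$gender == "f", 1, 0)
```

Cases with a missing date simply get an NA in the derived column, which is then handled during imputation.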

Imputing missing values

When looking at the dataset I created for modeling, it is obvious that we have quite a few missing values.

The missing values from the outcome column are what I want to predict but for the rest I would either have to remove the entire row from the data or impute the missing information. I decided to try the latter with the mice package.
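A minimal sketch of how the imputation might look with mice, on toy data standing in for the prepared flu features ("pmm" is predictive mean matching, mice's default method for numeric data):

```r
# impute missing values with multivariate imputation by chained equations
if (!require("mice")) install.packages("mice")
library(mice)

# toy data with missing values, standing in for the prepared flu dataset
dataset <- data.frame(
  age      = c(58, NA, 73, 34, NA, 61),
  gender_f = c(1, 0, 0, NA, 1, 0),
  hospital = c(1, 1, NA, 0, 1, 1)
)

set.seed(42)
imp <- mice(dataset, m = 5, method = "pmm", printFlag = FALSE)
dataset_complete <- complete(imp)  # extract one completed dataset
```

Note that imputed values are estimates, so any downstream classification inherits their uncertainty.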

This randomly generated decision tree shows that cases with an early outcome were most likely to die when they were 68 or older, when they also had an early onset and when they were sick for fewer than 13 days. If a person was not among the first cases and was younger than 52, they had a good chance of recovering, but if they were 82 or older, they were more likely to die from the flu.
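A tree like this can be grown with rpart; the sketch below uses simulated data in place of the imputed flu features:

```r
library(rpart)

# simulated stand-in for the imputed flu features
set.seed(27)
toy <- data.frame(
  outcome = factor(sample(c("Death", "Recovery"), 100, replace = TRUE, prob = c(0.4, 0.6))),
  age     = sample(20:90, 100, replace = TRUE),
  days.onset.to.outcome = sample(1:30, 100, replace = TRUE)
)

# grow a classification tree on the features
fit <- rpart(outcome ~ age + days.onset.to.outcome, data = toy,
             method = "class", control = rpart.control(minsplit = 10))

# plot(fit); text(fit)  # draws the tree
```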

Feature Importance

Not all of the features I created will be equally important to the model. The decision tree already gave me an idea of which features might be most important but I also want to estimate feature importance using a Random Forest approach with repeated cross validation.
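With caret, this can be sketched roughly as below; the data is simulated and the cross-validation scheme is kept deliberately small here (a real analysis would use more folds and repeats):

```r
if (!require("caret")) install.packages("caret")
library(caret)

# simulated stand-in for the imputed flu features
set.seed(27)
toy <- data.frame(
  outcome = factor(sample(c("Death", "Recovery"), 80, replace = TRUE)),
  age     = rnorm(80, mean = 55, sd = 15),
  days.onset.to.outcome = rpois(80, lambda = 12)
)

# repeated cross validation (kept small for the sketch)
ctrl <- trainControl(method = "repeatedcv", number = 3, repeats = 2)

# train a Random Forest and ask it to track variable importance
model <- train(outcome ~ ., data = toy, method = "rf",
               trControl = ctrl, importance = TRUE)
varImp(model)  # scaled importance of each feature
```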

ggplot(data = dataset_complete_gather, aes(x = as.numeric(value), fill = outcome, color = outcome)) +
  geom_density(alpha = 0.2) +
  geom_rug() +
  scale_color_brewer(palette = "Set1", na.value = "grey50") +
  scale_fill_brewer(palette = "Set1", na.value = "grey50") +
  my_theme() +
  facet_wrap(set ~ group, ncol = 11, scales = "free") +
  labs(
    x = "Value",
    y = "Density",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Features for classifying outcome",
    caption = "\nDensity distribution of all features used for classification of flu outcome."
  )

ggplot(subset(dataset_complete_gather, group == "Age" | group == "Days onset to hospital" | group == "Days onset to outcome"),
       aes(x = outcome, y = as.numeric(value), fill = set)) +
  geom_boxplot() +
  my_theme() +
  scale_fill_brewer(palette = "Set1", type = "div") +
  facet_wrap(~ group, ncol = 3, scales = "free") +
  labs(
    fill = "",
    x = "Outcome",
    y = "Value",
    title = "2013 Influenza A H7N9 cases in China",
    subtitle = "Features for classifying outcome",
    caption = "\nBoxplot of the features age, days from onset to hospitalisation and days from onset to outcome."
  )

Luckily, the distributions look reasonably similar between the validation and test data, except for a few outliers.

Comparing Machine Learning algorithms

Before I try to predict the outcome of the unknown cases, I am testing the models’ accuracy with the validation datasets on a couple of algorithms. I have chosen only a few of the more well-known algorithms, but caret implements many more.
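The validation split can be sketched with caret's createDataPartition, which samples within each outcome class so that both subsets keep roughly the same class balance (toy data here; variable names are assumptions):

```r
library(caret)

# labelled cases standing in for the training data
set.seed(27)
train_data <- data.frame(
  outcome = factor(sample(c("Death", "Recovery"), 100, replace = TRUE)),
  age     = rnorm(100, mean = 55, sd = 15)
)

# stratified 70/30 split into a training and a validation set
idx <- createDataPartition(train_data$outcome, p = 0.7, list = FALSE)
val_train_data <- train_data[idx, ]
val_test_data  <- train_data[-idx, ]
```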

I chose not to do any preprocessing because I was worried that the different data distributions of continuous variables (e.g. age) and binary variables (i.e. 0/1 classifications, e.g. hospitalisation) would lead to problems.

Random Forest

Random Forests predictions are based on the generation of multiple classification trees.
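As a sketch of the idea, a random forest can also be fit directly with the randomForest package (simulated data; in the post itself the models are trained through caret):

```r
if (!require("randomForest")) install.packages("randomForest")
library(randomForest)

# simulated stand-in for the imputed flu features
set.seed(27)
toy <- data.frame(
  outcome = factor(sample(c("Death", "Recovery"), 100, replace = TRUE)),
  age     = rnorm(100, mean = 55, sd = 15),
  days.onset.to.outcome = rpois(100, lambda = 12)
)

# grow 500 classification trees and aggregate their votes
fit <- randomForest(outcome ~ ., data = toy, ntree = 500)
fit$confusion  # out-of-bag confusion matrix
```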

Comparing accuracy of models

All models were similarly accurate.

# Create a list of models
models <- list(rf = model_rf, glmnet = model_glmnet, kknn = model_kknn, pda = model_pda,
               slda = model_slda, pam = model_pam, C5.0Tree = model_C5.0Tree, pls = model_pls)

# Resample the models
resample_results <- resamples(models)

# Generate a summary
summary(resample_results, metric = c("Kappa", "Accuracy"))

Combined results of predicting validation test samples

To compare the predictions from all models, I summed up the prediction probabilities for Death and Recovery from all models and calculated the log2 of the ratio of the summed probabilities for Recovery to the summed probabilities for Death. All cases with a log2 ratio greater than 1.5 were defined as Recovery, cases with a log2 ratio below -1.5 were defined as Death, and the remaining cases were defined as uncertain.
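The combination rule can be sketched in base R; the summed probabilities below are toy values, while in the post they come from the fitted models:

```r
# summed prediction probabilities across all models (toy values)
pred <- data.frame(
  case         = c("case_1", "case_2", "case_3"),
  sum_Recovery = c(4.2, 0.6, 2.1),
  sum_Death    = c(0.8, 4.4, 2.0)
)

# log2 ratio of summed Recovery vs. Death probabilities
pred$log2_ratio <- log2(pred$sum_Recovery / pred$sum_Death)

# threshold at +/- 1.5; everything in between stays uncertain
pred$predicted <- ifelse(pred$log2_ratio > 1.5, "Recovery",
                  ifelse(pred$log2_ratio < -1.5, "Death", "uncertain"))
```

A wider uncertainty band would flag more cases but make the confident calls more reliable.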

results_combined_gather <- results_combined %>%
  gather(group_dates, date, date.of.onset:date.of.hospitalisation)

results_combined_gather$group_dates <- factor(results_combined_gather$group_dates,
                                              levels = c("date.of.onset", "date.of.hospitalisation"))
results_combined_gather$group_dates <- mapvalues(results_combined_gather$group_dates,
                                                 from = c("date.of.onset", "date.of.hospitalisation"),
                                                 to = c("Date of onset", "Date of hospitalisation"))

results_combined_gather$gender <- mapvalues(results_combined_gather$gender,
                                            from = c("f", "m"),
                                            to = c("Female", "Male"))
levels(results_combined_gather$gender) <- c(levels(results_combined_gather$gender), "unknown")
results_combined_gather$gender[is.na(results_combined_gather$gender)] <- "unknown"

results_combined_gather$age <- as.numeric(as.character(results_combined_gather$age))

The comparison of date of onset, date of hospitalisation, gender and age with predicted outcome shows that predicted Deaths were associated with older age than predicted Recoveries. Date of onset does not show an obvious bias in either direction.

Conclusions

This dataset posed a couple of difficulties to begin with, like unequal distribution of data points across variables and missing data. This makes the modeling inherently prone to flaws. However, real life data isn’t perfect either, so I went ahead and tested the modeling success anyway.

By accounting for uncertain classifications with low prediction probability, the validation data could be classified accurately. However, these few cases don’t provide enough information to reliably predict the outcome of new cases. More cases, more information (i.e. more features) and fewer missing data would improve the modeling outcome.

Also, this example is only applicable to this specific flu outbreak. In order to draw more general conclusions about flu outcome, other cases and additional information, for example on medical parameters like preexisting medical conditions, disease parameters, demographic information, etc., would be necessary.

All in all, this dataset served as a nice example of the possibilities (and pitfalls) of machine learning applications and showcases a basic workflow for building prediction models with R.