Skip to Content

CHAID v ranger v xgboost - a comparison

In an earlier post, I focused on an in depth visit with CHAID (Chi-square automatic interaction detection). Quoting myself, I said “As the name implies it is fundamentally based on the venerable Chi-square test – and while not the most powerful (in terms of detecting the smallest possible differences) or the fastest, it really is easy to manage and more importantly to tell the story after using it”. In this post I’ll spend a little time comparing CHAID with a random forest algorithm in the ranger library and with a gradient boosting algorithm via the xgboost library. I’ll use the exact same data set for all three so we can draw some easy comparisons about their speed and their accuracy.

I do believe CHAID is a great choice for some sets of data and some circumstances but I’m interested in some empirical information, so off we go.

Setup and library loading

If you’ve never used CHAID before you may also not have partykit. CHAID isn’t on CRAN but I have provided the commented out install command below. ranger and xgboost are available from CRAN and are straightforward to install. You’ll also get a variety of messages, none of which is relevant to this example so I’ve suppressed them.

# install.packages("partykit")
# install.packages("CHAID", repos="http://R-Forge.R-project.org")
# install.packages("ranger")
# install.packages("xgboost")
require(dplyr)
require(tidyr)
require(ggplot2)
require(CHAID)
require(purrr)
require(caret)
require(ranger)
require(xgboost)
require(kableExtra) # just to make the output nicer
theme_set(theme_bw()) # set theme for ggplot2

Predicting customer churn for a fictional TELCO company

We’re going to use a dataset that comes to us from the IBM Watson Project.
It’s a very practical example and an understandable dataset. A great use case for the algorithms we’ll be using. Imagine yourself in a fictional company faced with the task of trying to predict which customers are going to leave your business for another provider a.k.a. churn. Obviously we’d like to be able to predict this phenomenon and potentially target these customers for retention or just better project our revenue. Being able to predict churn even a little bit better could save us lots of money, especially if we can identify the key indicators and influence them.

In the original posting I spent a great deal of time explaining the mechanics of loading and prepping the data. This time we’ll do that quickly and efficiently and if you need an explanation of what’s going on please refer back. I’ve embedded some comments in the code where I think they’ll be most helpful. First we’ll grab the data from the IBM site using read.csv, in this case I’m happy to let it tag most of our variables as factors since that’s what we’ll want for our CHAID work.

set.seed(2018)
churn <- read.csv("WA_Fn-UseC_-Telco-Customer-Churn.csv", 
                  stringsAsFactors = TRUE)
str(churn)
## 'data.frame':    7043 obs. of  21 variables:
##  $ customerID      : Factor w/ 7043 levels "0002-ORFBO","0003-MKNFE",..: 5376 3963 2565 5536 6512 6552 1003 4771 5605 4535 ...
##  $ gender          : Factor w/ 2 levels "Female","Male": 1 2 2 2 1 1 2 1 1 2 ...
##  $ SeniorCitizen   : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ Partner         : Factor w/ 2 levels "No","Yes": 2 1 1 1 1 1 1 1 2 1 ...
##  $ Dependents      : Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 2 1 1 2 ...
##  $ tenure          : int  1 34 2 45 2 8 22 10 28 62 ...
##  $ PhoneService    : Factor w/ 2 levels "No","Yes": 1 2 2 1 2 2 2 1 2 2 ...
##  $ MultipleLines   : Factor w/ 3 levels "No","No phone service",..: 2 1 1 2 1 3 3 2 3 1 ...
##  $ InternetService : Factor w/ 3 levels "DSL","Fiber optic",..: 1 1 1 1 2 2 2 1 2 1 ...
##  $ OnlineSecurity  : Factor w/ 3 levels "No","No internet service",..: 1 3 3 3 1 1 1 3 1 3 ...
##  $ OnlineBackup    : Factor w/ 3 levels "No","No internet service",..: 3 1 3 1 1 1 3 1 1 3 ...
##  $ DeviceProtection: Factor w/ 3 levels "No","No internet service",..: 1 3 1 3 1 3 1 1 3 1 ...
##  $ TechSupport     : Factor w/ 3 levels "No","No internet service",..: 1 1 1 3 1 1 1 1 3 1 ...
##  $ StreamingTV     : Factor w/ 3 levels "No","No internet service",..: 1 1 1 1 1 3 3 1 3 1 ...
##  $ StreamingMovies : Factor w/ 3 levels "No","No internet service",..: 1 1 1 1 1 3 1 1 3 1 ...
##  $ Contract        : Factor w/ 3 levels "Month-to-month",..: 1 2 1 2 1 1 1 1 1 2 ...
##  $ PaperlessBilling: Factor w/ 2 levels "No","Yes": 2 1 2 1 2 2 2 1 2 1 ...
##  $ PaymentMethod   : Factor w/ 4 levels "Bank transfer (automatic)",..: 3 4 4 1 3 3 2 4 3 1 ...
##  $ MonthlyCharges  : num  29.9 57 53.9 42.3 70.7 ...
##  $ TotalCharges    : num  29.9 1889.5 108.2 1840.8 151.7 ...
##  $ Churn           : Factor w/ 2 levels "No","Yes": 1 1 2 1 2 2 1 1 2 1 ...

We have data on 7,043 customers across 21 variables. customerID can’t really be a predictor but we will use it in a little bit. Churn is what we want to predict so we have 19 potential predictor variables to work with. Four of them were not automatically converted to factors so we’ll have to look into them for CHAID. For a review of what the output means and how CHAID works please refer back.

Let’s address the easiest thing first. SeniorCitizen is coded zero and one instead of yes/no so let’s recode that in a nice conservative fashion and see what the breakdown is.

# Fix senior citizen status
churn$SeniorCitizen <- recode_factor(
  churn$SeniorCitizen,
  `0` = "No",
  `1` = "Yes",
  .default = "Should not happen"
)
summary(churn$SeniorCitizen)
##   No  Yes 
## 5901 1142

We have three variables left that are numeric, now that we have addressed senior citizen status. Let’s use a combination of dplyr and ggplot2 to see what the distribution looks like using a density plot.

churn %>%
   select_if(is.numeric) %>%
   gather(metric, value) %>%
   ggplot(aes(value, fill = metric)) +
   geom_density(show.legend = FALSE) +
   facet_wrap( ~ metric, scales = "free")
## Warning: Removed 11 rows containing non-finite values (stat_density).

Well those aren’t the most normal looking distributions and we have this message ## Warning: Removed 11 rows containing non-finite values (stat_density). which alerts us to the fact that there are some missing values in our data. Let’s first figure out where the missing data is:

churn %>%
  select_if(anyNA) %>% summary
##   TotalCharges   
##  Min.   :  18.8  
##  1st Qu.: 401.4  
##  Median :1397.5  
##  Mean   :2283.3  
##  3rd Qu.:3794.7  
##  Max.   :8684.8  
##  NA's   :11

Now we know that total customer charges is missing 11 entries. Our three algorithms vary as to how gracefully they handle missing values but at this point we have several options including:

  • Eliminate the entire customer record if anything is missing
  • Impute or substitute in some reasonable value like the mean or the median for missing values
  • Do some fancier imputation to make sure we substitute in the most plausible value for TotalCharges

Elimination is easy, efficient, and conservative and since it is a very small percentage of our total data set unlikely to cost us a lot of information for the models that don’t handle missing values well. But for purposes of this blog post and to help demonstrate some of the capabilities within caret (since we’re going to use it anyway) we’ll try median and knn (k nearest neighbor) imputation.

First let’s make a vector that contains the customerID numbers of the eleven cases in question.

xxx <- churn %>%
         filter_all(any_vars(is.na(.))) %>% 
         select(customerID)
xxx <- as.vector(xxx$customerID)
xxx
##  [1] "4472-LVYGI" "3115-CZMZD" "5709-LVOEQ" "4367-NUYAO" "1371-DWPAZ" "7644-OMVMY" "3213-VVOLG" "2520-SGTTA" "2923-ARZLG" "4075-WKNIU" "2775-SEFEE"
churn %>% filter(customerID %in% xxx)
##    customerID gender SeniorCitizen Partner Dependents tenure PhoneService    MultipleLines InternetService      OnlineSecurity        OnlineBackup    DeviceProtection         TechSupport         StreamingTV     StreamingMovies Contract PaperlessBilling             PaymentMethod MonthlyCharges TotalCharges Churn
## 1  4472-LVYGI Female            No     Yes        Yes      0           No No phone service             DSL                 Yes                  No                 Yes                 Yes                 Yes                  No Two year              Yes Bank transfer (automatic)          52.55           NA    No
## 2  3115-CZMZD   Male            No      No        Yes      0          Yes               No              No No internet service No internet service No internet service No internet service No internet service No internet service Two year               No              Mailed check          20.25           NA    No
## 3  5709-LVOEQ Female            No     Yes        Yes      0          Yes               No             DSL                 Yes                 Yes                 Yes                  No                 Yes                 Yes Two year               No              Mailed check          80.85           NA    No
## 4  4367-NUYAO   Male            No     Yes        Yes      0          Yes              Yes              No No internet service No internet service No internet service No internet service No internet service No internet service Two year               No              Mailed check          25.75           NA    No
## 5  1371-DWPAZ Female            No     Yes        Yes      0           No No phone service             DSL                 Yes                 Yes                 Yes                 Yes                 Yes                  No Two year               No   Credit card (automatic)          56.05           NA    No
## 6  7644-OMVMY   Male            No     Yes        Yes      0          Yes               No              No No internet service No internet service No internet service No internet service No internet service No internet service Two year               No              Mailed check          19.85           NA    No
## 7  3213-VVOLG   Male            No     Yes        Yes      0          Yes              Yes              No No internet service No internet service No internet service No internet service No internet service No internet service Two year               No              Mailed check          25.35           NA    No
## 8  2520-SGTTA Female            No     Yes        Yes      0          Yes               No              No No internet service No internet service No internet service No internet service No internet service No internet service Two year               No              Mailed check          20.00           NA    No
## 9  2923-ARZLG   Male            No     Yes        Yes      0          Yes               No              No No internet service No internet service No internet service No internet service No internet service No internet service One year              Yes              Mailed check          19.70           NA    No
## 10 4075-WKNIU Female            No     Yes        Yes      0          Yes              Yes             DSL                  No                 Yes                 Yes                 Yes                 Yes                  No Two year               No              Mailed check          73.35           NA    No
## 11 2775-SEFEE   Male            No      No        Yes      0          Yes              Yes             DSL                 Yes                 Yes                  No                 Yes                  No                  No Two year              Yes Bank transfer (automatic)          61.90           NA    No

As you look at those eleven records it doesn’t appear they are “average”! In particular, I’m worried that the MonthlyCharges look small and they have 0 tenure for this group. No way of knowing for certain but it could be that these are just the newest customers with very little time using our service. Let’s use our list to do some comparing of these eleven versus the total population, that will help us decide what to do about the missing cases. Replacing with the median value is simple and easy but it may well not be the most accurate choice.

churn %>% 
   filter(customerID %in% xxx) %>% 
   summarise(median(MonthlyCharges))
##   median(MonthlyCharges)
## 1                  25.75
median(churn$MonthlyCharges, na.rm = TRUE)
## [1] 70.35
churn %>% 
   filter(customerID %in% xxx) %>% 
   summarise(median(tenure))
##   median(tenure)
## 1              0
median(churn$tenure, na.rm = TRUE)
## [1] 29

The median MonthlyCharges are much lower and instead of two years or so of median tenure this group has none. Let’s use the preProcess function in caret to accomplish several goals. We’ll ask it to impute the missing values for us using both knnImpute (k nearest neighbors) and a pure median medianImpute. From the ?preProcess help pages:

k-nearest neighbor imputation is carried out by finding the k closest samples (Euclidian distance) in the training set. Imputation via bagging fits a bagged tree model for each predictor (as a function of all the others). This method is simple, accurate and accepts missing values, but it has much higher computational cost. Imputation via medians takes the median of each predictor in the training set, and uses them to fill missing values. This method is simple, fast, and accepts missing values, but treats each predictor independently, and may be inaccurate.

We’ll also have it transform our numeric variables using YeoJohnson and identify any predictor variables that have near zero variance nzv.

# using k nearest neighbors
pp_knn <- preProcess(churn, method = c("knnImpute", "YeoJohnson", "nzv"))
# simple output
pp_knn
## Created from 7032 samples and 21 variables
## 
## Pre-processing:
##   - centered (3)
##   - ignored (18)
##   - 5 nearest neighbor imputation (3)
##   - scaled (3)
##   - Yeo-Johnson transformation (3)
## 
## Lambda estimates for Yeo-Johnson transformation:
## 0.45, 0.93, 0.25
# more verbose
pp_knn$method
## $knnImpute
## [1] "tenure"         "MonthlyCharges" "TotalCharges"  
## 
## $YeoJohnson
## [1] "tenure"         "MonthlyCharges" "TotalCharges"  
## 
## $ignore
##  [1] "customerID"       "gender"           "SeniorCitizen"    "Partner"          "Dependents"       "PhoneService"     "MultipleLines"    "InternetService"  "OnlineSecurity"   "OnlineBackup"     "DeviceProtection" "TechSupport"      "StreamingTV"      "StreamingMovies"  "Contract"         "PaperlessBilling" "PaymentMethod"    "Churn"           
## 
## $center
## [1] "tenure"         "MonthlyCharges" "TotalCharges"  
## 
## $scale
## [1] "tenure"         "MonthlyCharges" "TotalCharges"
# using medians
pp_median <- preProcess(churn, method = c("medianImpute", "YeoJohnson", "nzv"))
pp_median
## Created from 7032 samples and 21 variables
## 
## Pre-processing:
##   - ignored (18)
##   - median imputation (3)
##   - Yeo-Johnson transformation (3)
## 
## Lambda estimates for Yeo-Johnson transformation:
## 0.45, 0.93, 0.25
pp_median$method
## $medianImpute
## [1] "tenure"         "MonthlyCharges" "TotalCharges"  
## 
## $YeoJohnson
## [1] "tenure"         "MonthlyCharges" "TotalCharges"  
## 
## $ignore
##  [1] "customerID"       "gender"           "SeniorCitizen"    "Partner"          "Dependents"       "PhoneService"     "MultipleLines"    "InternetService"  "OnlineSecurity"   "OnlineBackup"     "DeviceProtection" "TechSupport"      "StreamingTV"      "StreamingMovies"  "Contract"         "PaperlessBilling" "PaymentMethod"    "Churn"

The preProcess function creates a list object of class preProcess that contains information about what needs to be done and what the results of the transformations will be, but we need to apply the predict function to actually make the changes proposed. So at this point let’s create two new dataframes nchurn1 and nchurn2 that contain the data after the pre-processing has occurred. Then we can see how the results compare.

nchurn1 <- predict(pp_knn,churn)
nchurn2 <- predict(pp_median,churn)
nchurn2 %>% 
   filter(customerID %in% xxx) %>% 
   summarise(median(TotalCharges))
##   median(TotalCharges)
## 1             20.79526
median(nchurn2$TotalCharges, na.rm = TRUE)
## [1] 20.79526
nchurn1 %>% 
   filter(customerID %in% xxx) %>% 
   summarise(median(TotalCharges))
##   median(TotalCharges)
## 1            -1.849681
median(nchurn1$TotalCharges, na.rm = TRUE)
## [1] 0.01820494

May also be useful to visualize the data as we did earlier to see how the transformations have changed the density plots.

nchurn1 %>%
  select_if(is.numeric) %>%
  gather(metric, value) %>%
  ggplot(aes(value, fill = metric)) +
      geom_density(show.legend = FALSE) +
      facet_wrap( ~ metric, scales = "free")

nchurn2 %>%
  select_if(is.numeric) %>%
  gather(metric, value) %>%
  ggplot(aes(value, fill = metric)) +
      geom_density(show.legend = FALSE) +
      facet_wrap( ~ metric, scales = "free")

If you compare the two plots you can see that they vary imperceptibly except for the y axis scale. There is no warning about missing values and if you scroll back and compare with the original plots of the raw variables the shape of tenure and TotalCharges have changed significantly because of the transformation.

I’m pretty convinced that knn provides a much better approximation of those eleven missing values than a mere median substitution so let’s make those changes and move on to comparing models. While we’re at it, let’s go ahead and remove the unique customer ID number as well. We really only needed it to compare a few specific cases.

churn <- predict(pp_knn,churn)
churn$customerID <- NULL
str(churn)
## 'data.frame':    7043 obs. of  20 variables:
##  $ gender          : Factor w/ 2 levels "Female","Male": 1 2 2 2 1 1 2 1 1 2 ...
##  $ SeniorCitizen   : Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 1 1 1 1 ...
##  $ Partner         : Factor w/ 2 levels "No","Yes": 2 1 1 1 1 1 1 1 2 1 ...
##  $ Dependents      : Factor w/ 2 levels "No","Yes": 1 1 1 1 1 1 2 1 1 2 ...
##  $ tenure          : num  -1.644 0.297 -1.495 0.646 -1.495 ...
##  $ PhoneService    : Factor w/ 2 levels "No","Yes": 1 2 2 1 2 2 2 1 2 2 ...
##  $ MultipleLines   : Factor w/ 3 levels "No","No phone service",..: 2 1 1 2 1 3 3 2 3 1 ...
##  $ InternetService : Factor w/ 3 levels "DSL","Fiber optic",..: 1 1 1 1 2 2 2 1 2 1 ...
##  $ OnlineSecurity  : Factor w/ 3 levels "No","No internet service",..: 1 3 3 3 1 1 1 3 1 3 ...
##  $ OnlineBackup    : Factor w/ 3 levels "No","No internet service",..: 3 1 3 1 1 1 3 1 1 3 ...
##  $ DeviceProtection: Factor w/ 3 levels "No","No internet service",..: 1 3 1 3 1 3 1 1 3 1 ...
##  $ TechSupport     : Factor w/ 3 levels "No","No internet service",..: 1 1 1 3 1 1 1 1 3 1 ...
##  $ StreamingTV     : Factor w/ 3 levels "No","No internet service",..: 1 1 1 1 1 3 3 1 3 1 ...
##  $ StreamingMovies : Factor w/ 3 levels "No","No internet service",..: 1 1 1 1 1 3 1 1 3 1 ...
##  $ Contract        : Factor w/ 3 levels "Month-to-month",..: 1 2 1 2 1 1 1 1 1 2 ...
##  $ PaperlessBilling: Factor w/ 2 levels "No","Yes": 2 1 2 1 2 2 2 1 2 1 ...
##  $ PaymentMethod   : Factor w/ 4 levels "Bank transfer (automatic)",..: 3 4 4 1 3 3 2 4 3 1 ...
##  $ MonthlyCharges  : num  -1.158 -0.239 -0.343 -0.731 0.214 ...
##  $ TotalCharges    : num  -1.81 0.254 -1.386 0.233 -1.249 ...
##  $ Churn           : Factor w/ 2 levels "No","Yes": 1 1 2 1 2 2 1 1 2 1 ...

One more step before we start using CHAID, ranger, and xgboost and while we have the data in one frame. Let’s take the 3 numeric variables and create 3 analogous variables as factors. This is necessary because CHAID requires categorical a.k.a. nominal data. If you’d like to review the options for how to “cut” the data please refer back to my earlier post.

churn <- churn %>%
   mutate_if(is.numeric, 
             funs(factor = cut_number(., n=5, 
                                      labels = c("Lowest","Below Middle","Middle","Above Middle","Highest"))))
## Warning: `funs()` is deprecated as of dplyr 0.8.0.
## Please use a list of either functions or lambdas: 
## 
##   # Simple named list: 
##   list(mean = mean, median = median)
## 
##   # Auto named with `tibble::lst()`: 
##   tibble::lst(mean, median)
## 
##   # Using lambdas
##   list(~ mean(., trim = .2), ~ median(., na.rm = TRUE))
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
summary(churn)
##     gender     SeniorCitizen Partner    Dependents     tenure        PhoneService          MultipleLines     InternetService             OnlineSecurity              OnlineBackup             DeviceProtection              TechSupport                StreamingTV              StreamingMovies           Contract    PaperlessBilling                   PaymentMethod  MonthlyCharges     TotalCharges       Churn           tenure_factor   MonthlyCharges_factor   TotalCharges_factor
##  Female:3488   No :5901      No :3641   No :4933   Min.   :-1.8439   No : 682     No              :3390   DSL        :2421   No                 :3498   No                 :3088   No                 :3095    No                 :3473   No                 :2810   No                 :2785   Month-to-month:3875   No :2872         Bank transfer (automatic):1544   Min.   :-1.5685   Min.   :-1.929306   No :5174   Lowest      :1481   Lowest      :1420      Lowest      :1409    
##  Male  :3555   Yes:1142      Yes:3402   Yes:2110   1st Qu.:-0.8555   Yes:6361     No phone service: 682   Fiber optic:3096   No internet service:1526   No internet service:1526   No internet service:1526    No internet service:1526   No internet service:1526   No internet service:1526   One year      :1473   Yes:4171         Credit card (automatic)  :1522   1st Qu.:-0.9632   1st Qu.:-0.783551   Yes:1869   Below Middle:1397   Below Middle:1397      Below Middle:1408    
##                                                    Median : 0.1183                Yes             :2971   No         :1526   Yes                :2019   Yes                :2429   Yes                :2422    Yes                :2044   Yes                :2707   Yes                :2732   Two year      :1695                    Electronic check         :2365   Median : 0.2021   Median : 0.018205              Middle      :1408   Middle      :1411      Middle      :1409    
##                                                    Mean   : 0.0000                                                                                                                                                                                                                                                                     Mailed check             :1612   Mean   : 0.0000   Mean   :-0.002732              Above Middle:1350   Above Middle:1407      Above Middle:1408    
##                                                    3rd Qu.: 0.9252                                                                                                                                                                                                                                                                                                      3rd Qu.: 0.8341   3rd Qu.: 0.868066              Highest     :1407   Highest     :1408      Highest     :1409    
##                                                    Max.   : 1.3421                                                                                                                                                                                                                                                                                                      Max.   : 1.7530   Max.   : 1.758003

Okay now we have three additional variables that end in _factor, they’re like their numeric equivalents only cut into more or less 5 equal bins.

Training and testing our models

We’re going to use caret to train and test all three of the algorithms on our data. We could operate directly by invoking the individual model functions directly but caret will allow us to use some common steps. We’ll employ cross-validation a.k.a. cv to mitigate the problem of over-fitting. This article explains it well so I won’t repeat that explanation here, I’ll simply show you how to run the steps in R.

This is also a good time to point out that caret has extraordinarily comprehensive documentation which I used extensively and I’m limiting myself to the basics.

As a first step, let’s just take 30% of our data and put is aside as the testing data set. Why 30%? Doesn’t have to be, could be as low as 20% or as high as 40% it really depends on how conservative you want to be, and how much data you have at hand. Since this is just a tutorial we’ll simply use 30% as a representative number. I’m going to use caret syntax which is the line with createDataPartition(churn$Churn, p=0.7, list=FALSE) in it. That takes our data set churn makes a 70% split ensuring that we keep our outcome variable Churn as close to 70/30 as we can. This is important because our data is already pretty lop-sided for outcomes. The two subsequent lines serve to take the vector intrain and produce two separate dataframes, testing and training. They have 2112 and 4931 customers respectively.

intrain <- createDataPartition(churn$Churn, p=0.7, list=FALSE)
training <- churn[intrain,]
testing <- churn[-intrain,]
dim(training)
## [1] 4931   23
dim(testing)
## [1] 2112   23

CHAID

Now that we have a training and testing dataset let’s remove the numeric version of the variables CHAID can’t use.

# first pass at CHAID
# remove numbers
training <- training %>%
  select_if(is.factor)
dim(training)
## [1] 4931   20
testing <- testing %>%
  select_if(is.factor)
dim(testing)
## [1] 2112   20

The next step is a little counter-intuitive but quite practical. Turns out that many models do not perform well when you feed them a formula for the model even if they claim to support a formula interface (as CHAID does). Here’s a Stack Overflow link that discusses in detail but my suggestion to you is to always separate them and avoid the problem altogether. We’re just taking our predictors or features and putting them in x while we put our outcome in y.

# create response and feature data
features <- setdiff(names(training), "Churn")
x <- training[, features]
y <- training$Churn

trainControl is the next function within caret we need to use. Chapter 5 in the caret doco covers it in great detail. I’m simply going to pluck out a few sane and safe options. method = "cv" gets us cross-validation. number = 5 is pretty obvious. I happen to like seeing the progress in case I want to go for coffee so verboseIter = TRUE (here I will turn it off since the static output is rather boring), and I play it safe and explicitly save my predictions savePredictions = "final". We put everything in train_control which we’ll use in a minute. We’ll use this same train_control for all our models

# set up 5-fold cross validation procedure
train_control <- trainControl(method = "cv",
                              number = 5,
#                              verboseIter = TRUE,
                              savePredictions = "final")

By default caret allows us to adjust three parameters in our chaid model; alpha2, alpha3, and alpha4. As a matter of fact it will allow us to build a grid of those parameters and test all the permutations we like, using the same cross-validation process. I’m a bit worried that we’re not being conservative enough. I’d like to train our model using p values for alpha that are not .05, .03, and .01 but instead the de facto levels in my discipline; .05, .01, and .001. The function in caret is tuneGrid. We’ll use the base R function expand.grid to build a dataframe with all the combinations and then feed it to caret in our training via tuneGrid = search_grid in our call to train.

# set up tuning grid default
search_grid <- expand.grid(
  alpha2 = c(.05, .01, .001),
  alpha4 = c(.05, .01, .001),
  alpha3 = -1
)

Now we can use the train function in caret to train our model! It wants to know what our x and y’s are, as well as our training control parameters which we’ve parked in train_control.

chaid.model <- train(
  x = x,
  y = y,
  method = "chaid",
  trControl = train_control,
  tuneGrid = search_grid
)
chaid.model
## CHi-squared Automated Interaction Detection 
## 
## 4931 samples
##   19 predictor
##    2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3944, 3946, 3945, 3944, 3945 
## Resampling results across tuning parameters:
## 
##   alpha2  alpha4  Accuracy   Kappa    
##   0.001   0.001   0.7903119  0.3910975
##   0.001   0.010   0.7862541  0.3913868
##   0.001   0.050   0.7815888  0.3942768
##   0.010   0.001   0.7903119  0.3910975
##   0.010   0.010   0.7852409  0.3893637
##   0.010   0.050   0.7809788  0.3950494
##   0.050   0.001   0.7886892  0.3938326
##   0.050   0.010   0.7844306  0.3937592
##   0.050   0.050   0.7840230  0.4017554
## 
## Tuning parameter 'alpha3' was held constant at a value of -1
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were alpha2 = 0.01, alpha3 = -1 and alpha4 = 0.001.

And after roughly two minutes it’s done. Let’s inspect what we have so far. The output gives us a nice concise summary. 4931 cases with 19 predictors. It gives us an idea of how many of the 4931 cases were used in the individual folds Summary of sample sizes: 3944, 3946, 3945, 3944, 3945. If you need a review of what alpha2, alpha4, and alpha3 are please review the ?chaid doco.

You’ll notice that I stored the results in an object called chaid.model. That object has lots of useful information you can access (it’s a list object of class “train”). As a matter of fact we will be creating one object per run and then using the stored information to build a nice comparison later. For now here are some useful examples of what’s contained in the object…

  1. Produce the confusionMatrix across all folds confusionMatrix(chaid.model)
  2. Plot the effect of the tuning parameters on accuracy plot(chaid.model). Note that the scaling deceives the eye and the results are close across the plot
  3. Check on variable importance varImp(chaid.model)
  4. How long did it take? Look in chaid.model$times

If you need a refresher on what these represent please see the earlier post on CHAID.

confusionMatrix(chaid.model)
## Cross-Validated (5 fold) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction   No  Yes
##        No  67.7 15.2
##        Yes  5.8 11.4
##                             
##  Accuracy (average) : 0.7903
plot(chaid.model)

varImp(chaid.model)
## ROC curve variable importance
## 
##                       Importance
## Contract                100.0000
## tenure_factor            95.3656
## OnlineSecurity           78.1234
## TechSupport              75.7125
## TotalCharges_factor      57.7107
## OnlineBackup             50.5202
## MonthlyCharges_factor    49.9909
## DeviceProtection         44.7066
## PaperlessBilling         43.5889
## Partner                  36.4864
## Dependents               34.7182
## PaymentMethod            25.7785
## SeniorCitizen            24.5579
## StreamingTV              10.6056
## StreamingMovies           8.9569
## MultipleLines             6.5964
## InternetService           3.0121
## gender                    0.3617
## PhoneService              0.0000
chaid.model$times
## $everything
##    user  system elapsed 
## 168.621   4.200 172.899 
## 
## $final
##    user  system elapsed 
##   2.485   0.195   2.681 
## 
## $prediction
## [1] NA NA NA

One of the nice aspects about CHAID as a method is that is relatively easy to “see”" your model in either text or plot format. While there are packages that will help you “see” a random forest; by definition (pardon the pun) it’s hard to see the forest because of all the trees. Simply “printing” the final model with chaid.model$finalModel gives you the text representation while you can plot the final model with plot(chaid.model$finalModel). As I explained in the earlier post it’s nice being able to see where your model fits well and where it misses at a high level.

chaid.model$finalModel
## 
## Model formula:
## .outcome ~ gender + SeniorCitizen + Partner + Dependents + PhoneService + 
##     MultipleLines + InternetService + OnlineSecurity + OnlineBackup + 
##     DeviceProtection + TechSupport + StreamingTV + StreamingMovies + 
##     Contract + PaperlessBilling + PaymentMethod + tenure_factor + 
##     MonthlyCharges_factor + TotalCharges_factor
## 
## Fitted party:
## [1] root
## |   [2] Contract in Month-to-month
## |   |   [3] InternetService in DSL
## |   |   |   [4] TotalCharges_factor in Lowest: No (n = 326, err = 49.7%)
## |   |   |   [5] TotalCharges_factor in Below Middle, Middle, Above Middle, Highest
## |   |   |   |   [6] PhoneService in No: No (n = 146, err = 30.1%)
## |   |   |   |   [7] PhoneService in Yes: No (n = 360, err = 16.4%)
## |   |   [8] InternetService in Fiber optic
## |   |   |   [9] tenure_factor in Lowest
## |   |   |   |   [10] OnlineSecurity in No, No internet service: Yes (n = 409, err = 22.0%)
## |   |   |   |   [11] OnlineSecurity in Yes: No (n = 32, err = 46.9%)
## |   |   |   [12] tenure_factor in Below Middle
## |   |   |   |   [13] MultipleLines in No, No phone service: No (n = 194, err = 49.0%)
## |   |   |   |   [14] MultipleLines in Yes: Yes (n = 223, err = 34.5%)
## |   |   |   [15] tenure_factor in Middle, Above Middle, Highest
## |   |   |   |   [16] OnlineSecurity in No, No internet service
## |   |   |   |   |   [17] PaymentMethod in Bank transfer (automatic), Credit card (automatic), Mailed check: No (n = 223, err = 33.6%)
## |   |   |   |   |   [18] PaymentMethod in Electronic check: Yes (n = 285, err = 49.5%)
## |   |   |   |   [19] OnlineSecurity in Yes: No (n = 142, err = 25.4%)
## |   |   [20] InternetService in No
## |   |   |   [21] TotalCharges_factor in Lowest, Middle, Above Middle, Highest: No (n = 278, err = 20.9%)
## |   |   |   [22] TotalCharges_factor in Below Middle: No (n = 102, err = 4.9%)
## |   [23] Contract in One year
## |   |   [24] StreamingMovies in No, No internet service
## |   |   |   [25] PaymentMethod in Bank transfer (automatic), Credit card (automatic), Mailed check: No (n = 468, err = 3.8%)
## |   |   |   [26] PaymentMethod in Electronic check: No (n = 100, err = 14.0%)
## |   |   [27] StreamingMovies in Yes: No (n = 449, err = 18.5%)
## |   [28] Contract in Two year
## |   |   [29] InternetService in DSL, No: No (n = 886, err = 1.4%)
## |   |   [30] InternetService in Fiber optic: No (n = 308, err = 7.8%)
## 
## Number of inner nodes:    13
## Number of terminal nodes: 17
plot(chaid.model$finalModel)

Finally, probably the most important step of all, we’ll take our trained model and apply it to the testing data that we held back to see how well it fits this data it’s never seen before. This is a key step because it reassures us that we have not overfit (if you want a fuller understanding please consider reading this post on EliteDataScience) our model. We’ll take our model we made with the training dataset chaid.model and have it predict against the testing dataset and see how we did with a confusionMatrix

confusionMatrix(predict(chaid.model, newdata = testing), testing$Churn)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   No  Yes
##        No  1410  325
##        Yes  142  235
##                                           
##                Accuracy : 0.7789          
##                  95% CI : (0.7606, 0.7964)
##     No Information Rate : 0.7348          
##     P-Value [Acc > NIR] : 1.709e-06       
##                                           
##                   Kappa : 0.3664          
##                                           
##  Mcnemar's Test P-Value : < 2.2e-16       
##                                           
##             Sensitivity : 0.9085          
##             Specificity : 0.4196          
##          Pos Pred Value : 0.8127          
##          Neg Pred Value : 0.6233          
##              Prevalence : 0.7348          
##          Detection Rate : 0.6676          
##    Detection Prevalence : 0.8215          
##       Balanced Accuracy : 0.6641          
##                                           
##        'Positive' Class : No              
## 

Very nice! Our accuracy on testing actually exceeds the accuracy we achieved in training.

Random Forest via ranger

One of the nicest things about using caret is that it is pretty straight-forward to move from one model to another. The amount of work we have to do while moving from CHAID to ranger and eventually xgboost is actually quite modest.

ranger will accept a mix of factors and numeric variables so our first step will be to go back and recreate training and testing using the numeric versions of tenure, MonthlyCharges, and TotalCharges instead of the _factor versions. intrain still holds our list of rows that should be in training so we’ll follow the exact same process just keep the numeric versions and arrive at x and y to feed to caret and ranger.

##### using ranger
# intrain <- createDataPartition(churn$Churn,p=0.7,list=FALSE)
training <- churn[intrain,]
testing <- churn[-intrain,]
dim(training)
## [1] 4931   23
dim(testing)
## [1] 2112   23
training <- training %>%
  select(-ends_with("_factor"))
dim(training)
## [1] 4931   20
# testing <- testing %>%
#  select(-ends_with("_factor"))
dim(testing)
## [1] 2112   23
# create response and feature data
features <- setdiff(names(training), "Churn")
x <- training[, features]
y <- training$Churn

As I mentioned earlier train_control doesn’t have to change at all. So I’ll just print it to remind you of what’s in there.

search_grid is almost always specific to the model and this is no exception. When we consult the documentation for ranger within caret we see that we can adjust mtry, splitrule, and min.node.size. We’ll put in some reasonable values for those and then put the resulting grid into rf_grid. I tried to give ranger’s search grid about the same amount of flexibility as I did for CHAID.

##### reusing train_control
head(train_control)
## $method
## [1] "cv"
## 
## $number
## [1] 5
## 
## $repeats
## [1] NA
## 
## $search
## [1] "grid"
## 
## $p
## [1] 0.75
## 
## $initialWindow
## NULL
# define a grid of parameter options to try with ranger
rf_grid <- expand.grid(mtry = c(2:4),
                       splitrule = c("gini"),
                       min.node.size = c(3, 5, 7))
rf_grid
##   mtry splitrule min.node.size
## 1    2      gini             3
## 2    3      gini             3
## 3    4      gini             3
## 4    2      gini             5
## 5    3      gini             5
## 6    4      gini             5
## 7    2      gini             7
## 8    3      gini             7
## 9    4      gini             7

Okay, we’re ready to train our model using ranger now. The only additional line we need (besides changing from chaid to ranger is to tell it what to use to capture variable importance e.g. “impurity”.

# re-fit the model with the parameter grid
rf.model <- train(
                  x = x,
                  y = y,
                  method = "ranger",
                  trControl = train_control,
                  tuneGrid = rf_grid,
                  importance = "impurity")
rf.model
## Random Forest 
## 
## 4931 samples
##   19 predictor
##    2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3945, 3945, 3944, 3944, 3946 
## Resampling results across tuning parameters:
## 
##   mtry  min.node.size  Accuracy   Kappa    
##   2     3              0.7969937  0.4313263
##   2     5              0.7967894  0.4284019
##   2     7              0.7986164  0.4326772
##   3     3              0.7939488  0.4297259
##   3     5              0.7923288  0.4239988
##   3     7              0.7929371  0.4236652
##   4     3              0.7907057  0.4209967
##   4     5              0.7925320  0.4249768
##   4     7              0.7935475  0.4257799
## 
## Tuning parameter 'splitrule' was held constant at a value of gini
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were mtry = 2, splitrule = gini and min.node.size = 7.

Now we can run the exact same set of commands as we did with chaid.model on rf.model.

confusionMatrix(rf.model)
## Cross-Validated (5 fold) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction   No  Yes
##        No  67.1 13.8
##        Yes  6.3 12.8
##                             
##  Accuracy (average) : 0.7986
plot(rf.model)

varImp(rf.model)
## ranger variable importance
## 
##                  Overall
## tenure           100.000
## TotalCharges      97.746
## MonthlyCharges    85.711
## Contract          68.963
## OnlineSecurity    43.128
## TechSupport       39.103
## InternetService   28.375
## PaymentMethod     28.361
## OnlineBackup      21.700
## DeviceProtection  17.458
## PaperlessBilling  13.968
## MultipleLines      8.983
## SeniorCitizen      8.652
## Partner            8.181
## StreamingTV        7.428
## Dependents         7.281
## StreamingMovies    7.236
## gender             6.341
## PhoneService       0.000
rf.model$times
## $everything
##    user  system elapsed 
##  86.463   0.637  16.145 
## 
## $final
##    user  system elapsed 
##   1.641   0.011   0.281 
## 
## $prediction
## [1] NA NA NA

Now, the all important prediction against the testing data set.

confusionMatrix(predict(rf.model, newdata = testing), testing$Churn)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   No  Yes
##        No  1421  307
##        Yes  131  253
##                                           
##                Accuracy : 0.7926          
##                  95% CI : (0.7747, 0.8097)
##     No Information Rate : 0.7348          
##     P-Value [Acc > NIR] : 4.068e-10       
##                                           
##                   Kappa : 0.4084          
##                                           
##  Mcnemar's Test P-Value : < 2.2e-16       
##                                           
##             Sensitivity : 0.9156          
##             Specificity : 0.4518          
##          Pos Pred Value : 0.8223          
##          Neg Pred Value : 0.6589          
##              Prevalence : 0.7348          
##          Detection Rate : 0.6728          
##    Detection Prevalence : 0.8182          
##       Balanced Accuracy : 0.6837          
##                                           
##        'Positive' Class : No              
## 

Very nice! Once again our accuracy on testing actually exceeds the accuracy we achieved in training. Looks like we were more accurate than CHAID but we’ll come back to that after we finish xgboost.

Extreme Gradient Boosting via xgboost

Moving from ranger to xgboost is even easier than it was from CHAID.

xgboost like ranger will accept a mix of factors and numeric variables so there is no need to change our training and testing datasets at all. There’s also no need to change our train_control. As far as tuning goes caret supports 7 of the many parameters that you could feed to ?xgboost. If you consult the caret documentation here under xgbTree you’ll see them listed. If you don’t provide any tuning guidance then it will provide a default set of pretty rational initial values. I initially ran it that way but below for purposes of this post have chosen only a few that seem to make the largest difference to accuracy and set the rest to a constant.

One final important note about the code below. Notice in the train command I am feeding a formula Churn ~ . to train. If you try to give it the same x = x & y = y syntax I used with ranger it will fail. That’s because as stated in the doco “xgb.train accepts only an xgb.DMatrix as the input. xgboost, in addition, also accepts matrix, dgCMatrix, or name of a local data file.” You could use commands like xx <- model.matrix(~. -1, data=x)[,-1] & yy <- as.numeric(y) -1 to convert them but since our dataset is small I’m just going to use the formula interface.

# reusing train_control
head(train_control)
## $method
## [1] "cv"
## 
## $number
## [1] 5
## 
## $repeats
## [1] NA
## 
## $search
## [1] "grid"
## 
## $p
## [1] 0.75
## 
## $initialWindow
## NULL
# define a grid of parameter options to try with xgboost
xgb_grid <- expand.grid(nrounds = c(100, 150, 200),
                       max_depth = 1,
                       min_child_weight = 1,
                       subsample = 1,
                       gamma = 0,
                       colsample_bytree = 0.8,
                       eta = c(.2, .3, .4))
xgb_grid
##   nrounds max_depth min_child_weight subsample gamma colsample_bytree eta
## 1     100         1                1         1     0              0.8 0.2
## 2     150         1                1         1     0              0.8 0.2
## 3     200         1                1         1     0              0.8 0.2
## 4     100         1                1         1     0              0.8 0.3
## 5     150         1                1         1     0              0.8 0.3
## 6     200         1                1         1     0              0.8 0.3
## 7     100         1                1         1     0              0.8 0.4
## 8     150         1                1         1     0              0.8 0.4
## 9     200         1                1         1     0              0.8 0.4
# Fit the model with the parameter grid
xgboost.model <- train(Churn ~ ., 
                       training , 
                       method = "xgbTree", 
                       tuneGrid = xgb_grid,
                       trControl = train_control)
xgboost.model
## eXtreme Gradient Boosting 
## 
## 4931 samples
##   19 predictor
##    2 classes: 'No', 'Yes' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 3946, 3945, 3944, 3945, 3944 
## Resampling results across tuning parameters:
## 
##   eta  nrounds  Accuracy   Kappa    
##   0.2  100      0.8012533  0.4418413
##   0.2  150      0.8034846  0.4522553
##   0.2  200      0.8042961  0.4568000
##   0.3  100      0.8028775  0.4531359
##   0.3  150      0.8026749  0.4537737
##   0.3  200      0.8026746  0.4538060
##   0.4  100      0.8049051  0.4615761
##   0.4  150      0.8042974  0.4588367
##   0.4  200      0.8040960  0.4567192
## 
## Tuning parameter 'max_depth' was held constant at a value of 1
## Tuning parameter 'gamma' was held constant at a value of 0
## Tuning parameter 'colsample_bytree' was held constant at a value of 0.8
## Tuning parameter 'min_child_weight' was held constant at a value of 1
## Tuning parameter 'subsample' was held constant at a value of 1
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 100, max_depth = 1, eta = 0.4, gamma = 0, colsample_bytree = 0.8, min_child_weight = 1 and subsample = 1.

After a (relatively) brief moment the results are back. Average accuracy on the training is .8029 which is better than CHAID or ranger. We can run the same additional commands simply by listing xgboost.model.

confusionMatrix(xgboost.model)
## Cross-Validated (5 fold) Confusion Matrix 
## 
## (entries are percentual average cell counts across resamples)
##  
##           Reference
## Prediction   No  Yes
##        No  66.6 12.7
##        Yes  6.8 13.9
##                             
##  Accuracy (average) : 0.8049
plot(xgboost.model)

varImp(xgboost.model)
## xgbTree variable importance
## 
##   only 20 most important variables shown (out of 30)
## 
##                                       Overall
## tenure                               100.0000
## InternetServiceFiber optic            71.1099
## ContractTwo year                      47.7325
## PaymentMethodElectronic check         32.8154
## ContractOne year                      12.5400
## OnlineSecurityNo internet service     11.8672
## PaperlessBillingYes                    7.1530
## InternetServiceNo                      6.6596
## TotalCharges                           5.8051
## OnlineSecurityYes                      5.5126
## MonthlyCharges                         3.3567
## StreamingMoviesYes                     2.9618
## TechSupportYes                         2.2343
## PhoneServiceYes                        2.0734
## MultipleLinesYes                       1.7923
## SeniorCitizenYes                       1.5314
## DependentsYes                          1.1139
## StreamingTVYes                         1.1047
## OnlineBackupYes                        0.4809
## PaymentMethodCredit card (automatic)   0.0000
xgboost.model$times
## $everything
##    user  system elapsed 
##  31.870   0.506   4.578 
## 
## $final
##    user  system elapsed 
##   0.909   0.009   0.126 
## 
## $prediction
## [1] NA NA NA

Now, the all important prediction against the testing data set.

confusionMatrix(predict(xgboost.model, newdata = testing), testing$Churn)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   No  Yes
##        No  1417  279
##        Yes  135  281
##                                           
##                Accuracy : 0.804           
##                  95% CI : (0.7864, 0.8207)
##     No Information Rate : 0.7348          
##     P-Value [Acc > NIR] : 6.831e-14       
##                                           
##                   Kappa : 0.4519          
##                                           
##  Mcnemar's Test P-Value : 2.094e-12       
##                                           
##             Sensitivity : 0.9130          
##             Specificity : 0.5018          
##          Pos Pred Value : 0.8355          
##          Neg Pred Value : 0.6755          
##              Prevalence : 0.7348          
##          Detection Rate : 0.6709          
##    Detection Prevalence : 0.8030          
##       Balanced Accuracy : 0.7074          
##                                           
##        'Positive' Class : No              
## 

Very nice! Once again our accuracy on testing .8063 actually exceeds the accuracy we achieved in training. Looks like we were more accurate than either CHAID or ranger and we’ll focus on the comparison in the next section.

Comparing Models

At this juncture we’re faced with a problem I’ve had before. We’re drowning in data from the individual confusionMatrix results. We’ll resort to the same purrr solution to give us a far more legible table of results focusing on the metrics I’m most interested in. To do that we need to:

  1. Make a named list called modellist that contains our 3 models with a descriptive name for each
  2. Use map from purrr to apply the predict command to each model in turn to our testing dataset
  3. Pipe those results to a second map command to generate a confusion matrix comparing our predictions to testing$Churn which are the actual outcomes.
  4. Pipe those results to a complex map_dfr (that I explained previously) that creates a dataframe of all the results with each model as a row.
  5. Separately grab the elapsed times for training with commands like chaid.model$times$everything[[3]]
  6. Separately grab the best accuracy for training with commands like max(chaid.model$results$Accuracy)
  7. Then use kable to make a pretty table that is much easier to understand.
modellist <- list("CHAID" = chaid.model,
                  "ranger" = rf.model,
                  "xgboost" = xgboost.model)

CompareResults <- map(modellist, ~ predict(.x, newdata = testing)) %>%
                  map(~ confusionMatrix(testing$Churn, .x)) %>%
                  map_dfr(~ cbind(as.data.frame(t(.x$overall)), 
                    as.data.frame(t(.x$byClass))), 
                    .id = "Model")

CompareResults[1,"ETime"] <- chaid.model$times$everything[[3]]
CompareResults[2,"ETime"] <- rf.model$times$everything[[3]]
CompareResults[3,"ETime"] <- xgboost.model$times$everything[[3]]
CompareResults[1,"BestTrain"] <- max(chaid.model$results$Accuracy)
CompareResults[2,"BestTrain"] <- max(rf.model$results$Accuracy)
CompareResults[3,"BestTrain"] <- max(xgboost.model$results$Accuracy)

kable(CompareResults, "html") %>% 
   kable_styling(bootstrap_options = c("striped", "hover", "condensed", "responsive"))
Model Accuracy Kappa AccuracyLower AccuracyUpper AccuracyNull AccuracyPValue McnemarPValue Sensitivity Specificity Pos Pred Value Neg Pred Value Precision Recall F1 Prevalence Detection Rate Detection Prevalence Balanced Accuracy ETime BestTrain
CHAID 0.7788826 0.3664152 0.7605662 0.7964217 0.8214962 0.9999997 0 0.8126801 0.6233422 0.9085052 0.4196429 0.9085052 0.8126801 0.8579252 0.8214962 0.6676136 0.7348485 0.7180111 172.899 0.7903119
ranger 0.7926136 0.4083988 0.7746853 0.8097260 0.8181818 0.9987667 0 0.8223380 0.6588542 0.9155928 0.4517857 0.9155928 0.8223380 0.8664634 0.8181818 0.6728220 0.7348485 0.7405961 16.145 0.7986164
xgboost 0.8039773 0.4519416 0.7863928 0.8207138 0.8030303 0.4694820 0 0.8354953 0.6754808 0.9130155 0.5017857 0.9130155 0.8354953 0.8725369 0.8030303 0.6709280 0.7348485 0.7554880 4.578 0.8049051

What do we know?

Well our table looks very nice but there’s probably still too much information. What data should we focus on and what conclusions can we draw from our little exercise in comparative modeling? I will draw your attention back to this webpage to review the terminology for classification models and how to interpret a confusion matrix.

So Accuracy, Kappa, and F1 are all measures of overall accuracy. There are merits to each. Pos Pred Value, and Neg Pred Value are related but different nuanced ideas we’ll discuss in a minute. We’ll also want to talk about time to complete training our model with ETime and training accuracy with BestTrain.

Let’s use dplyr to select just these columns we want and see what we can glean from this reduced table.

CompareResults %>%
  select(Model, ETime, BestTrain, Accuracy, Kappa, F1, 'Pos Pred Value', 'Neg Pred Value') %>%
  kable("html") %>% 
    kable_styling(bootstrap_options = c("striped", "hover", "condensed", "responsive"))
Model ETime BestTrain Accuracy Kappa F1 Pos Pred Value Neg Pred Value
CHAID 172.899 0.7903119 0.7788826 0.3664152 0.8579252 0.9085052 0.4196429
ranger 16.145 0.7986164 0.7926136 0.4083988 0.8664634 0.9155928 0.4517857
xgboost 4.578 0.8049051 0.8039773 0.4519416 0.8725369 0.9130155 0.5017857

Clearly xgboost is the fastest to train a model, more than 30 times faster than CHAID, and 3 times faster than ranger for this data. Not really surprising since xgboost is a very modern set of code designed from the ground up to be fast and efficient.

One interesting fact you can glean from all 3 models is that they all did better on testing than they did on training. This is slightly unusual since one would expect some differences to be missed but is likely simply due to a lucky split in our data with more of the difficult to predict cases falling in training than testing. The good news is it leaves us feeling comfortable that we did not overfit our model to the training data, which is why we were conservative in our fitting and cross validated the training data.

No matter which “accuracy measure” we look at Accuracy, F1 or Kappa the answer is pretty consistent, xgboost “wins” or is the most accurate. The exception is F1 where ranger edges is out by 0.11775% which means it was correct on about 3 more cases out of 2112 cases in the testing set.

Notice that the differences in accuracy are not large as percentages xgboost is 1.4678% more accurate than CHAID or it correctly predicted 31 more customers. While more accurate is always “better” the practical significance is also a matter of what the stakes are. If a wrong prediction costs you $1,000.00 dollars that additional accuracy is more concerning than a lesser dollar amount.

I also deliberately included Positive and Negative Predictive Values the columns labeled Pos Pred Value and Neg Pred Value for a very specific reason. Notice that CHAID has the highest Pos Pred Value that means is is the most accurate at predicting customers who did not “churn”. Of the 1,552 customers who did not leave us is correctly predicted 1,443 of them. xgboost on the other hand was much much better at Neg Pred Value correctly predicting 298 out of 560 customers who left us. While Accuracy, Kappa and F1 take different approaches to finding “balanced” accuracy sometimes one case negative or positive has more important implications for your business and you should choose those measures.

At least at this point after a possible tl;dr journey we have some empirical data to inform my original statement about CHAID: “As the name implies it is fundamentally based on the venerable Chi-square test – and while not the most powerful (in terms of detecting the smallest possible differences) or the fastest, it really is easy to manage and more importantly to tell the story after using it”.

What don’t we know?

  1. That this example would apply to other types of datasets. Absolutely not! This sort of data is almost ideal for CHAID since it involves a lot of nominal/categorical and/or ordinal data. CHAID will get much slower faster as we add more columns. More generally this was one example relatively small dataset more about learning something about caret and process than a true comparison of accuracy across a wide range of cases.

  2. This is the “best” these models can do with this data Absolutely not! I made no attempt to seriously tune any of them. Tried some mild comparability. Also made no effort to feature engineer or adjust. I’m pretty certain if you tried you can squeeze a little more out of all three. Even wth CHAID there’s more we could do very easily. I arbitrarily divided tenure into 5 equal sized bins. Why not 10? Why not equidistant instead of equal sized?

Done!

I hope you’ve found this useful. I am always open to comments, corrections and suggestions.

Chuck (ibecav at gmail dot com)

comments powered by Disqus