Do TidyModels float? Predicting Passenger Survival on the Titanic

Using Tidymodels to Analyze the Data and Predict Who Survived the Titanic Tragedy

Karat Sidhu
Feb 2, 2023
Photo by NOAA on Unsplash

The Titanic dataset is a popular choice among data scientists and machine learning practitioners, containing information about the passengers on the ill-fated voyage, such as their demographics, ticket class, and survival status. Using various classification algorithms and the tidymodels framework, I attempt to predict passenger survival on the Titanic.

Introduction

About the data

Both datasets contain the following variables, among others:

pclass: A proxy for socio-economic status (SES)
- 1st = Upper
- 2nd = Middle
- 3rd = Lower

age: Age in years. Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5

sibsp: Number of siblings/spouses aboard. The dataset defines family relations in this way:
- Sibling = brother, sister, stepbrother, stepsister
- Spouse = husband, wife (mistresses and fiancés were ignored)

parch: Number of parents/children aboard. The dataset defines family relations in this way:
- Parent = mother, father
- Child = daughter, son, stepdaughter, stepson
Some children traveled only with a nanny, so parch = 0 for them.
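The dictionary entries above translate directly into code. The sketch below uses a hypothetical helper, `decode_dictionary()` (my name, not part of the post or any package), that labels `Pclass` with the socio-economic levels and flags ages in the xx.5 form as estimated. It assumes Kaggle's column names `Pclass` and `Age`.

```r
# Hypothetical helper: decode two data-dictionary conventions.
# Assumes a data frame with Kaggle's `Pclass` (1/2/3) and `Age` columns.
decode_dictionary <- function(df) {
  # pclass is a proxy for socio-economic status: 1 = Upper, 2 = Middle, 3 = Lower
  df$SES <- factor(df$Pclass, levels = 1:3, labels = c("Upper", "Middle", "Lower"))
  # Estimated ages are stored as xx.5; ages below 1 are genuinely fractional,
  # so only ages of at least 1 with a .5 remainder are flagged (NA stays FALSE)
  df$AgeEstimated <- !is.na(df$Age) & df$Age >= 1 & df$Age %% 1 == 0.5
  df
}

passengers <- data.frame(Pclass = c(3, 1, 2), Age = c(28.5, 0.42, NA))
decode_dictionary(passengers)
```

Here 28.5 is flagged as an estimated age, while 0.42 is treated as a real infant age.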

Loading Packages

Loading the packages needed for the analysis: tidymodels builds the classification models, and discrim is needed for the linear discriminant analysis. skimr provides a summary of the dataset, knitr renders the report and supplies the kable() function for displaying tables, and tidyverse handles data wrangling and visualization. The hrbrthemes package, called later as hrbrthemes::theme_ipsum() without being attached, gives the plots a cleaner theme.

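# attach the packages quietly; the model engines (kknn, ranger, xgboost, MASS)
# only need to be installed, parsnip loads them when the models are fit
suppressPackageStartupMessages({
library(tidyverse)   # data wrangling and plotting
library(skimr)       # dataset summaries
library(knitr)       # kable() tables
library(tidymodels)  # rsample, parsnip, workflows, tune, yardstick, ...
library(discrim)     # linear discriminant analysis
})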
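The model specifications later in the post use engines from packages that the `library()` calls above never attach: `kknn`, `ranger`, `xgboost` and `MASS`, plus `hrbrthemes` for the plot theme. They still have to be installed for the models to fit. A small sketch to list any that are missing before running the analysis; `missing_packages()` is a hypothetical helper, not a function from any of these packages.

```r
# Hypothetical helper: return the packages in `pkgs` that cannot be loaded.
missing_packages <- function(pkgs) {
  installed <- vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)
  pkgs[!installed]
}

# Engines used by the models in this post, plus the plot theme package
needed <- c("kknn", "ranger", "xgboost", "MASS", "hrbrthemes")
missing_packages(needed)
# install.packages(missing_packages(needed))  # install whatever is missing
```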

Exploratory Data Analysis

Kaggle provides a training dataset and a testing dataset. The training dataset (891 observations, 12 attributes) is used to build and evaluate the classification models. The testing dataset (418 observations, 11 attributes) lacks the Survived attribute, the target variable for the analysis, so predictions on it can only be scored by submitting them to Kaggle.

Loading the data

training_data <- read_csv("data/train.csv")
Rows: 891 Columns: 12
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr (5): Name, Sex, Ticket, Cabin, Embarked
dbl (7): PassengerId, Survived, Pclass, Age, SibSp, Parch, Fare
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
testing_data <- read_csv("data/test.csv")
Rows: 418 Columns: 11
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr (5): Name, Sex, Ticket, Cabin, Embarked
dbl (6): PassengerId, Pclass, Age, SibSp, Parch, Fare
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.

Let's take a look at the first few rows of the training dataset.

kable(training_data |> head(5))
Brief Overview of Data — Image by Author

Using the skimr package, we can get a summary of the dataset and the attributes.

skim(training_data)
Data Summary using SkimR package — Image by Author

The dataset contains various attributes that pertain to the passengers on board the ship at the time of the sinking. The objective of this analysis is to identify which attributes are most relevant in determining passenger survival.

Upon initial inspection, some attributes appear to be uninformative for predicting survival. PassengerId is an arbitrary identifier that says nothing about the outcome, so it can be removed. Likewise, the ticket number does not appear to be related to survival, so it can also be dropped. The Name attribute, however, may carry information such as sex, nobility, or title, so it deserves a closer look before being discarded.

The dataset also contains attributes with missing values, notably Cabin and Age, which need to be cleaned before model building. In addition, categorical attributes stored as text, such as Sex and Embarked, should be converted to factors so that the models treat them as categories. Pclass, SibSp and Parch are whole-number codes or counts and could also be treated as factors if needed, while Fare is continuous and should stay numeric.
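A quick way to see how much is missing in each column is to count the NA values. The sketch below runs on a few made-up toy rows so that it is self-contained; in the post you would pass `training_data` or `combined_data` instead. `missing_summary()` is a hypothetical helper name.

```r
# Count missing values per column, as a number and as a percentage of rows.
missing_summary <- function(df) {
  n_missing <- colSums(is.na(df))
  data.frame(
    column = names(n_missing),
    n_missing = unname(n_missing),
    pct_missing = round(100 * unname(n_missing) / nrow(df), 1)
  )
}

# Toy rows in the shape of the Kaggle data (values are illustrative only)
toy <- data.frame(
  Age = c(22, NA, 26, NA),
  Cabin = c(NA, "C85", NA, NA),
  Fare = c(7.25, 71.28, 7.93, 53.1)
)
missing_summary(toy)
```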

Data Wrangling & Conversion

The Age column contains a substantial number of missing values. More advanced imputation methods, such as regression-based imputation, would likely give more accurate results, but for the purposes of this analysis and submission the missing ages are replaced with the mean age.

Combining the training and testing datasets lets the Age imputation draw on all 1,309 passengers and lets the same cleaning steps be applied to both sets at once. In the same step, the “Name”, “PassengerId” and “Ticket” attributes are removed, since they are not used in this analysis. The “Cabin” attribute is also removed for now because most of its values are missing; further analysis would be needed to decide whether it can be used.

training_data <- training_data |> 
mutate(group = 'training')
testing_data <- testing_data |>
mutate(group = 'testing')

combined_data <- bind_rows(training_data, testing_data)
combined_data <- combined_data |> select(-c(Name, PassengerId, Ticket,Cabin))
glimpse(combined_data)
Rows: 1,309
Columns: 9
$ Survived <dbl> 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0…
$ Pclass <dbl> 3, 1, 3, 1, 3, 3, 1, 3, 3, 2, 3, 1, 3, 3, 3, 2, 3, 2, 3, 3, 2…
$ Sex <chr> "male", "female", "female", "female", "male", "male", "male",…
$ Age <dbl> 22, 38, 26, 35, 35, NA, 54, 2, 27, 14, 4, 58, 20, 39, 14, 55,…
$ SibSp <dbl> 1, 1, 0, 1, 0, 0, 0, 3, 0, 1, 1, 0, 0, 1, 0, 0, 4, 0, 1, 0, 0…
$ Parch <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 1, 0, 0, 5, 0, 0, 1, 0, 0, 0, 0…
$ Fare <dbl> 7.2500, 71.2833, 7.9250, 53.1000, 8.0500, 8.4583, 51.8625, 21…
$ Embarked <chr> "S", "C", "S", "S", "S", "Q", "S", "S", "S", "C", "S", "S", "…
$ group <chr> "training", "training", "training", "training", "training", "…

Next we will impute the missing values in the “Age” attribute. The following code will replace the missing values in the “Age” attribute with the mean value for the attribute.

combined_data$Age[is.na(combined_data$Age)] <- mean(combined_data$Age, na.rm = TRUE)
sum(is.na(combined_data$Age))
[1] 0
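An alternative, sketched below on a toy data frame rather than the post's data, is to let a recipe perform the mean imputation with `recipes::step_impute_mean()`. When such a recipe is placed in a workflow, the mean is estimated only from the data passed to `fit()` (each analysis set during resampling), so held-out rows do not influence it. recipes also offers regression-based imputation through `step_impute_linear()`.

```r
library(recipes)

# Toy data with one missing age; the mean of the observed ages is 30
toy <- data.frame(
  Survived = factor(c(0, 1, 1, 0)),
  Age = c(20, NA, 40, 30)
)

# The recipe learns the mean of Age in prep() and fills the NAs in bake()
age_recipe <- recipe(Survived ~ Age, data = toy) |>
  step_impute_mean(Age)

imputed <- age_recipe |>
  prep() |>
  bake(new_data = NULL)
imputed
```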

The next step is to convert the variables “Survived”, “Sex”, and “Embarked” from their original data types to factors, so that these categorical variables are handled as categories by the models.

combined_data <- combined_data |> 
mutate(Survived = as.factor(Survived),
Sex = as.factor(Sex),
Embarked = as.factor(Embarked))
glimpse(combined_data)
Rows: 1,309
Columns: 9
$ Survived <fct> 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0…
$ Pclass <dbl> 3, 1, 3, 1, 3, 3, 1, 3, 3, 2, 3, 1, 3, 3, 3, 2, 3, 2, 3, 3, 2…
$ Sex <fct> male, female, female, female, male, male, male, male, female,…
$ Age <dbl> 22.00000, 38.00000, 26.00000, 35.00000, 35.00000, 29.88114, 5…
$ SibSp <dbl> 1, 1, 0, 1, 0, 0, 0, 3, 0, 1, 1, 0, 0, 1, 0, 0, 4, 0, 1, 0, 0…
$ Parch <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 1, 0, 0, 5, 0, 0, 1, 0, 0, 0, 0…
$ Fare <dbl> 7.2500, 71.2833, 7.9250, 53.1000, 8.0500, 8.4583, 51.8625, 21…
$ Embarked <fct> S, C, S, S, S, Q, S, S, S, C, S, S, S, S, S, S, Q, S, S, C, S…
$ group <chr> "training", "training", "training", "training", "training", "…

Important

The Name column could contain useful information for predicting survival, such as passenger titles, but I have chosen not to use it in this round.
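For a later round, titles can be pulled out of the names with a regular expression, since Kaggle's names follow a "Surname, Title. Given names" pattern. A minimal base-R sketch; `extract_title()` is a hypothetical helper, and names that do not match the pattern are returned unchanged by `sub()`.

```r
# Extract the title between the first comma and the following period,
# e.g. "Braund, Mr. Owen Harris" -> "Mr"
extract_title <- function(name) {
  sub("^[^,]*,\\s*([^.]+)\\..*$", "\\1", name, perl = TRUE)
}

passenger_names <- c(
  "Braund, Mr. Owen Harris",
  "Cumings, Mrs. John Bradley (Florence Briggs Thayer)",
  "Heikkinen, Miss. Laina"
)
extract_title(passenger_names)
```

The resulting titles could then be converted to a factor and added to `combined_data` before the Name column is dropped.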

This completes the data preprocessing and ‘feature engineering’ steps, and the data is ready for the models I would like to try.

Model Building

I use the tidymodels framework to build the models. Since the outcome, whether a passenger survived, is a binary category, this is a classification problem. The algorithms I would like to try are:

  • Logistic Regression,
  • Random Forest classification,
  • K-Nearest Neighbors classification,
  • XGBoost,
  • Decision Trees, and
  • Linear Discriminant Analysis.

Several other algorithms could also be useful, but for simplicity I compare only the models listed above.

Splitting the Data into Test and Train

The purpose of a train/test split is to evaluate a model on data it has not seen during training, which gives a more honest estimate of how well it will generalize to new data. A held-out test set does not prevent overfitting, which occurs when a model fits the training set too closely and performs poorly on new data, but it does reveal it. Because Kaggle's test file has no Survived labels, I split Kaggle's training file itself: initial_split() keeps 75% of the rows (668 passengers) for training and holds out the remaining 25% (223 passengers) for testing, and the training portion is further divided into 10 stratified cross-validation folds.

train <- 
combined_data |>
filter(group == "training") |>
mutate(
Survived = as.factor(Survived)
) |> select(-group)
set.seed(123)
split_data <- initial_split(train)
titanic_train <- training(split_data)
titanic_test <- testing(split_data)

set.seed(234)
titanic_folds <- vfold_cv(titanic_train, strata = Survived)
titanic_folds
#  10-fold cross-validation using stratification 
# A tibble: 10 × 2
splits id
<list> <chr>
1 <split [600/68]> Fold01
2 <split [600/68]> Fold02
3 <split [601/67]> Fold03
4 <split [601/67]> Fold04
5 <split [601/67]> Fold05
6 <split [601/67]> Fold06
7 <split [602/66]> Fold07
8 <split [602/66]> Fold08
9 <split [602/66]> Fold09
10 <split [602/66]> Fold10
titanic_formula <- Survived ~ .
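In the code above, `initial_split()` samples the 75/25 split at random, and only the cross-validation folds are stratified by Survived. A hedged sketch of stratifying the initial split as well, shown on toy data whose 60/40 class balance is made up for illustration: with `strata = Survived`, rsample samples 3/4 of each outcome class separately, so both portions keep the same balance. With Kaggle's data you would pass `train` instead of `toy`.

```r
library(rsample)

# Toy outcome with a 60/40 class balance, standing in for Survived
set.seed(123)
toy <- data.frame(
  Survived = factor(rep(c(0, 1), times = c(60, 40))),
  Fare = runif(100, 5, 100)
)

# strata = Survived samples 3/4 of each class separately
split_strat <- initial_split(toy, prop = 3/4, strata = Survived)
train_strat <- training(split_strat)
test_strat <- testing(split_strat)

nrow(train_strat)
mean(train_strat$Survived == "1")
```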

Logistic Regression

glm_spec <-
logistic_reg() |>
set_engine("glm")
glm_wf <- workflow(titanic_formula, glm_spec)
contrl_preds <- control_resamples(save_pred = TRUE)
glm_rs <- fit_resamples(
glm_wf,
resamples = titanic_folds,
control = contrl_preds
)
collect_metrics(glm_rs)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.815 10 0.0141 Preprocessor1_Model1
2 roc_auc binary 0.855 10 0.0180 Preprocessor1_Model1

K Nearest Neighbors

knn_spec <-
nearest_neighbor(mode = "classification", weight_func = "rectangular", neighbors = 5) |>
set_engine("kknn")
knn_wf <- workflow(titanic_formula, knn_spec)
knn_rs <- fit_resamples(
knn_wf,
resamples = titanic_folds,
control = contrl_preds
)
collect_metrics(knn_rs)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.810 10 0.0121 Preprocessor1_Model1
2 roc_auc binary 0.850 10 0.0160 Preprocessor1_Model1

Random Forest

rf_spec <-
rand_forest(mode = "classification", trees = 1000) |>
set_engine("ranger")
rf_wf <- workflow(titanic_formula, rf_spec)
rf_rs <- fit_resamples(
rf_wf,
resamples = titanic_folds,
control = contrl_preds)
collect_metrics(rf_rs)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.810 10 0.0150 Preprocessor1_Model1
2 roc_auc binary 0.871 10 0.0190 Preprocessor1_Model1

Decision Trees

dt_spec <-
decision_tree(mode = "classification") |>
set_engine("rpart")
dt_wf <- workflow(titanic_formula, dt_spec)
dt_rs <- fit_resamples(
dt_wf,
resamples = titanic_folds,
control = contrl_preds
)
collect_metrics(dt_rs)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.811 10 0.0129 Preprocessor1_Model1
2 roc_auc binary 0.813 10 0.0192 Preprocessor1_Model1

XGBoost

xgb_spec <-
boost_tree(mode = "classification") |>
set_engine("xgboost")
xgb_wf <- workflow(titanic_formula, xgb_spec)
xgb_rs <- fit_resamples(
xgb_wf,
resamples = titanic_folds,
control = contrl_preds
)
collect_metrics(xgb_rs)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.825 10 0.0163 Preprocessor1_Model1
2 roc_auc binary 0.842 10 0.0253 Preprocessor1_Model1

Linear Discriminant Analysis

lda_spec <-
discrim_linear(
mode = "classification",
engine = "MASS"
)
lda_wf <- workflow(titanic_formula, lda_spec)
lda_rs <- fit_resamples(
lda_wf,
resamples = titanic_folds,
control = contrl_preds
)
collect_metrics(lda_rs)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.804 10 0.0135 Preprocessor1_Model1
2 roc_auc binary 0.853 10 0.0188 Preprocessor1_Model1

None of the models were tuned, yet the results are promising. Every model reaches a cross-validated accuracy above 80%, and ROC AUC ranges from 0.813 to 0.871. The decision tree lags behind on ROC AUC (0.813) even though its accuracy (0.811) is in line with the others, while the random forest has the highest ROC AUC (0.871) and XGBoost the highest accuracy (0.825).
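To compare the six untuned models side by side, the cross-validated means printed above can be gathered into one table. This sketch simply re-enters those reported numbers by hand; in a live session you could instead row-bind the output of `collect_metrics()` for each resampling result.

```r
# Cross-validated means reported above by collect_metrics() for each model
cv_results <- data.frame(
  model = c("Logistic Regression", "K-Nearest Neighbors", "Random Forest",
            "Decision Trees", "XGBoost", "Linear Discriminant Analysis"),
  accuracy = c(0.815, 0.810, 0.810, 0.811, 0.825, 0.804),
  roc_auc = c(0.855, 0.850, 0.871, 0.813, 0.842, 0.853)
)

# Sort by ROC AUC, highest first, with accuracy alongside
cv_results[order(-cv_results$roc_auc), ]
```

Sorted this way, the random forest comes first and logistic regression second, while the accuracy column shows XGBoost on top.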

Model Comparison

Comparing the models using the ROC curve

bind_rows(
collect_predictions(glm_rs) |>
mutate(mod = "Logistic Regression"),
collect_predictions(knn_rs) |>
mutate(mod = "K-Nearest Neighbors"),
collect_predictions(rf_rs) |>
mutate(mod = "Random Forest"),
collect_predictions(dt_rs) |>
mutate(mod = "Decision Trees"),
collect_predictions(xgb_rs) |>
mutate(mod = "XGBoost"),
collect_predictions(lda_rs) |>
mutate(mod = "Linear Discriminant Analysis")) |>
group_by(mod) |>
roc_curve(Survived, .pred_0) |>
autoplot() +
hrbrthemes::theme_ipsum()
ROC Curve for all 6 Models — Image by Author

Based on the ROC curves, logistic regression is one of the stronger models: its cross-validated ROC AUC (0.855) is second only to the random forest (0.871). As the next step, I use the logistic regression workflow to predict the survival of the passengers in the test dataset.

Model Selection

final_fitted <- last_fit(glm_wf, split_data)
collect_metrics(final_fitted)
# A tibble: 2 × 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 accuracy binary 0.760 Preprocessor1_Model1
2 roc_auc binary 0.837 Preprocessor1_Model1

Prediction, Deployment & Submission

As an example, here is the final fitted workflow predicting the survival of a single passenger from the training split. The model predicts 0 (did not survive), although this passenger actually survived.

# extract the fitted workflow and predict one row of the training split
final_wf <- extract_workflow(final_fitted)
predict(final_wf, titanic_train[59,])
# A tibble: 1 × 1
.pred_class
<fct>
1 0
titanic_train[59,]
# A tibble: 1 × 8
Survived Pclass Sex Age SibSp Parch Fare Embarked
<fct> <dbl> <fct> <dbl> <dbl> <dbl> <dbl> <fct>
1 1 1 male 28 0 0 26.6 S

Next, I use the final fitted workflow to predict the survival of every passenger in the combined dataset, which includes both the training and testing data.

final_predict <- predict(final_wf, combined_data)
final_predict
# A tibble: 1,309 × 1
.pred_class
<fct>
1 0
2 1
3 1
4 1
5 0
6 0
7 0
8 0
9 1
10 1
# … with 1,299 more rows

With the predictions in hand, the next step is to attach them to the Kaggle test passengers and write the result to a CSV file.

Conclusion

Logistic regression was among the strongest models in cross-validation (accuracy 0.815, ROC AUC 0.855; only XGBoost had a higher accuracy and only the random forest a higher ROC AUC), so I use it to predict the survival of the passengers in the test dataset.

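# attach the predictions to every passenger, then keep Kaggle's test rows
final_submission <- combined_data |>
mutate(Survived = final_predict$.pred_class) |>
filter(group == "testing")
# the test rows are still in test.csv order, so the IDs can be copied across
submission_file <- final_submission |>
mutate(PassengerID = testing_data$PassengerId)
# write only the two columns Kaggle expects: PassengerId and Survived
write_csv(
submission_file |> select(PassengerId = PassengerID, Survived),
"submission.csv"
)
final_submission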
# A tibble: 418 × 9
Survived Pclass Sex Age SibSp Parch Fare Embarked group
<fct> <dbl> <fct> <dbl> <dbl> <dbl> <dbl> <fct> <chr>
1 0 3 male 34.5 0 0 7.83 Q testing
2 0 3 female 47 1 0 7 S testing
3 0 2 male 62 0 0 9.69 Q testing
4 0 3 male 27 0 0 8.66 S testing
5 1 3 female 22 1 1 12.3 S testing
6 0 3 male 14 0 0 9.22 S testing
7 1 3 female 30 0 0 7.63 Q testing
8 0 2 male 26 1 1 29 S testing
9 1 3 female 18 0 0 7.23 C testing
10 0 3 male 21 2 0 24.2 S testing
# … with 408 more rows
submission_file
# A tibble: 418 × 10
Survived Pclass Sex Age SibSp Parch Fare Embarked group PassengerID
<fct> <dbl> <fct> <dbl> <dbl> <dbl> <dbl> <fct> <chr> <dbl>
1 0 3 male 34.5 0 0 7.83 Q testing 892
2 0 3 female 47 1 0 7 S testing 893
3 0 2 male 62 0 0 9.69 Q testing 894
4 0 3 male 27 0 0 8.66 S testing 895
5 1 3 female 22 1 1 12.3 S testing 896
6 0 3 male 14 0 0 9.22 S testing 897
7 1 3 female 30 0 0 7.63 Q testing 898
8 0 2 male 26 1 1 29 S testing 899
9 1 3 female 18 0 0 7.23 C testing 900
10 0 3 male 21 2 0 24.2 S testing 901
# … with 408 more rows
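Before uploading, a quick sanity check of the file can save a rejected submission. As I understand Kaggle's instructions for this competition, the CSV must have exactly 418 rows plus a header and only the columns PassengerId and Survived, with Survived coded 0/1. The sketch below encodes those assumptions in a hypothetical helper, `check_submission()`, which returns a character vector of problems (empty when the file looks valid).

```r
# Hypothetical pre-upload check based on the assumed Kaggle format above.
check_submission <- function(sub, n_expected = 418) {
  problems <- character(0)
  if (!identical(sort(names(sub)), c("PassengerId", "Survived")))
    problems <- c(problems, "columns must be exactly PassengerId and Survived")
  if (nrow(sub) != n_expected)
    problems <- c(problems, sprintf("expected %d rows, found %d", n_expected, nrow(sub)))
  # NA values fail this check too, because NA is not in c("0", "1")
  if ("Survived" %in% names(sub) &&
      !all(as.character(sub$Survived) %in% c("0", "1")))
    problems <- c(problems, "Survived must be 0 or 1 with no missing values")
  problems
}
```

For example, `check_submission(read.csv("submission.csv"))` should return `character(0)` for a valid file.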

I submitted the file to Kaggle, with the following results:

For the logistic regression model:

Leaderboard position: 8700, Accuracy: 0.76794

Total Submissions: 55,402

That places the submission in roughly the top 16% (8,700 of 55,402), which is a fairly decent result for a first attempt.

The model could be improved with more feature engineering, which I plan to do in the future. Other models, such as support vector machines or naive Bayes, are also worth trying. Tuning the hyperparameters is another obvious step: none of the models in this project were tuned, so there is a lot of room for improvement. Neural networks could be tried as well, although the small size of the dataset may limit their advantage.
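As a starting point for tuning, here is a sketch of tuning the decision tree, the weakest model by ROC AUC, with `tune_grid()`. To keep it self-contained it runs on simulated data (an assumption, not the Titanic data); with the post's objects you would use `titanic_formula` and `titanic_folds` in place of `y ~ .` and `sim_folds`. The two tuned parameters and the grid of 10 candidates are arbitrary choices for illustration.

```r
library(tidymodels)

# Simulated stand-in data: the class depends on x1, while x2 is noise
set.seed(42)
n <- 300
sim <- tibble(x1 = rnorm(n), x2 = rnorm(n))
sim$y <- factor(ifelse(sim$x1 + rnorm(n, sd = 0.5) > 0, "yes", "no"))
sim_folds <- vfold_cv(sim, v = 5, strata = y)

# Mark the decision tree's hyperparameters for tuning with tune()
tree_spec <- decision_tree(
  mode = "classification",
  cost_complexity = tune(),
  tree_depth = tune()
) |>
  set_engine("rpart")
tree_wf <- workflow(y ~ ., tree_spec)

# Evaluate 10 candidate parameter combinations across the folds
tree_res <- tune_grid(tree_wf, resamples = sim_folds, grid = 10)

# Inspect the best candidates, then lock the best one into the workflow
best <- show_best(tree_res, metric = "roc_auc", n = 3)
final_tree_wf <- finalize_workflow(tree_wf, select_best(tree_res, metric = "roc_auc"))
best
```

The finalized workflow can then be passed to `last_fit()` with the original split, exactly as the logistic regression workflow was above.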

References

These are some of the resources I used to learn tidymodels, including posts in which the authors used tidymodels to predict the survival of Titanic passengers:

  • Tidymodels
  • Predicting Titanic passenger survival using tidymodels | Niels van der Velden
  • Tidymodels and the Titanic | the data diary
  • Experimenting with machine learning in R with tidymodels and the Kaggle titanic dataset | Olivier Gimenez
  • Tune XGBoost with tidymodels and #TidyTuesday beach volleyball
  • Titanic - Machine Learning from Disaster (Kaggle)

Find me on Twitter: @karat_sidhu
