Getting to know tidymodels

R
Published

August 4, 2023

Which packages belong to the tidymodels metapackage?

# List the packages that make up the tidymodels metapackage (tidymodels itself
# included). Called with `::` because tidymodels is not attached yet.
print(tidymodels::tidymodels_packages(include_self = TRUE))
 [1] "broom"        "cli"          "conflicted"   "dials"        "dplyr"       
 [6] "ggplot2"      "hardhat"      "infer"        "modeldata"    "parsnip"     
[11] "purrr"        "recipes"      "rlang"        "rsample"      "rstudioapi"  
[16] "tibble"       "tidyr"        "tune"         "workflows"    "workflowsets"
[21] "yardstick"    "tidymodels"  

Tutorial 1: the palmerpenguins dataset

# Fix the RNG state first (kept before attaching packages so the random
# stream used by the later split is unchanged), then attach the tidyverse.
set.seed(1)
library("tidyverse")
── Attaching core tidyverse packages ──────────────────────── tidyverse 2.0.0 ──
✔ dplyr     1.1.2     ✔ readr     2.1.4
✔ forcats   1.0.0     ✔ stringr   1.5.0
✔ ggplot2   3.4.2     ✔ tibble    3.2.1
✔ lubridate 1.9.2     ✔ tidyr     1.3.0
✔ purrr     1.0.1     
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag()    masks stats::lag()
ℹ Use the conflicted package (<http://conflicted.r-lib.org/>) to force all conflicts to become errors
library("palmerpenguins")

# Keep only penguins whose sex is recorded (sex is the outcome to predict) and
# drop the `year` and `island` columns, which are not used as predictors.
penguins_df <- penguins |>
  filter(!is.na(sex)) |>
  select(!c(year, island))

library("tidymodels")
── Attaching packages ────────────────────────────────────── tidymodels 1.1.0 ──
✔ broom        1.0.4     ✔ rsample      1.1.1
✔ dials        1.2.0     ✔ tune         1.1.1
✔ infer        1.0.4     ✔ workflows    1.1.3
✔ modeldata    1.1.0     ✔ workflowsets 1.0.1
✔ parsnip      1.1.0     ✔ yardstick    1.2.0
✔ recipes      1.0.6     
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ scales::discard() masks purrr::discard()
✖ dplyr::filter()   masks stats::filter()
✖ recipes::fixed()  masks stringr::fixed()
✖ dplyr::lag()      masks stats::lag()
✖ yardstick::spec() masks readr::spec()
✖ recipes::step()   masks stats::step()
• Search for functions across packages at https://www.tidymodels.org/find/
# Quick column-wise overview of the modelling data.
penguins_df |> glimpse()
Rows: 333
Columns: 6
$ species           <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Adelie, Adel…
$ bill_length_mm    <dbl> 39.1, 39.5, 40.3, 36.7, 39.3, 38.9, 39.2, 41.1, 38.6…
$ bill_depth_mm     <dbl> 18.7, 17.4, 18.0, 19.3, 20.6, 17.8, 19.6, 17.6, 21.2…
$ flipper_length_mm <int> 181, 186, 195, 193, 190, 181, 195, 182, 191, 198, 18…
$ body_mass_g       <int> 3750, 3800, 3250, 3450, 3650, 3625, 4675, 3200, 3800…
$ sex               <fct> male, female, female, female, male, female, male, fe…
# Train/test split (3/4 of the rows for training, the rsample default),
# stratified on `sex` so both parts keep a similar female/male balance.
penguin_split <- initial_split(penguins_df, prop = 3 / 4, strata = sex)
str(penguin_split)
List of 4
 $ data  : tibble [333 × 6] (S3: tbl_df/tbl/data.frame)
  ..$ species          : Factor w/ 3 levels "Adelie","Chinstrap",..: 1 1 1 1 1 1 1 1 1 1 ...
  ..$ bill_length_mm   : num [1:333] 39.1 39.5 40.3 36.7 39.3 38.9 39.2 41.1 38.6 34.6 ...
  ..$ bill_depth_mm    : num [1:333] 18.7 17.4 18 19.3 20.6 17.8 19.6 17.6 21.2 21.1 ...
  ..$ flipper_length_mm: int [1:333] 181 186 195 193 190 181 195 182 191 198 ...
  ..$ body_mass_g      : int [1:333] 3750 3800 3250 3450 3650 3625 4675 3200 3800 4400 ...
  ..$ sex              : Factor w/ 2 levels "female","male": 2 1 1 1 2 1 2 1 2 2 ...
 $ in_id : int [1:249] 2 3 4 11 12 18 24 26 28 33 ...
 $ out_id: logi NA
 $ id    : tibble [1 × 1] (S3: tbl_df/tbl/data.frame)
  ..$ id: chr "Resample1"
 - attr(*, "class")= chr [1:3] "initial_split" "mc_split" "rsplit"
# Materialise the training portion of the split and inspect its structure.
penguin_train <- training(penguin_split)
penguin_train |> str()
tibble [249 × 6] (S3: tbl_df/tbl/data.frame)
 $ species          : Factor w/ 3 levels "Adelie","Chinstrap",..: 1 1 1 1 1 1 1 1 1 1 ...
 $ bill_length_mm   : num [1:249] 39.5 40.3 36.7 36.6 38.7 35.9 37.9 39.5 39.5 42.2 ...
 $ bill_depth_mm    : num [1:249] 17.4 18 19.3 17.8 19 19.2 18.6 16.7 17.8 18.5 ...
 $ flipper_length_mm: int [1:249] 186 195 193 185 195 189 172 178 188 180 ...
 $ body_mass_g      : int [1:249] 3800 3250 3450 3700 3450 3800 3150 3250 3300 3550 ...
 $ sex              : Factor w/ 2 levels "female","male": 1 1 1 1 1 1 1 1 1 1 ...
# Hold-out test set, plus 25 bootstrap resamples (the rsample default) of the
# training data for comparing candidate models.
penguin_test <- testing(penguin_split)
set.seed(123)
penguin_boot <- bootstraps(penguin_train, times = 25)
penguin_boot
# Bootstrap sampling 
# A tibble: 25 × 2
   splits           id         
   <list>           <chr>      
 1 <split [249/93]> Bootstrap01
 2 <split [249/91]> Bootstrap02
 3 <split [249/90]> Bootstrap03
 4 <split [249/91]> Bootstrap04
 5 <split [249/85]> Bootstrap05
 6 <split [249/87]> Bootstrap06
 7 <split [249/94]> Bootstrap07
 8 <split [249/88]> Bootstrap08
 9 <split [249/95]> Bootstrap09
10 <split [249/89]> Bootstrap10
# ℹ 15 more rows

Question:

# How can I extract a single bootstrap sample?
# Careful: `split$data` is the *whole* data set the resamples were drawn from
# (all 249 training rows, the same for every split), not the bootstrap sample.
# rsample::analysis() returns the in-bag rows actually used for fitting
# (drawn with replacement); rsample::assessment() returns the out-of-bag rows.
x <- penguin_boot |>
  slice(2) |>
  pull(splits)

analysis(x[[1]])
# A tibble: 249 × 6
   species bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex   
   <fct>            <dbl>         <dbl>             <int>       <int> <fct> 
 1 Adelie            39.5          17.4               186        3800 female
 2 Adelie            40.3          18                 195        3250 female
 3 Adelie            36.7          19.3               193        3450 female
 4 Adelie            36.6          17.8               185        3700 female
 5 Adelie            38.7          19                 195        3450 female
 6 Adelie            35.9          19.2               189        3800 female
 7 Adelie            37.9          18.6               172        3150 female
 8 Adelie            39.5          16.7               178        3250 female
 9 Adelie            39.5          17.8               188        3300 female
10 Adelie            42.2          18.5               180        3550 female
# ℹ 239 more rows
# Same extraction, but selecting the split by its id rather than its row
# position, which stays correct if the resamples are reordered.
# As above, use analysis() to get the bootstrap (in-bag) rows; `$data` would
# return the full training data rather than this resample.
x <- penguin_boot |>
  filter(id == "Bootstrap01") |>
  pull(splits)

analysis(x[[1]])
# A tibble: 249 × 6
   species bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex   
   <fct>            <dbl>         <dbl>             <int>       <int> <fct> 
 1 Adelie            39.5          17.4               186        3800 female
 2 Adelie            40.3          18                 195        3250 female
 3 Adelie            36.7          19.3               193        3450 female
 4 Adelie            36.6          17.8               185        3700 female
 5 Adelie            38.7          19                 195        3450 female
 6 Adelie            35.9          19.2               189        3800 female
 7 Adelie            37.9          18.6               172        3150 female
 8 Adelie            39.5          16.7               178        3250 female
 9 Adelie            39.5          17.8               188        3300 female
10 Adelie            42.2          18.5               180        3550 female
# ℹ 239 more rows
# Candidate models: logistic regression via stats::glm() and a random forest
# via the ranger package (mode must be set explicitly for rand_forest()).
glm_spec <- logistic_reg(engine = "glm")

rf_spec <- rand_forest(mode = "classification", engine = "ranger")

rf_spec
Random Forest Model Specification (classification)

Computational engine: ranger 
# Workflow holding only the preprocessor (predict `sex` from every other
# column); each model spec is added later so the workflow can be reused.
penguin_wf <- workflow(preprocessor = sex ~ .)

penguin_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Formula
Model: None

── Preprocessor ────────────────────────────────────────────────────────────────
sex ~ .
# Fit the logistic regression on every bootstrap resample, keeping the
# out-of-bag predictions (save_pred = TRUE) for the ROC curves below.
glm_rs <- penguin_wf |>
  add_model(glm_spec) |>
  fit_resamples(
    resamples = penguin_boot,
    control = control_resamples(save_pred = TRUE)
  )
→ A | warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
There were issues with some computations   A: x1
There were issues with some computations   A: x2
# Show the per-resample results of the logistic regression.
print(glm_rs)
# Resampling results
# Bootstrap sampling 
# A tibble: 25 × 5
   splits           id          .metrics         .notes           .predictions
   <list>           <chr>       <list>           <list>           <list>      
 1 <split [249/93]> Bootstrap01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 2 <split [249/91]> Bootstrap02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 3 <split [249/90]> Bootstrap03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 4 <split [249/91]> Bootstrap04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 5 <split [249/85]> Bootstrap05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 6 <split [249/87]> Bootstrap06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 7 <split [249/94]> Bootstrap07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 8 <split [249/88]> Bootstrap08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 9 <split [249/95]> Bootstrap09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
10 <split [249/89]> Bootstrap10 <tibble [2 × 4]> <tibble [1 × 3]> <tibble>    
# ℹ 15 more rows

There were issues with some computations:

  - Warning(s) x2: glm.fit: fitted probabilities numerically 0 or 1 occurred

Run `show_notes(.Last.tune.result)` for more information.
# Resample the random forest exactly like the logistic regression so the two
# models are evaluated on identical bootstrap splits.
rf_rs <- fit_resamples(
  add_model(penguin_wf, rf_spec),
  resamples = penguin_boot,
  control = control_resamples(save_pred = TRUE)
)

rf_rs
# Resampling results
# Bootstrap sampling 
# A tibble: 25 × 5
   splits           id          .metrics         .notes           .predictions
   <list>           <chr>       <list>           <list>           <list>      
 1 <split [249/93]> Bootstrap01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 2 <split [249/91]> Bootstrap02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 3 <split [249/90]> Bootstrap03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 4 <split [249/91]> Bootstrap04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 5 <split [249/85]> Bootstrap05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 6 <split [249/87]> Bootstrap06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 7 <split [249/94]> Bootstrap07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 8 <split [249/88]> Bootstrap08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
 9 <split [249/95]> Bootstrap09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
10 <split [249/89]> Bootstrap10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble>    
# ℹ 15 more rows
# Resample-averaged metrics for the random forest.
rf_rs |> collect_metrics()
# A tibble: 2 × 6
  .metric  .estimator  mean     n std_err .config             
  <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy binary     0.899    25 0.00704 Preprocessor1_Model1
2 roc_auc  binary     0.971    25 0.00307 Preprocessor1_Model1
# Resample-averaged metrics for the logistic regression.
glm_rs |> collect_metrics()
# A tibble: 2 × 6
  .metric  .estimator  mean     n std_err .config             
  <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy binary     0.919    25 0.00527 Preprocessor1_Model1
2 roc_auc  binary     0.977    25 0.00220 Preprocessor1_Model1
# Confusion-matrix cell counts averaged over the bootstrap resamples.
conf_mat_resampled(glm_rs)
# A tibble: 4 × 3
  Prediction Truth   Freq
  <fct>      <fct>  <dbl>
1 female     female 41.8 
2 female     male    3.72
3 male       female  3.68
4 male       male   41.6 
# One ROC curve per bootstrap resample, built from the saved out-of-bag
# predictions. `.pred_female` is the score for "female", the first level of
# `sex`. The dashed diagonal marks a no-skill classifier.
glm_rs |>
  collect_predictions() |>
  group_by(id) |>
  roc_curve(truth = sex, .pred_female) |>
  ggplot(aes(x = 1 - specificity, y = sensitivity, color = id)) +
  geom_abline(linetype = "dashed", color = "gray80", linewidth = 1.5) +
  geom_path(show.legend = FALSE, alpha = 0.6, linewidth = 1.2) +
  coord_equal()

# Fit the chosen logistic-regression workflow once on the full training set
# and evaluate it on the held-out test set of the original split.
penguin_final <- last_fit(
  add_model(penguin_wf, glm_spec),
  split = penguin_split
)

penguin_final
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits           id               .metrics .notes   .predictions .workflow 
  <list>           <chr>            <list>   <list>   <list>       <list>    
1 <split [249/84]> train/test split <tibble> <tibble> <tibble>     <workflow>
# Test-set metrics of the final model.
penguin_final |> collect_metrics()
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.905 Preprocessor1_Model1
2 roc_auc  binary         0.966 Preprocessor1_Model1
# Test-set confusion matrix: true `sex` against the predicted class.
conf_mat(
  collect_predictions(penguin_final),
  truth = sex,
  estimate = .pred_class
)
          Truth
Prediction female male
    female     39    5
    male        3   37
# Coefficients of the final logistic regression as odds ratios.
# Uses the public extract_workflow() accessor instead of reaching into the
# internal `.workflow` list-column.
# Note: exponentiate = TRUE transforms only `estimate`; `std.error` and
# `statistic` remain on the log-odds scale (e.g. the intercept's statistic is
# log(estimate) / std.error).
penguin_final |>
  extract_workflow() |>
  tidy(exponentiate = TRUE)
# A tibble: 7 × 5
  term              estimate std.error statistic     p.value
  <chr>                <dbl>     <dbl>     <dbl>       <dbl>
1 (Intercept)       2.35e-42  18.1        -5.30  0.000000115
2 speciesChinstrap  9.76e- 4   1.91       -3.62  0.000295   
3 speciesGentoo     1.96e- 4   3.25       -2.63  0.00866    
4 bill_length_mm    1.88e+ 0   0.159       4.00  0.0000639  
5 bill_depth_mm     7.50e+ 0   0.455       4.43  0.00000940 
6 flipper_length_mm 1.06e+ 0   0.0653      0.863 0.388      
7 body_mass_g       1.01e+ 0   0.00138     4.63  0.00000374 
# Bill depth vs. bill length for each species, coloured by sex and sized by
# body mass. Reuses penguins_df, which already excludes penguins with unknown
# sex and keeps every column plotted here, instead of re-filtering `penguins`.
penguins_df |>
  ggplot(aes(
    bill_depth_mm, bill_length_mm,
    color = sex, size = body_mass_g
  )) +
  geom_point(alpha = 0.5) +
  facet_wrap(~species)