Tuning text models

tuning
regression

Prepare text data for predictive modeling and tune with both grid and iterative search.

Introduction

To use code in this article, you will need to install the following packages: stopwords, textrecipes, and tidymodels.

This article demonstrates an advanced example for training and tuning models for text data. Text data must be processed and transformed to a numeric representation to be ready for computation in modeling; in tidymodels, we use a recipe for this preprocessing. This article also shows how to extract information from each model fit during tuning to use later on.
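As a rough illustration of what transforming text to a numeric representation can look like, the sketch below tokenizes a few made-up reviews, removes stop words, and converts the remaining tokens to term-frequency columns. The `toy_reviews` data and the particular steps chosen here are assumptions for illustration only; they are not necessarily the recipe that this article tunes.

```r
library(tidymodels)
library(textrecipes)

# Hypothetical toy data; not the Amazon reviews used in this article
toy_reviews <- tibble(
  review = c(
    "I love this tea",
    "This coffee is not good",
    "Great tea, great price"
  ),
  score = factor(c("great", "other", "great"))
)

toy_rec <-
  recipe(score ~ review, data = toy_reviews) %>%
  step_tokenize(review) %>%   # split each review into word tokens
  step_stopwords(review) %>%  # drop common words such as "this" and "is"
  step_tf(review)             # one count column per remaining token

# Estimate the recipe and return the processed training data
toy_features <- toy_rec %>% prep() %>% bake(new_data = NULL)
toy_features
```

Each remaining token becomes its own numeric column, which is the form a model can compute on.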

Text as data

The text data we’ll use in this article are from Amazon:

This dataset consists of reviews of fine foods from amazon. The data span a period of more than 10 years, including all ~500,000 reviews up to October 2012. Reviews include product and user information, ratings, and a plaintext review.

This article uses a small subset of the total reviews available at the original source. We sampled a single review from 5,000 random products and allocated 80% of these data to the training set, with the remaining 1,000 reviews held out for the test set.

There is a column for the product, a column for the text of the review, and a factor column for the outcome variable. The outcome is whether the reviewer gave the product a five-star rating or not.

library(tidymodels)

data("small_fine_foods")
training_data
#> # A tibble: 4,000 × 3
#>    product    review                                                       score
#>    <chr>      <chr>                                                        <fct>
#>  1 B000J0LSBG "this stuff is  not stuffing  its  not good at all  save yo… other
#>  2 B000EYLDYE "I absolutely LOVE this dried fruit.  LOVE IT.  Whenever I … great
#>  3 B0026LIO9A "GREAT DEAL, CONVENIENT TOO.  Much cheaper than WalMart and… great
#>  4 B00473P8SK "Great flavor, we go through a ton of this sauce! I discove… great
#>  5 B001SAWTNM "This is excellent salsa/hot sauce, but you can get it for … great
#>  6 B000FAG90U "Again, this is the best dogfood out there.  One suggestion… great
#>  7 B006BXTCEK "The box I received was filled with teas, hot chocolates, a… other
#>  8 B002GWH5OY "This is delicious coffee which compares favorably with muc… great
#>  9 B003R0MFYY "Don't let these little tiny cans fool you.  They pack a lo… great
#> 10 B001EO5ZXI "One of the nicest, smoothest cup of chai I've made. Nice m… great
#> # ℹ 3,990 more rows
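A quick sanity check on the split described above: the snippet below confirms the row counts and tabulates the outcome. It assumes that `data("small_fine_foods")` also loads a `testing_data` object alongside `training_data`.

```r
library(tidymodels)

data("small_fine_foods")

# 80% of the 5,000 sampled reviews are for training, the rest for testing
nrow(training_data)
nrow(testing_data)

# How often does each outcome class occur in the training set?
training_data %>%
  count(score) %>%
  mutate(prop = n / sum(n))
```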

Our goal is to create modeling features from the text of the reviews and use them to predict whether a review was five-star or not.

Resampling

With 4,000 training reviews, 5-fold cross-validation holds out 800 reviews at a time to estimate performance. Performance estimates based on this many observations have sufficiently low noise to measure and tune models.

set.seed(8935)
folds <- vfold_cv(training_data, v = 5)
folds
#> #  5-fold cross-validation 
#> # A tibble: 5 × 2
#>   splits             id   
#>   <list>             <chr>
#> 1 <split [3200/800]> Fold1
#> 2 <split [3200/800]> Fold2
#> 3 <split [3200/800]> Fold3
#> 4 <split [3200/800]> Fold4
#> 5 <split [3200/800]> Fold5
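To confirm the fold sizes, one option is to count the rows in the analysis and assessment portions of each split. This sketch uses a placeholder data frame with 4,000 rows (`toy_data`, a stand-in for `training_data`) so that it runs on its own.

```r
library(rsample)
library(purrr)

# Placeholder with the same number of rows as the training set
toy_data <- data.frame(row = 1:4000)

set.seed(8935)
toy_folds <- vfold_cv(toy_data, v = 5)

# Rows used for fitting and for performance estimation in each fold
fold_sizes <- data.frame(
  id = toy_folds$id,
  n_analysis = map_int(toy_folds$splits, ~ nrow(analysis(.x))),
  n_assessment = map_int(toy_folds$splits, ~ nrow(assessment(.x)))
)
fold_sizes
```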

Extracted results

Let’s return to the grid search results and examine the results of our extract function. For each fitted model, a tibble was saved that contains the relationship between the number of predictors and the penalty value. Let’s look at these results for the best model:

params <- select_best(five_star_glmnet, metric = "roc_auc")
params
#> # A tibble: 1 × 4
#>   penalty mixture num_terms .config               
#>     <dbl>   <dbl>     <dbl> <chr>                 
#> 1  0.0379    0.25      4096 Preprocessor3_Model031

Recall that we saved the glmnet results in a tibble. The column five_star_glmnet$.extracts is a list of tibbles. As an example, the first element of the list is:

five_star_glmnet$.extracts[[1]]
#> # A tibble: 300 × 5
#>    num_terms penalty mixture .extracts          .config               
#>        <dbl>   <dbl>   <dbl> <list>             <chr>                 
#>  1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model001
#>  2       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model002
#>  3       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model003
#>  4       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model004
#>  5       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model005
#>  6       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model006
#>  7       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model007
#>  8       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model008
#>  9       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model009
#> 10       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model010
#> # ℹ 290 more rows

More nested tibbles! Let’s unnest() the five_star_glmnet$.extracts column:

library(tidyr)
extracted <- 
  five_star_glmnet %>% 
  dplyr::select(id, .extracts) %>% 
  unnest(cols = .extracts)
extracted
#> # A tibble: 1,500 × 6
#>    id    num_terms penalty mixture .extracts          .config               
#>    <chr>     <dbl>   <dbl>   <dbl> <list>             <chr>                 
#>  1 Fold1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model001
#>  2 Fold1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model002
#>  3 Fold1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model003
#>  4 Fold1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model004
#>  5 Fold1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model005
#>  6 Fold1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model006
#>  7 Fold1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model007
#>  8 Fold1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model008
#>  9 Fold1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model009
#> 10 Fold1       256       1    0.01 <tibble [100 × 2]> Preprocessor1_Model010
#> # ℹ 1,490 more rows
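If the list-column mechanics are unfamiliar, this small stand-alone example may help. The `nested` tibble is made up for illustration (its values are not results from this article); it mimics the shape of a `.extracts` column in which each element is itself a tibble.

```r
library(tibble)
library(tidyr)

# Hypothetical values, shaped like one level of the nested results
nested <- tibble(
  id = c("Fold1", "Fold2"),
  .extracts = list(
    tibble(penalty = c(0.1, 0.01), num_vars = c(2L, 10L)),
    tibble(penalty = c(0.1, 0.01), num_vars = c(3L, 12L))
  )
)

# Each row of an inner tibble becomes a row of the result,
# and the `id` value is repeated alongside it
flat <- unnest(nested, cols = .extracts)
flat
```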

Note that tune_grid() may not separately fit every model that it evaluates. In this case, for each combination of mixture and num_terms, a single fit covers all penalty values (this is a feature of this particular model and is not generally true for other engines). To pull out the extracts for the best parameter set, we can drop the penalty column from extracted and join on the remaining parameters:

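extracted <- 
  extracted %>% 
  # Drop the grid's penalty value; each fit covers the whole penalty path
  dplyr::select(-penalty) %>% 
  # Keep only the rows matching the best mixture and num_terms
  inner_join(params, by = c("num_terms", "mixture")) %>% 
  # Now remove the penalty column that came in from `params`
  dplyr::select(-penalty)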
extracted
#> # A tibble: 100 × 6
#>    id    num_terms mixture .extracts          .config.x              .config.y  
#>    <chr>     <dbl>   <dbl> <list>             <chr>                  <chr>      
#>  1 Fold1      4096    0.25 <tibble [100 × 2]> Preprocessor3_Model021 Preprocess…
#>  2 Fold1      4096    0.25 <tibble [100 × 2]> Preprocessor3_Model022 Preprocess…
#>  3 Fold1      4096    0.25 <tibble [100 × 2]> Preprocessor3_Model023 Preprocess…
#>  4 Fold1      4096    0.25 <tibble [100 × 2]> Preprocessor3_Model024 Preprocess…
#>  5 Fold1      4096    0.25 <tibble [100 × 2]> Preprocessor3_Model025 Preprocess…
#>  6 Fold1      4096    0.25 <tibble [100 × 2]> Preprocessor3_Model026 Preprocess…
#>  7 Fold1      4096    0.25 <tibble [100 × 2]> Preprocessor3_Model027 Preprocess…
#>  8 Fold1      4096    0.25 <tibble [100 × 2]> Preprocessor3_Model028 Preprocess…
#>  9 Fold1      4096    0.25 <tibble [100 × 2]> Preprocessor3_Model029 Preprocess…
#> 10 Fold1      4096    0.25 <tibble [100 × 2]> Preprocessor3_Model030 Preprocess…
#> # ℹ 90 more rows

Now we can get at the results that we want using another unnest():

extracted <- 
  extracted %>% 
  unnest(cols = .extracts) # <- these contain a `penalty` column
extracted
#> # A tibble: 10,000 × 7
#>    id    num_terms mixture penalty num_vars .config.x              .config.y    
#>    <chr>     <dbl>   <dbl>   <dbl>    <int> <chr>                  <chr>        
#>  1 Fold1      4096    0.25   0.352        0 Preprocessor3_Model021 Preprocessor…
#>  2 Fold1      4096    0.25   0.336        2 Preprocessor3_Model021 Preprocessor…
#>  3 Fold1      4096    0.25   0.321        2 Preprocessor3_Model021 Preprocessor…
#>  4 Fold1      4096    0.25   0.306        2 Preprocessor3_Model021 Preprocessor…
#>  5 Fold1      4096    0.25   0.292        2 Preprocessor3_Model021 Preprocessor…
#>  6 Fold1      4096    0.25   0.279        2 Preprocessor3_Model021 Preprocessor…
#>  7 Fold1      4096    0.25   0.266        3 Preprocessor3_Model021 Preprocessor…
#>  8 Fold1      4096    0.25   0.254        5 Preprocessor3_Model021 Preprocessor…
#>  9 Fold1      4096    0.25   0.243        7 Preprocessor3_Model021 Preprocessor…
#> 10 Fold1      4096    0.25   0.232        7 Preprocessor3_Model021 Preprocessor…
#> # ℹ 9,990 more rows

Let’s look at a plot of these results (per resample):

ggplot(extracted, aes(x = penalty, y = num_vars)) + 
  geom_line(aes(group = id, col = id), alpha = .5) + 
  ylab("Number of retained predictors") + 
  scale_x_log10()  + 
  ggtitle(paste("mixture =", params$mixture, "and", params$num_terms, "features")) + 
  theme(legend.position = "none")

These results might help guide the choice of the penalty range if more optimization were conducted.
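For instance, if the plot suggested that the interesting behavior lies in a narrower band of penalties, the range could be restricted in a follow-up search. A minimal sketch with dials is below; the endpoints are hypothetical and are given in log10 units, which is the default scale for `penalty()`.

```r
library(dials)

# Hypothetical narrower range: 10^-3 to 10^0 (endpoints are in log10 units)
narrow_penalty <- penalty(range = c(-3, 0))

# Five candidate values, evenly spaced on the log10 scale
penalty_grid <- grid_regular(narrow_penalty, levels = 5)
penalty_grid
```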

Session information

#> ─ Session info ─────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.0 (2024-04-24)
#>  os       macOS Sonoma 14.4.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       America/Los_Angeles
#>  date     2024-06-26
#>  pandoc   3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ─────────────────────────────────────────────────────────
#>  package     * version date (UTC) lib source
#>  broom       * 1.0.6   2024-05-17 [1] CRAN (R 4.4.0)
#>  dials       * 1.2.1   2024-02-22 [1] CRAN (R 4.4.0)
#>  dplyr       * 1.1.4   2023-11-17 [1] CRAN (R 4.4.0)
#>  ggplot2     * 3.5.1   2024-04-23 [1] CRAN (R 4.4.0)
#>  infer       * 1.0.7   2024-03-25 [1] CRAN (R 4.4.0)
#>  parsnip     * 1.2.1   2024-03-22 [1] CRAN (R 4.4.0)
#>  purrr       * 1.0.2   2023-08-10 [1] CRAN (R 4.4.0)
#>  recipes     * 1.0.10  2024-02-18 [1] CRAN (R 4.4.0)
#>  rlang         1.1.4   2024-06-04 [1] CRAN (R 4.4.0)
#>  rsample     * 1.2.1   2024-03-25 [1] CRAN (R 4.4.0)
#>  stopwords   * 2.3     2021-10-28 [1] CRAN (R 4.4.0)
#>  textrecipes * 1.0.6   2023-11-15 [1] CRAN (R 4.4.0)
#>  tibble      * 3.2.1   2023-03-20 [1] CRAN (R 4.4.0)
#>  tidymodels  * 1.2.0   2024-03-25 [1] CRAN (R 4.4.0)
#>  tune        * 1.2.1   2024-04-18 [1] CRAN (R 4.4.0)
#>  workflows   * 1.1.4   2024-02-19 [1] CRAN (R 4.4.0)
#>  yardstick   * 1.3.1   2024-03-21 [1] CRAN (R 4.4.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#> 
#> ────────────────────────────────────────────────────────────────────