Iterative Bayesian optimization of a classification model

Identify the best hyperparameters for a model using Bayesian optimization of iterative search.

Introduction

To use code in this article, you will need to install the following packages: kernlab, modeldata, themis, and tidymodels.

Many of the examples for model tuning focus on grid search. For that method, all the candidate tuning parameter combinations are defined prior to evaluation. Alternatively, iterative search can be used to analyze the existing tuning parameter results and then predict which tuning parameters to try next.

There are a variety of methods for iterative search; the focus in this article is on Bayesian optimization. For more information on this method, these resources might be helpful:

Practical Bayesian Optimization of Machine Learning Algorithms (2012). J. Snoek, H. Larochelle, and R. P. Adams. Advances in Neural Information Processing Systems.

A Tutorial on Bayesian Optimization for Machine Learning (2018). R. Adams.

Gaussian Processes for Machine Learning (2006). C. E. Rasmussen and C. Williams.

Cell segmenting revisited

To demonstrate this approach to tuning models, let’s return to the cell segmentation data from the Getting Started article on resampling:

library(tidymodels)
library(modeldata)

# Load data
data(cells)

set.seed(2369)
tr_te_split <- initial_split(cells %>% select(-case), prop = 3/4)
cell_train <- training(tr_te_split)
cell_test  <- testing(tr_te_split)

set.seed(1697)
folds <- vfold_cv(cell_train, v = 10)
The tuning scheme
Since the predictors are highly correlated, we can use a recipe to convert the original predictors to principal component scores. There is also a slight class imbalance in these data; about 64% of the data are poorly segmented. To mitigate this, the data will be down-sampled at the end of the pre-processing so that the number of poorly and well segmented cells occurs with equal frequency. We can use a recipe for all this pre-processing, but the number of principal components will need to be tuned so that we have enough (but not too many) representations of the data.
library(themis)
cell_pre_proc <-
  recipe(class ~ ., data = cell_train) %>%
  step_YeoJohnson(all_predictors()) %>%
  step_normalize(all_predictors()) %>%
  step_pca(all_predictors(), num_comp = tune()) %>%
  step_downsample(class)
In this analysis, we will use a support vector machine to model the data. Let’s use a radial basis function (RBF) kernel and tune its main parameter (\(\sigma\)). Additionally, the main SVM parameter, the cost value, also needs optimization.
svm_mod <-
  svm_rbf(mode = "classification", cost = tune(), rbf_sigma = tune()) %>%
  set_engine("kernlab")
These two objects (the recipe and model) will be combined into a single object via the workflow() function from the workflows package; this object will be used in the optimization process.
svm_wflow <-
  workflow() %>%
  add_model(svm_mod) %>%
  add_recipe(cell_pre_proc)
From this object, we can derive information about what parameters are slated to be tuned. A parameter set is derived by:
svm_set <- extract_parameter_set_dials(svm_wflow)
svm_set
#> Collection of 3 parameters for tuning
#>
#> identifier type object
#> cost cost nparam[+]
#> rbf_sigma rbf_sigma nparam[+]
#> num_comp num_comp nparam[+]
The default range for the number of PCA components is rather small for this data set. A member of the parameter set can be modified using the update() function. Let’s constrain the search to one to twenty components by updating the num_comp parameter. Additionally, the lower bound of this parameter is set to zero, which specifies that the original predictor set should also be evaluated (i.e., with no PCA step at all):
svm_set <-
  svm_set %>%
  update(num_comp = num_comp(c(0L, 20L)))
Sequential tuning
Bayesian optimization is a sequential method that uses a model to predict new candidate parameters for assessment. When scoring a potential parameter value, the mean and variance of performance are predicted. The strategy used to define how these two statistical quantities are used is defined by an acquisition function.
For example, one approach for scoring new candidates is to use a confidence bound. Suppose accuracy is being optimized. For a metric that we want to maximize, a lower confidence bound can be used. The multiplier on the standard error (denoted as \(\kappa\)) is a value that can be used to make trade-offs between exploration and exploitation.
Exploration means that the search will consider candidates in untested space.
Exploitation focuses in areas where the previous best results occurred.
The variance predicted by the Bayesian model is mostly spatial variation; the value will be large for candidate values that are not close to values that have already been evaluated. If the standard error multiplier is high, the search process will be more likely to avoid areas without candidate values in the vicinity.
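As a sketch of this scoring rule (writing \(\mu(x)\) and \(\sigma(x)\) for the predicted mean and standard error of performance at candidate \(x\)), the confidence bound for a metric being maximized is:

```latex
% Lower confidence bound for a metric being maximized. A larger kappa
% penalizes high-variance (i.e., unexplored) candidates more heavily,
% trading exploration for exploitation.
\text{score}(x) = \mu(x) - \kappa \, \sigma(x)
```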
We’ll use a different acquisition function, expected improvement, which determines which candidates are likely to be helpful relative to the current best results. This is the default acquisition function. More information on these functions can be found in the package vignette for acquisition functions.
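As a rough illustration (this is not code from the tune package), the closed-form expected improvement under a Gaussian process prediction can be computed in a few lines of base R. Here `mu`, `sigma`, and `best` are hypothetical inputs standing in for the GP mean, GP standard deviation, and the current best metric value:

```r
# Sketch of expected improvement for a metric being maximized.
# mu, sigma: Gaussian process mean and standard deviation at a candidate.
# best: the best metric value observed so far.
expected_improvement <- function(mu, sigma, best) {
  z <- (mu - best) / sigma
  (mu - best) * pnorm(z) + sigma * dnorm(z)
}

# A candidate with a slightly lower predicted mean but high uncertainty
# can still have a positive expected improvement:
expected_improvement(mu = 0.85, sigma = 0.05, best = 0.88)
```

When `sigma` is near zero, the score collapses to `max(mu - best, 0)`, which is why candidates close to already-evaluated points contribute little unless their predicted mean beats the current best.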
set.seed(12)
search_res <-
  svm_wflow %>%
  tune_bayes(
resamples = folds,
# To use non-default parameter ranges
param_info = svm_set,
# Generate five at semi-random to start
initial = 5,
iter = 50,
# How to measure performance?
metrics = metric_set(roc_auc),
control = control_bayes(no_improve = 30, verbose = TRUE)
  )
#>
#> ❯ Generating a set of 5 initial parameter results
#> ✓ Initialization complete
#>
#> i Gaussian process model
#> ✓ Gaussian process model
#> i Generating 5000 candidates
#> i Predicted candidates
#> i Estimating performance
#> ✓ Estimating performance
#> ! No improvement for 30 iterations; returning current results.
The resulting tibble is a stacked set of rows of the rsample object with an additional column for the iteration number:
search_res
#> # Tuning results
#> # 10-fold cross-validation
#> # A tibble: 410 × 5
#> splits id .metrics .notes .iter
#> <list> <chr> <list> <list> <int>
#> 1 <split [1362/152]> Fold01 <tibble [5 × 7]> <tibble [0 × 3]> 0
#> 2 <split [1362/152]> Fold02 <tibble [5 × 7]> <tibble [0 × 3]> 0
#> 3 <split [1362/152]> Fold03 <tibble [5 × 7]> <tibble [0 × 3]> 0
#> 4 <split [1362/152]> Fold04 <tibble [5 × 7]> <tibble [0 × 3]> 0
#> 5 <split [1363/151]> Fold05 <tibble [5 × 7]> <tibble [0 × 3]> 0
#> 6 <split [1363/151]> Fold06 <tibble [5 × 7]> <tibble [0 × 3]> 0
#> 7 <split [1363/151]> Fold07 <tibble [5 × 7]> <tibble [0 × 3]> 0
#> 8 <split [1363/151]> Fold08 <tibble [5 × 7]> <tibble [0 × 3]> 0
#> 9 <split [1363/151]> Fold09 <tibble [5 × 7]> <tibble [0 × 3]> 0
#> 10 <split [1363/151]> Fold10 <tibble [5 × 7]> <tibble [0 × 3]> 0
#> # ℹ 400 more rows
As with grid search, we can summarize the results over resamples:
estimates <-
  collect_metrics(search_res) %>%
  arrange(.iter)

estimates
#> # A tibble: 45 × 10
#> cost rbf_sigma num_comp .metric .estimator mean n std_err .config
#> <dbl> <dbl> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 0.233 6.26e-7 6 roc_auc binary 0.346 10 0.115 Prepro…
#> 2 2.74 3.06e-4 17 roc_auc binary 0.873 10 0.0128 Prepro…
#> 3 0.00140 3.26e-9 14 roc_auc binary 0.237 10 0.0586 Prepro…
#> 4 0.0185 4.40e-2 8 roc_auc binary 0.881 10 0.0119 Prepro…
#> 5 4.36 7.49e-5 1 roc_auc binary 0.773 10 0.0106 Prepro…
#> 6 0.114 3.39e-3 9 roc_auc binary 0.875 10 0.0122 Iter1
#> 7 31.8 3.82e-2 15 roc_auc binary 0.855 10 0.0137 Iter2
#> 8 0.00176 9.67e-1 5 roc_auc binary 0.872 10 0.0124 Iter3
#> 9 0.00115 1.20e-3 19 roc_auc binary 0.347 10 0.115 Iter4
#> 10 0.306 2.97e-1 13 roc_auc binary 0.868 10 0.0126 Iter5
#> # ℹ 35 more rows
#> # ℹ 1 more variable: .iter <int>
The best performance of the initial set of candidate values was AUC = 0.8805091. The best results were achieved at iteration 10 with a corresponding AUC value of 0.8917785. The five best results are:
show_best(search_res, metric = "roc_auc")
#> # A tibble: 5 × 10
#> cost rbf_sigma num_comp .metric .estimator mean n std_err .config .iter
#> <dbl> <dbl> <int> <chr> <chr> <dbl> <int> <dbl> <chr> <int>
#> 1 0.935 0.0223 11 roc_auc binary 0.892 10 0.0103 Iter10 10
#> 2 1.23 0.0213 13 roc_auc binary 0.892 10 0.0103 Iter22 22
#> 3 0.806 0.0767 17 roc_auc binary 0.890 10 0.0101 Iter37 37
#> 4 0.301 0.0892 13 roc_auc binary 0.889 10 0.0110 Iter40 40
#> 5 0.141 0.0448 9 roc_auc binary 0.889 10 0.0112 Iter19 19
A plot of the search iterations can be created via:
autoplot(search_res, type = "performance")
There are many parameter combinations that have roughly equivalent results.
How did the parameters change over iterations?
autoplot(search_res, type = "parameters") +
labs(x = "Iterations", y = NULL)
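As a possible next step (not shown in this article), the numerically best parameters could be carried forward to a final model fit. This sketch assumes the search_res, svm_wflow, and tr_te_split objects created above are still available:

```r
# Extract the best parameter combination from the Bayesian search
best_parms <- select_best(search_res, metric = "roc_auc")

# Splice those values into the workflow in place of the tune() placeholders
final_wflow <- finalize_workflow(svm_wflow, best_parms)

# Fit on the training set and evaluate once on the held-out test set
final_fit <- last_fit(final_wflow, split = tr_te_split)
collect_metrics(final_fit)
```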
Session information
#> ─ Session info ─────────────────────────────────────────────────────
#> setting value
#> version R version 4.4.0 (2024-04-24)
#> os macOS Sonoma 14.4.1
#> system aarch64, darwin20
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz America/Los_Angeles
#> date 2024-06-26
#> pandoc 3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ─────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> broom * 1.0.6 2024-05-17 [1] CRAN (R 4.4.0)
#> dials * 1.2.1 2024-02-22 [1] CRAN (R 4.4.0)
#> dplyr * 1.1.4 2023-11-17 [1] CRAN (R 4.4.0)
#> ggplot2 * 3.5.1 2024-04-23 [1] CRAN (R 4.4.0)
#> infer * 1.0.7 2024-03-25 [1] CRAN (R 4.4.0)
#> kernlab * 0.9-32 2023-01-31 [1] CRAN (R 4.4.0)
#> modeldata * 1.4.0 2024-06-19 [1] CRAN (R 4.4.0)
#> parsnip * 1.2.1 2024-03-22 [1] CRAN (R 4.4.0)
#> purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.4.0)
#> recipes * 1.0.10 2024-02-18 [1] CRAN (R 4.4.0)
#> rlang * 1.1.4 2024-06-04 [1] CRAN (R 4.4.0)
#> rsample * 1.2.1 2024-03-25 [1] CRAN (R 4.4.0)
#> themis * 1.0.2 2023-08-14 [1] CRAN (R 4.4.0)
#> tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.4.0)
#> tidymodels * 1.2.0 2024-03-25 [1] CRAN (R 4.4.0)
#> tune * 1.2.1 2024-04-18 [1] CRAN (R 4.4.0)
#> workflows * 1.1.4 2024-02-19 [1] CRAN (R 4.4.0)
#> yardstick * 1.3.1 2024-03-21 [1] CRAN (R 4.4.0)
#>
#> [1] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#>
#> ────────────────────────────────────────────────────────────────────