Введение в машинное обучение на R

Base R, Tidymodels

2025

Учебные вопросы

  1. Основные статистические характеристики
  2. Создание моделей машинного обучения
  3. R tidymodels
  4. Решение задач машинного обучения
    • кластеризация
    • классификация
    • регрессия

Выборка

  • Генеральная совокупность (statistaical population)
  • Выборка (sample)

Центральная предельная теорема – Central Limit Theorem

Если много раз выбирать данные из генеральной совокупности, то распределение средних (среднее значение всех средних) нашего множества выборок будет стремиться к среднему генеральной совокупности.

Статистические характеристики

iris <- iris
iris$Sepal.Length |>
    head(n = 15)
 [1] 5.1 4.9 4.7 4.6 5.0 5.4 4.6 5.0 4.4 4.9 5.4 4.8 4.8 4.3 5.8
summary(iris$Sepal.Length)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
  4.300   5.100   5.800   5.843   6.400   7.900 

Доверительный интервал

Мы не можем быть на 100 процентов уверены, что знаем параметр iris$Sepal.Length генеральной совокупности

Различие в статистических характеристиках

Доверительный интервал

Для нормального распределения

\[x_{med}\pm{z}\frac{\sigma}{\sqrt{n}}\], где \(x_{med}\) – среднее, \(z\) – стандартизированная оценка, \(\sigma\) – СКО, \(n\) – мощность выборки.

confidence_interval <- function(vector, interval) {
  vec_sd <- sd(vector)
  n <- length(vector)
  vec_mean <- mean(vector)
  error <- qt((interval + 1)/2, df = n - 1) * vec_sd / sqrt(n)
  result <- c("lower" = vec_mean - error, "upper" = vec_mean + error)
  return(result)
}

Вычисление доверительного интервала

confidence_interval(iris$Sepal.Length, 0.90)
   lower    upper 
5.731427 5.955240 
confidence_interval(iris$Sepal.Length, 0.95)
   lower    upper 
5.709732 5.976934 
confidence_interval(iris$Sepal.Length, 0.99)
   lower    upper 
5.666920 6.019747 

Подготовка данных

flowchart TD
A[Данные] --> B[Тренировочные -- train]
A -.-> C[Валидационные -- dev]
A --> D[Тестовые -- test]

Для чего?

Контроль результата обучения модели

Возможнве проблемы

  1. Низкокачественные данные: выбросы, ошибки
  2. Переобучение: когда входных данных недостаточно, но построенная модель хорошо объясняет параметры из обучающей выборки, то считается, что любой выброс или колебания приводят к недостоверным прогнозам.
  3. Несбалансированный набор данных (дисбаланс классов). Исследовательский анализ данных (EDA) поможет выявить несбалансированные данные.

Переобучение и недообучение

Переобучение

явление, при котором алгоритм слишком приспособлен для данных, на которых он обучался. Переобучение имеет место при выборе слишком сложных моделей (model complexity).

Недообучение

явление, обратное переобучению, при котором алгоритм не полностью использует предоставленные ему для обучения данные. Недообучение имеет место при выборе недостаточно сложных моделей.

Как возникает переобучение

  1. Избыточная сложность модели – избыточно много весов, которые “расходуются” на подгонку к имеющимся данным
  2. Переобучение есть всегда – обучение происходит по заведомо неполной выборке из генеральной совокупности

Как обнаружить переобучение

Если качество на тестовой выборке сильно хуже качества на обучающих данных — у нас переобучение

Линейная регрессия

Общий вид

\[ a(x)=w_0+w_1 \cdot x_1+w_2 \cdot x_2 + \ldots + w_n \cdot x_n \] где \(x_1, \ldots , x_n\) – признаки объекта \(X\)

Сокращенная запись

\[ a(x)=w_0 + \sum_{j=1}^{n} w_j x_j \]

Обучение = минимизация среднеквадратической ошибки (СКО)

\[ Q(a,X) = \frac {1} {l} \sum_{i=1}^{l} (a(x_i) - y_i)^2 = \frac {1} {l} \sum_{i=1}^{l} ((w_i,x_i) - y_i)^2 \to \min {w} \]

Tidymodels

https://www.tidymodels.org/packages/#core-tidymodels

  • rsample – деление данных
  • parsnip – унифицированный интерфейс для моделей
  • infer – dplyr-like функции тестирования гипотез
  • recipes – препроцессор для конструирования параметров
  • dials – тюнинг параметров моделей
  • yardstick – оценка моделей
  • workflows – выстраивание многоэтапных процессов построения моделей

Rsample

https://rsample.tidymodels.org/

Задачи:

  • Создание сэмплов для различных этапов исследования данных и машинного обучения
  • Сравнение распределений параметров в полученных сэмплах
  • Оценка полученных моделей

Parsnip

https://parsnip.tidymodels.org/

Интерфейс взаимодействия с моделью, а не новая имплементация существующих пакетов.

Задачи:

  • Разделение этапов задания параметров модели от процесса обучения и инференса
  • Разделение спецификации модели от ее конкретной реализации (R, Spark, Keras, Tensorflow…)
  • Унификация названий параметров среди разных реализаций моделей (RF, boosting, bagging…)

Parsnip

  • Логистическая регрессия
  • Дерево решений
  • Случайный лес
  • Градиентный бустинг (XGBoost, LightGBM)
  • Метод опорных векторов (SVM)
  • Наивный Байесовский классификатор
  • K-ближайших соседей (KNN)
  • Дискриминантный анализ (LDA/QDA)
  • Нейронные сети
  • Ансамбли моделей – пакеты stacks и workflowset

https://parsnip.tidymodels.org/reference/index.html – Models

Recipes

https://recipes.tidymodels.org/

Задачи:

  • feature engineering – спецификация этапов подготовки данных (признаков) для дальнейшего построения модели (аналог dplyr)
  • устранение дупликации данных для тестирования разных моделей машинного обучения с разными требованиями к загружаемым в них данным.

Workflows

https://workflows.tidymodels.org/

Задачи:

  • Построение пайплайна предобработки данных и моделирования

flowchart TD
    A[Рецепт] --> C[Workflow]
    B[Модель с параметрами] --> C

Кластеризация

KNN – k-nearest neighbors algorithm – Метод k-ближайших соседей

Регрессия

  • Выявление закономерностей в распределении данных
  • Выявление зависимостей в данных
  • Построение математической (статистической) модели

Корреляционный анализ

  • Есть ли зависимость между различными переменными?
  • Какой вид этой зависимости?
  • Есть ли причинно-следственная связь?

Корреляция

(лат. correlatio – “соотношение”) – корреляционная зависимость – статистическая взаимосвязь двух или более СВ, при этом изменения значений одной или нескольких из этих величин сопутствуют систематическому изменению значений другой или других величин.

cor(iris$Sepal.Length, iris$Petal.Length)
[1] 0.8717538

Кореляционная матрица

cor(iris)

Ошибка выполнения

 

library(dplyr)

iris %>% 
  select(1:4) %>% 
  cor()
             Sepal.Length Sepal.Width Petal.Length Petal.Width
Sepal.Length    1.0000000  -0.1175698    0.8717538   0.8179411
Sepal.Width    -0.1175698   1.0000000   -0.4284401  -0.3661259
Petal.Length    0.8717538  -0.4284401    1.0000000   0.9628654
Petal.Width     0.8179411  -0.3661259    0.9628654   1.0000000

Корреляционный анализ – 2

library(corrplot)

iris %>% 
  select(1:4) %>% 
  cor() %>% 
  corrplot(., type = "upper", order = "hclust")

Корреляционный анализ – 3

library(corrplot)
library(nycflights13)

nycflights13::flights %>% 
  select(2:9) %>% 
  na.omit() %>% 
  cor() %>% 
  corrplot(., type = "upper", order = "hclust")

Кейс: датасет iris – структура данных

library(dplyr)
glimpse(iris)
Rows: 150
Columns: 5
$ Sepal.Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.…
$ Sepal.Width  <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.…
$ Petal.Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5, 1.…
$ Petal.Width  <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2, 0.…
$ Species      <fct> setosa, setosa, setosa, setosa, setosa, setosa, setosa, s…
  • Petal – лепесток
  • Sepal – чашелистник

Кейс: датасет iris – вид распределения

library(ggplot2)
ggplot(iris,
       aes(x = Petal.Length, 
           y = Sepal.Length, 
           group = Species, 
           col = Species)) + 
  geom_point() + 
  geom_smooth(method = lm, se = FALSE) +
  scale_color_viridis_d(option = "plasma", end = .7)

Кейс: датасет iris – вид распределения

library(ggplot2)
ggplot(iris,
       aes(x = Petal.Length, 
           y = Petal.Width, 
           group = Species, 
           col = Species)) + 
  geom_point() + 
  geom_smooth(method = lm, se = FALSE) +
  scale_color_viridis_d(option = "plasma", end = .7)

Гипотеза

Petal.Length ~ Sepal.Length * Sepal.Width * Species

Справа – предикторы

Parsnip

library(tidymodels)

lm_mod <- parsnip::linear_reg() %>% 
  parsnip::set_engine("keras")

Keras

https://keras.io/

install.packages("keras")

pak::pkg_install("keras")

Создание модели

lm_mod <- linear_reg()

lm_fit <- 
  lm_mod %>% 
  fit(Petal.Length ~ Sepal.Length * Sepal.Width * Species, data = iris)
lm_fit
parsnip model object


Call:
stats::lm(formula = Petal.Length ~ Sepal.Length * Sepal.Width * 
    Species, data = data)

Coefficients:
                               (Intercept)  
                                   -3.9404  
                              Sepal.Length  
                                    1.0955  
                               Sepal.Width  
                                    1.3500  
                         Speciesversicolor  
                                   -6.2616  
                          Speciesvirginica  
                                    4.2902  
                  Sepal.Length:Sepal.Width  
                                   -0.2729  
            Sepal.Length:Speciesversicolor  
                                    1.2115  
             Sepal.Length:Speciesvirginica  
                                   -0.3127  
             Sepal.Width:Speciesversicolor  
                                    2.6566  
              Sepal.Width:Speciesvirginica  
                                   -1.2532  
Sepal.Length:Sepal.Width:Speciesversicolor  
                                   -0.3522  
 Sepal.Length:Sepal.Width:Speciesvirginica  
                                    0.2606  

Создание модели – 2

tidy(lm_fit)
# A tibble: 12 × 5
   term                                     estimate std.error statistic p.value
   <chr>                                       <dbl>     <dbl>     <dbl>   <dbl>
 1 (Intercept)                                -3.94      4.09     -0.962   0.338
 2 Sepal.Length                                1.10      0.827     1.32    0.188
 3 Sepal.Width                                 1.35      1.19      1.14    0.257
 4 Speciesversicolor                          -6.26      5.41     -1.16    0.249
 5 Speciesvirginica                            4.29      4.99      0.860   0.391
 6 Sepal.Length:Sepal.Width                   -0.273     0.234    -1.16    0.246
 7 Sepal.Length:Speciesversicolor              1.21      1.03      1.18    0.240
 8 Sepal.Length:Speciesvirginica              -0.313     0.928    -0.337   0.737
 9 Sepal.Width:Speciesversicolor               2.66      1.75      1.52    0.131
10 Sepal.Width:Speciesvirginica               -1.25      1.54     -0.813   0.418
11 Sepal.Length:Sepal.Width:Speciesversico…   -0.352     0.320    -1.10    0.273
12 Sepal.Length:Sepal.Width:Speciesvirgini…    0.261     0.275     0.948   0.345

Использование модели – 1

new_points_1 <- expand.grid(Sepal.Width = 4, 
                          Sepal.Length = 5,
                          Species = c("setosa", "virginica", "versicolor"))

mean_pred <- predict(lm_fit, new_data = new_points_1)
mean_pred
# A tibble: 3 × 1
  .pred
  <dbl>
1  1.48
2  4.40
3  4.86

Использование модели – 2

new_points_2 <- expand.grid(Sepal.Width = 2.5, 
                          Sepal.Length = 5,
                          Species = c("setosa", "virginica", "versicolor"))

mean_pred <- predict(lm_fit, new_data = new_points_2)
mean_pred
# A tibble: 3 × 1
  .pred
  <dbl>
1  1.50
2  4.35
3  3.54

Оценка модели

Доверительный интервал

conf_int_pred <- predict(lm_fit, 
                         new_data = new_points_2, 
                         type = "conf_int")
conf_int_pred
# A tibble: 3 × 2
  .pred_lower .pred_upper
        <dbl>       <dbl>
1        1.23        1.77
2        4.09        4.62
3        3.38        3.70

Визуализация

plot_data <- 
  new_points_2 %>% 
  bind_cols(mean_pred) %>% 
  bind_cols(conf_int_pred)

ggplot(plot_data, aes(x = Species)) + 
  geom_point(aes(y = .pred)) + 
  geom_errorbar(aes(ymin = .pred_lower, 
                    ymax = .pred_upper),
                width = .2) + 
  labs(y = "Petal.Length")

Вопросы ?

Спасибо за внимание!