Fecha de publicación

12 de diciembre de 2024

Objetivo del manual

  • Entender como se construyen los árboles de decisión

  • Familiarizarse con los principales métodos de regresión y clasificación basados en ensamblado de árboles

  • Aprender a aplicar estos métodos en R

Paquetes a utilizar en este manual:

Código
# instalar/cargar paquetes
sketchy::load_packages(
  c("ggplot2", 
    "viridis", 
    "caret",
    "ISLR",
    "rpart",
    "rpart.plot",
    "tree",
    "randomForest",
    "xgboost"
   )
  )
Loading required package: caret
Loading required package: lattice

Los métodos basados en árboles estratifican o subdividen el espacio predictor en en una serie de regiones simples. Dado que el conjunto de reglas de división utilizadas para segmentar el espacio predictor puede representarse en un árbol, este tipo de enfoques se conocen como métodos de árboles de decisión. Los árboles de decisión se utilizan tanto para clasificación como para regresión. Su ventaja radica en la simplicidad de interpretación y en su capacidad para capturar interacciones no lineales entre predictores.

1 Regresión con árboles de decisión

Usaremos el conjunto de datos Hitters para predecir el salario de un jugador de béisbol basado en los años (“Years”, el número de años que ha jugado en las grandes ligas) y “Hits” (el número de hits que realizó en el año anterior). Primero eliminamos las observaciones que tienen valores de salario faltantes (“Salary”) y aplicamos una transformación logarítmica al salario para que su distribución tenga una forma más típica de campana.

Código
# cargar datos
data(Hitters)

# Eliminar observaciones con valores faltantes en Salary
Hitters <- na.omit(Hitters)

# Transformar la variable Salary al logaritmo natural
Hitters$LogSalary <- log(Hitters$Salary)

head(Hitters)
AtBat Hits HmRun Runs RBI Walks Years CAtBat CHits CHmRun CRuns CRBI CWalks League Division PutOuts Assists Errors Salary NewLeague LogSalary
-Alan Ashby 315 81 7 24 38 39 14 3449 835 69 321 414 375 N W 632 43 10 475.0 N 6.1633
-Alvin Davis 479 130 18 66 72 76 3 1624 457 63 224 266 263 A W 880 82 14 480.0 A 6.1738
-Andre Dawson 496 141 20 65 78 37 11 5628 1575 225 828 838 354 N E 200 11 3 500.0 N 6.2146
-Andres Galarraga 321 87 10 39 42 30 2 396 101 12 48 46 33 N E 805 40 4 91.5 N 4.5163
-Alfredo Griffin 594 169 4 74 51 35 11 4408 1133 19 501 336 194 A W 282 421 25 750.0 A 6.6201
-Al Newman 185 37 1 23 8 21 2 214 42 1 30 9 24 N E 76 127 7 70.0 A 4.2485

La siguiente figura muestra un árbol de regresión ajustado a estos datos. Consiste en una serie de reglas de división, comenzando desde la parte superior del árbol. La división superior asigna las observaciones con años < 4.5 a la rama izquierda. El salario predicho para estos jugadores se da por el valor promedio de respuesta para los jugadores en el conjunto de datos con años < 4.5. Para tales jugadores, el salario logarítmico medio es 5.107, por lo que hacemos una predicción de e^5.107 miles de dólares, es decir, $165,174, para estos jugadores. Los jugadores con Años >= 4.5 son asignados a la rama derecha, y luego ese grupo se subdivide aún más por “Hits”.

En general, el árbol divide a los jugadores en tres regiones del espacio de predictores: jugadores que han jugado durante cuatro años o menos, jugadores que han jugado durante cinco años o más y que hicieron menos de 118 hits el año pasado, y jugadores que han jugado durante cinco años o más y que hicieron al menos 118 hits el año pasado. Estas tres regiones se pueden escribir como R1 = {X | Años < 4.5}, R2 = {X | Años >= 4.5, Hits < 117.5}, y R3 = {X | Años >= 4.5, Hits >= 117.5}. La Figura 8.2 ilustra las regiones como función de los Años y Hits. Los salarios predichos para estos tres grupos son $1,000 × e^5.107 = $165,174, $1,000 × e^5.999 = $402,834, y $1,000 × e^6.740 = $845,346 respectivamente.

Siguiendo con la analogía del árbol, las regiones R1, R2 y R3 se conocen como nodos terminales o hojas del árbol. Los puntos a lo largo del árbol donde se divide el espacio del predictor se denominan nodos internos. En el árbol graficado mas arriba los dos nodos internos están indicados por el texto “Years < 4.5” y “Hits < 117.5”. Nos referimos a los segmentos del árbol que conectan los nodos como ramas.

Podríamos interpretar el árbol de regresión mostrado mas arriba de la siguiente manera: Los años son el factor más importante para determinar el salario, y los jugadores con menos experiencia ganan salarios más bajos que los jugadores más experimentados. Dado que un jugador tiene menos experiencia, parece que el número de hits que realizó en el año anterior juega un papel poco relevante en su salario. Pero entre los jugadores que han estado en las grandes ligas durante cinco años o más, el número de hits realizados en el año anterior sí afecta al salario, y los jugadores que hicieron más hits el año pasado tienden a tener salarios más altos.

Los valores predichos por el modelo dependen del estrato en el que caen los datos. Por ejemplo, si un jugador ha jugado durante 5 años y ha realizado 100 hits, entonces caerá en el estrato R2 y su salario predicho será el promedio de los salarios de los jugadores en R2.

1.1 Ajuste del modelo

El árbol que se genera para estos datos en realidad es mucho mas complejo y contiene mas divisiones de los datos en mas estratos. Los paquetes rpart y rpart.plot nos permite ajustar y visualizar estos modelos de una forma muy amigable:

Código
# Ajustar un árbol de regresión
arbol_regresion <- rpart::rpart(LogSalary ~ Years + Hits, data = Hitters)

# Visualizar el árbol
rpart.plot::rpart.plot(arbol_regresion, extra = 101)

La complejidad del árbol la podemos controlar mediante el parámetro cp (“complexity parameter”). Un valor más alto de cp resulta en un árbol más simple (con menos divisiones o reglas), mientras que un valor más bajo permite árboles más grandes y complejos. El objetivo es encontrar un equilibrio entre un árbol que sea suficientemente complejo para capturar patrones importantes en los datos, pero no tan complejo que incurra en sobreajuste. Por ejemplo, este es un árbol con un valor de cp de 0.05:

Código
arbol_cp.05 <- rpart(LogSalary ~ Years + Hits,
                     data = Hitters,
                     control = rpart.control(cp = 0.05))

rpart.plot(arbol_cp.05, extra = 101)

Este en cambio tiene un valor de cp de 0.001:

Código
arbol_cp.001 <- rpart(LogSalary ~ Years + Hits, data = Hitters, control = rpart.control(cp = 0.001))

rpart.plot(arbol_cp.001, extra = 101)

1.2 Validación cruzada

Afortunadamente la librería caret nos permite realizar validación cruzada para encontrar el mejor valor de cp para nuestro modelo. Podemos utilzar cualquiera de los métodos vistos en el manual de “sobreajuste y entrenamiento de modelos”. En este caso usamos el método de “dejar uno afuera” (LOOCV) para optimizar el valor de cp:

Código
# Configuración de validación cruzada
set.seed(42) # Para reproducibilidad
# Validación cruzada
train_control <- trainControl(method = "LOOCV")

# Entrenar el modelo de árbol usando caret
modelo_arbol <- train(
  LogSalary ~ Years + Hits,
  data = Hitters,
  method = "rpart", # Árbol de decisión
  trControl = train_control,
  tuneLength = 10 # Número de combinaciones de parámetros a probar
)         

# Resumen del modelo ajustado
print(modelo_arbol)
CART 

263 samples
  2 predictor

No pre-processing
Resampling: Leave-One-Out Cross-Validation 
Summary of sample sizes: 262, 262, 262, 262, 262, 262, ... 
Resampling results across tuning parameters:

  cp         RMSE     Rsquared  MAE    
  0.0042120  0.60271  0.547947  0.43523
  0.0046796  0.59658  0.555573  0.42811
  0.0085782  0.59364  0.557312  0.42568
  0.0096474  0.59458  0.554649  0.42884
  0.0110721  0.59198  0.558190  0.42732
  0.0169020  0.58909  0.562176  0.43504
  0.0183127  0.59250  0.555813  0.44211
  0.0444602  0.62855  0.501931  0.47958
  0.1145455  0.71442  0.361503  0.58884
  0.4445745  0.97762  0.016381  0.87625

RMSE was used to select the optimal model using the smallest value.
The final value used for the model was cp = 0.016902.

El mejor valor de cp encontrado por la validación cruzada es 0.00421 y un valor de RMSE de 0.60271. Podemos extraer y visualizar el árbol final así:

Código
# extraer mejor modelo
mejor_modelo <- modelo_arbol$finalModel

# Visualizar el árbol final
rpart.plot(mejor_modelo, extra = 101)

1.3 Comparación con modelos lineales

Con el siguiente gráfico podemos comparar el comportamiento de un modelo lineal con el de un árbol de regresión, usando un ejemplo de clasificación en dos dimensiones. En este ejemplo en el que la verdadera frontera de decisión es lineal, y está indicada por las regiones sombreadas. En la fila superior se ilustra como un enfoque clásico que asume una frontera lineal (izquierda) superará a un árbol de decisión que realiza divisiones paralelas a los ejes (derecha). En la fila inferior la verdadera frontera de decisión es no lineal. En este caso, un modelo lineal no puede capturar la verdadera frontera de decisión (izquierda), mientras que un árbol de decisión tiene éxito (derecha).

Tomado de Gareth et al 2013

1.4 Ejercicio 1

  1. Utilice un árbol de decisión de regresión para resolver el ejercicio 1 de la tarea 3.

  2. Realice la validación cruzada con el método de remuestreo de “boostrap” para entrenar el modelo.

2 Clasificación con árboles de decisión

Al igual que los modelos de regresión con árboles de decisión, los árboles de clasificación dividen el espacio predictor en una serie de regiones simples utilizando reglas de decisión, con el objetivo de predecir una clase o categoría como resultado. Este enfoque es similar al de los árboles de regresión, pero el criterio de división optimiza una métrica asociada con la clasificación, como la entropía o el índice de Gini.

Usaremos el conjunto de datos Carseats para predecir si las ventas son altas (High = “Yes”) o bajas (High = “No”) en función de varias características del conjunto de datos, como precio (Price) o publicidad (Advertising).

Código
# explorar
head(Carseats)
Sales CompPrice Income Advertising Population Price ShelveLoc Age Education Urban US
9.50 138 73 11 276 120 Bad 42 17 Yes Yes
11.22 111 48 16 260 83 Good 65 10 Yes Yes
10.06 113 35 10 269 80 Medium 59 12 Yes Yes
7.40 117 100 4 466 97 Medium 55 14 Yes Yes
4.15 141 64 3 340 128 Bad 38 13 Yes No
10.81 124 113 13 501 72 Bad 78 16 No Yes

Primero preprocesamos los datos para crear la variable objetivo binaria:

Código
# Cargar datos
data(Carseats)

# Crear variable de respuesta binaria
Carseats$High <- ifelse(Carseats$Sales > 8, "Yes", "No")
Carseats$High <- factor(Carseats$High)  # Convertir a factor

# Eliminar la variable 'Sales' para evitar colinealidad
Carseats$Sales <- NULL

2.1 Ajuste del modelo

Ajustaremos un árbol de clasificación simple utilizando la librería rpart y visualizaremos el árbol generado.

Código
# Ajustar un árbol de clasificación
arbol_clasificacion <- rpart(High ~ Price + Advertising + ShelveLoc + Age, 
                             data = Carseats, 
                             method = "class")

# Visualizar el árbol
rpart.plot(arbol_clasificacion, extra = 104, fallen.leaves = TRUE, shadow.col = "gray")

El árbol resultante muestra cómo los datos se dividen en regiones basadas en los predictores. Por ejemplo, el predictor más importante puede ser Price, donde precios más bajos están asociados con mayores ventas.

En el gráfico los nodos terminales indican la clase predicha (Yes o No) y el porcentaje de datos que pertenecen a esa clase. Las divisiones están basadas en reglas como Price >= 93.

2.2 Evaluación

Para evaluar el desempeño del modelo, dividiremos los datos en conjuntos de entrenamiento y prueba y generaremos una matriz de confusión:

Código
# Dividir en conjunto de entrenamiento y prueba
set.seed(42)
train_index <- sample(seq_len(nrow(Carseats)), size = 0.7 * nrow(Carseats))
train_data <- Carseats[train_index, ]
test_data <- Carseats[-train_index, ]

# Ajustar árbol con datos de entrenamiento
arbol_train <- rpart(High ~ Price + Advertising + ShelveLoc + Age, 
                     data = train_data, 
                     method = "class")

# Predicciones en el conjunto de prueba
predicciones <- predict(arbol_train, test_data, type = "class")

# Matriz de confusión
caret::confusionMatrix(test_data$High, predicciones)
Confusion Matrix and Statistics

          Reference
Prediction No Yes
       No  64  10
       Yes 16  30
                                        
               Accuracy : 0.783         
                 95% CI : (0.699, 0.853)
    No Information Rate : 0.667         
    P-Value [Acc > NIR] : 0.00351       
                                        
                  Kappa : 0.53          
                                        
 Mcnemar's Test P-Value : 0.32680       
                                        
            Sensitivity : 0.800         
            Specificity : 0.750         
         Pos Pred Value : 0.865         
         Neg Pred Value : 0.652         
             Prevalence : 0.667         
         Detection Rate : 0.533         
   Detection Prevalence : 0.617         
      Balanced Accuracy : 0.775         
                                        
       'Positive' Class : No            
                                        

2.3 Validación Cruzada

Para seleccionar el mejor valor de cp, usamos validación cruzada con la librería caret:

Código
# Configuración de validación cruzada
set.seed(42)
control <- trainControl(method = "cv", number = 10)

# Ajustar modelo con validación cruzada
modelo_cv <- train(
  High ~ Price + Advertising + ShelveLoc + Age, 
  data = train_data,
  method = "rpart",
  trControl = control
)

# Mostrar los resultados
print(modelo_cv)
CART 

280 samples
  4 predictor
  2 classes: 'No', 'Yes' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 252, 252, 252, 253, 251, 253, ... 
Resampling results across tuning parameters:

  cp        Accuracy  Kappa   
  0.050847  0.69182   0.356741
  0.080508  0.66680   0.296495
  0.254237  0.60375   0.089011

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.050847.
Código
# Visualizar el árbol final
rpart.plot(modelo_cv$finalModel, extra = 104)

El valor óptimo de cp es 0.05085. El árbol final representa la mejor combinación de simplicidad y precisión según la validación cruzada.

Ahora podemos evaluar el desempeño del mejor modelo:

Código
arbol_train_cv <-
  rpart(
    High ~ Price + Advertising + ShelveLoc + Age,
    data = Carseats,
    control = rpart.control(cp = modelo_cv$bestTune$cp),
    method = "class"
  )

# Predicciones en el conjunto de prueba
predicciones <- predict(arbol_train_cv, Carseats, type = "class")

# Matriz de confusión
caret::confusionMatrix(Carseats$High, predicciones)
Confusion Matrix and Statistics

          Reference
Prediction  No Yes
       No  203  33
       Yes  66  98
                                        
               Accuracy : 0.752         
                 95% CI : (0.707, 0.794)
    No Information Rate : 0.672         
    P-Value [Acc > NIR] : 0.000301      
                                        
                  Kappa : 0.472         
                                        
 Mcnemar's Test P-Value : 0.001299      
                                        
            Sensitivity : 0.755         
            Specificity : 0.748         
         Pos Pred Value : 0.860         
         Neg Pred Value : 0.598         
             Prevalence : 0.672         
         Detection Rate : 0.507         
   Detection Prevalence : 0.590         
      Balanced Accuracy : 0.751         
                                        
       'Positive' Class : No            
                                        

2.4 Ejercicio 2

  1. Ajusta un árbol de clasificación utilizando otras variables predictoras (por ejemplo, Income o Population).

  2. Realiza validación cruzada con un enfoque distinto (por ejemplo, validación LOOCV).

3 Métodos de ensamblado de árboles

Estos modelos son una extensión de árboles de decisión, los cuales representan una familia diversa y poderosa de técnicas en el aprendizaje estadístico y computacional. Al igual que los árboles de decisión, las extensiones de estos se basan en dividir los datos en subconjuntos más pequeños mediante reglas de decisión, lo que facilita la interpretación y el manejo de relaciones complejas entre variables. Sin embargo, los árboles simples a menudo carecen de precisión en comparación con métodos más sofisticados. Por ello, se han desarrollado diversas extensiones y variantes que optimizan su desempeño y amplían su aplicabilidad. Estas extensiones utilizan múltiples árboles de decisión para mejorar la precisión y robustez del modelo, y se conocen como métodos basados en el emsamblado de múltiples árboles.

Los métodos de emsamblado de árboles de decisión (como Bagging, Random Forest, GBM, XGBoost y Extra Trees) comparten varias características en su construcción y diseño del modelo. Estas características clave incluyen:

  1. Uso de múltiples árboles de decisión: Todos estos métodos generan una población de árboles de decisión, que se combinan (i.e. ensamblan) para mejorar el rendimiento predictivo en comparación con un único árbol de decisión.

  2. Aleatoriedad en la construcción de los árboles: Seleccionan subconjuntos aleatorios de datos para construir cada árbol (bootstrap sampling) e introducen aleatoriedad adicional en la selección de las variables que se evalúan en cada división.

  3. Uso de hiperparámetros: Los hiperparámetros en los modelos de aprendizaje estadístico son configuraciones o valores que controlan el comportamiento del modelo. Todos estos métodos requieren la configuración de hiperparámetros para controlar aspectos como:

    • Número de árboles.
    • Profundidad máxima de los árboles.
    • Fracción de muestras o variables utilizadas por árbol.

Ventajas de los métodos de ensamblado de árboles:

  1. Reducción de la varianza: combinan las predicciones de múltiples árboles para promediar (regresión) o votar (clasificación), reduciendo la varianza del modelo.

  2. Manejo de alta dimensionalidad: Son efectivos en conjuntos de datos con un gran número de variables predictoras, aprovechando su capacidad para identificar las más relevantes.

  3. Robustez frente al sobreajuste: La combinación de múltiples árboles hace que estos métodos sean menos propensos al sobreajuste en comparación con un solo árbol. Sin embargo, boosting puede sobreajustar si no se ajustan correctamente los hiperparámetros.

Estas características hacen que los métodos basados en poblaciones de árboles sean potentes y flexibles, siendo adecuados para problemas complejos en los que un modelo de árbol único podría fallar.

En este tutorial trabajaremos con 2 de estos métodos: Random Forest y XGBoost

3.1 Random Forest

El método de Random Forest es una extensión de los árboles de decisión que utiliza una combinación de múltiples árboles para mejorar la precisión y la robustez del modelo. Cada árbol se ajusta utilizando una muestra aleatoria con reemplazo (bootstrap) de los datos de entrenamiento, y en cada división del árbol se selecciona un subconjunto aleatorio de predictores. Esta estrategia reduce la correlación entre los árboles, mejorando el rendimiento general. A diferencia de un único árbol de decisión, los Random Forest son menos propensos a sobreajustar los datos.

Los principales hiperparámetros del Random Forest son los siguientes:

  1. Número de árboles (ntree):Determina cuántos árboles se generarán en el bosque. Más árboles generalmente mejoran la estabilidad y la capacidad de generalización, pero aumentan el tiempo de cálculo. Valor típico: 500 o 100.

  2. Número de predictores seleccionados por división (mtry): Define cuántas variables se seleccionan aleatoriamente de todas las disponibles para considerar en cada división de nodo. Valores más bajos aumentan la diversidad entre los árboles. Valores más altos hacen que los árboles sean más similares. Valor típico en clasificación: raíz cuadrada del número total de predictores. Valor típico en regresión: Total de predictores / 3.

  3. Tamaño mínimo de los nodos terminales (nodesize): Controla el número mínimo de observaciones en los nodos terminales. Valores pequeños permiten modelos más complejos. Valores grandes simplifican los árboles y previenen el sobreajuste. Valor típico: Clasificación: 1. Regresión: 5.

Estos hiperparámetros se pueden ajustar utilizando técnicas como búsqueda en cuadrícula (grid search) o búsqueda aleatoria (random search) en el paquete caret.

3.2 Ajuste de un modelo Random Forest

En esta sección, utilizaremos el conjunto de datos heart para clasificar si un paciente tiene enfermedad cardíaca (sick) basada en variables clínicas como el colesterol, la frecuencia cardíaca máxima, entre otras. El conjunto de datos incluye observaciones categóricas y numéricas, lo que lo hace ideal para ilustrar la flexibilidad de Random Forest.

Primero debemos leer los datos y darles el formato adecuado:

Código
heart <- read.csv("https://raw.githubusercontent.com/maRce10/aprendizaje_estadistico_2024/refs/heads/master/data/heart_data.csv")

# Nueva Columna
heart$sick <- ifelse(heart$sick == 0, "No", "Yes")

# hacerlo factor para q sea interpretado como categorico
heart$sick <- factor(heart$sick)

# Eliminar datos faltantes
heart <- na.omit(heart)

# revisar
head(heart)
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal sick
63 1 1 145 233 1 2 150 0 2.3 3 0 6 No
67 1 4 160 286 0 2 108 1 1.5 2 3 3 Yes
67 1 4 120 229 0 2 129 1 2.6 2 2 7 Yes
37 1 3 130 250 0 0 187 0 3.5 3 0 3 No
41 0 2 130 204 0 2 172 0 1.4 1 0 3 No
56 1 2 120 236 0 0 178 0 0.8 1 0 3 No

En este enlace podemos ver una descripción de los datos.

Ahora podemos ajustar un modelo Random Forest a los datos:

Código
# Ajustar modelo Random Forest
set.seed(42) # Para reproducibilidad
modelo_rf <- randomForest::randomForest(
  sick ~ ., 
  data = heart, 
  importance = TRUE,  # Para calcular la importancia de las variables
  ntree = 50000         # Número de árboles
)

# Resumen del modelo
print(modelo_rf)

Call:
 randomForest(formula = sick ~ ., data = heart, importance = TRUE,      ntree = 50000) 
               Type of random forest: classification
                     Number of trees: 50000
No. of variables tried at each split: 3

        OOB estimate of  error rate: 16.5%
Confusion matrix:
     No Yes class.error
No  140  20     0.12500
Yes  29 108     0.21168

El resultado muestra la tasa de error para cada clase, así como el error general (Out-of-Bag error, OOB). Este error se calcula al predecir observaciones no incluidas en la muestra bootstrap utilizada para construir cada árbol, lo que proporciona una estimación interna de la precisión del modelo.

3.3 Importancia de las variables

Una ventaja de los Random Forest es que pueden calcular automáticamente la importancia de cada predictor en la clasificación o predicción. Esto se mide mediante la reducción en la pureza del nodo (Gini index) o la precisión del modelo al permutar aleatoriamente los valores de cada predictor.

Código
# Importancia de las variables
importancia <- importance(modelo_rf)
print(importancia)
               No      Yes MeanDecreaseAccuracy MeanDecreaseGini
age       69.9604  47.3872              84.3381          12.9674
sex      110.7576  63.3700             125.7804           4.6577
cp       118.5269 172.5159             195.1832          17.8627
trestbps  16.2051   8.5533              17.8384          10.8094
chol       5.5427 -21.7215             -10.0971          11.6011
fbs       15.0411 -19.7773              -1.5309           1.3610
restecg    3.3443  14.8573              12.5993           2.9091
thalach   91.8999  58.5110             107.8348          17.1350
exang     43.4908  89.5156              94.4112           7.2705
oldpeak   95.4565 124.2312             155.5998          15.6410
slope     22.9566  85.8175              80.4838           6.4293
ca       210.9774 182.5273             255.6418          17.9004
thal     184.6603 165.7431             231.2559          18.8793
Código
# print ggplot gini importance
ggplot(importancia, aes(x = reorder(rownames(importancia), MeanDecreaseGini), y = MeanDecreaseGini)) +
  geom_bar(stat = "identity", fill = viridis(10)[3]) +
  coord_flip() +
  labs(x = "Variables", y = "Importancia (Mean Decrease Gini)")

La gráfica de importancia muestra las variables que contribuyen más al modelo. En este ejemplo, podemos observar que ca, thal y cp son especialmente importantes para predecir si un paciente tiene enfermedad cardíaca.

3.4 Validación cruzada

Podemos usar el paquete caret para realizar una búsqueda de hiperparámetros en Random Forest, optimizando el número de predictores considerados en cada división del árbol (mtry) mediante validación cruzada.

Código
library(caret)

# Configuración de validación cruzada
set.seed(42)
train_control <- trainControl(method = "cv", number = 10) # Validación cruzada 10-fold

# Entrenar modelo usando caret
modelo_rf_caret <- train(
  sick ~ ., 
  data = heart, 
  method = "rf", 
  trControl = train_control, 
  tuneLength = 10 # Número de combinaciones de parámetros a probar
)

# Resultados
print(modelo_rf_caret)
Random Forest 

297 samples
 13 predictor
  2 classes: 'No', 'Yes' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 267, 267, 267, 267, 268, 268, ... 
Resampling results across tuning parameters:

  mtry  Accuracy  Kappa  
   2    0.82517   0.64610
   3    0.81816   0.63295
   4    0.80161   0.60046
   5    0.80828   0.61204
   6    0.79126   0.57814
   8    0.79805   0.59108
   9    0.79816   0.59112
  10    0.78805   0.57168
  11    0.78805   0.57137
  13    0.78805   0.57168

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 2.
Código
# Hiperarametros del mejor modelo
modelo_rf_caret$bestTune
mtry
2

El resultado muestra el mejor valor de mtry encontrado por validación cruzada y la precisión asociada. Este valor puede ser utilizado para ajustar un modelo final.

Podemos evaluar el desempeño del modelo con los hiperparametros optimizados:

Código
# predecir el modelo en todos los datos
predicciones <- predict(modelo_rf_caret, heart)

# matriz de confusion
confusionMatrix(predicciones, heart$sick)
Confusion Matrix and Statistics

          Reference
Prediction  No Yes
       No  160   1
       Yes   0 136
                                    
               Accuracy : 0.997     
                 95% CI : (0.981, 1)
    No Information Rate : 0.539     
    P-Value [Acc > NIR] : <2e-16    
                                    
                  Kappa : 0.993     
                                    
 Mcnemar's Test P-Value : 1         
                                    
            Sensitivity : 1.000     
            Specificity : 0.993     
         Pos Pred Value : 0.994     
         Neg Pred Value : 1.000     
             Prevalence : 0.539     
         Detection Rate : 0.539     
   Detection Prevalence : 0.542     
      Balanced Accuracy : 0.996     
                                    
       'Positive' Class : No        
                                    

3.5 Visualización del error Out-of-Bag (OOB)

El Random Forest proporciona una estimación del error OOB durante el ajuste, lo que permite analizar cómo se comporta el modelo conforme aumentan los árboles.

Código
# Error OOB
a <- plot(
  modelo_rf, 
  main = "Error OOB en Random Forest",
  col = viridis(10)[3]
)

Este gráfico muestra cómo se estabiliza el error conforme se incrementa el número de árboles, lo que ayuda a determinar si se ha utilizado un número suficiente. El gráfico muestra 3 líneas: una para el error global y una para cada una de las categorías de la variable respuesta.

3.6 Ejercicio 3

  1. Utilice Random Forest para resolver el ejercicio 4 de la tarea 3.

  2. Realice la validación cruzada con el método de remuestreo repetido (“repeated CV”) para entrenar el modelo.

  3. Calcule la matriz de confusión, la exactitud y el área bajo la curva para el modelo del punto anterior.

3.7 Boosting

El método XGBoost (Extreme Gradient Boosting) es una implementación eficiente y optimizada de Boosting. A diferencia de los métodos de Random Forest, donde se genera un conjunto de árboles entrenados de manera independiente, XGBoost construye los árboles secuencialmente, corrigiendo los errores del modelo anterior en cada iteración. Cada árbol adicional se ajusta a los residuos (errores) del modelo anterior. Esta técnica tiende a ser más poderosa y flexible para resolver problemas de clasificación y regresión.

3.7.1 Hiperparámetros

Los principales hiperparámetros del XGBoost son los siguientes:

  1. Número de árboles (nrounds): Número total de árboles a generar. Más árboles pueden mejorar la precisión, pero también pueden aumentar el sobreajuste si no se regularizan adecuadamente. Valor típico: 100-1000 dependiendo del tamaño de los datos.

  2. Tasa de aprendizaje (eta): Controla cuánto contribuye cada árbol nuevo en el modelo. Un valor más bajo puede mejorar la generalización, pero requiere más árboles. Valores más altos pueden acelerar el entrenamiento pero aumentar el riesgo de sobreajuste. Valor típico: 0.01-0.3.

  3. Profundidad máxima de los árboles (max_depth): Limita la profundidad máxima de los árboles. Valores más altos permiten que el árbol aprenda más patrones complejos, pero también pueden causar sobreajuste. Valor típico: 3-10.

  4. Tamaño mínimo de los nodos (min_child_weight): Determina el número mínimo de muestras en un nodo para crear una nueva división. Valores más bajos permiten más divisiones, lo que puede llevar al sobreajuste. Valor típico: 1-10.

  5. Submuestreo (subsample): Define la proporción de muestras que se usarán para entrenar cada árbol. La submuestreo ayuda a reducir el sobreajuste, especialmente cuando se tienen muchos datos. Valor típico: 0.5-1.0.

  6. Submuestreo de características (colsample_bytree): Controla la fracción de características que se utilizan en cada árbol. Ayuda a reducir el sobreajuste al crear árboles más diversos. Valor típico: 0.5-1.0.

Estos hiperparámetros se pueden ajustar utilizando técnicas como búsqueda en cuadrícula o búsqueda aleatoria (como en caret).

3.8 Ajuste de un modelo XGBoost

En esta sección, utilizaremos nuevamente el conjunto de datos heart para clasificar si un paciente tiene enfermedad cardíaca (sick) basándonos en las mismas variables que usamos con Random Forest.

Podemos ajustar el modelo XGBoost así:

Código
# Crear el conjunto de datos de entrenamiento
X <- as.matrix(heart[, -ncol(heart)])  # Variables predictoras
y <- as.numeric(heart$sick) - 1  # Convertir a valores 0 y 1

# Configuración de los hiperparámetros
param <- list(
  objective = "binary:logistic", 
  eval_metric = "logloss", 
  max_depth = 6, 
  eta = 0.1, 
  subsample = 0.8, 
  colsample_bytree = 0.8
)

# Entrenamiento del modelo XGBoost
modelo_xgb <- xgboost(
  data = X, 
  label = y, 
  nrounds = 30, 
  params = param, 
  verbose = 0
)

# Ver el resultado del modelo
print(modelo_xgb)
##### xgb.Booster
raw: 58.7 Kb 
call:
  xgb.train(params = params, data = dtrain, nrounds = nrounds, 
    watchlist = watchlist, verbose = verbose, print_every_n = print_every_n, 
    early_stopping_rounds = early_stopping_rounds, maximize = maximize, 
    save_period = save_period, save_name = save_name, xgb_model = xgb_model, 
    callbacks = callbacks)
params (as set within xgb.train):
  objective = "binary:logistic", eval_metric = "logloss", max_depth = "6", eta = "0.1", subsample = "0.8", colsample_bytree = "0.8", validate_parameters = "TRUE"
xgb.attributes:
  niter
callbacks:
  cb.evaluation.log()
# of features: 13 
niter: 30
nfeatures : 13 
evaluation_log:
    iter train_logloss
       1       0.64292
       2       0.59494
---                   
      29       0.20190
      30       0.19719

3.9 Importancia de las variables

Al igual que con Random Forest, podemos ver la importancia de las variables utilizando el método xgb.importance():

Código
# Importancia de las variables
importancia_xgb <- xgb.importance(colnames(X), model = modelo_xgb)
print(importancia_xgb)
     Feature       Gain     Cover Frequency
 1:       cp 0.19424016 0.1287904 0.0676983
 2:     thal 0.17611240 0.1180942 0.0638298
 3:       ca 0.16140347 0.1505455 0.0851064
 4:      age 0.09613041 0.1147223 0.1779497
 5:  oldpeak 0.09473140 0.1190074 0.1237911
 6:  thalach 0.05883721 0.0744752 0.1083172
 7:    slope 0.04642379 0.0437283 0.0406190
 8:     chol 0.04580415 0.0686996 0.1083172
 9: trestbps 0.04114643 0.0724745 0.0967118
10:      sex 0.03351186 0.0375628 0.0406190
11:    exang 0.02874264 0.0375218 0.0328820
12:  restecg 0.02221044 0.0332299 0.0522244
13:      fbs 0.00070564 0.0011479 0.0019342
Código
# Visualizar la importancia
xgb.plot.importance(importance_matrix = importancia_xgb)

3.10 Validación cruzada

Podemos realizar una búsqueda de hiperparámetros usando validación cruzada con el paquete caret, al igual que con Random Forest. Sin embargo, en este caso, utilizaremos el método xgbTree para ajustar el modelo XGBoost. Ajustar este modelo es computacional intensivo y puede durar varios minutos en correr. Por lo tanto luego de correrlo lo guardamos como un archivo “RDS”. Estos archivos permiten guardar objetos de R de forma que se puedan leer nuevamente con facilidad manteniendo todos sus atributos:

Código
# Configuración de validación cruzada
train_control <- trainControl(method = "cv", number = 10)  # Validación cruzada 10-fold

# Entrenar el modelo con caret
modelo_xgb_caret <- train(
  sick ~ ., 
  data = heart, 
  method = "xgbTree", 
  trControl = train_control
)

saveRDS(modelo_xgb_caret, "modelo_xgb_caret.RDS")

Ahora podemos leer el modelo y ver los resultados:

Código
modelo_xgb_caret <- readRDS("modelo_xgb_caret.RDS")

# Resultados
print(modelo_xgb_caret)
eXtreme Gradient Boosting 

297 samples
 13 predictor
  2 classes: 'No', 'Yes' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 268, 268, 267, 268, 267, 267, ... 
Resampling results across tuning parameters:

  eta  max_depth  colsample_bytree  subsample  nrounds  Accuracy  Kappa  
  0.3  1          0.6               0.50        50      0.84839   0.69370
  0.3  1          0.6               0.50       100      0.82506   0.64612
  0.3  1          0.6               0.50       150      0.80494   0.60524
  0.3  1          0.6               0.75        50      0.83517   0.66752
  0.3  1          0.6               0.75       100      0.83851   0.67435
  0.3  1          0.6               0.75       150      0.82172   0.64039
  0.3  1          0.6               1.00        50      0.83839   0.67388
  0.3  1          0.6               1.00       100      0.83506   0.66728
  0.3  1          0.6               1.00       150      0.82517   0.64684
  0.3  1          0.8               0.50        50      0.82851   0.65405
  0.3  1          0.8               0.50       100      0.82184   0.64141
  0.3  1          0.8               0.50       150      0.81172   0.62081
  0.3  1          0.8               0.75        50      0.83172   0.66091
  0.3  1          0.8               0.75       100      0.83517   0.66702
  0.3  1          0.8               0.75       150      0.82506   0.64675
  0.3  1          0.8               1.00        50      0.83851   0.67387
  0.3  1          0.8               1.00       100      0.84184   0.68042
  0.3  1          0.8               1.00       150      0.83195   0.66002
  0.3  2          0.6               0.50        50      0.83529   0.66740
  0.3  2          0.6               0.50       100      0.81529   0.62712
  0.3  2          0.6               0.50       150      0.80483   0.60651
  0.3  2          0.6               0.75        50      0.83161   0.65904
  0.3  2          0.6               0.75       100      0.80471   0.60596
  0.3  2          0.6               0.75       150      0.80126   0.59690
  0.3  2          0.6               1.00        50      0.81529   0.62678
  0.3  2          0.6               1.00       100      0.80506   0.60646
  0.3  2          0.6               1.00       150      0.80184   0.60011
  0.3  2          0.8               0.50        50      0.81494   0.62671
  0.3  2          0.8               0.50       100      0.80483   0.60671
  0.3  2          0.8               0.50       150      0.78828   0.57256
  0.3  2          0.8               0.75        50      0.82506   0.64651
  0.3  2          0.8               0.75       100      0.80816   0.61262
  0.3  2          0.8               0.75       150      0.78161   0.56061
  0.3  2          0.8               1.00        50      0.81529   0.62689
  0.3  2          0.8               1.00       100      0.81874   0.63346
  0.3  2          0.8               1.00       150      0.80172   0.59824
  0.3  3          0.6               0.50        50      0.81862   0.63367
  0.3  3          0.6               0.50       100      0.81184   0.61973
  0.3  3          0.6               0.50       150      0.79517   0.58706
  0.3  3          0.6               0.75        50      0.80517   0.60843
  0.3  3          0.6               0.75       100      0.80828   0.61425
  0.3  3          0.6               0.75       150      0.79839   0.59358
  0.3  3          0.6               1.00        50      0.82851   0.65218
  0.3  3          0.6               1.00       100      0.81851   0.63249
  0.3  3          0.6               1.00       150      0.80816   0.61214
  0.3  3          0.8               0.50        50      0.81172   0.62131
  0.3  3          0.8               0.50       100      0.81149   0.62059
  0.3  3          0.8               0.50       150      0.79460   0.58541
  0.3  3          0.8               0.75        50      0.81184   0.62043
  0.3  3          0.8               0.75       100      0.81517   0.62679
  0.3  3          0.8               0.75       150      0.81207   0.62033
  0.3  3          0.8               1.00        50      0.81161   0.62012
  0.3  3          0.8               1.00       100      0.79483   0.58546
  0.3  3          0.8               1.00       150      0.79483   0.58576
  0.4  1          0.6               0.50        50      0.82529   0.64685
  0.4  1          0.6               0.50       100      0.82195   0.63885
  0.4  1          0.6               0.50       150      0.81839   0.63323
  0.4  1          0.6               0.75        50      0.83517   0.66684
  0.4  1          0.6               0.75       100      0.83517   0.66691
  0.4  1          0.6               0.75       150      0.80851   0.61285
  0.4  1          0.6               1.00        50      0.83862   0.67358
  0.4  1          0.6               1.00       100      0.83517   0.66733
  0.4  1          0.6               1.00       150      0.82851   0.65326
  0.4  1          0.8               0.50        50      0.83851   0.67357
  0.4  1          0.8               0.50       100      0.80161   0.60043
  0.4  1          0.8               0.50       150      0.80828   0.61231
  0.4  1          0.8               0.75        50      0.82529   0.64595
  0.4  1          0.8               0.75       100      0.83517   0.66648
  0.4  1          0.8               0.75       150      0.81149   0.61958
  0.4  1          0.8               1.00        50      0.82851   0.65352
  0.4  1          0.8               1.00       100      0.84195   0.68050
  0.4  1          0.8               1.00       150      0.82517   0.64665
  0.4  2          0.6               0.50        50      0.80138   0.59930
  0.4  2          0.6               0.50       100      0.78793   0.57223
  0.4  2          0.6               0.50       150      0.78816   0.57336
  0.4  2          0.6               0.75        50      0.81506   0.62710
  0.4  2          0.6               0.75       100      0.77793   0.55141
  0.4  2          0.6               0.75       150      0.79172   0.58011
  0.4  2          0.6               1.00        50      0.82494   0.64668
  0.4  2          0.6               1.00       100      0.80138   0.59885
  0.4  2          0.6               1.00       150      0.78483   0.56461
  0.4  2          0.8               0.50        50      0.80874   0.61430
  0.4  2          0.8               0.50       100      0.79839   0.59295
  0.4  2          0.8               0.50       150      0.79184   0.58005
  0.4  2          0.8               0.75        50      0.81506   0.62574
  0.4  2          0.8               0.75       100      0.78425   0.56350
  0.4  2          0.8               0.75       150      0.77793   0.55016
  0.4  2          0.8               1.00        50      0.81828   0.63205
  0.4  2          0.8               1.00       100      0.79138   0.57893
  0.4  2          0.8               1.00       150      0.78149   0.55882
  0.4  3          0.6               0.50        50      0.82195   0.64224
  0.4  3          0.6               0.50       100      0.79839   0.59455
  0.4  3          0.6               0.50       150      0.77138   0.53919
  0.4  3          0.6               0.75        50      0.82195   0.64039
  0.4  3          0.6               0.75       100      0.80494   0.60542
  0.4  3          0.6               0.75       150      0.79839   0.59277
  0.4  3          0.6               1.00        50      0.79149   0.57869
  0.4  3          0.6               1.00       100      0.79149   0.57899
  0.4  3          0.6               1.00       150      0.78816   0.57214
  0.4  3          0.8               0.50        50      0.80839   0.61236
  0.4  3          0.8               0.50       100      0.79851   0.59257
  0.4  3          0.8               0.50       150      0.81529   0.62681
  0.4  3          0.8               0.75        50      0.81184   0.62078
  0.4  3          0.8               0.75       100      0.80184   0.60005
  0.4  3          0.8               0.75       150      0.82195   0.64113
  0.4  3          0.8               1.00        50      0.81862   0.63326
  0.4  3          0.8               1.00       100      0.81851   0.63368
  0.4  3          0.8               1.00       150      0.81839   0.63304

Tuning parameter 'gamma' was held constant at a value of 0
Tuning
 parameter 'min_child_weight' was held constant at a value of 1
Accuracy was used to select the optimal model using the largest value.
The final values used for the model were nrounds = 50, max_depth = 1, eta
 = 0.3, gamma = 0, colsample_bytree = 0.6, min_child_weight = 1 and subsample
 = 0.5.

Estos son los valores optimizados de los hiperparámetros del modelo XGBoost. Podemos usar estos valores para ajustar un modelo final con los mejores hiperparámetros.

Código
# Mejor valor de max_depth y eta
modelo_xgb_caret$bestTune
nrounds max_depth eta gamma colsample_bytree min_child_weight subsample
50 1 0.3 0 0.6 1 0.5

Podemos evaluar el desempeño del modelo con los hiperparametros optimizados:

Código
# predecir el modelo en todos los datos
predicciones <- predict(modelo_xgb_caret, newdata = heart)

# matriz de confusion
confusionMatrix(predicciones, heart$sick)
Confusion Matrix and Statistics

          Reference
Prediction  No Yes
       No  146  24
       Yes  14 113
                                        
               Accuracy : 0.872         
                 95% CI : (0.829, 0.908)
    No Information Rate : 0.539         
    P-Value [Acc > NIR] : <2e-16        
                                        
                  Kappa : 0.741         
                                        
 Mcnemar's Test P-Value : 0.144         
                                        
            Sensitivity : 0.912         
            Specificity : 0.825         
         Pos Pred Value : 0.859         
         Neg Pred Value : 0.890         
             Prevalence : 0.539         
         Detection Rate : 0.492         
   Detection Prevalence : 0.572         
      Balanced Accuracy : 0.869         
                                        
       'Positive' Class : No            
                                        

3.11 Comparación de Random Forest y XGBoost

Característica XGBoost Random_Forest
Algoritmo base Gradient Boosting (modelos secuenciales) Bagging (modelos independientes)
Velocidad de entrenamiento Más lento debido a su naturaleza secuencial Más rápido gracias al entrenamiento paralelo
Desempeño en datos complejos Excelente para relaciones no lineales complejas Bueno, pero menos preciso en relaciones complejas
Robustez frente al ruido Puede ser sensible al ruido si no se regulariza adecuadamente Muy robusto frente al ruido
Regularización Incluye regularización L1 y L2 para evitar sobreajuste No incluye regularización explícita
Interpretabilidad Difícil de interpretar Más sencillo, especialmente con medidas de importancia
Optimización de hiperparámetros Requiere un ajuste cuidadoso para obtener buen desempeño Menos dependiente del ajuste de hiperparámetros
Escenarios recomendados Problemas grandes y complejos con alta dimensionalidad Exploración inicial de datos o problemas más simples
Uso común Competencias de machine learning, predicción precisa Modelos base y análisis preliminares
Bibliotecas xgboost, lightgbm randomForest, ranger

Referencias

Gareth, J., Daniela, W., Trevor, H., & Robert, T. (2013). An introduction to statistical learning: with applications in R. Spinger.

Información de la sesión

R version 4.3.2 (2023-10-31)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 22.04.2 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.10.0 
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.10.0

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
 [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C            
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

time zone: America/Costa_Rica
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] caret_6.0-94         lattice_0.20-45      xgboost_1.7.8.1     
 [4] randomForest_4.7-1.1 tree_1.0-43          rpart.plot_3.1.2    
 [7] rpart_4.1.16         ISLR_1.4             viridis_0.6.5       
[10] viridisLite_0.4.2    ggplot2_3.5.1        knitr_1.48          

loaded via a namespace (and not attached):
 [1] tidyselect_1.2.1     timeDate_4032.109    farver_2.1.2        
 [4] dplyr_1.1.4          fastmap_1.2.0        pROC_1.18.5         
 [7] digest_0.6.37        timechange_0.2.0     lifecycle_1.0.4     
[10] survival_3.2-13      magrittr_2.0.3       compiler_4.3.2      
[13] rlang_1.1.4          tools_4.3.2          utf8_1.2.4          
[16] yaml_2.3.10          data.table_1.14.10   labeling_0.4.3      
[19] htmlwidgets_1.6.4    plyr_1.8.9           withr_3.0.1         
[22] purrr_1.0.2          nnet_7.3-17          grid_4.3.2          
[25] stats4_4.3.2         fansi_1.0.6          e1071_1.7-16        
[28] colorspace_2.1-1     future_1.34.0        globals_0.16.3      
[31] scales_1.3.0         iterators_1.0.14     MASS_7.3-55         
[34] cli_3.6.3            rmarkdown_2.28       crayon_1.5.3        
[37] generics_0.1.3       remotes_2.5.0        rstudioapi_0.16.0   
[40] future.apply_1.11.2  reshape2_1.4.4       proxy_0.4-27        
[43] stringr_1.5.1        splines_4.3.2        parallel_4.3.2      
[46] vctrs_0.6.5          hardhat_1.3.0        Matrix_1.6-5        
[49] jsonlite_1.8.9       listenv_0.9.1        packrat_0.9.2       
[52] foreach_1.5.2        gower_1.0.1          recipes_1.0.9       
[55] glue_1.8.0           parallelly_1.38.0    codetools_0.2-18    
[58] xaringanExtra_0.8.0  lubridate_1.9.3      stringi_1.8.4       
[61] gtable_0.3.5         munsell_0.5.1        tibble_3.2.1        
[64] pillar_1.9.0         htmltools_0.5.8.1    ipred_0.9-14        
[67] lava_1.7.3           R6_2.5.1             evaluate_1.0.0      
[70] sketchy_1.0.3        class_7.3-20         Rcpp_1.0.13         
[73] gridExtra_2.3        nlme_3.1-155         prodlim_2023.08.28  
[76] xfun_0.47            pkgconfig_2.0.3      ModelMetrics_1.2.2.2