## ----setup, include=FALSE------------------------------------------------
# General purpose packages (data handling, pretty plotting, ...)
library(tidyverse)
library(latex2exp) # Latex in ggplot2 labels
# Eight-colour palette shared by all figures below. Entries are picked by
# position: cbPalette[1] (grey) fills the constraint regions and
# cbPalette[-1] feeds the manual colour/fill scales.
cbPalette <- c(
  "#999999", "#E69F00", "#56B4E9", "#009E73",
  "#F0E442", "#0072B2", "#D55E00", "#CC79A7") # colour-blind friendly palette

# Packages for actual computation
# library(ridge)         # Used for fitting ridge regression
# library(ElemStatLearn) # Package for the ESL book. Used for data and a
#                        # ridge regression implementation
# library(glmnet)        # Used for fitting ridge regression and lasso
# library(FNN)           # Provides the kNN implementation that is used
#                        # here for visualisation purposes (when showing
#                        # which regions map to which parts of the border
#                        # of the unit ball)


## ----penalisation-intuition, fig.width=4, fig.height=2.25, fig.align="center", echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache = TRUE----
# Trace the boundary of the two-dimensional l_q ball of radius t, i.e. the
# set of points (x, y) with ||(x, y)||_q = t.
#
# Arguments:
#   q   Norm exponent. `Inf` gives the axis-aligned square with corners at
#       (+-t, +-t).
#   t   Radius of the ball (default 1).
#   h1  Step size of the grid along which the boundary is sampled.
#
# Returns a tibble with the columns
#   x, y  boundary coordinates,
#   seg   id (1-4) of the boundary piece a point belongs to (a factor for
#         q = Inf, an integer otherwise),
#   q     the exponent, repeated for every row.
# Rows are ordered by segment and, within a segment, along the boundary, so
# that the result can be passed to geom_polygon().
create_lq_ball <- function(q, t = 1, h1 = 1e-2) {
  message(sprintf("q = %.1f, t = %.2f, h1 = %e", q, t, h1))
  if (is.infinite(q)) {
    # Four edges of the square. The fixed coordinate of every edge is +-t
    # (it used to be hard-coded as +-1, which was wrong for t != 1).
    xs <- seq(-t, t, by = h1)
    n_pts <- length(xs)
    data_lq_ball <- tibble(
      x = c(xs, rep.int(-t, n_pts), rev(xs), rep.int(t, n_pts)),
      y = c(rep.int(t, n_pts), rev(xs), rep.int(-t, n_pts), xs),
      seg = as.factor(rep(1:4, each = n_pts)),
      q = rep.int(q, 4 * n_pts))
  } else {
    # Split up for computational accuracy: the top and bottom pieces are
    # parametrised by x, the right and left pieces by y.
    split <- t / (2 ^ (1 / q))
    xs <- seq(-split, split, by = h1)
    ys <- (t ^ q - abs(xs) ^ q) ^ (1 / q)

    # One boundary piece with its rows sorted by increasing `key`.
    # order() leaves ties in their original order.
    piece <- function(seg, x, y, key) {
      idx <- order(key)
      tibble(x = x[idx], y = y[idx], seg = seg, q = q)
    }

    # The extra point appended to each piece is the exact intersection of
    # the boundary with a coordinate axis, at distance t from the origin.
    data_lq_ball <- bind_rows(
      piece(1L, c(xs, 0), c(ys, t), c(xs, 0)),     # top: x ascending
      piece(2L, c(xs, 0), c(-ys, -t), -c(xs, 0)),  # bottom: x descending
      piece(3L, c(ys, t), c(xs, 0), -c(xs, 0)),    # right: y descending
      piece(4L, c(-ys, -t), c(xs, 0), c(xs, 0)))   # left: y ascending
  }

  data_lq_ball
}

# Boundaries of the l1 (lasso) and l2 (ridge) unit balls
data_lq_ball <- bind_rows(
  lapply(c(1, 2), function(q) create_lq_ball(q, h1 = 1e-3))) %>%
  mutate(seg = as.factor(seg), q = as.factor(q))

# Small simulated regression problem with two standardised predictors
X <- scale(matrix(c(seq(1, 10.5, by = 0.5), sin(1:20)), ncol = 2))
beta_true <- c(0.4, 1.8)
# NOTE(review): no seed is set before rnorm(), so the simulated response
# (and therefore the figure) differs between uncached runs.
y <- X %*% beta_true + rnorm(dim(X)[1], mean = 0, sd = 0.1)
# Solve the normal equations directly instead of forming an explicit inverse
beta_ols <- solve(crossprod(X), crossprod(X, y))
# Calculate RSS on a grid
h2 <- 0.005
xs_test <- seq(-3, 3, by = h2)
ys_test <- seq(-3, 3, by = h2)
test_grid <- expand.grid(xs_test, ys_test)
colnames(test_grid) <- c("x", "y")
data_ellipse <- bind_cols(
  as_tibble(test_grid),
  tibble(
    z = apply(
      as.matrix(test_grid), 1, function(b) sum((y - X %*% b) ^ 2))))

# Has to be (0, 1) for the lasso in this constructed example,
# but this is how the optimization could be done over a grid
# (for demonstration purposes only, not what you would do in a real example).
# The constraint is evaluated on the columns of `data_ellipse` itself, and
# which.min() returns exactly one row even if several grid points tie.
tmp_lasso <- data_ellipse %>%
  filter(rowSums(abs(cbind(x, y))) <= 1) %>%
  slice(which.min(z)) %>%
  as.numeric
beta_lasso <- tmp_lasso[1:2]
z_lasso <- tmp_lasso[3]
# Same thing but for ridge. Less obvious here
tmp_ridge <- data_ellipse %>%
  filter(rowSums(cbind(x, y) ^ 2) <= 1) %>%
  slice(which.min(z)) %>%
  as.numeric
beta_ridge <- tmp_ridge[1:2]
z_ridge <- tmp_ridge[3]

# Base layer shared by the lasso and the ridge panel: coordinate axes with
# arrow heads, axis labels and the OLS estimate (point and label).
# The horizontal coordinate is the first coefficient (the points are drawn
# at x = beta[1], y = beta[2]), so the label at the tip of the horizontal
# axis is beta_1 and the one at the tip of the vertical axis is beta_2.
p <- ggplot(mapping = aes(x, y)) +
  geom_segment(
    aes(xend = xend, yend = yend),
    data = tibble(
      x = c(-1.25, 0),
      y = c(0, -1.25),
      xend = c(2, 0),
      yend = c(0, 2.75)),
    size = 0.2,
    arrow = arrow(length = unit(0.1, "cm"))) +
  geom_text(
    aes(label = label),
    data = tibble(
      # Positions: top of the vertical axis, tip of the horizontal axis,
      # next to the OLS estimate
      x = c(-0.25, 1.8, beta_ols[1] + 0.15),
      y = c(2.7, -0.25, beta_ols[2] - 0.15),
      label = c(
        TeX("$\\beta_2$", output = "character"),
        TeX("$\\beta_1$", output = "character"),
        TeX("$\\beta_{\\mathrm{OLS}}$", output = "character"))),
    parse = TRUE,
    size = 2.25) +
  geom_point(
    data = tibble(
      x = beta_ols[1],
      y = beta_ols[2]),
    colour = cbPalette[2], size = 1) +
  theme_void() +
  coord_fixed() +
  theme(
    plot.margin = margin(0,0,0,0,"cm"))

# Add the layers that are specific to one penalty to the base plot.
#
# Arguments:
#   base_plot  ggplot object holding the axes and the OLS estimate.
#   ball_q     Exponent of the constraint ball to draw (a level of
#              data_lq_ball$q).
#   beta_hat   Constrained estimate, c(beta_1, beta_2).
#   z_hat      RSS at the constrained estimate; only RSS contours up to
#              this level are drawn.
#   tex_label  LaTeX source of the label placed next to the estimate.
#   title      Panel title.
#   label_dx   Horizontal offset of the label from the estimate.
#
# Reads data_ellipse, data_lq_ball and cbPalette from the calling
# environment. Returns the extended ggplot object.
add_penalty_layers <- function(
    base_plot, ball_q, beta_hat, z_hat, tex_label, title, label_dx) {
  base_plot +
    geom_contour(
      aes(z = z),
      data = data_ellipse %>% filter(z <= z_hat),
      alpha = 0.5,
      size = 0.3,
      bins = 5,
      colour = cbPalette[3]) +
    geom_polygon(
      data = data_lq_ball %>% filter(q == ball_q),
      fill = cbPalette[1],
      size = 0,
      alpha = 0.3) +
    geom_point(
      data = tibble(
        x = beta_hat[1],
        y = beta_hat[2]),
      colour = cbPalette[2], size = 1) +
    geom_text(
      aes(label = label),
      data = tibble(
        x = c(beta_hat[1] + label_dx),
        y = c(beta_hat[2] - 0.15),
        label = c(TeX(tex_label, output = "character"))),
      parse = TRUE,
      size = 2.25) +
    ggtitle(title) +
    theme(plot.title = element_text(size = 8))
}

# Lasso panel: l1 ball
p1 <- add_penalty_layers(
  p, ball_q = 1, beta_hat = beta_lasso, z_hat = z_lasso,
  tex_label = "$\\beta_{\\mathrm{lasso}}$", title = "Lasso",
  label_dx = 0.25)

# Ridge panel: l2 ball
p2 <- add_penalty_layers(
  p, ball_q = 2, beta_hat = beta_ridge, z_hat = z_ridge,
  tex_label = "$\\beta_{\\mathrm{ridge}}$", title = "Ridge",
  label_dx = 0.15)

ggpubr::ggarrange(
  p1, p2, ncol = 2)


## ----penalisation-viz, fig.width=3.25, fig.height=3.25, fig.align="center", echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache = TRUE----
# For the l_q ball of radius t, find the regions outside the ball whose
# closest boundary point is (close to) one of three points of interest:
# (0, t), (t, 0) and the boundary point on the diagonal x = y.
#
# Arguments:
#   q   Norm exponent (`Inf` gives the square with corners (+-t, +-t)).
#   t   Radius of the ball (default 1).
#   h1  Step size used to sample the boundary of the ball.
#   h2  Step size of the test grid on [-2t, 2t] x [-2t, 2t].
#
# Returns a list with
#   data_lq_ball  tibble (x, y, seg, q) tracing the boundary; `seg` is a
#                 factor for q = Inf and an integer otherwise,
#   data_region   ungrouped tibble (x, y, reg, q) holding, for each region
#                 id `reg` (1, 2, 3 for the three points above), the
#                 vertices of the convex hull of the grid points assigned
#                 to it.
create_regions <- function(q, t = 1, h1 = 1e-2, h2 = 1e-2) {
  message(sprintf("q = %.1f, t = %.2f, h1 = %e, h2 = %e", q, t, h1, h2))
  if (is.infinite(q)) {
    # Four edges of the square. The fixed coordinate of every edge is +-t
    # (it used to be hard-coded as +-1, which was wrong for t != 1).
    xs <- seq(-t, t, by = h1)
    n_pts <- length(xs)
    data_lq_ball <- tibble(
      x = c(xs, rep.int(-t, n_pts), rev(xs), rep.int(t, n_pts)),
      y = c(rep.int(t, n_pts), rev(xs), rep.int(-t, n_pts), xs),
      seg = as.factor(rep(1:4, each = n_pts)),
      q = rep.int(q, 4 * n_pts))
  } else {
    # Split up for computational accuracy: the top and bottom pieces are
    # parametrised by x, the right and left pieces by y.
    split <- t / (2 ^ (1 / q))
    xs <- seq(-split, split, by = h1)
    ys <- (t ^ q - abs(xs) ^ q) ^ (1 / q)

    # One boundary piece with its rows sorted by increasing `key`.
    # order() leaves ties in their original order.
    piece <- function(seg, x, y, key) {
      idx <- order(key)
      tibble(x = x[idx], y = y[idx], seg = seg, q = q)
    }

    # The extra point appended to each piece is the exact intersection of
    # the boundary with a coordinate axis, at distance t from the origin.
    data_lq_ball <- bind_rows(
      piece(1L, c(xs, 0), c(ys, t), c(xs, 0)),     # top: x ascending
      piece(2L, c(xs, 0), c(-ys, -t), -c(xs, 0)),  # bottom: x descending
      piece(3L, c(ys, t), c(xs, 0), -c(xs, 0)),    # right: y descending
      piece(4L, c(-ys, -t), c(xs, 0), c(xs, 0)))   # left: y ascending
  }

  # Unique boundary points as a two-column matrix for the neighbour search
  lq_ball <- data_lq_ball %>%
    select(x, y) %>%
    as.matrix %>%
    unique

  grid_axis <- seq(-2 * t, 2 * t, by = h2)
  test_grid <- expand.grid(x = grid_axis, y = grid_axis)
  # Filter out points inside the lq ball
  if (is.infinite(q)) {
    test_grid <- test_grid %>%
      filter(pmax(abs(x), abs(y)) > t)
  } else {
    test_grid <- test_grid %>%
      filter((abs(x) ^ q + abs(y) ^ q) ^ (1 / q) > t)
  }
  # Row index (into lq_ball) of the boundary point closest to each grid
  # point, found with a 1-nearest-neighbour search
  min_indices <- FNN::knnx.index(
    lq_ball, test_grid, 1, algorithm = "kd_tree")

  # Row indices of all boundary points that all.equal() considers equal to
  # `poi` under tolerance `tol`
  near_poi <- function(poi, tol) {
    which(
      apply(lq_ball, 1, function(pt) {
        isTRUE(all.equal(pt, poi, tolerance = tol))
      }))
  }

  # Points of interest on the lq ball
  poi1_index <- near_poi(c(x = 0, y = t), 2 * h1)
  poi2_index <- near_poi(c(x = t, y = 0), 2 * h1)
  if (is.infinite(q)) {
    poi3 <- c(x = t, y = t)
  } else {
    poi3 <- c(x = t / (2 ^ (1 / q)), y = t / (2 ^ (1 / q)))
  }
  poi3_index <- near_poi(poi3, 3 * h1)

  list(
    data_lq_ball = data_lq_ball,
    data_region = tibble(
      x = test_grid[,1],
      y = test_grid[,2],
      q = rep.int(q, dim(test_grid)[1]),
      reg1 = min_indices %in% poi1_index,
      reg2 = min_indices %in% poi2_index,
      reg3 = min_indices %in% poi3_index) %>%
      mutate(
        reg = case_when(
          reg1 ~ 1,
          reg2 ~ 2,
          reg3 ~ 3)) %>%
      select(x, y, reg, q) %>%
      # Grid points that belong to none of the three regions are dropped
      filter(!is.na(reg)) %>%
      mutate(reg = as.factor(reg)) %>%
      # Keep only the convex hull of each region; it is drawn as a polygon
      group_by(reg) %>%
      slice(chull(x, y)) %>%
      ungroup())
}

# Compute ball boundary and regions for four exponents
regions <- map(
  c(0.7, 1, 2, Inf),
  function(q) create_regions(q, h1 = 1e-3, h2 = 1e-2))

# Stack the boundaries. `seg` is a factor for q = Inf and an integer for the
# other exponents, so it is made numeric before the tibbles are combined.
data_lq_ball <- regions %>%
  map(function(l) mutate(l[["data_lq_ball"]], seg = as.numeric(seg))) %>%
  bind_rows() %>%
  mutate(seg = as.factor(seg), q = as.factor(q))

# Stack the region polygons
data_region <- regions %>%
  map(function(l) l[["data_region"]]) %>%
  bind_rows() %>%
  mutate(q = as.factor(q))

# One facet per exponent q: coordinate axes, the three regions (convex
# hulls), the l_q ball, and the three points of interest on its boundary.
ggplot(mapping = aes(x, y)) +
  geom_segment(
    aes(xend = xend, yend = yend),
    data = tibble(
      x = c(-2.25, 0),
      y = c(0, -2.25),
      xend = c(2.25, 0),
      yend = c(0, 2.25)),
    size = 0.2,
    arrow = arrow(length = unit(0.1, "cm"))) +
  geom_text(
    aes(label = label),
    data = tibble(
      # Positions: top of the vertical axis, tip of the horizontal axis.
      # By the usual convention the horizontal axis is beta_1 and the
      # vertical axis is beta_2.
      x = c(-0.25, 2.2),
      y = c(2.2, -0.25),
      label = c(
        TeX("$\\beta_2$", output = "character"),
        TeX("$\\beta_1$", output = "character"))),
    parse = TRUE,
    size = 2.25) +
  geom_polygon(
    aes(fill = reg, colour = reg),
    data = data_region,
    alpha = 0.5,
    size = 0.3) +
  geom_polygon(
    data = data_lq_ball,
    fill = cbPalette[1],
    alpha = 0.3) +
  geom_point(
    aes(colour = reg),
    data = tibble(
      # Per exponent: (0, 1), the boundary point on the diagonal, (1, 0)
      x = c(0, 1 / (2 ^ (1 / 0.7)), 1, 0, 0.5, 1, 0, 1 / sqrt(2), 1, 0, 1, 1),
      y = c(1, 1 / (2 ^ (1 / 0.7)), 0, 1, 0.5, 0, 1, 1 / sqrt(2), 0, 1, 1, 0),
      reg = as.factor(rep(c(1, 3, 2), times = 4)),
      q = as.factor(rep(c(0.7, 1, 2, Inf), each = 3))), size = 1) +
  facet_wrap(~ q, labeller = label_both) +
  # guide = "none" replaces the deprecated guide = FALSE
  scale_colour_manual(values = cbPalette[-1], guide = "none") +
  scale_fill_manual(values = cbPalette[-1], guide = "none") +
  theme_void() +
  coord_fixed() +
  theme(
    plot.margin = margin(0,0,0,0,"cm"),
    strip.text = element_text(size = 7))


## ----ols-shrinkage-viz, fig.width=3, fig.height=1.6, fig.align="center", echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache = TRUE----
# Grid of OLS coefficient values and the penalty parameter of the example
xs <- seq(from = -0.7, to = 0.7, by = 0.1)
lambda <- 0.2

# Four curves of length(xs) points each, stacked in this order:
#   ridge: identity, proportional shrinkage x / (1 + lambda)
#   lasso: identity, soft thresholding sign(x) * max(|x| - lambda, 0)
data_plot <- tibble(
  x = rep.int(xs, 4),
  y = c(
    xs,
    xs / (1 + lambda),
    xs,
    sign(xs) * pmax(abs(xs) - lambda, 0)),
  trans = as.factor(
    rep(
      rep(c("Untransformed", "Transformed"), times = 2),
      each = length(xs))),
  type = factor(
    rep(c("Ridge", "Lasso"), each = 2 * length(xs)),
    levels = c("Ridge", "Lasso"),
    ordered = TRUE))

# Ridge and lasso shrinkage of an OLS coefficient, one facet per method:
# the identity line (solid) and the shrunken coefficient (dashed).
ggplot(data_plot, mapping = aes(x, y)) +
  geom_segment(
    aes(xend = xend, yend = yend),
    data = tibble(
      x = c(-0.7, 0),
      y = c(0, -0.7),
      xend = c(0.7, 0),
      yend = c(0, 0.7)),
    arrow = arrow(length = unit(0.1, "cm")),
    size = 0.2) +
  # Vertical marker of height lambda at x = lambda (lasso facet only)
  geom_line(
    data = tibble(
      x = c(lambda, lambda),
      y = c(0, lambda),
      type = factor(c("Lasso"), levels = c("Ridge", "Lasso"))),
    linetype = "dashed") +
  geom_text(
    aes(label = label),
    data = tibble(
      x = c(lambda + 0.05),
      y = c(lambda / 2 + 0.03),
      label = TeX("$\\lambda$", output = "character"),
      type = factor(c("Lasso"), levels = c("Ridge", "Lasso"))),
    parse = TRUE,
    size = 2.5) +
  # The line type follows `trans` through a scale instead of a per-row
  # vector, so it no longer depends on the row order of data_plot.
  geom_line(aes(colour = trans, linetype = trans)) +
  # guide = "none" replaces the deprecated guide = FALSE
  scale_colour_manual(values = cbPalette[-1], guide = "none") +
  scale_linetype_manual(
    values = c(Untransformed = "solid", Transformed = "dashed"),
    guide = "none") +
  facet_wrap(~ type, ncol = 2) +
  theme_void() +
  coord_fixed()


## ----path-example, fig.width=5, fig.height=3, fig.align="center", echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache = TRUE----
# There are many packages implementing ridge regression. However,
# it can be instructive to implement it yourself.
# By uncommenting the right lines below you can use the packages ridge,
# glmnet or ElemStatLearn instead.
# Ridge regression through the singular value decomposition of the
# standardised design matrix.
#
# Arguments:
#   X       Numeric design matrix (n x p). It is centred and scaled with
#           scale() before fitting.
#   y       Numeric response of length n. It is centred before fitting.
#   lambda  Numeric vector of penalty parameters.
#
# Returns a list with
#   beta0  the intercept, i.e. the mean of y,
#   beta   a length(lambda) x p matrix; row i holds the coefficients (on the
#          scale of the standardised predictors) for lambda[i].
custom_ridge <- function(X, y, lambda) {
  # Prepare data
  beta0 <- mean(y)
  y <- y - beta0
  X <- scale(X)

  X_svd <- svd(X)
  d <- X_svd$d
  # U'y does not depend on lambda, so it is computed only once
  uty <- crossprod(X_svd$u, y)

  beta <- do.call(rbind, lapply(lambda, function(l) {
    # beta = V diag(d / (d^2 + l)) U'y. The diagonal matrix is applied as an
    # elementwise product: diag() turns a length-one vector into an
    # identity (or empty) matrix, which broke the single-predictor case.
    matrix(X_svd$v %*% (d / (d ^ 2 + l) * uty), nrow = 1)
  }))

  list(beta0 = beta0, beta = beta)
}

# Use a prostate cancer data set with p = 8 predictors, 1 response
# and n = 67 samples in a training data set. There is more data but
# using only these 67 samples makes the plots comparable to the
# ones in ESL (Fig 3.8, p. 65, Fig. 3.10, p. 70)
# Training part of the prostate data: every column except the response is
# standardised, the response lpsa is only centred
data_all <- ElemStatLearn::prostate %>%
  as_tibble() %>%
  filter(train) %>%
  mutate_at(vars(-lpsa), function(column) as.numeric(scale(column))) %>%
  mutate(lpsa = as.numeric(scale(lpsa, scale = FALSE))) %>%
  select(-train)

# Number of samples and number of predictors
n <- dim(data_all)[1]
p <- dim(data_all)[2] - 1

# Grid of log-penalties used for cross-validation and the coefficient paths
log_lambda_seq <- seq(from = -6, to = 10, length.out = 40)

# Implement some simple cross-validation
# For every fold, ridge and lasso are fitted on the remaining folds and the
# sum of squared prediction errors on the held-out fold is recorded for
# every lambda on the grid. cv_err holds the average over the folds
# (one row per lambda, one column per method).
folds <- caret::createFolds(data_all$lpsa, 10)
cv_err <- vapply(folds, function(fold) {
  train <- data_all[-fold,]
  test <- data_all[fold,]
  x_train <- train[,1:8] %>% as.matrix
  x_test <- test[,1:8] %>% as.matrix

  # MLE standard deviation of the response. Both this and the scaling of
  # lambda below use the size of the training fold the model is fitted on,
  # not the size of the full data set.
  n_train <- nrow(train)
  sd_y <- sqrt(var(train$lpsa) * (n_train - 1) / n_train)

  # The glmnet path does not depend on lambda, so it is fitted once per
  # fold instead of once per fold and lambda
  # fit_ridge <- glmnet::glmnet(
  #   x_train,
  #   train$lpsa,
  #   alpha = 0)
  # Lasso
  fit_lasso <- glmnet::glmnet(
    x_train,
    train$lpsa,
    alpha = 1)

  vapply(exp(log_lambda_seq), function(lambda) {
    # fit_ridge <- ridge::linearRidge(
    #   lpsa ~ ., train, lambda = lambda / n_train)
    # fit_ridge <- ElemStatLearn::simple.ridge(
    #   x_train,
    #   train$lpsa,
    #   lambda = lambda)
    fit_ridge <- custom_ridge(x_train, train$lpsa, lambda)

    # # For ridge::linearRidge
    # sse_ridge <- sum((test$lpsa - predict(fit_ridge, test)) ^ 2)
    # # For glmnet::glmnet
    # # glmnet uses a different scaling of the lambda values. The package
    # # is made for fitting elastic net models and the target function
    # # convenient for elastic nets includes lasso and ridge as special cases.
    # # However, scaling of lambda values is different in this case
    # sse_ridge <- sum((test$lpsa -
    #   predict(
    #     fit_ridge,
    #     x_test,
    #     s = sd_y * lambda / n_train,
    #     exact = TRUE,
    #     x = x_train,
    #     y = train$lpsa)) ^ 2)
    # # For ElemStatLearn::simple.ridge
    # sse_ridge <- sum((test$lpsa -
    #   as.numeric(
    #     cbind(rep.int(1, dim(test)[1]), x_test) %*%
    #     rbind(fit_ridge$beta0, fit_ridge$beta))) ^ 2)
    # For custom implementation
    sse_ridge <- sum((test$lpsa -
      as.numeric(
        cbind(rep.int(1, dim(test)[1]), x_test) %*%
        c(fit_ridge$beta0, drop(fit_ridge$beta)))) ^ 2)
    # For glmnet::glmnet lasso
    # See notes for predict.glmnet for ridge regression above
    sse_lasso <- sum((test$lpsa -
      predict(
        fit_lasso,
        x_test,
        s = sd_y * lambda / n_train,
        exact = TRUE,
        x = x_train,
        y = train$lpsa)) ^ 2)
    c(sse_ridge, sse_lasso)
  }, numeric(2))
}, matrix(0, nrow = 2, ncol = length(log_lambda_seq))) %>%
  # 2 x n_lambda x n_folds array: average over the folds, then transpose to
  # one row per lambda
  rowMeans(dims = 2) %>%
  t

colnames(cv_err) <- c("ridge", "lasso")
# Log-penalty with the smallest cross-validation error for each method
log_lambda <- log_lambda_seq[apply(cv_err, 2, which.min)]
names(log_lambda) <- c("ridge", "lasso")

# The glmnet calls are best run once and then the coefficients can be obtained
# for specific lambdas
# # Ridge case
# fit_ridge <- glmnet::glmnet(
#   data_all[,1:8] %>% as.matrix,
#   data_all$lpsa,
#   alpha = 0)
# Lasso case
fit_lasso <- glmnet::glmnet(
  data_all[,1:8] %>% as.matrix,
  data_all$lpsa,
  alpha = 1)
# MLE standard deviation of the response
# Needed to properly scale all input lambdas for lasso
sd_all <- sqrt(var(data_all$lpsa) * (n - 1) / n)

# Create a coefficient matrix (one row per lambda) for each method
# The result is a three-dimensional array that is indexed below as
# est_coef[lambda, predictor, method], with method 1 = ridge and
# method 2 = lasso (the order of the two cbind() columns).
est_coef <- plyr::aaply(
  exp(log_lambda_seq),
  1,
  function(lambda) {
    # fit_ridge <- ridge::linearRidge(lpsa ~ ., data_all, lambda = lambda / n)
    # fit_ridge <- ElemStatLearn::simple.ridge(
    #     data_all[,1:8] %>% as.matrix,
    #     data_all$lpsa,
    #     lambda = lambda)
    # Ridge is refitted for every lambda with the custom implementation
    fit_ridge <- custom_ridge(
      data_all[,1:8] %>% as.matrix, data_all$lpsa, lambda)
    # First column: ridge coefficients, second column: lasso coefficients
    cbind(
      # coef(fit_ridge)[-1],
      # coef(
      #   fit_ridge,
      #   s = sd_all * lambda / n,
      #   exact = TRUE,
      #   x = data_all[,1:8] %>% as.matrix,
      #   y = data_all$lpsa)[-1],
      # as.numeric(fit_ridge$beta),
      drop(fit_ridge$beta),
      # Lasso coefficients at the rescaled lambda, taken from the path fitted
      # once above; [-1] drops the first entry (the intercept)
      coef(
        fit_lasso,
        s = sd_all * lambda / n,
        exact = TRUE,
        x = data_all[,1:8] %>% as.matrix,
        y = data_all$lpsa)[-1])
})

# Calculate the effective degrees of freedom of the ridge fit for every
# lambda on the grid, from the singular values of the design matrix
d <- svd(as.matrix(data_all[,1:8]))$d
df <- vapply(
  exp(log_lambda_seq),
  function(l) sum(d ^ 2 / (d ^ 2 + l)),
  numeric(1))

# Shrinkage parameter for Lasso
# The l1 norm of the lasso coefficients as a proportion of the l1 norm of
# the least squares solution
ls_coef <- coef(lm(lpsa ~ ., data_all))[-1]
shrinkage <- rowSums(abs(est_coef[,,2])) / sum(abs(ls_coef))

# Long format for plotting: one row per method, lambda and coefficient.
# `x` is the method-specific x-axis value (df for ridge, shrinkage for lasso).
data_plot <- tibble(
  alg = as.factor(rep(c("ridge", "lasso"), each = length(log_lambda_seq))),
  log_lambda = rep.int(log_lambda_seq, 2),
  x = c(df, shrinkage)) %>%
  bind_cols(as_tibble(rbind(est_coef[,,1], est_coef[,,2]))) %>%
  gather(yvar, yval, -alg, -log_lambda, -x) %>%
  mutate(yvar = as.factor(yvar))

# Used to visualise the selected lambda and corresponding df/shrinkage
data_cv_select <- tibble(
  alg = as.factor(rep(c("ridge", "lasso"), each = 2)),
  xvar = as.factor(c("df", "log_lambda", "shrinkage", "log_lambda")),
  xval = c(
    # ridge: df at the selected lambda, then the selected log(lambda)
    sum(d ^ 2 / (d ^ 2 + exp(log_lambda[1]))),
    log_lambda[1],
    # lasso: shrinkage at the selected lambda, then the selected log(lambda)
    shrinkage[which.min(cv_err[,2])],
    log_lambda[2]))

# Plot the coefficient paths of one method against one x-axis variable.
#
# Arguments:
#   alg_name   Method to plot, "ridge" or "lasso" (a level of data_plot$alg).
#   x_mapping  Aesthetic mapping of the plot, e.g. aes(x = x, y = yval).
#   xvar_name  Level of data_cv_select$xvar holding the cross-validated
#              value that is marked with a vertical red line.
#   x_label    Title of the x-axis (a string or a plotmath expression).
#   title      Optional plot title; none is added when NULL.
#
# Reads data_plot and data_cv_select from the calling environment.
# Returns a ggplot object.
plot_coef_path <- function(
    alg_name, x_mapping, xvar_name, x_label, title = NULL) {
  path_plot <- ggplot(data_plot %>% filter(alg == alg_name), x_mapping) +
    geom_line(aes(group = yvar), size = 0.3) +
    geom_vline(
      aes(xintercept = xval),
      data = data_cv_select %>% filter(alg == alg_name, xvar == xvar_name),
      colour = "red", linetype = "dashed", size = 0.3) +
    geom_hline(
      aes(yintercept = 0),
      linetype = "dashed", size = 0.3) +
    scale_x_continuous(x_label) +
    scale_y_continuous("Coefficient") +
    theme_minimal()
  if (!is.null(title)) {
    path_plot <- path_plot + ggtitle(title)
  }
  path_plot
}

# Ridge: against the effective degrees of freedom and against log(lambda)
p1 <- plot_coef_path(
  "ridge", aes(x = x, y = yval), "df",
  "Effective degrees of freedom", title = "Ridge")

p2 <- plot_coef_path(
  "ridge", aes(x = log_lambda, y = yval), "log_lambda",
  TeX("$\\log(\\lambda)$"))

# Lasso: against the shrinkage factor and against log(lambda)
p3 <- plot_coef_path(
  "lasso", aes(x = x, y = yval), "shrinkage",
  "Shrinkage", title = "Lasso")

p4 <- plot_coef_path(
  "lasso", aes(x = log_lambda, y = yval), "log_lambda",
  TeX("$\\log(\\lambda)$"))

# 2 x 2 layout: the methods in columns (ridge, lasso), the x-axis variables
# in rows. All panels get the same small font sizes; the top row is slightly
# taller to make room for the titles.
ggpubr::ggarrange(
  plotlist = map(
    list(p1, p3, p2, p4),
    function(panel) {
      panel +
        theme(
          plot.title = element_text(size = 8),
          axis.title = element_text(size = 6),
          axis.text = element_text(size = 6))
    }),
  ncol = 2,
  nrow = 2,
  heights = c(1.2, 1))