## ----setup, include=FALSE------------------------------------------------
# General purpose packages (data handling, pretty plotting, ...)
library(tidyverse)
library(latex2exp) # Latex in ggplot2 labels

cbPalette <- c(
  "#999999", "#E69F00", "#56B4E9", "#009E73",
  "#F0E442", "#0072B2", "#D55E00", "#CC79A7") # colour-blind friendly palette
bgCol <- "#FAFAFA"
theme_set(theme_minimal()) # Set a less verbose standard theme

library(rgl) # 3D plotting
# Namespace-qualified so that registering the hook does not depend on knitr
# being attached to the search path when this file is evaluated
knitr::knit_hooks$set(rgl = rgl::hook_rgl)

# Packages for actual computation
# library(glmnet)    # Very comprehensive package for fitting ridge regression,
#                    # lasso and the elastic net. Also more advanced versions
#                    # like group lasso or adaptive lasso can be fit with this
# library(pamr)      # Implementation of nearest shrunken centroids
#

## ----penalisation-intuition, fig.width=4, fig.height=2.25, fig.align="center", echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache = TRUE----
# Build the boundary of the two-dimensional l_q ball
#   { (x, y) : |x| ^ q + |y| ^ q <= t ^ q }
# as a sequence of points that can be drawn with `geom_polygon`.
#
# Arguments:
#   q   exponent of the norm; `Inf` gives the square max(|x|, |y|) <= t
#   t   radius of the ball
#   h1  step size of the grid along each of the four boundary segments
#
# Returns a tibble with the columns
#   x, y  coordinates of the boundary points
#   seg   index (1 to 4) of the boundary segment the point belongs to;
#         a factor for q = Inf and an integer vector otherwise
#   q     the exponent, repeated for every row
#
# The chosen arguments are reported through `message`.
create_lq_ball <- function(q, t = 1, h1 = 1e-2) {
  message(sprintf("q = %.1f, t = %.2f, h1 = %e", q, t, h1))

  if (is.infinite(q)) {
    # The max-norm ball is a square with edges at -t and t
    xs <- seq(-t, t, by = h1)
    n_pts <- length(xs)
    return(tibble(
      x = c(xs, rep.int(-t, n_pts), rev(xs), rep.int(t, n_pts)),
      y = c(rep.int(t, n_pts), rev(xs), rep.int(-t, n_pts), xs),
      seg = as.factor(rep(1:4, each = n_pts)),
      q = rep.int(q, 4 * n_pts)))
  }

  # Split up for computational accuracy: each of the four segments is
  # parametrised by the coordinate along which the boundary is flat enough
  corner <- t / (2 ^ (1 / q))
  xs <- seq(-corner, corner, by = h1)
  ys <- (t ^ q - abs(xs) ^ q) ^ (1 / q)
  n_pts <- length(xs)

  # The grid does not necessarily hit 0, so the four points on the axes
  # (at distance t from the origin) are added explicitly
  boundary <- tibble(
    x = c(rep.int(xs, 2), ys, -ys, c(0, 0, t, -t)),
    y = c(ys, -ys, rep.int(xs, 2), c(t, -t, 0, 0)),
    seg = c(rep(1:4, each = n_pts), 1:4),
    q = rep.int(q, 4 * n_pts + 4))

  # Depending on the segment we need to sort by x or y
  bind_rows(
    boundary %>% filter(seg == 1) %>% arrange(x),
    boundary %>% filter(seg == 2) %>% arrange(desc(x)),
    boundary %>% filter(seg == 3) %>% arrange(desc(y)),
    boundary %>% filter(seg == 4) %>% arrange(y))
}

data_lq_ball <- bind_rows(
  lapply(c(1, 2), function(q) create_lq_ball(q, h1 = 1e-3))) %>%
  mutate(seg = as.factor(seg), q = as.factor(q))

# Simulated regression problem with two standardised predictors
X <- scale(matrix(c(seq(1, 10.5, by = 0.5), sin(1:20)), ncol = 2))
beta_true <- c(0.4, 1.8)
y <- X %*% beta_true + rnorm(dim(X)[1], mean = 0, sd = 0.1)
beta_ols <- solve(t(X) %*% X) %*% t(X) %*% y

# Calculate RSS on a grid
h2 <- 0.005
xs_test <- seq(-3, 3, by = h2)
ys_test <- seq(-3, 3, by = h2)
test_grid <- expand.grid(xs_test, ys_test)
colnames(test_grid) <- c("x", "y")
data_ellipse <- bind_cols(
  as_tibble(test_grid),
  tibble(
    z = apply(test_grid, 1, function(p) sum((y - X %*% p) ^ 2))))

# Has to be (0, 1) for the lasso in this constructed example,
# but this is how the optimization could be done over a grid
# (for demonstration purposes only, not what you would do in a real example)
# `slice(1)` after sorting keeps exactly one row even if the minimal RSS is
# attained at several grid points
tmp_lasso <- data_ellipse %>%
  filter(rowSums(abs(test_grid)) <= 1) %>%
  arrange(z) %>%
  slice(1) %>%
  unlist(use.names = FALSE)
beta_lasso <- tmp_lasso[1:2]
z_lasso <- tmp_lasso[3]

# Same thing but for ridge.
# Less obvious here
# `slice(1)` after sorting keeps exactly one row even if the minimal RSS is
# attained at several grid points
tmp_ridge <- data_ellipse %>%
  filter(rowSums(test_grid ^ 2) <= 1) %>%
  arrange(z) %>%
  slice(1) %>%
  unlist(use.names = FALSE)
beta_ridge <- tmp_ridge[1:2]
z_ridge <- tmp_ridge[3]

# Base plot shared by both panels: coordinate axes, axis labels and the
# OLS estimate. The horizontal axis carries the first coefficient (the `x`
# column of the grid), so it is labelled beta_1; the vertical axis is beta_2.
p <- ggplot(mapping = aes(x, y)) +
  geom_segment(
    aes(xend = xend, yend = yend),
    data = tibble(
      x = c(-1.25, 0),
      y = c(0, -1.25),
      xend = c(2, 0),
      yend = c(0, 2.75)),
    size = 0.2,
    arrow = arrow(length = unit(0.1, "cm"))) +
  geom_text(
    aes(label = label),
    data = tibble(
      x = c(-0.25, 1.8, beta_ols[1] + 0.15),
      y = c(2.7, -0.25, beta_ols[2] - 0.15),
      label = c(
        TeX("$\\beta_2$", output = "character"),
        TeX("$\\beta_1$", output = "character"),
        TeX("$\\beta_{\\mathrm{OLS}}$", output = "character"))),
    parse = TRUE,
    size = 2.25) +
  geom_point(
    data = tibble(
      x = beta_ols[1],
      y = beta_ols[2]),
    colour = cbPalette[2],
    size = 1) +
  theme_void() +
  coord_fixed() +
  theme(
    plot.margin = margin(0,0,0,0,"cm"))

# Lasso panel: RSS contours up to the constrained optimum, the l1 ball and
# the constrained estimate
p1 <- p +
  geom_contour(
    aes(z = z),
    data = data_ellipse %>% filter(z <= z_lasso),
    alpha = 0.5,
    size = 0.3,
    bins = 5,
    colour = cbPalette[3]) +
  geom_polygon(
    data = data_lq_ball %>% filter(q == 1),
    fill = cbPalette[1],
    size = 0,
    alpha = 0.3) +
  geom_point(
    data = tibble(
      x = beta_lasso[1],
      y = beta_lasso[2]),
    colour = cbPalette[2],
    size = 1) +
  geom_text(
    aes(label = label),
    data = tibble(
      x = c(beta_lasso[1] + 0.25),
      y = c(beta_lasso[2] - 0.15),
      label = c(
        TeX("$\\beta_{\\mathrm{lasso}}$", output = "character"))),
    parse = TRUE,
    size = 2.25) +
  ggtitle("Lasso") +
  theme(plot.title = element_text(size = 8))

# Ridge panel: same construction with the l2 ball
p2 <- p +
  geom_contour(
    aes(z = z),
    data = data_ellipse %>% filter(z <= z_ridge),
    alpha = 0.5,
    size = 0.3,
    bins = 5,
    colour = cbPalette[3]) +
  geom_polygon(
    data = data_lq_ball %>% filter(q == 2),
    fill = cbPalette[1],
    size = 0,
    alpha = 0.3) +
  geom_point(
    data = tibble(
      x = beta_ridge[1],
      y = beta_ridge[2]),
    colour = cbPalette[2],
    size = 1) +
  geom_text(
    aes(label = label),
    data = tibble(
      x = c(beta_ridge[1] + 0.15),
      y = c(beta_ridge[2] - 0.15),
      label = c(
        TeX("$\\beta_{\\mathrm{ridge}}$", output = "character"))),
    parse = TRUE,
    size = 2.25) +
  ggtitle("Ridge") +
  theme(plot.title = element_text(size = 8))

ggpubr::ggarrange(
  p1, p2, ncol = 2)

## ----penalisation-viz, fig.width=3.25, fig.height=3.25, fig.align="center", echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache = TRUE----
# For the l_q ball of radius t, find the regions outside of the ball whose
# closest point on the ball is one of three points of interest:
#   region 1: (0, t), region 2: (t, 0), region 3: the "diagonal" point
#   (t, t) for q = Inf and (t / 2^(1/q), t / 2^(1/q)) otherwise.
#
# Arguments:
#   q   exponent of the norm (may be `Inf`)
#   t   radius of the ball
#   h1  step size used for the boundary of the ball
#   h2  step size of the test grid on [-2t, 2t] x [-2t, 2t]
#
# Returns a list with
#   data_lq_ball  tibble (x, y, seg, q) describing the boundary of the ball;
#                 `seg` is a factor for q = Inf and an integer otherwise
#   data_region   ungrouped tibble (x, y, reg, q) containing, per region,
#                 the convex hull of the grid points assigned to it
#
# The chosen arguments are reported through `message`.
create_regions <- function(q, t = 1, h1 = 1e-2, h2 = 1e-2) {
  message(sprintf("q = %.1f, t = %.2f, h1 = %e, h2 = %e", q, t, h1, h2))

  if (is.infinite(q)) {
    # The max-norm ball is a square with edges at -t and t
    xs <- seq(-t, t, by = h1)
    n_pts <- length(xs)
    data_lq_ball <- tibble(
      x = c(xs, rep.int(-t, n_pts), rev(xs), rep.int(t, n_pts)),
      y = c(rep.int(t, n_pts), rev(xs), rep.int(-t, n_pts), xs),
      seg = as.factor(rep(1:4, each = n_pts)),
      q = rep.int(q, 4 * n_pts))
  } else {
    # Split up for computational accuracy
    corner <- t / (2 ^ (1 / q))
    xs <- seq(-corner, corner, by = h1)
    ys <- (t ^ q - abs(xs) ^ q) ^ (1 / q)
    n_pts <- length(xs)

    # The grid does not necessarily hit 0, so the four points on the axes
    # (at distance t from the origin) are added explicitly
    boundary <- tibble(
      x = c(rep.int(xs, 2), ys, -ys, c(0, 0, t, -t)),
      y = c(ys, -ys, rep.int(xs, 2), c(t, -t, 0, 0)),
      seg = c(rep(1:4, each = n_pts), 1:4),
      q = rep.int(q, 4 * n_pts + 4))

    # Depending on the segment we need to sort by x or y
    data_lq_ball <- bind_rows(
      boundary %>% filter(seg == 1) %>% arrange(x),
      boundary %>% filter(seg == 2) %>% arrange(desc(x)),
      boundary %>% filter(seg == 3) %>% arrange(desc(y)),
      boundary %>% filter(seg == 4) %>% arrange(y))
  }

  lq_ball <- data_lq_ball %>%
    select(x, y) %>%
    as.matrix %>%
    unique

  xs_test <- seq(-2 * t, 2 * t, by = h2)
  ys_test <- seq(-2 * t, 2 * t, by = h2)
  test_grid <- expand.grid(xs_test, ys_test)
  colnames(test_grid) <- c("x", "y")

  # Filter out points inside the lq ball
  if (is.infinite(q)) {
    test_grid <- test_grid %>%
      filter(pmax(abs(x), abs(y)) > t)
  } else {
    test_grid <- test_grid %>%
      filter((abs(x) ^ q + abs(y) ^ q) ^ (1 / q) > t)
  }

  # Nearest-neighbour search (k = 1) to find, for every grid point, the
  # index of the closest point on the lq ball
  min_indices <- FNN::knnx.index(
    lq_ball, test_grid, 1, algorithm = "kd_tree")

  # Indices of all boundary points that coincide with `poi` up to the
  # tolerance given to `all.equal`
  find_poi <- function(poi, tolerance) {
    which(apply(lq_ball, 1, function(pt) {
      isTRUE(all.equal(pt, poi, tolerance = tolerance))
    }))
  }

  # Points of interest on lq ball
  poi1_index <- find_poi(c(x = 0, y = t), 2 * h1)
  poi2_index <- find_poi(c(x = t, y = 0), 2 * h1)
  if (is.infinite(q)) {
    poi3 <- c(x = t, y = t)
  } else {
    poi3 <- c(x = t / (2 ^ (1 / q)), y = t / (2 ^ (1 / q)))
  }
  poi3_index <- find_poi(poi3, 3 * h1)

  data_region <- tibble(
    x = test_grid[,1],
    y = test_grid[,2],
    q = rep.int(q, dim(test_grid)[1]),
    reg1 = min_indices %in% poi1_index,
    reg2 = min_indices %in% poi2_index,
    reg3 = min_indices %in% poi3_index) %>%
    mutate(
      reg = case_when(
        reg1 ~ 1,
        reg2 ~ 2,
        reg3 ~ 3)) %>%
    select(x, y, reg, q) %>%
    filter(!is.na(reg)) %>%
    mutate(reg = as.factor(reg)) %>%
    group_by(reg) %>%
    # Only the convex hull of each region is needed for plotting
    slice(chull(x, y)) %>%
    # Do not leak the grouping to the caller
    ungroup()

  list(
    data_lq_ball = data_lq_ball,
    data_region = data_region)
}

regions <- lapply(
  c(0.7, 1, 2, Inf),
  function(q) create_regions(q, h1 = 1e-3, h2 = 1e-2))

# `seg` is a factor for q = Inf and an integer otherwise; make it numeric
# everywhere before the rows are combined
data_lq_ball <- bind_rows(lapply(
  regions,
  function(l) l$data_lq_ball %>% mutate(seg = as.numeric(seg)))) %>%
  mutate(seg = as.factor(seg), q = as.factor(q))

data_region <- bind_rows(lapply(
  regions,
  function(l) l$data_region)) %>%
  mutate(q = as.factor(q))

ggplot(mapping = aes(x, y)) +
  geom_segment(
    aes(xend = xend, yend = yend),
    data = tibble(
      x = c(-2.25, 0),
      y = c(0, -2.25),
      xend = c(2.25, 0),
      yend = c(0, 2.25)),
    size = 0.2,
    arrow = arrow(length = unit(0.1, "cm"))) +
  # Horizontal axis is the first coordinate (beta_1), vertical axis beta_2
  geom_text(
    aes(label = label),
    data = tibble(
      x = c(-0.25, 2.2),
      y = c(2.2, -0.25),
      label = c(
        TeX("$\\beta_2$", output = "character"),
        TeX("$\\beta_1$", output = "character"))),
    parse = TRUE,
    size = 2.25) +
  geom_polygon(
    aes(fill = reg, colour = reg),
    data = data_region,
    alpha = 0.5,
    size = 0.3) +
  geom_polygon(
    data = data_lq_ball,
    fill = cbPalette[1],
    alpha = 0.3) +
  # The three points of interest for each q, in the order (0, 1),
  # diagonal point, (1, 0)
  geom_point(
    aes(colour = reg),
    data = tibble(
      x = c(0, 1 / (2 ^ (1 / 0.7)), 1, 0, 0.5, 1, 0, 1 / sqrt(2), 1, 0, 1, 1),
      y = c(1, 1 / (2 ^ (1 / 0.7)), 0, 1, 0.5, 0, 1, 1 / sqrt(2), 0, 1, 1 ,0),
      reg = as.factor(rep(c(1, 3, 2), times = 4)),
      q = as.factor(rep(c(0.7, 1, 2, Inf), each = 3))),
    size = 1) +
  facet_wrap(~ q, labeller = label_both) +
  scale_colour_manual(values = cbPalette[-1], guide = "none") +
  scale_fill_manual(values = cbPalette[-1], guide = "none") +
  theme_void() +
  coord_fixed() +
  theme(
    plot.margin = margin(0,0,0,0,"cm"),
    strip.text = element_text(size = 7))

## ----path-example, fig.width=5, fig.height=3, fig.align="center", echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache = TRUE----
# There are many packages implementing ridge regression. However,
# it can be instructive to implement it yourself
# By uncommenting the right lines below you can use the packages ridge,
# glmnet or ElemStatLearn instead

# Ridge regression through the singular value decomposition of the
# standardised design matrix.
#
# Arguments:
#   X       numeric matrix of predictors (observations in rows); it is
#           centred and scaled with `scale` before fitting
#   y       numeric response vector; it is centred before fitting
#   lambda  numeric vector of penalty parameters
#
# Returns a list with
#   beta0  the mean of `y`
#   beta   matrix with one row per element of `lambda` holding the
#          coefficients with respect to the standardised predictors
custom_ridge <- function(X, y, lambda) {
  # Prepare data
  beta0 <- mean(y)
  y <- y - beta0
  X <- scale(X)

  X_svd <- svd(X)
  d <- X_svd$d
  # t(U) %*% y does not depend on lambda
  uty <- crossprod(X_svd$u, y)

  # beta(lambda) = V diag(d / (d^2 + lambda)) t(U) y. The diagonal matrix is
  # applied as an elementwise product, which also works for a single
  # predictor (where `diag` applied to a scalar would build an identity
  # matrix) and when there are fewer singular values than predictors
  beta <- do.call(rbind, lapply(lambda, function(l) {
    matrix(X_svd$v %*% ((d / (d ^ 2 + l)) * uty), nrow = 1)
  }))

  list(beta0 = beta0, beta = beta)
}

# Use a prostate cancer data set with p = 8 predictors, 1 response
# and n = 67 samples in a training data set. There is more data but
# using only these 67 samples makes the plots comparable to the
# ones in ESL (Fig 3.8, p. 65, Fig. 3.10, p.
# 70)
data_all <- as_tibble(ElemStatLearn::prostate) %>%
  filter(train) %>%
  mutate_at(vars(-lpsa), function(x) as.numeric(scale(x))) %>%
  mutate(lpsa = as.numeric(scale(lpsa, scale = FALSE))) %>%
  select(-train)

n <- nrow(data_all)
p <- ncol(data_all) - 1

log_lambda_seq <- seq(-6, 10, length.out = 40)

# Implement some simple cross-validation
# For every lambda the squared prediction error on the held-out fold is
# computed for ridge and lasso and then averaged over the folds
folds <- caret::createFolds(data_all$lpsa, 10)
cv_err <- vapply(exp(log_lambda_seq), function(lambda) {
  fold_err <- vapply(folds, function(fold) {
    train <- data_all[-fold,]
    test <- data_all[fold,]
    # The models below are fitted on the training part of the fold only, so
    # its size (and not the size of the full data set) determines the
    # scaling of lambda
    n_train <- nrow(train)

    # MLE standard deviation of the response
    sd_y <- sqrt(var(train$lpsa) * (n_train - 1) / n_train)

    # fit_ridge <- ridge::linearRidge(lpsa ~ ., train, lambda = lambda / n)
    # fit_ridge <- glmnet::glmnet(
    #   train[,1:8] %>% as.matrix,
    #   train$lpsa,
    #   alpha = 0)
    # fit_ridge <- ElemStatLearn::simple.ridge(
    #   train[,1:8] %>% as.matrix,
    #   train$lpsa,
    #   lambda = lambda)
    fit_ridge <- custom_ridge(train[,1:8] %>% as.matrix, train$lpsa, lambda)

    # Lasso
    fit_lasso <- glmnet::glmnet(
      train[,1:8] %>% as.matrix,
      train$lpsa,
      alpha = 1)

    # # For ridge::linearRidge
    # mse_ridge <- sum((test$lpsa - predict(fit_ridge, test)) ^ 2)
    # # For glmnet::glmnet
    # # glmnet uses a different scaling of the lambda values. The package
    # # is made for fitting elastic net models and the target function
    # # convenient for elastic nets includes lasso and ridge as special cases.
    # # However, scaling of lambda values is different in this case
    # mse_ridge <- sum((test$lpsa -
    #   predict(
    #     fit_ridge,
    #     test[,1:8] %>% as.matrix,
    #     s = sd_y * lambda / n,
    #     exact = TRUE,
    #     x = train[,1:8] %>% as.matrix,
    #     y = train$lpsa)) ^ 2)
    # # For ElemStatLearn::simple.ridge
    # mse_ridge <- sum((test$lpsa -
    #   as.numeric(
    #     cbind(rep.int(1, dim(test)[1]), (test[,1:8] %>% as.matrix)) %*%
    #     rbind(fit_ridge$beta0, fit_ridge$beta))) ^ 2)

    # For custom implementation
    mse_ridge <- sum((test$lpsa -
      as.numeric(
        cbind(rep.int(1, dim(test)[1]), (test[,1:8] %>% as.matrix)) %*%
        c(fit_ridge$beta0, drop(fit_ridge$beta)))) ^ 2)

    # For glmnet::glmnet lasso
    # See notes for predict.glmnet for ridge regression above
    mse_lasso <- sum((test$lpsa -
      predict(
        fit_lasso,
        test[,1:8] %>% as.matrix,
        s = sd_y * lambda / n_train,
        exact = TRUE,
        x = train[,1:8] %>% as.matrix,
        y = train$lpsa)) ^ 2)

    c(mse_ridge, mse_lasso)
  }, numeric(2))
  rowMeans(fold_err)
}, numeric(2)) %>% t
colnames(cv_err) <- c("ridge", "lasso")

# Value of log(lambda) with the smallest cross-validation error per method
log_lambda <- log_lambda_seq[apply(cv_err, 2, which.min)]
names(log_lambda) <- c("ridge", "lasso")

# The glmnet calls are best run once and then the coefficients can be obtained
# for specific lambdas
# # Ridge case
# fit_ridge <- glmnet::glmnet(
#   data_all[,1:8] %>% as.matrix,
#   data_all$lpsa,
#   alpha = 0)
# Lasso case
fit_lasso <- glmnet::glmnet(
  data_all[,1:8] %>% as.matrix,
  data_all$lpsa,
  alpha = 1)

# MLE standard deviation of the response
# Needed to properly scale all input lambdas for lasso
sd_all <- sqrt(var(data_all$lpsa) * (n - 1) / n)

# Create a coefficient matrix (one row per lambda) for each method
# The result is an array with dimensions lambda x predictor x method
est_coef <- plyr::aaply(
  exp(log_lambda_seq), 1, function(lambda) {
    # fit_ridge <- ridge::linearRidge(lpsa ~ ., data_all, lambda = lambda / n)
    # fit_ridge <- ElemStatLearn::simple.ridge(
    #   data_all[,1:8] %>% as.matrix,
    #   data_all$lpsa,
    #   lambda = lambda)
    fit_ridge <- custom_ridge(
      data_all[,1:8] %>% as.matrix,
      data_all$lpsa,
      lambda)

    cbind(
      # coef(fit_ridge)[-1],
      # coef(
      #   fit_ridge,
      #   s = sd_all * lambda / n,
      #   exact = TRUE,
      #   x = data_all[,1:8] %>% as.matrix,
      #   y = data_all$lpsa)[-1],
      # as.numeric(fit_ridge$beta),
      drop(fit_ridge$beta),
      coef(
        fit_lasso,
        s = sd_all * lambda / n,
        exact = TRUE,
        x = data_all[,1:8] %>% as.matrix,
        y = data_all$lpsa)[-1])
  })

# Calculate the effective degrees of freedom for ridge regression
d <- svd(data_all[,1:8] %>% as.matrix)$d
df <- vapply(
  exp(log_lambda_seq),
  function(l) sum(d ^ 2 / (d ^ 2 + l)),
  numeric(1))

# Shrinkage parameter for Lasso
# How large is the l1 norm of the coefficients proportional to the
# l1 norm of the least squares solution
ls_coef <- coef(lm(lpsa ~ ., data_all))[-1]
shrinkage <- rowSums(abs(est_coef[,,2])) / sum(abs(ls_coef))

# Long format: one row per method, lambda and coefficient. `x` holds the
# effective degrees of freedom for ridge and the shrinkage for lasso
data_plot <- bind_cols(
  tibble(
    alg = as.factor(rep(c("ridge", "lasso"), each = length(log_lambda_seq))),
    log_lambda = rep.int(
      log_lambda_seq, 2),
    x = c(df, shrinkage)),
  as_tibble(do.call(rbind, lapply(1:2, function(i) est_coef[,,i])))) %>%
  gather(yvar, yval, -alg, -log_lambda, -x) %>%
  mutate(
    yvar = as.factor(yvar))

# Used to visualise the selected lambda and corresponding df/shrinkage
data_cv_select <- tibble(
  alg = as.factor(rep(c("ridge", "lasso"), each = 2)),
  xvar = as.factor(c("df", "log_lambda", "shrinkage", "log_lambda")),
  xval = c(
    sum(d ^ 2 / (d ^ 2 + exp(log_lambda[1]))),
    log_lambda[1],
    shrinkage[which.min(cv_err[,2])],
    log_lambda[2]))

# Ridge coefficients against the effective degrees of freedom
p1 <- ggplot(data_plot %>% filter(alg == "ridge"), aes(x = x, y = yval)) +
  geom_line(aes(group = yvar), size = 0.3) +
  geom_vline(
    aes(xintercept = xval),
    data = data_cv_select %>% filter(alg == "ridge", xvar == "df"),
    colour = "red", linetype = "dashed", size = 0.3) +
  geom_hline(
    aes(yintercept = 0),
    linetype = "dashed", size = 0.3) +
  scale_x_continuous("Effective degrees of freedom") +
  scale_y_continuous("Coefficient") +
  theme_minimal() +
  ggtitle("Ridge")

# Ridge coefficients against log(lambda)
p2 <- ggplot(
  data_plot %>% filter(alg == "ridge"),
  aes(x = log_lambda, y = yval)) +
  geom_line(aes(group = yvar), size = 0.3) +
  geom_vline(
    aes(xintercept = xval),
    data = data_cv_select %>% filter(alg == "ridge", xvar == "log_lambda"),
    colour = "red", linetype = "dashed", size = 0.3) +
  geom_hline(
    aes(yintercept = 0),
    linetype = "dashed", size = 0.3) +
  scale_x_continuous(TeX("$\\log(\\lambda)$")) +
  scale_y_continuous("Coefficient") +
  theme_minimal()

# Lasso coefficients against the shrinkage factor
p3 <- ggplot(data_plot %>% filter(alg == "lasso"), aes(x = x, y = yval)) +
  geom_line(aes(group = yvar), size = 0.3) +
  geom_vline(
    aes(xintercept = xval),
    data = data_cv_select %>% filter(alg == "lasso", xvar == "shrinkage"),
    colour = "red", linetype = "dashed", size = 0.3) +
  geom_hline(
    aes(yintercept = 0),
    linetype = "dashed", size = 0.3) +
  scale_x_continuous("Shrinkage") +
  scale_y_continuous("Coefficient") +
  theme_minimal() +
  ggtitle("Lasso")

# Lasso coefficients against log(lambda)
p4 <- ggplot(
  data_plot %>% filter(alg == "lasso"),
  aes(x = log_lambda, y = yval)) +
  geom_line(aes(group = yvar), size = 0.3) +
  geom_vline(
    aes(xintercept = xval),
    data = data_cv_select %>% filter(alg == "lasso", xvar == "log_lambda"),
    colour = "red", linetype = "dashed", size = 0.3) +
  geom_hline(
    aes(yintercept = 0),
    linetype = "dashed", size = 0.3) +
  scale_x_continuous(TeX("$\\log(\\lambda)$")) +
  scale_y_continuous("Coefficient") +
  theme_minimal()

# Methods in columns (ridge, lasso), x-axis variable in rows
ggpubr::ggarrange(
  plotlist = lapply(
    list(p1, p3, p2, p4),
    function(p) p + theme(
      plot.title = element_text(size = 8),
      axis.title = element_text(size = 6),
      axis.text = element_text(size = 6))),
  ncol = 2, nrow = 2, heights = c(1.2, 1))

## ----microarray-avg-expr-prep, echo=FALSE, cache=TRUE--------------------
# Load a microarray dataset for visualisation purposes
data(khan, package = "pamr")

# The implementation below is very much un-optimized and should not be
# used for any real project.
# This is only to gain more understanding of the
# algorithm

# Clean up and convert to more tidy objects
# After transposing, the first two rows and the first column are dropped;
# the first column is used as the class labels
classes <- factor(
  t(khan)[-c(1:2), 1] %>% unname,
  levels = c("BL", "EWS", "NB", "RMS"),
  ordered = TRUE)

data_train <- t(khan)[-c(1:2), -1] %>%
  as_tibble %>%
  mutate_all(as.numeric) %>%
  as.matrix

#### Perform nearest shrunken centroid
# Fit a nearest shrunken centroid classifier and evaluate it on test data.
#
# Arguments:
#   X_train  numeric matrix of training data with samples as columns
#   y_train  factor of class labels, one per column of `X_train`
#   X_test   numeric matrix of test data with samples as columns
#   y_test   class labels, one per column of `X_test`
#   lambda   soft-thresholding parameter
#
# Returns a list with
#   mcerror           misclassification rate on the test data
#   centroids_shrunk  matrix of shrunken class centroids (one column per
#                     class)
nsc <- function(X_train, y_train, X_test, y_test, lambda) {
  # X_train and X_test have samples as columns
  # Adapted after `pamr` package
  N <- length(y_train)
  nk <- table(y_train)
  K <- length(nk)

  # Indicator matrix of class membership (samples x classes)
  Y <- model.matrix(~ y_train - 1)
  dimnames(Y) <- list(NULL, levels(y_train))

  # Class centroids and overall centroid
  centroids <- scale(X_train %*% Y, center = FALSE, scale = nk)
  centroid_overall <- drop(X_train %*% rep(1 / N, N))

  # Pooled within-class standard deviation per row of X_train, stabilised
  # by adding its median
  stddev <- drop(sqrt(
    ((X_train - centroids %*% t(Y)) ^ 2) %*% rep.int(1 / (N - K), N)))
  offset <- quantile(stddev, 0.5)
  stddev <- stddev + offset

  # Pre-soft-thresholding scaling
  delta <- (centroids - centroid_overall) / stddev
  delta <- scale(delta, center = FALSE, sqrt(1 / nk - 1 / N))

  # Perform soft-thresholding
  dif <- abs(delta) - lambda
  delta_shrunk <- sign(delta) * dif * (dif > 0)

  # Post-soft-thresholding scaling
  delta_shrunk <- scale(
    delta_shrunk, center = FALSE, 1 / sqrt(1 / nk - 1 / N))

  # Predict
  # Discriminant score per test sample and class: mixed term minus the
  # quadratic term, which also carries the log class proportions
  da_mixed <- t((X_test - centroid_overall) / stddev) %*% delta_shrunk
  da_quad <- drop(rep(1, nrow(delta_shrunk)) %*% (delta_shrunk ^ 2)) / 2 -
    log(nk / N)
  names(da_quad) <- NULL

  class_indices <- apply(
    scale(da_mixed, da_quad, scale = FALSE), 1, which.max)

  list(
    mcerror = mean(levels(y_train)[class_indices] != as.character(y_test)),
    centroids_shrunk = centroid_overall + delta_shrunk * stddev)
}

# `nsc` expects samples as columns; transpose once instead of in every fold
X_all <- t(data_train)

# Determine lambda with cross-validation
folds <- caret::createFolds(classes, k = 5)
lambdas <- seq(0, 7, length.out = 20)
cv_error <- vapply(lambdas, function(lambda) {
  # Average misclassification rate over the folds
  mean(vapply(folds, function(fold) {
    nsc(
      X_all[,-fold], classes[-fold],
      X_all[,fold], classes[fold],
      lambda)$mcerror
  }, numeric(1)))
}, numeric(1))
cv_misclass <- cbind(lambdas, error = cv_error)

# Among the lambdas with minimal error choose the largest one
lambda_min <- as_tibble(cv_misclass) %>%
  arrange(error, desc(lambdas)) %>%
  .[[1,1]]

centroids_shrunk <- nsc(
  X_all, classes, X_all, classes, lambda_min)$centroids_shrunk

N <- length(classes)
centroid_overall <- drop(X_all %*% rep(1 / N, N))

# Visualise the unpenalised centroids against the shrunken centroids
# (type 2 = shrunken, type 1 = unpenalised)
centroids_plot1 <- bind_cols(
  gene = seq_len(nrow(centroids_shrunk)),
  type = rep.int(2, nrow(centroids_shrunk)),
  as_tibble(centroids_shrunk - centroid_overall)) %>%
  gather(class, expr, -type, -gene) %>%
  mutate(class = factor(class, levels = levels(classes), ordered = TRUE))

nk <- table(classes)
Y <- model.matrix(~ classes - 1)
dimnames(Y) <- list(NULL, levels(classes))
centroids <- scale(X_all %*% Y, center = FALSE, scale = nk)

centroids_plot2 <- bind_cols(
  gene = seq_len(nrow(centroids)),
  type = rep.int(1, nrow(centroids)),
  as_tibble(centroids - centroid_overall)) %>%
  gather(class, expr, -type, -gene) %>%
  mutate(class = factor(class, levels = levels(classes), ordered = TRUE))

## ----microarray-avg-expr-cv, fig.width=3, fig.height=1.5, fig.align="center", echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache=TRUE----
ggplot(as_tibble(cv_misclass), aes(lambdas, error)) +
  geom_line(colour = "grey") +
  # A constant intercept is passed as a parameter so that a single line is
  # drawn instead of one per row of the data
  geom_vline(
    xintercept = lambda_min,
    linetype = "dashed", colour = "red") +
  geom_point(size=1) +
  scale_x_continuous(TeX("$\\lambda$")) +
  scale_y_continuous("Misclassification rate") +
  theme(
    axis.title = element_text(size = 8),
    axis.text = element_text(size = 8))

## ----microarray-avg-expr-centroids, fig.width=4, fig.height=3, fig.align="center", echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache=TRUE----
ggplot(bind_rows(centroids_plot2, centroids_plot1), aes(y = expr, x = gene)) +
  geom_line(aes(colour = as.factor(type)), size = 0.4) +
  scale_x_continuous("Gene", expand = ggplot2::expand_scale(0.01)) +
  scale_y_continuous(
    "Average Expression", expand = c(0, 0), breaks = c(-2, 0, 2)) +
  scale_colour_manual(values = c("grey", "red"), guide = "none") +
  facet_wrap(~ class, nrow = 4) +
  coord_cartesian(ylim = c(-1.8, 1.8)) +
  theme(
    panel.grid = element_blank(),
    axis.title = element_text(size = 7),
    axis.text = element_text(size = 7),
    strip.text = element_text(size = 7))

## ----lasso-probs, fig.width=4, fig.height=1.8, fig.align="center", echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache = TRUE----
set.seed(30206932)

# The problem occurs when predictors are close to perfectly correlated
# Two groups of three predictors with correlation one inside each group
Sigma <- rbind(
  cbind(matrix(rep.int(1, 9), ncol = 3), matrix(rep.int(0, 9), ncol = 3)),
  cbind(matrix(rep.int(0, 9), ncol = 3), matrix(rep.int(1, 9), ncol = 3)))

n <- 100
# Predictors which are perfectly correlated
X <- MASS::mvrnorm(n, mu = rep.int(0, 6), Sigma = Sigma)
colnames(X) <- sprintf("x%d", 1:6)

beta_true <- c(3, 0, 0, 0, -1.5, 0)
y <- as.numeric(X %*% beta_true + 2 * rnorm(n))

# Add some slight noise to the predictors so they are not exactly the
# same, which is unrealistic either way
X <- X + MASS::mvrnorm(n, mu = rep.int(0, 6), Sigma = diag(rep.int(1, 6)) / 25)

# Data preparation
X <- scale(X)
beta0 <- mean(y)
y <- y - mean(y)

fit_lasso <- glmnet::glmnet(
  X, y, alpha = 1, standardize = FALSE, intercept = FALSE)
alpha <- 0.3
fit_elnet <- glmnet::glmnet(
  X, y, alpha = alpha, standardize = FALSE, intercept = FALSE)

log_lambda <- seq(-6, 7, by = 0.1)

sd_y <- sqrt(var(y) * (n - 1) / n)

# The lasso in glmnet is slightly different
# to the general notation and lambdas need
# to be scaled
# Both results have one row per lambda and one column per predictor
beta_lasso <- as.matrix(coef(
  fit_lasso, s = sd_y * exp(log_lambda) / n,
  exact = TRUE, x = X, y = y))[-1,] %>%
  t %>%
  as_tibble
beta_elnet <- as.matrix(coef(
  fit_elnet, s = sd_y * exp(log_lambda) / n,
  exact = TRUE, x = X, y = y))[-1,] %>%
  t %>%
  as_tibble

# Penalty of the least squares solution, used to normalise the shrinkage
beta_ols <- coef(lm(y ~ X))[-1] %>% unname
beta_ols_l1 <- sum(abs(beta_ols))
beta_ols_l21 <- 0.5 * (1 - alpha) * sum(beta_ols ^ 2) +
  alpha * sum(abs(beta_ols))

bind_cols(
  type = factor(rep(
    c("Lasso",
      TeX("Elastic net ($\\alpha = 0.3$)", output = "character")),
    each = length(log_lambda)), ordered = TRUE),
  log_lambda = rep.int(log_lambda, 2),
  shrinkage = c(
    apply(beta_lasso, 1, function(b) sum(abs(b)) / beta_ols_l1),
    apply(beta_elnet, 1, function(b)
      (0.5 * (1 - alpha) * sum(b ^ 2) + alpha * sum(abs(b))) / beta_ols_l21)),
  bind_rows(beta_lasso, beta_elnet)) %>%
  gather(var, val, -type, -log_lambda, -shrinkage) %>%
  mutate(
    var = as.factor(var),
    # Colour by group of correlated predictors
    group = as.factor(case_when(
      var %in% c("x1", "x2", "x3") ~ 0,
      TRUE ~ 1))) %>%
  ggplot(aes(shrinkage, val, group = var, colour = group)) +
  geom_hline(aes(yintercept = 0), linetype = "dashed", size = 0.2) +
  geom_line() +
  facet_wrap(~ type, labeller = label_parsed) +
  scale_x_continuous("Shrinkage") +
  scale_y_continuous("Coefficient") +
  scale_colour_manual(values = cbPalette[-1], guide = "none") +
  theme(
    panel.spacing.x = unit(1, "lines"),
    axis.title = element_text(size = 8),
    axis.text = element_text(size = 8))

## ----mesh-preparation, echo=FALSE----------------------------------------
# %%%%%%%% Derivations for the elastic net ball
# %% For each variable separately
# 0.5 * (1 - alpha) * x ^ 2 + alpha * abs(x) <= 1 =>
# (1 - alpha) * x ^ 2 + 2 * alpha * abs(x) <= 2 =>
# (abs(x) + alpha / (1 - alpha)) ^ 2 - alpha ^ 2 / (1 - alpha) ^ 2 <=
#   2 / (1 - alpha) =>
# (abs(x) + alpha / (1 - alpha)) ^ 2 <=
#   2 / (1 - alpha) + alpha ^ 2 / (1 - alpha) ^ 2 =>
# abs(x) <=
#   (sqrt(2 * (1 - alpha) + alpha ^ 2) - alpha) / (1 - alpha)
# %% In 2D, given one variable fixed
# 0.5 * (1 - alpha) * (x ^ 2 + y ^ 2) + alpha * (abs(x) + abs(y)) <= 1 =>
# (1 - alpha) * (x ^ 2 + y ^ 2) + 2 * alpha * (abs(x) + abs(y)) <= 2 =>
# (1 - alpha) * y ^ 2 + 2 * alpha * abs(y) <=
#   2 - ((1 - alpha) * x ^ 2 + 2 * alpha * abs(x)) =>
# (1 - alpha) * (y ^ 2 + 2 * alpha / (1 - alpha) * abs(y) + alpha ^ 2 / (1 - alpha) ^ 2) - alpha ^ 2 / (1 - alpha) <=
#   2 - (1 - alpha) * (x ^ 2 + 2 * alpha / (1 - alpha) * abs(x) + alpha ^ 2 / (1 - alpha) ^ 2) + alpha ^ 2 / (1 - alpha) =>
# (1 - alpha) * (abs(y) + alpha / (1 - alpha)) ^ 2 <=
#   2 - (1 - alpha) * (abs(x) + alpha / (1 - alpha)) ^ 2 + 2 * alpha ^ 2 / (1 - alpha) =>
# (abs(y) + alpha / (1 - alpha)) ^ 2 <=
#   2 / (1 - alpha) - 1 / (1 - alpha) ^ 2 * ((1 - alpha) * abs(x) + alpha) ^ 2 + 2 * alpha ^ 2 / (1 - alpha) ^ 2 =>
# (abs(y) + alpha / (1 - alpha)) ^ 2 <=
#   ((1 - alpha) * 2 - ((1 - alpha) * abs(x) + alpha) ^ 2 + 2 * alpha ^ 2) / (1 - alpha) ^ 2 =>
# abs(y) + alpha / (1 - alpha) <=
#   sqrt((1 - alpha) * 2 - ((1 - alpha) * abs(x) + alpha) ^ 2 + 2 * alpha ^ 2) / (1 - alpha) =>
# abs(y) <=
#   (sqrt((1 - alpha) * 2 - ((1 - alpha) * abs(x) + alpha) ^ 2 + 2 * alpha ^ 2) - alpha) / (1 - alpha)
# %% In 3D, given two variables fixed
# 0.5 * (1 - alpha) * (x ^ 2 + y ^ 2 + z ^ 2) + alpha * (abs(x) + abs(y) + abs(z)) = 1 =>
# (1 - alpha) * z ^ 2 + 2 * alpha * abs(z) = 2 - (1 - alpha) * (x ^ 2 + y ^ 2 ) - 2 * alpha * (abs(x) + abs(y)) =>
# (1 - alpha) * (abs(z) + alpha / (1 - alpha)) ^ 2 - alpha ^ 2 / (1 - alpha) = 2 - (1 - alpha) * (x ^ 2 + y ^ 2 ) - 2 * alpha * (abs(x) + abs(y)) =>
# (1 - alpha) * (abs(z) + alpha / (1 - alpha)) ^ 2 = 2 - (1 - alpha) * (x ^ 2 + y ^ 2 ) - 2 * alpha * (abs(x) + abs(y)) + alpha ^ 2 / (1 - alpha) =>
# (abs(z) + alpha / (1 - alpha)) ^ 2 = 2 / (1 - alpha) - (x ^ 2 + y ^ 2 ) - 2 * alpha / (1 - alpha) * (abs(x) + abs(y)) + alpha ^ 2 / (1 - alpha) ^ 2 =>
# abs(z) + alpha / (1 - alpha) =
#   sqrt(2 / (1 - alpha) - (x ^ 2 + y ^ 2 ) - 2 * alpha / (1 - alpha) * (abs(x) + abs(y)) + alpha ^ 2 / (1 - alpha) ^ 2) =>
# abs(z) =
#   sqrt(2 / (1 - alpha) - (x ^ 2 + y ^ 2 ) - 2 * alpha / (1 - alpha) * (abs(x) + abs(y)) + alpha ^ 2 / (1 - alpha) ^ 2) - alpha / (1 - alpha)

# %%%%% Group lasso derivation
# sqrt(x ^ 2 + z ^ 2) + abs(y) <= 1 =>
# %% In 1D:
# -1 <= x, y, z <= 1
# %% In 2D for fixed x:
# abs(x) + abs(y) <= 1 =>
# abs(y) <= 1 - abs(x)
# %% In 3D for fixed (x, y)
# sqrt(x ^ 2 + z ^ 2) + abs(y) = 1 =>
# sqrt(x ^ 2 + z ^ 2) = 1 - abs(y) =>
# x ^ 2 + z ^ 2 = (1 - abs(y)) ^ 2 =>
# z ^ 2 = (1 - abs(y)) ^ 2 - x ^ 2 =>
# abs(z) = sqrt((1 - abs(y)) ^ 2 - x ^ 2)

##### Actual calculations

# Helper: for every x in `xs` create the grid -ybnd(x), ..., ybnd(x) with
# step size `h` and return all (x, y) pairs as a two-column matrix
pen_grid_xy <- function(xs, ybnd_fun, h) {
  yss <- lapply(xs, function(x) {
    ybnd <- ybnd_fun(x)
    seq(-ybnd, ybnd, by = h)
  })
  cbind(rep(xs, lengths(yss)), unlist(yss))
}

# Helper: turn (x, y) pairs and the corresponding non-negative heights `z`
# into the surface points (x, y, z) and (x, y, -z). Returns a matrix with
# columns x, y, z without duplicated rows
pen_mirror_z <- function(XY, z) {
  XYZ <- cbind(
    x = rep(XY[, 1], each = 2),
    y = rep(XY[, 2], each = 2),
    # Interleave z and -z so that both points of a pair are adjacent
    z = as.vector(rbind(z, -z)))
  unique(XYZ)
}

# For lasso
# abs(x) + abs(y) + abs(z) = 1 =>
# -1 <= x, y, z <= 1
# Generate grid -1 <= x <= 1
# Generate grid -(1 - abs(x)) <= y <= 1 - abs(x) for each x
# Set z1 = (1 - abs(x) - abs(y)) and z2 = -(1 - abs(x) - abs(y))
#
# Points on the surface of the three-dimensional l1 ball of radius 1.
#   h  step size of the (x, y) grid
# Returns a matrix with columns x, y, z.
pen_lasso <- function(h) {
  # Create (x, y) coordinates
  XY <- pen_grid_xy(seq(-1, 1, by = h), function(x) 1 - abs(x), h)

  # Create the corresponding z values and return as a matrix
  z <- 1 - abs(XY[, 1]) - abs(XY[, 2])
  pen_mirror_z(XY, z)
}

# For elnet
# 0.5 * (1 - alpha) * (x ^ 2 + y ^ 2 + z ^ 2) +
#   alpha * (abs(x) + abs(y) + abs(z)) = 1 =>
# Generate grid
# abs(x) <= (sqrt(2 * (1 - alpha) + alpha ^ 2) - alpha) / (1 - alpha)
# Generate grid for each x
# abs(y) <=
#   (sqrt((1 - alpha) * 2 - ((1 - alpha) * abs(x) + alpha) ^ 2 + 2 * alpha ^ 2) - alpha) / (1 - alpha)
# Set
# abs(z) =
#   sqrt(2 / (1 - alpha) - (x ^ 2 + y ^ 2 ) - 2 * alpha / (1 - alpha) * (abs(x) + abs(y)) + alpha ^ 2 / (1 - alpha) ^ 2) - alpha / (1 - alpha)
#
# Points on the surface of the three-dimensional elastic net ball.
#   h      step size of the (x, y) grid
#   alpha  mixing parameter; the formulas divide by (1 - alpha), so
#          0 <= alpha < 1 is required
# Returns a matrix with columns x, y, z.
pen_elnet <- function(h, alpha) {
  stopifnot(is.numeric(alpha), length(alpha) == 1, alpha >= 0, alpha < 1)

  # Create x coordinates
  xbnd <- abs((sqrt(2 * (1 - alpha) + alpha ^ 2) - alpha) / (1 - alpha))
  xs <- seq(-xbnd, xbnd, by = h)

  # Create (x, y) coordinates
  XY <- pen_grid_xy(xs, function(x) {
    abs((sqrt((1 - alpha) * 2 - ((1 - alpha) * abs(x) + alpha) ^ 2 +
      2 * alpha ^ 2) - alpha) / (1 - alpha))
  }, h)

  # Create the corresponding z values and return as a matrix
  x <- XY[, 1]
  y <- XY[, 2]
  z <- sqrt(2 / (1 - alpha) - (x ^ 2 + y ^ 2) -
    2 * alpha / (1 - alpha) * (abs(x) + abs(y)) +
    alpha ^ 2 / (1 - alpha) ^ 2) - alpha / (1 - alpha)
  pen_mirror_z(XY, z)
}

# Points on the surface of the group lasso ball
# sqrt(x ^ 2 + z ^ 2) + abs(y) = 1 (x and z form one group).
#   h  step size of the (x, y) grid
# Returns a matrix with columns x, y, z.
pen_group_lasso <- function(h) {
  # Create (x, y) coordinates
  XY <- pen_grid_xy(seq(-1, 1, by = h), function(x) 1 - abs(x), h)

  # Create the corresponding z values and return as a matrix
  # On the boundary of the grid the radicand is zero in exact arithmetic;
  # clamp it so that rounding errors cannot produce NaN
  z <- sqrt(pmax((1 - abs(XY[, 2])) ^ 2 - XY[, 1] ^ 2, 0))
  pen_mirror_z(XY, z)
}

# XYZ_lasso <- pen_lasso(0.1)
# XYZ_elnet <- pen_elnet(0.01, 0.7)
# XYZ_group_lasso <- pen_group_lasso(0.01)
# XYZ_lasso_tris <- geometry::convhulln(XYZ_lasso)
# XYZ_elnet_tris <- geometry::convhulln(XYZ_elnet)
# XYZ_group_lasso_tris <- geometry::convhulln(XYZ_group_lasso)
# # Mesh generation is slow, therefore saved and loaded
# saveRDS(list(
#   XYZ_lasso = XYZ_lasso, XYZ_elnet = XYZ_elnet,
#   XYZ_group_lasso = XYZ_group_lasso,
#   XYZ_lasso_tris = XYZ_lasso_tris, XYZ_elnet_tris = XYZ_elnet_tris,
#   XYZ_group_lasso_tris = XYZ_group_lasso_tris),
#   "~/Documents/PhD/stat-learning-for-bigdata/lectures/presentations/3d_meshes.Rds")
# NOTE(review): the file name loaded here differs from the one in the
# commented-out `saveRDS` call above -- confirm which file is current
meshes <- readRDS("~/Documents/PhD/stat-learning-for-bigdata/lectures/presentations/lasso_elnet_meshes.Rds")

# Create axes
# Each pair of consecutive rows is one line segment for `segments3d`
x0 <- 1.5
XYZ_axes <- matrix(c(
  -x0, 0, 0,
  x0, 0, 0,
  0, -x0, 0,
  0, x0, 0,
  0, 0, -x0,
  0, 0, x0
), ncol = 3, byrow = TRUE)

# Create axes labels
x0_lab <- 1.7
XYZ_labels <- matrix(c(
  x0_lab, 0, 0,
  0, x0_lab, 0,
  0, 0, x0_lab
), ncol = 3, byrow = TRUE)

# Draw one panel of the 3D figures into the current rgl subscene: title,
# coordinate axes with labels and the triangulated penalty ball.
#   title      text shown above the ball
#   title_pos  position (x, y, z) of the title
#   vertices   matrix of surface points (one row per point)
#   triangles  matrix of row indices into `vertices`, one triangle per row
# Uses the global objects `XYZ_axes`, `XYZ_labels` and `cbPalette`.
draw_penalty_ball <- function(title, title_pos, vertices, triangles) {
  text3d(
    title_pos,
    texts = title,
    cex = 1.8)
  segments3d(XYZ_axes)
  plotmath3d(XYZ_labels, text = c(
    TeX("$\\beta_1$"), TeX("$\\beta_2$"), TeX("$\\beta_3$")),
    fixedSize = FALSE, cex = 0.4)
  triangles3d(
    vertices[as.vector(t(triangles)),],
    col = cbPalette[3])
  # `view3d` replaces the deprecated `rgl.viewpoint`
  view3d(theta = 30, phi = 15, zoom = 0.6)
}

## ----elnet-3d-plot, fig.width=3, fig.height=1.6, fig.align="center", rgl=TRUE, echo=FALSE, dev="png", dpi=300, results=FALSE, warning=FALSE, message=FALSE, cache=TRUE----
bg3d(color = bgCol)
mfrow3d(nr = 1, nc = 2)
draw_penalty_ball(
  "Lasso", c(-0.5, 2.1, 1),
  meshes$XYZ_lasso, meshes$XYZ_lasso_tris)
next3d()
draw_penalty_ball(
  "Elastic net (α = 0.7)", c(0.1, 2.1, 1),
  meshes$XYZ_elnet, meshes$XYZ_elnet_tris)

## ----group-lasso-3d-plot, fig.width=3, fig.height=1.6, fig.align="center", rgl=TRUE, dev="png", dpi=300, rgl.margin=0, echo=FALSE, results=FALSE, warning=FALSE, message=FALSE, cache=TRUE----
bg3d(color = bgCol)
mfrow3d(nr = 1, nc = 2)
draw_penalty_ball(
  "Lasso", c(-0.5, 2.1, 1),
  meshes$XYZ_lasso, meshes$XYZ_lasso_tris)
next3d()
draw_penalty_ball(
  "Group lasso", c(0.1, 2.1, 1),
  meshes$XYZ_group_lasso, meshes$XYZ_group_lasso_tris)