

reticulate::use_condaenv("r-reticulate", required = TRUE)

library(keras)
library(tensorflow)
library(Matrix)
library(glmnet)
library(SGL)
library(ggplot2)
library(rstudioapi)
library(reticulate)

# Project root: the directory of this script when it is run from the RStudio
# editor. Outside RStudio (e.g. `Rscript`) getSourceEditorContext() errors, so
# fall back to the current working directory there.
main.path <- if (rstudioapi::isAvailable()) {
  dirname(rstudioapi::getSourceEditorContext()$path)
} else {
  getwd()
}
# directories
python.path <- file.path(main.path, "python")  # location of the Python `layers` module
temp.path <- file.path(main.path, "temp")      # result tables (.rds / .csv)
plot.path <- file.path(main.path, "plots")     # saved figures
# saveRDS(), write.csv() and ggsave() do not create missing directories.
dir.create(temp.path, showWarnings = FALSE, recursive = TRUE)
dir.create(plot.path, showWarnings = FALSE, recursive = TRUE)
setwd(main.path)

# Parameters ----

# Problem size and data simulation
num_samples <- 200L          # training observations
predictor_dim <- 200L        # number of predictors
num_groups <- 40L            # predictor groups (predictor_dim must be divisible by it)
noisy_variance <- 1          # noise multiplier handed to the data simulator
output_dim <- 1L
group_sparsity <- 0.8        # probability that a true coefficient group is zeroed

# Optimization
thresh <- 1e-6               # group 2-norms below this count as zero
lr <- 0.05                   # base SGD learning rate
mom <- 0.9                   # SGD momentum
batch_size <- num_samples    # full-batch gradient descent
cosine_scheduler <- TRUE     # use cosine-decayed learning rates
dep <- 2L                    # NOTE(review): not referenced elsewhere in this script
epochs <- 1500L
verbose <- 0                 # keras fit() verbosity
verbose_cb <- 0              # HadamardCallback verbosity

# Bounds of a log-spaced penalty grid (natural-log scale)
num_lambdas <- 30L
range_lam_lower <- log(2e-3)
range_lam_upper <- log(10)

# Seeds
init_seed <- 123             # data design and shared weight initialization
run_seed <- 42L              # reset before each Keras model is built

# cosine decayed learning rate
#
# lr_scheduler  -> Direct L21 and D=2 models (base rate lr)
# lr_scheduler3 -> D=3 model (lr / 2)
# lr_scheduler4 -> D=4 model (lr / 2.5)
# With cosine_scheduler = FALSE the plain constant rates are used instead.
#
# Keras learning-rate schedules advance once per optimizer step (one step per
# mini-batch), so the decay horizon is epochs * steps_per_epoch. With full-batch
# training (batch_size == num_samples) this equals `epochs`.
steps_per_epoch <- ceiling(num_samples / batch_size)

# Cosine decay from `initial_lr` to 0 over the whole training run.
make_lr_schedule <- function(initial_lr, schedule_name) {
  learning_rate_schedule_cosine_decay(
    initial_learning_rate = initial_lr,
    decay_steps = as.integer(epochs * steps_per_epoch),
    alpha = 0,
    name = schedule_name
  )
}

if (isTRUE(cosine_scheduler)) {
  lr_scheduler <- make_lr_schedule(lr, "CosineDecay")
  lr_scheduler3 <- make_lr_schedule(lr / 2, "CosineDecay2")
  # Distinct name (was a copy-pasted duplicate of "CosineDecay2").
  lr_scheduler4 <- make_lr_schedule(lr / 2.5, "CosineDecay3")
} else {
  lr_scheduler <- lr
  lr_scheduler3 <- lr / 2
  lr_scheduler4 <- lr / 2.5
}


# Lambda grid ----
# Hand-picked, strictly increasing grid of 34 penalty strengths (0 = no penalty).
# A log-spaced alternative driven by num_lambdas / range_lam_* would be:
#   lambda_grid <- c(0, 1e-5, 1e-4, 1e-3,
#                    exp(seq(range_lam_lower, range_lam_upper, length.out = num_lambdas)),
#                    15)
# (It used to be computed here and then immediately overwritten.)
lambda_grid <- c(0, 1e-5, 1e-4, 1e-3, 2e-3, 3e-3, 5e-3, 7e-3, 9e-3,
                 1e-2, 2e-2, 3e-2, 4e-2, 5e-2, 6e-2, 7e-2, 8e-2, 9e-2, 1e-1,
                 0.2, 0.3, 0.4, 0.6, 0.8, 1, 1.5, 2, 3, 4, 5, 6, 7, 10, 15)

# Custom layers and regularizers

# Construct a `HadamardLayer` from the Python `layers` module found in
# python.path; `...` is forwarded unchanged to the Python constructor.
layer_hadamard <- function(units, la, depth, ...) {
  py_layers <- reticulate::import_from_path("layers", path = python.path)
  py_layers$HadamardLayer(units = units, la = la, depth = depth, ...)
}

# Construct a `HadamardDense` layer from the Python `layers` module found in
# python.path; `...` is forwarded unchanged to the Python constructor.
layer_hadamardFC <- function(units, la, depth, ...) {
  py_layers <- reticulate::import_from_path("layers", path = python.path)
  py_layers$HadamardDense(units = units, la = la, depth = depth, ...)
}

# Construct a `GroupHadamardLayer` from the Python `layers` module found in
# python.path. `group_idx` is the list of (0-based) column-index lists per group;
# `...` is forwarded unchanged to the Python constructor.
layer_group_hadamard2 <- function(units, group_idx, la, depth, ...) {
  py_layers <- reticulate::import_from_path("layers", path = python.path)
  py_layers$GroupHadamardLayer(
    units = units,
    group_idx = group_idx,
    la = la,
    depth = depth,
    ...
  )
}

# Construct a `TibGroupLasso` layer from the Python `layers` module found in
# python.path; `...` is forwarded unchanged to the Python constructor.
# NOTE(review): `depth` is accepted for signature compatibility with the other
# wrappers but is NOT passed on to TibGroupLasso.
layer_group_hadamard <- function(units, group_idx, la, depth, ...) {
  py_layers <- reticulate::import_from_path("layers", path = python.path)
  py_layers$TibGroupLasso(
    units = units,
    group_idx = group_idx,
    la = la,
    ...
  )
}

# Construct the `ExplicitGroupLasso` kernel regularizer from the Python `layers`
# module found in python.path, with penalty weight `la` over the groups in
# `group_idx`. NOTE(review): arguments passed through `...` are ignored.
regularizer_l21 <- function(la, group_idx, ...) {
  py_layers <- reticulate::import_from_path("layers", path = python.path)
  py_layers$ExplicitGroupLasso(la = la, group_idx = group_idx)
}

# Sparsity callback ----
# `layers` (the Python module) and `hadamard_sparsity_callback` (its
# HadamardCallback class) stay available as script-level names.
layers <- reticulate::import_from_path("layers", path = python.path)
hadamard_sparsity_callback <- layers$HadamardCallback
# NOTE(review): sparsity_cb is not passed as `callbacks =` to any fit() call in
# this script, so it does not currently take part in training.
sparsity_cb <- hadamard_sparsity_callback(
  save_metrics = FALSE,
  verbose = verbose_cb
)

################################################################################

# Data generation

# Seed TensorFlow's global RNG and R's RNG before data generation.
tf$random$set_seed(42)
set.seed(42)

# Simulate a grouped-sparse linear regression problem.
#
# Draws an N(0, 1) design matrix x (num_samples x predictor_dim) and N(0, 1)
# coefficients trans (predictor_dim x output_dim). If sparse > 0, the predictors
# are split into num_groups consecutive groups and, independently for every
# group and output column, the whole coefficient block is zeroed with
# probability `sparse`. The response is y = x %*% trans + noise * noisy_variance.
#
# Args:
#   num_samples, predictor_dim, output_dim: problem dimensions.
#   noisy_variance: multiplier applied to N(0, 1) noise. NOTE(review): it acts
#     as a standard deviation (noise variance = noisy_variance^2); the two only
#     coincide for noisy_variance = 1.
#   sparse: probability that a group's coefficients are zeroed.
#   num_groups: number of consecutive predictor groups.
#   seed: seed set right before the noise is drawn.
#   design_seed: seed set before drawing x, trans and the group masks. Defaults
#     to the script-level `init_seed`, which was previously read implicitly.
# Returns:
#   list(data = list(x, y), trans = trans)
isotropic_predictor_data <- function(num_samples, predictor_dim, output_dim, noisy_variance,
                                     sparse = 0.8, num_groups, seed = 666,
                                     design_seed = init_seed) {
  set.seed(design_seed)
  x <- matrix(rnorm(num_samples * predictor_dim), nrow = num_samples, ncol = predictor_dim)
  trans <- matrix(rnorm(predictor_dim * output_dim), nrow = predictor_dim, ncol = output_dim)

  if (sparse > 0) {
    # Group id (1..num_groups) of every predictor: consecutive equal-width bins.
    group_id <- cut(seq_len(predictor_dim), breaks = num_groups, labels = FALSE)
    # Row g holds the keep (1) / drop (0) flags of group g per output column.
    # byrow = TRUE consumes the uniforms group by group, i.e. the same stream
    # as drawing runif(output_dim) once per group.
    group_mask <- matrix(as.numeric(runif(num_groups * output_dim) > sparse),
                         nrow = num_groups, ncol = output_dim, byrow = TRUE)
    trans <- trans * group_mask[group_id, , drop = FALSE]
  }

  set.seed(seed)
  noise <- matrix(rnorm(num_samples * output_dim), nrow = num_samples, ncol = output_dim)
  y <- x %*% trans + noise * noisy_variance

  list(data = list(x, y), trans = trans)
}


################################################################################

# generate simulated regression data
# The first num_samples simulated rows form the training set, the following
# num_test_samples rows the held-out test set.
num_test_samples <- 2000L
train_rows <- seq_len(num_samples)
test_rows <- num_samples + seq_len(num_test_samples)

iso.dat <- isotropic_predictor_data(num_samples = num_samples + num_test_samples,
                                    predictor_dim = predictor_dim,
                                    output_dim = output_dim,
                                    noisy_variance = noisy_variance, sparse = group_sparsity,
                                    num_groups = num_groups,
                                    seed = 666)

# Training data. Only the first response column is used (output_dim = 1 here).
x <- iso.dat$data[[1]][train_rows, ]
x_intercept <- cbind(rep(1, num_samples), x)
y_unscaled <- iso.dat$data[[2]][train_rows, 1]
mean_y_unscaled <- mean(y_unscaled)
var_y_unscaled <- var(y_unscaled)
y_scaled <- scale(y_unscaled)
beta_true <- iso.dat$trans
y <- as.numeric(y_unscaled)

# Test data; the scaled response uses the training-set mean and variance.
x_test <- iso.dat$data[[1]][test_rows, ]
x_test_intercept <- cbind(rep(1, num_test_samples), x_test)
y_test_unscaled <- iso.dat$data[[2]][test_rows, 1]
y_test_scaled <- (y_test_unscaled - mean_y_unscaled) / sqrt(var_y_unscaled)
y_test <- as.numeric(y_test_unscaled)

# Generate indices for SGL group lasso: an integer label vector of length
# predictor_dim that assigns consecutive, equally sized blocks of predictors to
# groups 1..num_groups. Errors unless num_groups divides predictor_dim.
make_sgl_indices <- function(predictor_dim, num_groups) {
  remainder <- predictor_dim %% num_groups
  if (remainder != 0) {
    stop("predictor_dim must be divisible by num_groups")
  }
  group_labels <- 1:num_groups
  rep(group_labels, each = predictor_dim / num_groups)
}

# 1-based group label of every predictor (used by SGL and for R-side group norms).
sgl_indices <- make_sgl_indices(predictor_dim, num_groups)

# Group index lists for the Python/Keras layers and regularizer.
#
# Returns a list of num_groups elements; element g is a list of the 0-based
# (Python-style) integer column indices of group g. Groups are consecutive and
# equally sized, matching make_sgl_indices(). Because the indices are 0-based,
# add 1 before using them to index R objects. Errors unless num_groups divides
# predictor_dim.
make_keras_indices <- function(predictor_dim, num_groups) {
  if (predictor_dim %% num_groups != 0) {
    stop("predictor_dim must be divisible by num_groups")
  }

  elements_per_group <- predictor_dim / num_groups
  # Build all groups in one pass instead of growing a list inside a loop.
  lapply(seq_len(num_groups), function(group) {
    start_index <- (group - 1) * elements_per_group  # 0-based first column
    end_index <- group * elements_per_group - 1      # 0-based last column
    as.list(as.integer(start_index:end_index))
  })
}

# 0-based group index lists for the custom Keras layers/regularizer. The
# argument is spelled out in full (`num_group =` only worked via partial matching).
sgd_indices <- make_keras_indices(predictor_dim = predictor_dim, num_groups = num_groups)


# Prepare results data frame ----
# One row per (lambda, model) pair, with the five models cycling fastest within
# each lambda. All metric columns start at 0 and are filled by the fitting loop.
model_labels <- c("Group Lasso", "Direct L21", "D=2", "D=3", "D=4")
n_result_rows <- length(model_labels) * length(lambda_grid)
results <- data.frame(
  lambda = rep(lambda_grid, each = length(model_labels)),
  model = rep(model_labels, times = length(lambda_grid)),
  sparsity_ratio = numeric(n_result_rows),
  l1_norm = numeric(n_result_rows),
  test_error = numeric(n_result_rows),
  misalignment = numeric(n_result_rows)
)

# One extra column per group ("group1", "group2", ...) for group 2-norms.
for (g in seq_len(num_groups)) {
  results[[paste0("group", g)]] <- numeric(n_result_rows)
}

# Shared initial weights for all SGD-based models: N(0, 1 / predictor_dim)
# draws stored as a (predictor_dim / num_groups) x num_groups matrix. Column g
# is used as the initial weight vector of group g.
set.seed(init_seed)
init_sd <- sqrt(1 / predictor_dim)
init_matrix <- matrix(
  rnorm(predictor_dim, sd = init_sd),
  nrow = predictor_dim / num_groups,
  ncol = num_groups
)

# compute models for each lambda value
#
# For every lambda in lambda_grid, five linear models without intercept are fit
# on (x, y) and evaluated on (x_test, y_test):
#   "Group Lasso"      : SGL::SGL with alpha = 0,
#   "Direct L21"       : Keras dense layer + SGD with the ExplicitGroupLasso regularizer,
#   "D=2", "D=3", "D=4": Keras GroupHadamardLayer of depth 2/3/4 + SGD.
# For each model the group 2-norms, the group sparsity ratio (share of groups
# whose norm is below `thresh`), the L1 norm, the test RMSE and the
# post-training misalignment are stored in the matching row of `results`.

group_size <- predictor_dim / num_groups

# Group-wise 2-norms of a coefficient vector (or one-column matrix), using the
# 1-based SGL group labels. The 0-based `sgd_indices` must not be used to index
# R objects: R silently drops index 0, shifting every group by one column.
compute_group_norms <- function(beta) {
  beta <- as.numeric(beta)
  vapply(seq_len(num_groups),
         function(g) sqrt(sum(beta[sgl_indices == g]^2)),
         numeric(1))
}

# Write one model's metrics into the (lam, model_name) row of `res` and return
# the updated data frame.
store_result <- function(res, lam, model_name, sparsity, l1, test_err,
                         misalign, group_norms) {
  row_sel <- res$lambda == lam & res$model == model_name
  res$sparsity_ratio[row_sel] <- sparsity
  res$l1_norm[row_sel] <- l1
  res$test_error[row_sel] <- test_err
  res$misalignment[row_sel] <- misalign
  for (g in seq_along(group_norms)) {
    res[row_sel, paste0("group", g)] <- group_norms[g]
  }
  res
}

# Fit a GroupHadamardLayer model of the given depth and collapse its factors.
#
# Weight layout relied on below: weights[[1]] is the first per-group scale
# vector, weights[[2 .. num_groups + 1]] are the group weight vectors and, for
# depth > 2, the further scale vectors follow at num_groups + 2, num_groups + 3, ...
# Returns list(model, beta, misalignment):
#   beta: the effective linear coefficients (group vector times all its scales),
#   misalignment: per group, (||w_g||^2 + sum_k alpha_k^2) / depth
#                 - ||beta_g||^(2 / depth), which is >= 0 by the AM-GM inequality.
fit_ghpp <- function(depth, la, lr_schedule) {
  set.seed(run_seed)
  tf$random$set_seed(run_seed)

  model_d <- keras_model_sequential() %>%
    layer_group_hadamard2(units = 1L, group_idx = sgd_indices, la = la, depth = depth)()
  model_d %>% compile(loss = 'mean_squared_error',
                      optimizer = optimizer_sgd(learning_rate = lr_schedule, momentum = mom))
  model_d$build(input_shape = list(NULL, c(predictor_dim)))

  # Replace the group weights (entries 2 .. num_groups + 1) by the columns of
  # the shared init_matrix; the scale factors keep their initial values.
  new_init <- model_d$get_weights()
  for (grp in seq_len(num_groups)) {
    new_init[[grp + 1]] <- matrix(init_matrix[, grp], nrow = group_size, ncol = 1)
  }
  model_d$set_weights(new_init)

  model_d %>% fit(x, y, epochs = epochs,
                  batch_size = batch_size,
                  verbose = verbose,
                  view_metrics = FALSE)

  # Positions of the scale vectors: 1, then num_groups + 2, ... for depth > 2.
  scale_idx <- c(1L, num_groups + 1L + seq_len(depth - 2L))
  raw_weights <- model_d$weights
  factor_weights <- lapply(raw_weights, as.matrix)
  for (k in seq_len(num_groups)) {
    # Multiply the scales in sequence (same floating-point order as before).
    for (s in scale_idx) {
      factor_weights[[k + 1]] <- factor_weights[[k + 1]] * factor_weights[[s]][k]
    }
  }
  beta <- unlist(factor_weights[-scale_idx])

  misalignment <- vapply(seq_len(num_groups), function(g) {
    w_g <- as.vector(as.matrix(raw_weights[[g + 1]]))
    alphas <- vapply(scale_idx, function(s) as.numeric(raw_weights[[s]])[g], numeric(1))
    w <- Reduce(`*`, alphas) * w_g
    norm_sq <- sum(w_g^2)
    norm_w <- sqrt(sum(w^2))^(2 / depth)
    Reduce(`+`, alphas^2, norm_sq) / depth - norm_w
  }, numeric(1))

  list(model = model_d, beta = beta, misalignment = misalignment)
}

# Depth-specific settings of the factorized models: penalty la = l * 2 / depth,
# learning rate, result label and the printing precision of the sparsity ratio.
ghpp_specs <- list(
  list(depth = 2L, model = "D=2", label = "GHPP+L2", lr = lr_scheduler, sparsity_digits = 3),
  list(depth = 3L, model = "D=3", label = "GHPP+L2,2/3", lr = lr_scheduler3, sparsity_digits = 5),
  list(depth = 4L, model = "D=4", label = "GHPP+L2,2/4", lr = lr_scheduler4, sparsity_digits = 3)
)

set.seed(42)
tf$random$set_seed(42)
for (i in seq_along(lambda_grid)) {
  l <- lambda_grid[i]

  # Group Lasso with SGL
  ##############################################################################
  # SGL penalizes lam * sqrt(group_size) * ||beta_g||; rescale so the effective
  # group penalty is l * ||beta_g||, as for the Keras models.
  lam_sgl <- l / sqrt(group_size)
  fit_sgl <- SGL(data = list(x = x, y = y), index = sgl_indices, standardize = FALSE, verbose = TRUE,
                 alpha = 0, lambdas = c(lam_sgl))
  coef_sgl <- fit_sgl$beta

  pred_sgl <- x_test %*% coef_sgl
  test_error_sgl <- sqrt(mean((pred_sgl - y_test)^2))
  group_norms_sgl <- compute_group_norms(coef_sgl)
  sparsity_ratio_sgl <- mean(group_norms_sgl < thresh)
  l1_sgl <- sum(abs(coef_sgl))

  results <- store_result(results, l, "Group Lasso", sparsity_ratio_sgl, l1_sgl,
                          test_error_sgl, 0, group_norms_sgl)
  cat(paste0("SGL group lasso sparsity=", round(sparsity_ratio_sgl, digits = 4),
             " with L1 norm=", round(l1_sgl, digits = 3), " at lam=",
             round(l, digits = 3), " with test error = ", round(test_error_sgl, 3), "\n"))

  # L21 penalty with GD in Keras model
  ##############################################################################
  set.seed(run_seed)
  tf$random$set_seed(run_seed)
  model <- keras_model_sequential() %>%
    layer_dense(units = output_dim, input_shape = c(ncol(x)),
                activation = 'linear',
                kernel_regularizer = regularizer_l21(la = 2 * l, group_idx = sgd_indices),
                use_bias = FALSE)
  model %>% compile(loss = 'mean_squared_error',
                    optimizer = optimizer_sgd(learning_rate = lr_scheduler, momentum = mom))
  model$build(input_shape = list(c(predictor_dim)))

  # Start from the shared initialization (init_matrix flattened column-wise).
  w_new <- model$get_weights()
  w_new[[1]] <- matrix(init_matrix, nrow = predictor_dim, ncol = 1)
  model$set_weights(w_new)

  model %>% fit(x, y, epochs = epochs, batch_size = batch_size, verbose = verbose, view_metrics = FALSE)
  weights_keras <- get_weights(model)[[1]]

  pred_keras <- x_test %*% weights_keras
  test_error_keras <- sqrt(mean((pred_keras - y_test)^2))
  # 1-based group labels here; indexing with the 0-based sgd_indices shifted
  # every group by one column.
  group_norms_keras <- compute_group_norms(weights_keras)
  sparsity_ratio_keras <- mean(group_norms_keras < thresh)
  l1_keras <- sum(abs(weights_keras))

  results <- store_result(results, l, "Direct L21", sparsity_ratio_keras, l1_keras,
                          test_error_keras, 0, group_norms_keras)
  cat(paste0("GD+L21 sparsity=", round(sparsity_ratio_keras, digits = 3),
             " with L1 norm=", round(l1_keras, digits = 3), " at lam=",
             round(l, digits = 3), " with test error = ", round(test_error_keras, 3), "\n"))

  # L2,2/D penalties via GHPP with l2 penalty (D = 2, 3, 4)
  ##############################################################################
  for (spec in ghpp_specs) {
    depth <- spec[["depth"]]
    ghpp <- fit_ghpp(depth, la = l * (2 / depth), lr_schedule = spec[["lr"]])

    pred_ghpp <- ghpp$model %>% predict(x_test)
    test_error_ghpp <- sqrt(mean((pred_ghpp - y_test)^2))
    group_norms_ghpp <- compute_group_norms(ghpp$beta)
    sparsity_ratio_ghpp <- mean(group_norms_ghpp < thresh)
    l1_ghpp <- sum(abs(ghpp$beta))
    misalign_ghpp <- sum(ghpp$misalignment)

    results <- store_result(results, l, spec[["model"]], sparsity_ratio_ghpp, l1_ghpp,
                            test_error_ghpp, misalign_ghpp, group_norms_ghpp)
    cat(paste0(spec[["label"]], " sparsity=",
               round(sparsity_ratio_ghpp, digits = spec[["sparsity_digits"]]),
               " with L1 norm=", round(l1_ghpp, digits = 3), " at lam=",
               round(l, digits = 3), " with test error = ", round(test_error_ghpp, 3),
               " and misalign = ", round(misalign_ghpp, digits = 5), "\n"))
  }

  cat(paste0("Run ", i, " finished \n"))
}

# Persist the results as .rds (R) and .csv (pandas post-processing), both with
# the same timestamped file stem. temp.path is created if it does not exist.
timestamp <- format(Sys.time(), "%Y%m%d_%H%M%S")
dir.create(temp.path, showWarnings = FALSE, recursive = TRUE)
results_stem <- file.path(temp.path,
                          paste0("linmod_res_single_run_seed_", init_seed, "_", timestamp))
saveRDS(results, file = paste0(results_stem, ".rds"))
write.csv(results, file = paste0(results_stem, ".csv"))

################################################################################
# plotting
library(ggplot2)
library(dplyr)
library(tidyr)

# Data preparation: drop incomplete rows and shift (near-)zero L1 norms by
# 1e-40 so that log10(l1_norm) in p2 stays finite.
plot.res <- results %>% na.omit()
tiny_l1 <- plot.res$l1_norm <= 1e-30
plot.res$l1_norm[tiny_l1] <- plot.res$l1_norm[tiny_l1] + 1e-40

# Save `plot` as PDF and PNG in plot.path (created if missing) under the name
# "<stem>_init<init_seed>.<ext>".
save_plot <- function(plot, stem, width = 9, height = 5) {
  dir.create(plot.path, showWarnings = FALSE, recursive = TRUE)
  for (ext in c("pdf", "png")) {
    ggsave(file.path(plot.path, paste0(stem, "_init", init_seed, ".", ext)),
           plot = plot, width = width, height = height)
  }
  invisible(plot)
}

# Define a base theme with the desired customizations
base_theme <- theme_minimal() +
  theme(
    text = element_text(size = 14),
    axis.title = element_text(size = 16),
    axis.text = element_text(size = 14),
    legend.title = element_text(size = 14),
    legend.text = element_text(size = 12),
    plot.title = element_text(size = 18, hjust = 0.5),
    panel.grid.major = element_line(color = "grey80"),
    panel.grid.minor = element_line(color = "grey90")
  )

# Plot 1: Sparsity Ratio vs log10(lambda)
p1 <- ggplot(plot.res, aes(x = log10(lambda), y = sparsity_ratio, color = model)) +
  geom_line(linewidth = 1.2) +
  geom_point(size = 2) +
  labs(title = "Group Sparsity vs Lambda",
       x = "Lambda (log10)", y = "Sparsity") +
  base_theme
save_plot(p1, "p1")

# Plot 2: Log10(L1 Norm) of Weights vs Log10(lambda)
p2 <- ggplot(plot.res, aes(x = log10(lambda), y = log10(l1_norm), color = model)) +
  geom_line(linewidth = 1.2) +
  geom_point(size = 2) +
  ylim(-40, 10) +
  labs(title = "L1 weight norm vs Lambda",
       x = "Lambda (log10)", y = "L1 Norm of Weights (Log10)") +
  base_theme
save_plot(p2, "p2")

# Plot 3: Log10(Lambda) vs RMSE
p3 <- ggplot(plot.res, aes(x = log10(lambda), y = test_error, color = model)) +
  geom_line(linewidth = 1.2) +
  geom_point(size = 2) +
  labs(title = "RMSE vs Lambda",
       x = "Log10(lambda)", y = "RMSE") +
  base_theme
save_plot(p3, "p3")

# Plot 4: Group Sparsity vs RMSE
p4 <- ggplot(plot.res, aes(x = sparsity_ratio, y = test_error, color = model)) +
  geom_line(linewidth = 1.2) +
  geom_point(size = 2) +
  labs(title = "RMSE vs Group Sparsity",
       x = "Group Sparsity", y = "RMSE") +
  base_theme
save_plot(p4, "p4")

# Plot 5: Misalignment at end of training vs log10 lambda
p5 <- ggplot(plot.res, aes(x = log10(lambda), y = misalignment, color = model)) +
  geom_line(linewidth = 1.2) +
  geom_point(size = 2) +
  labs(title = "Misalignment vs Lambda",
       x = "Lambda (log10)", y = "Misalignment") +
  base_theme
save_plot(p5, "p5")

# Reshape the data from wide to long format for the group columns.
# Reduced data sets: D=2 vs Group Lasso (SGL), and additionally Direct L21.
plot.res.red1 <- plot.res %>% filter(model %in% c("D=2", "Group Lasso"))
plot.res.red2 <- plot.res %>% filter(model %in% c("D=2", "Direct L21", "Group Lasso"))

plot_long1 <- pivot_longer(plot.res.red1,
                           cols = starts_with("group"),
                           names_to = "group",
                           values_to = "norm")
plot_long2 <- pivot_longer(plot.res.red2,
                           cols = starts_with("group"),
                           names_to = "group",
                           values_to = "norm")

# Scale the D=2 misalignment onto the group-norm axis. Fall back to 1 when the
# ratio is not a positive finite number (e.g. all misalignments are 0), which
# would otherwise make the secondary-axis transform degenerate.
df_D2 <- subset(plot.res, model == "D=2")
norm_max <- max(plot_long1$norm, na.rm = TRUE)
misalignment_max <- max(df_D2$misalignment, na.rm = TRUE)
scale_factor <- norm_max / misalignment_max
if (!is.finite(scale_factor) || scale_factor <= 0) {
  scale_factor <- 1
}
# Smallest lambda at which the D=2 misalignment falls below 1e-4 (NA if never).
df_D2 <- df_D2[order(df_D2$lambda), ]
lambda_threshold <- df_D2$lambda[which(df_D2$misalignment < 1e-4)[1]]

p6 <- ggplot() +
  # Plot group 2-norms for all models and groups
  geom_line(data = plot_long1,
            aes(x = log10(lambda), y = norm, group = interaction(model, group), color = model),
            linewidth = 1.2) +
  geom_point(data = plot_long1,
             aes(x = log10(lambda), y = norm, group = interaction(model, group), color = model),
             size = 2.2) +
  # Overlay misalignment for model "D=2", scaled to primary y-axis
  geom_line(data = subset(plot.res, model == "D=2"),
            aes(x = log10(lambda), y = misalignment * scale_factor),
            linewidth = 1.2, color = "black") +
  geom_point(data = subset(plot.res, model == "D=2"),
             aes(x = log10(lambda), y = misalignment * scale_factor),
             size = 2.2, color = "black") +
  scale_y_continuous(name = "Group-wise weight norms",
                     sec.axis = sec_axis(~ . / scale_factor, name = "Post-training misalignment")) +
  labs(title = "Transition coincides with zero misalignment", x = "Lambda (log10)", color = "Model") +
  base_theme +
  theme(axis.text = element_text(size = 14),
        axis.title = element_text(size = 16),
        plot.title = element_text(size = 18),
        legend.text = element_text(size = 14),
        legend.title = element_text(size = 16),
        axis.ticks = element_line(linewidth = 1),
        axis.ticks.length = unit(0.1, "cm"),
        axis.text.x = element_text(size = 18),
        axis.text.y = element_text(size = 18),
        axis.text.y.right = element_text(size = 18))
# Mark the transition only if the misalignment ever drops below the threshold.
if (!is.na(lambda_threshold)) {
  p6 <- p6 +
    geom_vline(xintercept = log10(lambda_threshold), color = "black",
               linetype = "dashed", linewidth = 1.5)
}
save_plot(p6, "p6")

# Plot 7: log10(2-norm of groups) vs log10(lambda), grouped by model and group
p7 <- ggplot(plot_long2, aes(x = log10(lambda), y = log10(norm),
                             group = interaction(model, group),
                             color = model)) +
  geom_line(linewidth = 1.2) +
  geom_point(size = 2) +
  labs(title = "2-norm of Groups vs Lambda",
       x = "Lambda (log10)", y = "Group weight norms (log10)") +
  base_theme
save_plot(p7, "p7")

# Display plots
print(p1)
print(p2)
print(p3)
print(p4)
print(p5)
print(p6)
print(p7)

