library(MASS)
library(mvtnorm)
library(fairml)
library("Renvlp")
library(pracma)
library(transport)
library(T4transport)
library(dplyr)
library(ggplot2)
library(data.table)
library(dplyr)
library(stringr)

# Split the raw insurance data into the response Y, the sensitive attributes S
# and the remaining predictors X, dummy-encoding categorical columns.
#
# Args:
#   df: data frame / tibble holding `response_col`, every column named in
#     `sensitive_cols`, and any further predictor columns.
#   response_col: name of the response column (default "charges").
#   sensitive_cols: names of the sensitive columns, in the order they should
#     appear in S.
#
# Returns list(Y = response vector, S = sensitive data frame, X = predictor
# data frame). Character and factor columns of S and X are replaced by
# indicator columns via fastDummies::dummy_cols(); numeric columns (e.g. age)
# pass through unchanged. X keeps the original column order of `df`.
#
# Errors (naming the columns) if any required column is missing from `df`.
preprocess_insurance_data <- function(df,
                                      response_col = "charges",
                                      sensitive_cols = c("gender", "age",
                                                         "medical_history",
                                                         "family_medical_history",
                                                         "region")) {
  required <- c(response_col, sensitive_cols)
  missing_cols <- setdiff(required, names(df))
  if (length(missing_cols) > 0) {
    stop("Missing required column(s): ",
         paste(missing_cols, collapse = ", "), call. = FALSE)
  }

  # Dummy-encode every categorical (character or factor) column of `block`.
  # Factors are included so they cannot reach later numeric arithmetic.
  encode_categorical <- function(block) {
    is_categorical <- function(col) is.character(col) || is.factor(col)
    cat_cols <- names(Filter(is_categorical, block))
    if (length(cat_cols) > 0) {
      block <- fastDummies::dummy_cols(block,
                                       select_columns = cat_cols,
                                       remove_selected_columns = TRUE)
    }
    block
  }

  Y <- df[[response_col]]
  S <- encode_categorical(df[, sensitive_cols, drop = FALSE])
  # X: every column that is neither the response nor sensitive.
  X <- encode_categorical(df[, setdiff(names(df), required), drop = FALSE])

  list(Y = Y, S = S, X = X)
}


library(readr)

# n <- 500
# train_ratio <- 0.8
# df_all <- df[1:n,]
# 
# result <- preprocess_insurance_data(df_all)
# Y <- scale(result$Y)
# S <- result$S
# S <- S + matrix(0.1 * rnorm(dim(S)[1] * dim(S)[2]), dim(S)[1], dim(S)[2])
# S <- scale(S)
# 
# X <- result$X
# X <- X + matrix(0.1 * rnorm(dim(X)[1] * dim(X)[2]), dim(X)[1], dim(X)[2])
# X <- scale(X)
# 
# train_idx <- sample(1:n, size = round(train_ratio * n))
# test_idx <- setdiff(1:n, train_idx)
# 
# S_train <- as.matrix(S[train_idx, ])
# S_test <- as.matrix(S[test_idx, ])
# X_train <- as.matrix(X[train_idx, ])
# X_test <- as.matrix(X[test_idx, ])
# Y_train <- as.matrix(Y[train_idx])
# Y_test <- as.matrix(Y[test_idx])



# Two-stage fair baseline built on fairml::frrm().
#
# Stage 1 regresses X on S (with intercept) and keeps the residuals as the
# predictors for stage 2. Stage 2 is frrm() with S as the sensitive block and
# the residuals as predictors, under the `r_fair` unfairness bound.
#
# The coefficient vector returned by frrm() is split with its logical
# "sensitive" attribute: TRUE entries form alpha (S coefficients), FALSE
# entries form beta, which is applied to cbind(1, residuals).
#
# Returns list(Fair_SX_train, MSE_train, Fair_SX_test, MSE_test, alpha, beta)
# and prints the four metrics. Train MSE is taken from frrm()'s residuals.
result_ols <- function(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair) {
  first_stage <- lm(X_train ~ S_train)
  resid_train <- residuals(first_stage)
  resid_test <- X_test - cbind(1, S_test) %*% first_stage$coefficients

  fit <- frrm(response = Y_train, sensitive = S_train,
              predictors = resid_train, unfairness = r_fair)
  coefs <- fit$main$coefficients
  is_sensitive <- attr(coefs, "sensitive")
  coef_values <- as.numeric(coefs)
  alpha <- coef_values[is_sensitive]
  beta <- coef_values[!is_sensitive]

  # Training fairness: S-explained variance over the summed component variances.
  var_s_part <- var(S_train %*% alpha)
  var_x_part <- var(cbind(1, resid_train) %*% beta)
  fair_train <- var_s_part / (var_s_part + var_x_part)
  mse_train <- mean(fit$main$residuals^2)

  cat("Fairness metric for OLS residual model (train):", fair_train, "\n")
  cat("MSE for OLS residual model (train):", mse_train, "\n")

  # Test fairness: S-explained variance over the variance of the full prediction.
  pred_test <- S_test %*% alpha + cbind(1, resid_test) %*% beta
  fair_test <- var(S_test %*% alpha) / var(pred_test)
  mse_test <- mean((Y_test - pred_test)^2)

  cat("Fairness metric for OLS residual model (test):", fair_test, "\n")
  cat("MSE for OLS residual model (test):", mse_test, "\n")

  list(
    Fair_SX_train = fair_train,
    MSE_train = mse_train,
    Fair_SX_test = fair_test,
    MSE_test = mse_test,
    alpha = alpha,
    beta = beta
  )
}


# Two-stage fair OLS estimator (reported as "FRRM" in the experiment table).
#
# Stage 1 regresses X on S with an intercept and keeps the residuals U. These
# are orthogonal to [1, S] in the training sample. Stage 2 fits beta by OLS of
# Y on [1, U] and alpha by OLS of Y on S (no intercept). If the share of fitted
# variance due to S is not below `r_fair`, alpha is shrunk with a ridge penalty.
# The penalty is the one with the smallest training MSE among a coarse-to-fine
# grid of penalties that satisfy the bound.
#
# Args:
#   S_train, X_train, Y_train: training sensitive attributes, predictors and
#     response (numeric matrices; Y as a one-column matrix).
#   S_test, X_test, Y_test: the same for the held-out split.
#   r_fair: upper bound on Var(S alpha) / Var(fitted Y).
#
# Returns list(Fair_SX_train, MSE_train, Fair_SX_test, MSE_test, alpha, beta)
# and prints the four metrics. If no ridge penalty meets the bound (e.g.
# r_fair = 0), alpha falls back to 0 with a warning instead of erroring.
result_ols_new <- function(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair) {
  aux_lm <- lm(X_train ~ S_train)
  U_train <- residuals(aux_lm)
  U_test <- X_test - cbind(1, S_test) %*% aux_lm$coefficients

  # Fit beta on `design` and alpha on `sensitive`; shrink alpha by ridge if the
  # unpenalised fit violates `unfairness`. Returns list(alpha, beta[, best_MSE]).
  fair_fit <- function(response, design, sensitive, unfairness, q = 100) {
    beta <- solve(crossprod(design), crossprod(design, response))
    # Quantities that do not depend on the penalty, hoisted out of the search.
    StS <- crossprod(sensitive)
    StY <- crossprod(sensitive, response)
    I_s <- diag(ncol(sensitive))
    fitted_X <- design %*% beta

    # alpha for ridge penalty `lam` (lam = 0 is plain OLS), with its fairness
    # ratio and training MSE.
    evaluate <- function(lam) {
      alpha <- solve(StS + lam * I_s, StY)
      fitted_S <- sensitive %*% alpha
      Y_pred <- fitted_S + fitted_X
      list(alpha = alpha, lambda = lam,
           fair = var(fitted_S) / var(Y_pred),
           MSE = mean((response - Y_pred)^2))
    }

    ols <- evaluate(0)
    if (isTRUE(ols$fair < unfairness)) {
      print(ols$fair)
      return(list(alpha = ols$alpha, beta = beta))
    }

    # Keep the first penalty with the smallest MSE among those that meet the
    # bound and strictly improve on `best`.
    scan_grid <- function(lambdas, best) {
      for (lam in lambdas) {
        cand <- evaluate(lam)
        if (isTRUE(cand$fair <= unfairness) && cand$MSE < best$MSE) {
          best <- cand
        }
      }
      best
    }

    best <- scan_grid(10^seq(0, 9, length.out = q),
                      list(alpha = NULL, lambda = NULL, MSE = Inf))
    # Zoom in twice around the best feasible penalty found so far.
    for (half_width in c(5, 3)) {
      if (is.null(best$lambda)) break
      best <- scan_grid(10^seq(-half_width, half_width, length.out = q) * best$lambda,
                        best)
    }

    if (is.null(best$alpha)) {
      # alpha = 0 removes S from the prediction, so its fairness ratio is 0.
      warning("No ridge penalty satisfied the fairness bound; using alpha = 0.",
              call. = FALSE)
      best$alpha <- matrix(0, nrow = ncol(sensitive), ncol = 1)
    }
    list(alpha = best$alpha, beta = beta, best_MSE = best$MSE)
  }

  U_all_train <- cbind(1, U_train)
  result <- fair_fit(Y_train, U_all_train, S_train, r_fair)
  alpha <- result$alpha
  beta <- result$beta

  Y_pred_train <- S_train %*% alpha + U_all_train %*% beta
  R_S <- var(S_train %*% alpha)
  R_X <- var(U_all_train %*% beta)
  # U is orthogonal to [1, S] in-sample, so R_S + R_X equals Var(Y_pred_train).
  Fair_SX <- R_S / (R_S + R_X)
  MS_residuals <- mean((Y_train - Y_pred_train)^2)

  cat("Fairness metric for OLS residual model (train):", Fair_SX, "\n")
  cat("MSE for OLS residual model (train):", MS_residuals, "\n")

  U_all_test <- cbind(1, U_test)
  Y_pred_test <- S_test %*% alpha + U_all_test %*% beta
  Fair_SX_test <- var(S_test %*% alpha) / var(Y_pred_test)
  MS_test_residuals <- mean((Y_test - Y_pred_test)^2)

  cat("Fairness metric for OLS residual model (test):", Fair_SX_test, "\n")
  cat("MSE for OLS residual model (test):", MS_test_residuals, "\n")

  list(
    Fair_SX_train = Fair_SX,
    MSE_train = MS_residuals,
    Fair_SX_test = Fair_SX_test,
    MSE_test = MS_test_residuals,
    alpha = alpha,
    beta = beta
  )
}


# Fair regression on envelope-reduced predictors ("FREM" in the experiments).
#
# Renvlp::env() is fitted with S as predictor and X as response, and its Gamma0
# basis is kept. Renvlp::xenv() is fitted with X as predictor and Y as
# response, and its Gamma basis is kept. X is then projected onto:
#   * intersec = FALSE: Gamma0 from env();
#   * intersec = TRUE:  the right singular vectors (singular value > 1e-3) of
#     Gamma Gamma' Gamma0 Gamma0'.
# beta is fitted by OLS of Y on [1, X Proj] and alpha by OLS of Y on S. If
# Var(S alpha) / Var(fitted Y) is not below `r_fair`, alpha is shrunk with a
# ridge penalty chosen by a coarse-to-fine grid search (smallest training MSE
# among penalties meeting the bound).
#
# Returns list(Fair_SX_train, MSE_train, Fair_SX_test, MSE_test, alpha, beta)
# and prints the four metrics. If no penalty meets the bound (e.g. r_fair = 0),
# alpha falls back to 0 with a warning instead of erroring.
result_envelop <- function(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair, intersec = FALSE) {
  env_est_xs <- u.env(S_train, X_train, alpha = 0.01)
  envmodel_xs <- env(S_train, X_train, env_est_xs$u.bic)
  Gamma0_xs <- envmodel_xs$Gamma0

  env_est_yx <- u.xenv(X_train, Y_train, alpha = 0.01)
  envmodel_yx <- xenv(X_train, Y_train, env_est_yx$u.bic)
  Gamma_yx <- envmodel_yx$Gamma

  # NOTE(review): guard in case Renvlp returns NULL for an empty basis at a
  # boundary dimension -- treat it as a p x 0 basis. Confirm against Renvlp.
  p <- ncol(X_train)
  if (is.null(Gamma0_xs)) Gamma0_xs <- matrix(0, p, 0)
  if (is.null(Gamma_yx)) Gamma_yx <- matrix(0, p, 0)

  if (intersec) {
    P <- tcrossprod(Gamma_yx) %*% tcrossprod(Gamma0_xs)
    svd_result <- svd(P)
    threshold <- 1e-3
    # drop = FALSE keeps a p x k matrix even when k is 0 or 1.
    Proj <- svd_result$v[, svd_result$d > threshold, drop = FALSE]
  } else {
    Proj <- Gamma0_xs
  }

  # Fit beta on `design` and alpha on `sensitive`; shrink alpha by ridge if the
  # unpenalised fit violates `unfairness`. Returns list(alpha, beta[, best_MSE]).
  fair_fit <- function(response, design, sensitive, unfairness, q = 100) {
    beta <- solve(crossprod(design), crossprod(design, response))
    # Quantities that do not depend on the penalty, hoisted out of the search.
    StS <- crossprod(sensitive)
    StY <- crossprod(sensitive, response)
    I_s <- diag(ncol(sensitive))
    fitted_X <- design %*% beta

    # alpha for ridge penalty `lam` (lam = 0 is plain OLS), with its fairness
    # ratio and training MSE.
    evaluate <- function(lam) {
      alpha <- solve(StS + lam * I_s, StY)
      fitted_S <- sensitive %*% alpha
      Y_pred <- fitted_S + fitted_X
      list(alpha = alpha, lambda = lam,
           fair = var(fitted_S) / var(Y_pred),
           MSE = mean((response - Y_pred)^2))
    }

    ols <- evaluate(0)
    if (isTRUE(ols$fair < unfairness)) {
      print(ols$fair)
      return(list(alpha = ols$alpha, beta = beta))
    }

    # Keep the first penalty with the smallest MSE among those that meet the
    # bound and strictly improve on `best`.
    scan_grid <- function(lambdas, best) {
      for (lam in lambdas) {
        cand <- evaluate(lam)
        if (isTRUE(cand$fair <= unfairness) && cand$MSE < best$MSE) {
          best <- cand
        }
      }
      best
    }

    best <- scan_grid(10^seq(0, 9, length.out = q),
                      list(alpha = NULL, lambda = NULL, MSE = Inf))
    # Zoom in twice around the best feasible penalty found so far.
    for (half_width in c(5, 3)) {
      if (is.null(best$lambda)) break
      best <- scan_grid(10^seq(-half_width, half_width, length.out = q) * best$lambda,
                        best)
    }

    if (is.null(best$alpha)) {
      # alpha = 0 removes S from the prediction, so its fairness ratio is 0.
      warning("No ridge penalty satisfied the fairness bound; using alpha = 0.",
              call. = FALSE)
      best$alpha <- matrix(0, nrow = ncol(sensitive), ncol = 1)
    }
    list(alpha = best$alpha, beta = beta, best_MSE = best$MSE)
  }

  X_new_train <- cbind(1, X_train %*% Proj)
  X_new_test <- cbind(1, X_test %*% Proj)

  result <- fair_fit(Y_train, X_new_train, S_train, r_fair)
  alpha <- result$alpha
  beta <- result$beta

  Y_pred_train <- S_train %*% alpha + X_new_train %*% beta
  Fair_SX_train <- var(S_train %*% alpha) / var(Y_pred_train)
  MSE_train <- mean((Y_train - Y_pred_train)^2)
  cat("Fairness metric for envelope model (train):", Fair_SX_train, "\n")
  cat("MSE for envelope model (train):", MSE_train, "\n")

  Y_pred_test <- S_test %*% alpha + X_new_test %*% beta
  Fair_SX_test <- var(S_test %*% alpha) / var(Y_pred_test)
  MSE_test <- mean((Y_test - Y_pred_test)^2)
  cat("Fairness metric for envelope model (test):", Fair_SX_test, "\n")
  cat("MSE for envelope model (test):", MSE_test, "\n")

  list(
    Fair_SX_train = Fair_SX_train,
    MSE_train = MSE_train,
    Fair_SX_test = Fair_SX_test,
    MSE_test = MSE_test,
    alpha = alpha,
    beta = beta
  )
}

# r_fair = 0.1
# 
# OLS_result <- result_ols(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair)
# Envelop_result <- result_envelop(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair, intersec = FALSE)
# Envelop_intersec_result <- result_envelop(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair, intersec = TRUE)

# Experiment driver ----
# For every fairness bound r_fair, sample size n and repeat:
#   1. take the first n rows of the insurance data;
#   2. add N(0, 0.01^2) jitter to S and X and standardise;
#   3. split 80/20 at random;
#   4. fit the three estimators and record their train/test fairness and MSE.
df_full <- read_csv("insurance_dataset.csv")
r_fair_values <- c(0.1, 0.2, 0.3)
sample_sizes <- c(2000, 4000, 8000)
repeats <- 30
train_ratio <- 0.8

# Fail fast rather than let df_full[seq_len(n), ] pad with NA rows.
if (max(sample_sizes) > nrow(df_full)) {
  stop("insurance_dataset.csv has ", nrow(df_full),
       " rows but sample_sizes requires ", max(sample_sizes), call. = FALSE)
}

set.seed(123)  # For reproducibility

# Preprocessing is deterministic (no random draws), so do it once per sample
# size instead of once per repeat; the RNG stream below is unchanged.
prepared <- lapply(sample_sizes, function(n) {
  preprocess_insurance_data(df_full[seq_len(n), ])
})
names(prepared) <- as.character(sample_sizes)

# One-row summary of a fitted method for the results table.
summarise_fit <- function(fit, method, r_fair, n, rep_id) {
  data.frame(
    method = method, r_fair = r_fair, n = n, rep = rep_id,
    Fair_SX_train = fit$Fair_SX_train,
    MSE_train = fit$MSE_train,
    Fair_SX_test = fit$Fair_SX_test,
    MSE_test = fit$MSE_test
  )
}

n_methods <- 3
all_results <- vector("list",
                      length(r_fair_values) * length(sample_sizes) * repeats * n_methods)
k <- 0

for (r_fair in r_fair_values) {
  for (n in sample_sizes) {
    result <- prepared[[as.character(n)]]
    for (rep_id in seq_len(repeats)) {
      Y <- scale(result$Y)
      S <- result$S + matrix(0.01 * rnorm(n * ncol(result$S)), n, ncol(result$S))
      S <- scale(S)
      X <- result$X + matrix(0.01 * rnorm(n * ncol(result$X)), n, ncol(result$X))
      X <- scale(X)

      train_idx <- sample(seq_len(n), size = round(train_ratio * n))
      test_idx <- setdiff(seq_len(n), train_idx)

      S_train <- as.matrix(S[train_idx, , drop = FALSE])
      S_test <- as.matrix(S[test_idx, , drop = FALSE])
      X_train <- as.matrix(X[train_idx, , drop = FALSE])
      X_test <- as.matrix(X[test_idx, , drop = FALSE])
      Y_train <- as.matrix(Y[train_idx])
      Y_test <- as.matrix(Y[test_idx])

      # Fitted in this order (list() evaluates left to right). Results are not
      # stored in a global named `env`, which would shadow Renvlp's env().
      fits <- list(
        "FRRM" = result_ols_new(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair),
        "FREM(general)" = result_envelop(S_train, X_train, Y_train, S_test, X_test, Y_test,
                                         r_fair, intersec = FALSE),
        "FREM(with S)" = result_envelop(S_train, X_train, Y_train, S_test, X_test, Y_test,
                                        r_fair, intersec = TRUE)
      )

      for (method in names(fits)) {
        k <- k + 1
        all_results[[k]] <- summarise_fit(fits[[method]], method, r_fair, n, rep_id)
      }
    }
  }
}

# Combine and write to CSV
result_df <- bind_rows(all_results)
write_csv(result_df, "insurance_experiment_results.csv")

# Plotting ----
filtered_df <- result_df %>% filter(n %in% sample_sizes)
r_fair_label <- paste(r_fair_values, collapse = ", ")

# Boxplots of `metric` by sample size, faceted by r_fair. print() makes the
# plot render even when the script is run via source().
plot_metric <- function(data, metric, title, y_label) {
  p <- ggplot(data, aes(x = factor(n), y = .data[[metric]], fill = method)) +
    geom_boxplot() +
    facet_wrap(~ r_fair, labeller = label_both) +
    labs(title = sprintf("%s vs Sample Size (r_fair = %s)", title, r_fair_label),
         x = "Sample Size", y = y_label) +
    theme_minimal() +
    theme(legend.position = "bottom")
  print(p)
  invisible(p)
}

plot_metric(filtered_df, "MSE_test", "Test MSE", "Test MSE")
plot_metric(filtered_df, "MSE_train", "Train MSE", "Train MSE")
plot_metric(filtered_df, "Fair_SX_test", "Test Fairness", "Fairness (Test)")








