library(MASS)
library(mvtnorm)
library(fairml)
library("Renvlp")
library(pracma)
library(transport)
library(T4transport)
library(dplyr)
library(ggplot2)
library(cccp)


# Simulate one data set for the fairness experiments and split it into a
# training part and a test part.
#
# Args:
#   n: number of observations.
#   p: total number of predictors; the first p_corr of them depend on S.
#   p_corr: number of predictors generated from S (must be <= p).
#   p_corr_dim: not used in this function (in the commented-out setup of this
#     script it sets how many columns of Q are used to build S_Q).
#   p_S: number of sensitive attributes.
#   S_Q: p_S x p_S matrix applied to S before mapping to the correlated
#     predictors.
#   beta_SX: p_corr x p_S loading matrix from S to the correlated predictors.
#   beta_S: response coefficients for S (p_S of them).
#   beta_X: response coefficients for X (p of them).
#   train_ratio: fraction of rows assigned to the training set.
#
# Returns a list with S_train, S_test, X_train, X_test, Y_train, Y_test.
# S, X and Y are each standardised with scale() on the full sample before
# the split. Every element is a matrix, even when p_S == 1 or a split ends
# up with a single row.
data_generate <- function(n, p, p_corr, p_corr_dim, p_S, S_Q, beta_SX, beta_S,
                          beta_X, train_ratio = 0.8) {
  stopifnot(p_corr <= p)

  S <- matrix(rnorm(p_S * n), n, p_S)
  # Skewed, Poisson-based noise for the S-driven predictors. Note that
  # 0.5 * Pois(1) - 1 has mean -0.5, not 0; the offset is removed by the
  # column-wise scale() below, so only the shape of the noise matters.
  #X_corr <- S %*% S_Q %*% t(beta_SX) + matrix(0.5 * rnorm(p_corr * n), n, p_corr)
  X_corr <- S %*% S_Q %*% t(beta_SX) +
    matrix(0.5 * rpois(p_corr * n, lambda = 1) - 1, n, p_corr)
  X_indep <- matrix(rnorm((p - p_corr) * n), n, p - p_corr)
  X <- cbind(X_corr, X_indep)

  sigma_Y <- 0.5
  Y <- S %*% beta_S + X %*% beta_X + rnorm(n, mean = 0, sd = sigma_Y)

  S <- scale(S)
  X <- scale(X)
  Y <- scale(Y)

  # sample.int(n, k) draws the same indices as the former sample(1:n, k).
  train_idx <- sample.int(n, size = round(train_ratio * n))
  test_idx <- setdiff(seq_len(n), train_idx)

  # drop = FALSE keeps one-column (p_S == 1) or one-row subsets as matrices
  # instead of silently collapsing them to plain vectors.
  list(
    S_train = S[train_idx, , drop = FALSE],
    S_test = S[test_idx, , drop = FALSE],
    X_train = X[train_idx, , drop = FALSE],
    X_test = X[test_idx, , drop = FALSE],
    Y_train = Y[train_idx, , drop = FALSE],
    Y_test = Y[test_idx, , drop = FALSE]
  )
}


# Fair regression with fairml's frrm on predictors residualised against S.
#
# X is first regressed on S (multivariate OLS with intercept). The residuals
# U are used as predictors for the training fit, and the same linear map is
# applied to the test predictors. The fairness metric is the share of fitted
# variance carried by S %*% alpha.
#
# Args:
#   S_train, X_train, Y_train: training sensitive attributes, predictors and
#     response.
#   S_test, X_test, Y_test: the same for the held-out set.
#   r_fair: passed to frrm() as `unfairness`.
#
# Returns a list with Fair_SX_train, MSE_train, Fair_SX_test, MSE_test,
# alpha (coefficients flagged "sensitive") and beta (the remaining
# coefficients). The metrics are also printed with cat().
result_ols <- function(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair) {
  resid_fit <- lm(X_train ~ S_train)
  gamma_hat <- resid_fit$coefficients
  U_tr <- residuals(resid_fit)
  U_te <- X_test - cbind(1, S_test) %*% gamma_hat

  fit <- frrm(response = Y_train, sensitive = S_train, predictors = U_tr,
              unfairness = r_fair)
  coefs <- fit$main$coefficients
  is_sensitive <- attr(coefs, "sensitive")
  coef_values <- as.numeric(coefs)
  alpha <- coef_values[is_sensitive]
  # NOTE(review): beta is paired with cbind(1, U) below, which assumes the
  # non-sensitive coefficients are ordered intercept first - confirm in fairml.
  beta <- coef_values[!is_sensitive]

  # Training metrics: S share of the summed component variances, in-sample MSE.
  var_s_train <- var(S_train %*% alpha)
  var_x_train <- var(cbind(1, U_tr) %*% beta)
  fair_train <- var_s_train / (var_s_train + var_x_train)
  mse_train <- mean(fit$main$residuals^2)

  cat("Fairness metric for OLS residual model (train):", fair_train, "\n")
  cat("MSE for OLS residual model (train):", mse_train, "\n")

  # Test metrics: S share of the variance of the full prediction.
  s_part_test <- S_test %*% alpha
  pred_test <- s_part_test + cbind(1, U_te) %*% beta
  fair_test <- var(s_part_test) / var(pred_test)
  mse_test <- mean((Y_test - pred_test)^2)

  cat("Fairness metric for OLS residual model (test):", fair_test, "\n")
  cat("MSE for OLS residual model (test):", mse_test, "\n")

  list(
    Fair_SX_train = fair_train,
    MSE_train = mse_train,
    Fair_SX_test = fair_test,
    MSE_test = mse_test,
    alpha = alpha,
    beta = beta
  )
}

# Fair regression with fairml's nclm (the Komiyama et al. model) on
# predictors residualised against S.
#
# X is first regressed on S (multivariate OLS with intercept). The residuals
# U are used as predictors for the training fit, and the same linear map is
# applied to the test predictors. The fairness metric is the share of fitted
# variance carried by S %*% alpha.
#
# Args:
#   S_train, X_train, Y_train: training sensitive attributes, predictors and
#     response.
#   S_test, X_test, Y_test: the same for the held-out set.
#   r_fair: passed to nclm() as `unfairness`.
#
# Returns a list with Fair_SX_train, MSE_train, Fair_SX_test, MSE_test,
# alpha (coefficients flagged "sensitive") and beta (the remaining
# coefficients). The metrics are also printed with cat().
result_Komi <- function(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair) {
  resid_fit <- lm(X_train ~ S_train)
  gamma_hat <- resid_fit$coefficients
  U_tr <- residuals(resid_fit)
  U_te <- X_test - cbind(1, S_test) %*% gamma_hat

  fit <- nclm(response = Y_train, sensitive = S_train, predictors = U_tr,
              unfairness = r_fair)
  coefs <- fit$main$coefficients
  is_sensitive <- attr(coefs, "sensitive")
  coef_values <- as.numeric(coefs)
  alpha <- coef_values[is_sensitive]
  # NOTE(review): beta is paired with cbind(1, U) below, which assumes the
  # non-sensitive coefficients are ordered intercept first - confirm in fairml.
  beta <- coef_values[!is_sensitive]

  # Training metrics: S share of the summed component variances, in-sample MSE.
  var_s_train <- var(S_train %*% alpha)
  var_x_train <- var(cbind(1, U_tr) %*% beta)
  fair_train <- var_s_train / (var_s_train + var_x_train)
  mse_train <- mean(fit$main$residuals^2)

  cat("Fairness metric for Komi residual model (train):", fair_train, "\n")
  cat("MSE for Komi residual model (train):", mse_train, "\n")

  # Test metrics: S share of the variance of the full prediction.
  s_part_test <- S_test %*% alpha
  pred_test <- s_part_test + cbind(1, U_te) %*% beta
  fair_test <- var(s_part_test) / var(pred_test)
  mse_test <- mean((Y_test - pred_test)^2)

  cat("Fairness metric for Komi residual model (test):", fair_test, "\n")
  cat("MSE for Komi residual model (test):", mse_test, "\n")

  list(
    Fair_SX_train = fair_train,
    MSE_train = mse_train,
    Fair_SX_test = fair_test,
    MSE_test = mse_test,
    alpha = alpha,
    beta = beta
  )
}





# Projection matrix applied to the predictors in result_envelop().
#
# Gamma0_xs is the Gamma0 basis returned by Renvlp::env(S, X) (presumably the
# directions of X that do not vary with S - see the Renvlp docs). With
# intersec = TRUE, the result is instead the right singular vectors of
# Gamma_yx Gamma_yx' Gamma0_xs Gamma0_xs' whose singular values exceed
# `threshold`, where Gamma_yx comes from Renvlp::xenv(X, Y).
.envelope_projection <- function(S, X, Y, intersec, threshold = 1e-3) {
  u_xs <- u.env(S, X, alpha = 0.01)
  Gamma0_xs <- env(S, X, u_xs$u.bic)$Gamma0
  if (!intersec) {
    return(Gamma0_xs)
  }
  # The Y | X envelope is only needed for the intersection variant, so it is
  # not fitted otherwise.
  u_yx <- u.xenv(X, Y, alpha = 0.01)
  Gamma_yx <- xenv(X, Y, u_yx$u.bic)$Gamma
  P <- Gamma_yx %*% t(Gamma_yx) %*% Gamma0_xs %*% t(Gamma0_xs)
  sv <- svd(P)
  # drop = FALSE keeps a p x 1 matrix when exactly one direction survives.
  sv$v[, sv$d > threshold, drop = FALSE]
}

# One stage of the ridge-penalty grid search for alpha.
#
# For each lambda (in order), alpha = (S'S + lambda I)^-1 S'y. A candidate
# replaces `best` when its fairness ratio var(S alpha) / var(S alpha + x_part)
# is <= unfairness and its MSE is strictly below the current best. The
# earliest minimiser wins ties. `best` is list(alpha, mse, lambda) and is
# returned unchanged when no lambda improves on it.
.ridge_grid_search <- function(lambdas, sensitive, response, x_part,
                               unfairness, best) {
  gram <- crossprod(sensitive)
  s_y <- crossprod(sensitive, response)
  ident <- diag(NCOL(sensitive))
  for (lam in lambdas) {
    cand <- solve(gram + lam * ident, s_y)
    s_part <- sensitive %*% cand
    y_hat <- s_part + x_part
    mse <- mean((response - y_hat)^2)
    # isTRUE() treats a NaN ratio (zero fitted variance) as infeasible.
    feasible <- isTRUE(drop(var(s_part) / var(y_hat)) <= unfairness)
    if (feasible && mse < best$mse) {
      best <- list(alpha = cand, mse = mse, lambda = lam)
    }
  }
  best
}

# Fit beta (OLS of the response on cbind(1, predictors %*% Proj)) and alpha
# (OLS of the response on `sensitive`, no intercept).
#
# If the unpenalised fit already has fairness ratio < unfairness, the ratio is
# printed and that fit is returned. Otherwise alpha is ridge-shrunk by a
# three-stage grid search over lambda (beta is held fixed). Each refinement
# stage is centred on the best lambda found so far. If no lambda in the first
# grid (1 .. 1e9) meets the bound, alpha falls back to 0, the lambda -> Inf
# limit, and a warning is raised.
.envelope_fair_fit <- function(response, predictors, sensitive, Proj,
                               unfairness, q = 100) {
  design <- cbind(1, predictors %*% Proj)
  beta <- solve(crossprod(design), crossprod(design, response))
  x_part <- design %*% beta
  alpha <- solve(crossprod(sensitive), crossprod(sensitive, response))
  s_part <- sensitive %*% alpha
  fair <- var(s_part) / var(s_part + x_part)

  if (isTRUE(drop(fair) < unfairness)) {
    print(fair)
    return(list(alpha = alpha, beta = beta))
  }

  best <- list(alpha = NULL, mse = Inf, lambda = NULL)
  best <- .ridge_grid_search(10^seq(0, 9, length.out = q), sensitive,
                             response, x_part, unfairness, best)
  if (is.null(best$lambda)) {
    warning("No ridge penalty up to 1e9 met the fairness bound; ",
            "using alpha = 0 for the sensitive attributes.", call. = FALSE)
    alpha0 <- matrix(0, NCOL(sensitive), 1)
    return(list(alpha = alpha0, beta = beta,
                best_MSE = mean((response - x_part)^2)))
  }
  # A stage that finds no improvement leaves best$lambda unchanged, so the
  # following stage still has a centre to refine around.
  best <- .ridge_grid_search(10^seq(-5, 5, length.out = q) * best$lambda,
                             sensitive, response, x_part, unfairness, best)
  best <- .ridge_grid_search(10^seq(-3, 3, length.out = q) * best$lambda,
                             sensitive, response, x_part, unfairness, best)
  list(alpha = best$alpha, beta = beta, best_MSE = best$mse)
}

# Fairness ratio var(S alpha) / var(prediction) and MSE of the envelope model
# on one data split. The ratio is a 1x1 matrix, as returned by var().
.envelope_metrics <- function(S, X, Y, Proj, alpha, beta) {
  s_part <- S %*% alpha
  y_hat <- s_part + cbind(1, X %*% Proj) %*% beta
  list(fair = var(s_part) / var(y_hat), mse = mean((Y - y_hat)^2))
}

# Envelope-based fair regression.
#
# The predictors are projected with an envelope basis estimated by Renvlp
# (see .envelope_projection). The response is regressed separately on the
# projected predictors (beta, with intercept) and on S (alpha, no intercept),
# and alpha is ridge-shrunk until var(S alpha) / var(fitted) on the training
# data is within r_fair.
#
# Args:
#   S_train, X_train, Y_train: training sensitive attributes, predictors and
#     response.
#   S_test, X_test, Y_test: the same for the held-out set.
#   r_fair: fairness bound for the training data.
#   intersec: FALSE to use the env(S, X) Gamma0 basis; TRUE to intersect it
#     with the xenv(X, Y) envelope.
#
# Returns a list with Fair_SX_train, MSE_train, Fair_SX_test, MSE_test,
# alpha and beta. The metrics are also printed with cat().
result_envelop <- function(S_train, X_train, Y_train, S_test, X_test, Y_test,
                           r_fair, intersec = FALSE) {
  Proj <- .envelope_projection(S_train, X_train, Y_train, intersec)
  fit <- .envelope_fair_fit(Y_train, X_train, S_train, Proj, r_fair)
  alpha <- fit$alpha
  beta <- fit$beta

  train <- .envelope_metrics(S_train, X_train, Y_train, Proj, alpha, beta)
  cat("Fairness metric for envelope model (train):", train$fair, "\n")
  cat("MSE for envelope model (train):", train$mse, "\n")

  test <- .envelope_metrics(S_test, X_test, Y_test, Proj, alpha, beta)
  cat("Fairness metric for envelope model (test):", test$fair, "\n")
  cat("MSE for envelope model (test):", test$mse, "\n")

  list(
    Fair_SX_train = train$fair,
    MSE_train = train$mse,
    Fair_SX_test = test$fair,
    MSE_test = test$mse,
    alpha = alpha,
    beta = beta
  )
}

#set.seed(123)
#n <- 10000   
#p <- 40     
#p_corr <- 20 
#p_corr_dim <- 5
#p_S <- 10
#r_fair_levels <- seq(0.02, 0.5, by = 0.1) 
#num_replicates <- 5 

# First set working directory to source file location

#output_folder <- "./fair_results_csv"  

#if (!dir.exists(output_folder)) {
#  dir.create(output_folder)
#}

#random_matrix <- matrix(rnorm(p_S * p_S), nrow = p_S, ncol = p_S)
#A <- qr.Q(qr(t(random_matrix) %*% random_matrix))
#Q <- A[, 1:p_corr_dim]
#S_Q <- Q %*% t(Q)

#beta_SX <- matrix(rnorm(p_corr * p_S), p_corr, p_S) + 2
#beta_S <- matrix(rnorm(p_S), p_S, 1) + 1
#beta_X <- matrix(rnorm(p), p, 1) + 2


# Empty per-method result accumulators at top level, as used by the
# commented-out simulation loop. Process() and Processreal() do not read
# these objects.
results_ols <- results_Komi <- results_envelop <-
  results_envelop_intersec <- data.frame()


# for (r_fair in r_fair_levels) {
#   cat("Processing r_fair =", r_fair, "\n")
# 
#   for (rep in 1:num_replicates) {
#     # Generate data
#     Data_all <- data_generate(n, p, p_corr, p_corr_dim, p_S, S_Q, beta_SX, beta_S, beta_X, train_ratio = 0.8)
#     S_train <- Data_all$S_train
#     S_test <- Data_all$S_test
#     X_train <- Data_all$X_train
#     X_test <- Data_all$X_test
#     Y_train <- Data_all$Y_train
#     Y_test <- Data_all$Y_test
# 
#     OLS_result <- result_ols(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair)
#     Envelop_result <- result_envelop(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair, intersec = FALSE)
#     Envelop_intersec_result <- result_envelop(S_train, X_train, Y_train, S_test, X_test, Y_test, r_fair, intersec = TRUE)
# 
#     results_ols <- rbind(
#       results_ols,
#       data.frame(
#         r_fair = r_fair,
#         replicate = rep,
#         Fair_SX_train = OLS_result$Fair_SX_train,
#         MSE_train = OLS_result$MSE_train,
#         Fair_SX_test = OLS_result$Fair_SX_test,
#         MSE_test = OLS_result$MSE_test
#       )
#     )
# 
#     results_envelop <- rbind(
#       results_envelop,
#       data.frame(
#         r_fair = r_fair,
#         replicate = rep,
#         Fair_SX_train = Envelop_result$Fair_SX_train,
#         MSE_train = Envelop_result$MSE_train,
#         Fair_SX_test = Envelop_result$Fair_SX_test,
#         MSE_test = Envelop_result$MSE_test
#       )
#     )
# 
#     results_envelop_intersec <- rbind(
#       results_envelop_intersec,
#       data.frame(
#         r_fair = r_fair,
#         replicate = rep,
#         Fair_SX_train = Envelop_intersec_result$Fair_SX_train,
#         MSE_train = Envelop_intersec_result$MSE_train,
#         Fair_SX_test = Envelop_intersec_result$Fair_SX_test,
#         MSE_test = Envelop_intersec_result$MSE_test
#       )
#     )
#   }
# }
# # 
# # 
# # results_ols <- data.frame()
# # results_envelop <- data.frame()
# # results_envelop_intersec <- data.frame()
# 
# 
# Simulation study: for each fairness level, draw a fresh data set with
# data_generate() and fit the OLS-residual, Komi-residual, envelope and
# envelope-intersection models.
#
# Args:
#   n, p, p_corr, p_corr_dim, p_S, S_Q, beta_SX, beta_S, beta_X: passed
#     through to data_generate().
#   r_fair_levels: fairness bounds passed to each method as r_fair.
#   train_ratio: training fraction passed to data_generate() (default 0.8,
#     the value previously hard-coded).
#
# Returns a data frame with columns r_fair, Fair_SX_train, MSE_train,
# Fair_SX_test, MSE_test and Method. Rows are ordered by method (OLS, Komi,
# Envelop, Envelop_intersec) and, within a method, by r_fair_levels. An empty
# r_fair_levels gives an empty data.frame().
Process <- function(n, p, p_corr, p_corr_dim, p_S, S_Q, beta_SX, beta_S,
                    beta_X, r_fair_levels, train_ratio = 0.8) {
  methods <- c("OLS", "Komi", "Envelop", "Envelop_intersec")
  n_levels <- length(r_fair_levels)
  # One preallocated slot per fairness level for each method; a single rbind
  # at the end avoids re-binding the growing tables on every iteration.
  rows <- lapply(setNames(methods, methods), function(m) vector("list", n_levels))

  # One-row summary of a fitted method at one fairness level.
  summary_row <- function(fit, r_fair, method) {
    data.frame(
      r_fair = r_fair,
      Fair_SX_train = fit$Fair_SX_train,
      MSE_train = fit$MSE_train,
      Fair_SX_test = fit$Fair_SX_test,
      MSE_test = fit$MSE_test,
      Method = method
    )
  }

  for (i in seq_len(n_levels)) {
    r_fair <- r_fair_levels[[i]]
    cat("Processing r_fair =", r_fair, "\n")

    # A fresh data set is drawn for every fairness level.
    d <- data_generate(n, p, p_corr, p_corr_dim, p_S, S_Q, beta_SX, beta_S,
                       beta_X, train_ratio = train_ratio)

    fits <- list(
      OLS = result_ols(d$S_train, d$X_train, d$Y_train,
                       d$S_test, d$X_test, d$Y_test, r_fair),
      Komi = result_Komi(d$S_train, d$X_train, d$Y_train,
                         d$S_test, d$X_test, d$Y_test, r_fair),
      Envelop = result_envelop(d$S_train, d$X_train, d$Y_train,
                               d$S_test, d$X_test, d$Y_test, r_fair,
                               intersec = FALSE),
      Envelop_intersec = result_envelop(d$S_train, d$X_train, d$Y_train,
                                        d$S_test, d$X_test, d$Y_test, r_fair,
                                        intersec = TRUE)
    )
    for (m in methods) {
      rows[[m]][[i]] <- summary_row(fits[[m]], r_fair, m)
    }
  }

  all_rows <- do.call(c, unname(rows))
  if (length(all_rows) == 0L) {
    return(data.frame())
  }
  do.call(rbind, all_rows)
}
  
  
  
  
  
  
  
  
  
# Real-data study on the communities.and.crime data set: take one random
# subsample and one train/test split, then fit the envelope,
# envelope-intersection, OLS-residual and Komi-residual models at every
# fairness level.
#
# Args:
#   ratio: fraction of the complete-case rows kept in the random subsample.
#   r_fair_levels: fairness bounds passed to each method as r_fair.
#   train_ratio: fraction of the subsample used for training (default 0.8,
#     the value previously hard-coded).
#
# Returns a data frame with columns r_fair, Fair_SX_train, MSE_train,
# Fair_SX_test, MSE_test and Method. Rows are ordered by method (OLS, Komi,
# Envelop, Envelop_intersec) and, within a method, by r_fair_levels. An empty
# r_fair_levels gives an empty data.frame().
Processreal <- function(ratio, r_fair_levels, train_ratio = 0.8) {
  # data() defaults to envir = .GlobalEnv; loading into this frame keeps the
  # call from leaving a copy of the data set in the user's workspace.
  data("communities.and.crime", envir = environment())
  cc <- communities.and.crime[complete.cases(communities.and.crime), ]
  # Keep round(ratio * nrow(cc)) randomly chosen complete rows.
  cc <- cc[sample(seq_len(nrow(cc)), size = round(ratio * nrow(cc))), ]

  sensitive_cols <- c("racepctblack", "PctForeignBorn", "MalePctDivorce",
                      "pctUrban", "blackPerCap", "FemalePctDiv")
  # Predictors: every column except the response, the sensitive attributes
  # and the county/state/fold/Num columns.
  excluded_cols <- c("ViolentCrimesPerPop", sensitive_cols,
                     "county", "state", "fold", "Num")
  r <- as.matrix(cc[, "ViolentCrimesPerPop"])
  s <- as.matrix(cc[, sensitive_cols])
  p <- as.matrix(cc[, setdiff(names(cc), excluded_cols)])

  # One train/test split, shared by every fairness level and method.
  n_obs <- nrow(p)
  train_idx <- sample(seq_len(n_obs), size = round(train_ratio * n_obs))
  test_idx <- setdiff(seq_len(n_obs), train_idx)
  S_train <- s[train_idx, , drop = FALSE]
  S_test <- s[test_idx, , drop = FALSE]
  X_train <- p[train_idx, , drop = FALSE]
  X_test <- p[test_idx, , drop = FALSE]
  # The response is passed as a plain vector, as before.
  Y_train <- r[train_idx]
  Y_test <- r[test_idx]

  methods <- c("OLS", "Komi", "Envelop", "Envelop_intersec")
  n_levels <- length(r_fair_levels)
  # One preallocated slot per fairness level for each method; a single rbind
  # at the end avoids re-binding the growing tables on every iteration.
  rows <- lapply(setNames(methods, methods), function(m) vector("list", n_levels))

  # One-row summary of a fitted method at one fairness level.
  summary_row <- function(fit, r_fair, method) {
    data.frame(
      r_fair = r_fair,
      Fair_SX_train = fit$Fair_SX_train,
      MSE_train = fit$MSE_train,
      Fair_SX_test = fit$Fair_SX_test,
      MSE_test = fit$MSE_test,
      Method = method
    )
  }

  for (i in seq_len(n_levels)) {
    r_fair <- r_fair_levels[[i]]
    cat("Processing r_fair =", r_fair, "\n")

    # Same fitting order as before: envelope variants first.
    fits <- list(
      Envelop = result_envelop(S_train, X_train, Y_train, S_test, X_test,
                               Y_test, r_fair, intersec = FALSE),
      Envelop_intersec = result_envelop(S_train, X_train, Y_train, S_test,
                                        X_test, Y_test, r_fair,
                                        intersec = TRUE),
      OLS = result_ols(S_train, X_train, Y_train, S_test, X_test, Y_test,
                       r_fair),
      Komi = result_Komi(S_train, X_train, Y_train, S_test, X_test, Y_test,
                         r_fair)
    )
    for (m in methods) {
      rows[[m]][[i]] <- summary_row(fits[[m]], r_fair, m)
    }
  }

  all_rows <- do.call(c, unname(rows))
  if (length(all_rows) == 0L) {
    return(data.frame())
  }
  do.call(rbind, all_rows)
}
  
  
  
  