library(e1071)
library(ggplot2)
library(MASS)
library(tidyverse)
library(truncnorm)
library(stats)
library(clusterGeneration)


## gauss_cop
# Draw a random correlation matrix of size num_vars x num_vars.
#
# A positive-definite matrix is generated with the "c-vine" method of
# genPositiveDefMat() (shape parameter `eta` is forwarded unchanged) and its
# Sigma component is rescaled to unit diagonal with cov2cor().
sampleCorrelationMatrix <- function(num_vars, eta) {
  pd_draw <- genPositiveDefMat(dim = num_vars, covMethod = "c-vine", eta = eta)
  cov2cor(pd_draw$Sigma)
}

# Sample the third variable of a zero-mean trivariate Gaussian conditional on
# the first two.
#
# Args:
#   X2_samples: numeric matrix (or data frame) with two columns; each row is one
#     observed value of the two conditioning variables.
#   R: covariance (or correlation) matrix; rows/columns 1:2 are the
#     conditioning variables and row/column 3 is the target variable.
#   n_samples: number of conditional draws to generate per row of X2_samples.
#
# Returns a list with
#   conditional_means:    nrow(X2_samples) x 1 matrix of conditional means,
#   conditional_variance: 1 x 1 matrix holding the conditional variance,
#   generated_samples:    nrow(X2_samples) x n_samples matrix of draws
#                         (row i is drawn around conditional_means[i]).
bivariate_conditional_mean_and_samples <- function(X2_samples, R, n_samples) {
  X2_samples <- as.matrix(X2_samples)
  R <- as.matrix(R)

  # Partition R. drop = FALSE keeps the cross-covariance as a 1 x 2 matrix;
  # with a plain vector, t() yields a 1 x 2 row and the quadratic form below
  # is not conformable.
  R11 <- R[3, 3]
  R12 <- R[3, 1:2, drop = FALSE]
  R22 <- R[1:2, 1:2, drop = FALSE]

  # Regression weights R12 %*% R22^{-1} (1 x 2), shared by mean and variance.
  regression_weights <- R12 %*% solve(R22)

  conditional_variance <- R11 - regression_weights %*% t(R12)
  conditional_means <- X2_samples %*% t(regression_weights)

  conditional_sd <- sqrt(as.numeric(conditional_variance))
  n_rows <- nrow(X2_samples)
  generated_samples <- matrix(NA_real_, nrow = n_rows, ncol = n_samples)
  # seq_len() so that an input with zero rows simply yields a 0-row result.
  for (i in seq_len(n_rows)) {
    generated_samples[i, ] <- rnorm(
      n_samples,
      mean = conditional_means[i],
      sd = conditional_sd
    )
  }

  list(
    conditional_means = conditional_means,
    conditional_variance = conditional_variance,
    generated_samples = generated_samples
  )
}


# Sample the last variable of a zero-mean d-variate Gaussian conditional on
# the first d - 1 variables.
#
# Args:
#   X2_samples: numeric matrix (or data frame) with d - 1 columns; each row is
#     one observed value of the conditioning variables.
#   R: d x d covariance (or correlation) matrix; the last row/column belongs to
#     the variable being sampled.
#
# Returns a list with
#   generated_samples:    nrow(X2_samples) x 1 matrix, one draw per input row,
#   conditional_means:    nrow(X2_samples) x 1 matrix of conditional means,
#   conditional_variance: 1 x 1 matrix holding the conditional variance.
multivariate_conditional_mean_and_samples <- function(X2_samples, R) {
  # Coerce first so that the dimension checks see matrices, not data frames.
  X2_samples <- as.matrix(X2_samples)
  R <- as.matrix(R)

  n_samples <- nrow(X2_samples)
  d <- nrow(R)                # Total number of variables

  if (d < 2) {
    stop("`R` must describe at least two variables.", call. = FALSE)
  }
  if (ncol(X2_samples) != d - 1) {
    stop(
      "`X2_samples` must have nrow(R) - 1 = ", d - 1, " columns, not ",
      ncol(X2_samples), ".",
      call. = FALSE
    )
  }

  # Indices for partitioning the covariance matrix
  indices_1 <- d                   # Target variable (X1)
  indices_2 <- seq_len(d - 1)      # Conditioning variables (X2)

  R11 <- R[indices_1, indices_1]                           # Variance of X1
  R12 <- matrix(R[indices_1, indices_2], nrow = 1)         # Cov(X1, X2), 1 x (d - 1)
  R22 <- R[indices_2, indices_2, drop = FALSE]             # Cov(X2)

  R22_inv <- solve(R22)

  conditional_variance <- R11 - R12 %*% R22_inv %*% t(R12)
  conditional_means <- t(R12 %*% R22_inv %*% t(X2_samples))

  # One draw per row. rnorm() is vectorised over `mean`, which consumes the
  # random stream in the same order as a row-by-row loop and also returns a
  # clean 0 x 1 result when there are no rows.
  generated_samples <- matrix(
    rnorm(n_samples, mean = conditional_means, sd = sqrt(conditional_variance)),
    nrow = n_samples,
    ncol = 1
  )

  list(
    generated_samples = generated_samples,
    conditional_means = conditional_means,
    conditional_variance = conditional_variance
  )
}


## inference_utils
# Fit per-column empirical marginal transforms for a data set that mixes
# continuous and binary columns.
#
# Args:
#   data:  data frame (or list of vectors), one element per variable. A matrix
#          is converted to a data frame column by column.
#   flags: one entry per column of `data`; 1 marks a binary (0/1) column, any
#          other value marks a continuous column. A single value is recycled.
#
# Returns a list with
#   transformed_data:  data frame with every column mapped to the unit
#       interval. Continuous columns use their empirical CDF. Binary columns
#       are randomised: with p = mean(column), zeros are drawn uniformly from
#       [0, 1 - p] and all other values from [1 - p, 1] (this uses runif(), so
#       set a seed beforehand for reproducibility).
#   reverse_transform: function(column_index, quantiles) mapping unit-interval
#       values back to the original scale of the given column.
#   forward_transform: function(column_index, new_data) applying the stored
#       CDF of the given column to new values (deterministic).
#   ecdf_list:         list of the per-column CDF functions.
combined_empirical_transform <- function(data, flags) {
  if (is.matrix(data)) {
    data <- as.data.frame(data)
  }
  n_cols <- length(data)
  if (length(flags) == 1L && n_cols > 1L) {
    flags <- rep(flags, n_cols)
  }
  if (length(flags) != n_cols) {
    stop("`flags` must have one entry per column of `data`.", call. = FALSE)
  }

  # Continuous forward map: ECDF of the column padded with two far-away
  # sentinels, so observed values lying between the sentinels never map to
  # exactly 0 or 1.
  fit_continuous_cdf <- function(x) {
    ecdf(c(x, -1e9, 1e9))
  }

  # Continuous reverse map: type-1 (inverse ECDF) sample quantiles. A matrix of
  # quantiles is processed column by column.
  invert_continuous <- function(x, quantiles) {
    sorted_x <- sort(x)  # sort() also discards NA values
    if (is.matrix(quantiles)) {
      apply(quantiles, 2, function(col) quantile(sorted_x, probs = col, type = 1))
    } else {
      quantile(sorted_x, probs = quantiles, type = 1)
    }
  }

  # Binary randomised forward map used for the training data itself.
  randomise_binary <- function(x, p) {
    is_zero <- x == 0
    runif(
      length(x),
      min = ifelse(is_zero, 0, 1 - p),
      max = ifelse(is_zero, 1 - p, 1)
    )
  }

  # Binary deterministic "CDF": 0 maps to 1 - p, values >= 1 map to 1, and any
  # other value (e.g. -1) maps to 0.
  make_binary_cdf <- function(p) {
    function(new_data) {
      ifelse(new_data >= 1, 1, ifelse(new_data == 0, 1 - p, 0))
    }
  }

  # Fit every column in order, collecting the transformed values and the CDF
  # function used for later forward transforms.
  ecdf_list <- vector("list", n_cols)
  transformed_columns <- vector("list", n_cols)
  names(transformed_columns) <- names(data)
  for (idx in seq_len(n_cols)) {
    column <- data[[idx]]
    if (flags[idx] == 1) {
      p <- mean(column)  # Proportion of 1's
      ecdf_list[[idx]] <- make_binary_cdf(p)
      transformed_columns[[idx]] <- randomise_binary(column, p)
    } else {
      column_cdf <- fit_continuous_cdf(column)
      ecdf_list[[idx]] <- column_cdf
      transformed_columns[[idx]] <- column_cdf(column)
    }
  }
  transformed_data <- as.data.frame(transformed_columns)

  reverse_transform <- function(column_index, quantiles) {
    x <- data[[column_index]]
    if (flags[column_index] == 1) {
      # Quantiles below the empirical share of zeros map to 0, the rest to 1.
      ifelse(quantiles < (1 - mean(x)), 0, 1)
    } else {
      invert_continuous(x, quantiles)
    }
  }

  forward_transform <- function(column_index, new_data) {
    cdf_fun <- ecdf_list[[column_index]]
    if (flags[column_index] != 1 && is.null(cdf_fun)) {
      stop("No ECDF function available for the specified continuous variable.")
    }
    cdf_fun(new_data)
  }

  list(
    transformed_data = transformed_data,
    reverse_transform = reverse_transform,
    forward_transform = forward_transform,
    ecdf_list = ecdf_list
  )
}

## synthetic_data_functions

# Simulate a "test" data set (covariates Z, treatment X, outcome Y) from a
# Gaussian copula over the covariates and the outcome.
#
# Args:
#   n_samples: number of rows to simulate.
#   corr_matrix_trial: (num_covariates + 1)-dimensional copula correlation
#     matrix; the last row/column belongs to the outcome.
#   num_covariates: number of covariates.
#   test_covariate_marginals: list with a `reverse_transform(index, quantiles)`
#     function, e.g. the value returned by combined_empirical_transform().
#   treatment_family: "bernoulli", "gaussian" or "truncated-gaussian".
#   prop_score_params: intercept followed by one coefficient per covariate;
#     for the two Gaussian families the last entry is the standard deviation.
#   causal_effect_family: "gaussian", "exp", "gamma" or "bernoulli".
#   causal_effect_params: c(intercept, slope on X, third parameter). The third
#     entry is the sd for "gaussian" and 1 / shape for "gamma"; it is not read
#     for "exp" and "bernoulli".
#   seed: optional RNG seed; NULL (default) leaves the RNG state untouched.
#
# Returns a list with Z_test, X_test, Y_test, test_covariate_ranks (the
# uniform copula ranks of the covariates) and causal_effect_ranks (those of
# the outcome).
generate_realistic_test_data <- function(
  n_samples,
  corr_matrix_trial,
  num_covariates,
  test_covariate_marginals, # List returned by `combined_empirical_transform`
  treatment_family,
  prop_score_params,
  causal_effect_family,
  causal_effect_params,
  seed = NULL
) {
  if (!is.null(seed)) {
    set.seed(seed)
  }
  covariate_cols <- seq_len(num_covariates)
  outcome_col <- num_covariates + 1

  gaussian_draws <- mvrnorm(
    n = n_samples,
    mu = rep(0, num_covariates + 1),
    Sigma = corr_matrix_trial
  )
  if (is.null(dim(gaussian_draws))) {
    # mvrnorm() returns a plain vector when n = 1; restore the 1-row layout.
    gaussian_draws <- matrix(
      gaussian_draws,
      nrow = 1,
      dimnames = list(NULL, names(gaussian_draws))
    )
  }
  test_ranks <- pnorm(gaussian_draws)

  # drop = FALSE keeps a matrix even with a single covariate.
  Z_test <- test_ranks[, covariate_cols, drop = FALSE]
  for (d in covariate_cols) {
    Z_test[, d] <- test_covariate_marginals$reverse_transform(d, Z_test[, d])
  }

  # Treatment assignment from a linear predictor in (1, Z).
  design <- cbind(rep(1, nrow(Z_test)), Z_test)
  n_params <- length(prop_score_params)
  X_test <- switch(
    treatment_family,
    "bernoulli" = rbinom(
      n_samples, 1, plogis(design %*% matrix(prop_score_params))
    ),
    "gaussian" = {
      mu <- design %*% matrix(prop_score_params[1:(n_params - 1)])
      rnorm(n_samples, mean = mu, sd = prop_score_params[n_params])
    },
    "truncated-gaussian" = {
      mu <- design %*% matrix(prop_score_params[1:(n_params - 1)])
      rtruncnorm(a = 0, n = n_samples, mean = mu, sd = prop_score_params[n_params])
    },
    stop("Unsupported treatment_family: ", treatment_family, call. = FALSE)
  )

  # Outcome: push the outcome's copula rank through the conditional quantile
  # function of Y given X.
  outcome_ranks <- test_ranks[, outcome_col]
  linear_predictor <- causal_effect_params[1] + X_test * causal_effect_params[2]
  Y_test <- switch(
    causal_effect_family,
    "gaussian" = qnorm(
      outcome_ranks,
      mean = linear_predictor,
      sd = causal_effect_params[3]
    ),
    "exp" = qexp(outcome_ranks, rate = exp(linear_predictor)),
    "gamma" = {
      shape <- 1 / causal_effect_params[3]
      qgamma(outcome_ranks, shape = shape, scale = exp(linear_predictor) / shape)
    },
    "bernoulli" = qbinom(outcome_ranks, size = 1, prob = plogis(linear_predictor)),
    stop("Unsupported causal_effect_family: ", causal_effect_family, call. = FALSE)
  )

  list(
    Z_test = Z_test,
    X_test = X_test,
    Y_test = Y_test,
    test_covariate_ranks = test_ranks[, covariate_cols, drop = FALSE],
    causal_effect_ranks = outcome_ranks
  )
}

# Simulate a "train" data set whose covariates follow the training marginals
# while the outcome rank is drawn from the trial copula, conditional on the
# covariate ranks re-expressed on the test marginals' scale.
#
# Args:
#   n_samples: number of rows to simulate.
#   corr_matrix_trial: (num_covariates + 1)-dimensional copula correlation
#     matrix; the last row/column belongs to the outcome.
#   num_covariates: number of covariates.
#   train_covariate_marginals, test_covariate_marginals: lists exposing
#     `reverse_transform(index, quantiles)` and `forward_transform(index,
#     values)`, e.g. values returned by combined_empirical_transform().
#   covariate_type: one entry per covariate, 0 if continuous, 1 if binary.
#   treatment_family: "bernoulli", "gaussian" or "truncated-gaussian".
#   prop_score_params: intercept followed by one coefficient per covariate;
#     for the two Gaussian families the last entry is the standard deviation.
#   causal_effect_family: "gaussian", "gamma", "exp" or "bernoulli".
#   causal_effect_params: c(intercept, slope on X, third parameter). The third
#     entry is the sd for "gaussian" and 1 / shape for "gamma"; it is not read
#     for "exp" and "bernoulli".
#   seed: optional RNG seed; NULL (default) leaves the RNG state untouched.
#
# Returns a list with Z_train, X_train, Y_train, train_ranks (rescaled
# covariate ranks with the outcome ranks appended as last column) and
# train_causal_effect_ranks.
generate_realistic_train_data <- function(
    n_samples,
    corr_matrix_trial,
    num_covariates,
    train_covariate_marginals,
    test_covariate_marginals,
    covariate_type, # 0 if cts, 1 if binary
    treatment_family,
    prop_score_params,
    causal_effect_family,
    causal_effect_params,
    seed = NULL
) {
  if (!is.null(seed)) {
    set.seed(seed)
  }
  covariate_cols <- seq_len(num_covariates)

  gaussian_draws <- mvrnorm(
    n = n_samples,
    mu = rep(0, num_covariates + 1),
    Sigma = corr_matrix_trial
  )
  if (is.null(dim(gaussian_draws))) {
    # mvrnorm() returns a plain vector when n = 1; restore the 1-row layout.
    gaussian_draws <- matrix(
      gaussian_draws,
      nrow = 1,
      dimnames = list(NULL, names(gaussian_draws))
    )
  }
  base_train_ranks <- pnorm(gaussian_draws)

  # drop = FALSE keeps a matrix even with a single covariate.
  Z_train <- base_train_ranks[, covariate_cols, drop = FALSE]
  rescaled_train_covariate_ranks <- matrix(
    0, nrow = nrow(Z_train), ncol = ncol(Z_train)
  )
  for (d in covariate_cols) {
    Z_train[, d] <- train_covariate_marginals$reverse_transform(d, Z_train[, d])
    if (covariate_type[d] == 1) {
      # Binary covariate: locate the base rank inside the train CDF step
      # [F_train(z - 1), F_train(z)] and map it linearly onto the matching
      # test CDF step [F_test(z - 1), F_test(z)].
      qbase <- base_train_ranks[, d]
      qtrain_Xtrain <- train_covariate_marginals$forward_transform(d, Z_train[, d] - 1)
      qtrain_Xtrainp1 <- train_covariate_marginals$forward_transform(d, Z_train[, d])
      qtest_Xtrain <- test_covariate_marginals$forward_transform(d, Z_train[, d] - 1)
      qtest_Xtrainp1 <- test_covariate_marginals$forward_transform(d, Z_train[, d])
      rescaled_train_covariate_ranks[, d] <- qtest_Xtrain + (
        (qtest_Xtrainp1 - qtest_Xtrain) *
          (qbase - qtrain_Xtrain) / (qtrain_Xtrainp1 - qtrain_Xtrain)
      )
    } else if (covariate_type[d] == 0) {
      # Continuous covariate: rank of the sampled value under the test CDF.
      rescaled_train_covariate_ranks[, d] <-
        test_covariate_marginals$forward_transform(d, Z_train[, d])
    } else {
      # Leaving the column at 0 would silently turn into -Inf after qnorm().
      stop("covariate_type[", d, "] must be 0 or 1.", call. = FALSE)
    }
  }
  rescaled_train_covariate_ranks_gauss <- qnorm(rescaled_train_covariate_ranks)

  # Outcome ranks: conditional Gaussian draw given the rescaled covariates.
  train_causal_effect_ranks <- pnorm(
    multivariate_conditional_mean_and_samples(
      rescaled_train_covariate_ranks_gauss,
      corr_matrix_trial
    )$generated_samples
  )

  # Treatment assignment from a linear predictor in (1, Z).
  design <- cbind(rep(1, nrow(Z_train)), Z_train)
  n_params <- length(prop_score_params)
  X_train <- switch(
    treatment_family,
    "bernoulli" = rbinom(
      n_samples, 1, plogis(design %*% matrix(prop_score_params))
    ),
    "gaussian" = {
      mu <- design %*% matrix(prop_score_params[1:(n_params - 1)])
      rnorm(n_samples, mean = mu, sd = prop_score_params[n_params])
    },
    "truncated-gaussian" = {
      mu <- design %*% matrix(prop_score_params[1:(n_params - 1)])
      rtruncnorm(a = 0, n = n_samples, mean = mu, sd = prop_score_params[n_params])
    },
    stop("Unsupported treatment_family: ", treatment_family, call. = FALSE)
  )

  # Outcome: push the outcome rank through the conditional quantile function.
  linear_predictor <- causal_effect_params[1] + X_train * causal_effect_params[2]
  Y_train <- switch(
    causal_effect_family,
    "gaussian" = qnorm(
      train_causal_effect_ranks,
      mean = linear_predictor,
      sd = causal_effect_params[3]
    ),
    "gamma" = {
      shape <- 1 / causal_effect_params[3]
      qgamma(
        train_causal_effect_ranks,
        shape = shape,
        scale = exp(linear_predictor) / shape
      )
    },
    "exp" = qexp(train_causal_effect_ranks, rate = exp(linear_predictor)),
    "bernoulli" = qbinom(
      train_causal_effect_ranks,
      size = 1,
      prob = plogis(linear_predictor)
    ),
    stop("Unsupported causal_effect_family: ", causal_effect_family, call. = FALSE)
  )

  list(
    Z_train = Z_train,
    X_train = X_train,
    Y_train = Y_train,
    train_ranks = cbind(rescaled_train_covariate_ranks, train_causal_effect_ranks),
    train_causal_effect_ranks = train_causal_effect_ranks
  )
}



# Build matched synthetic train/test data sets from observed data.
#
# Steps: (1) fit empirical marginal transforms to the test and train
# covariates, (2) fit a Gamma GLM of Y_test on X_test, (3) estimate a Gaussian
# copula correlation matrix over the test covariates and the outcome, and
# (4) simulate test and train data with Bernoulli treatment and Gamma outcome.
#
# Args:
#   n_samples_test, n_samples_train: number of rows to simulate for each set.
#   prop_score_params: propensity-score parameters forwarded to both
#     generators (intercept followed by one coefficient per covariate).
#   Y_test, X_test: observed outcome and treatment used to fit the Gamma GLM.
#   Z_test, Z_train: observed covariates, passed to
#     combined_empirical_transform() (presumably data frames with identical
#     columns -- confirm against callers).
#   cov_datatype: per-covariate flag, 0 if continuous, 1 if binary.
#   marginal_cdf_seed: optional seed set before fitting the marginals.
#   training_data_seed, test_data_seed: optional seeds forwarded to the
#     train/test generators. NULL (default) leaves the RNG state untouched.
#
# Returns a list with train_data, test_data, causal_effect_lin_params
# (GLM intercept and slope) and causal_effect_dispersion.
experiments_generate_realistic_data <- function(
    n_samples_test,
    n_samples_train,
    prop_score_params,
    Y_test,
    X_test,
    Z_test,
    Z_train,
    cov_datatype,
    marginal_cdf_seed = NULL,
    training_data_seed = NULL,
    test_data_seed = NULL
) {
  d <- ncol(Z_test) # Dim of covariates

  if (!is.null(marginal_cdf_seed)) {
    set.seed(marginal_cdf_seed)
  }
  # Empirical marginal transforms of the covariates
  test_marginal_cdfs <- combined_empirical_transform(Z_test, cov_datatype)
  train_marginal_cdfs <- combined_empirical_transform(Z_train, cov_datatype)

  # Estimate params of causal effect
  regression_data <- data.frame(Y = Y_test, X = X_test)
  glm_result <- fit_gamma_glm(regression_data)
  cdf_values <- glm_result$cdf_values
  lin_coeffs <- glm_result$lin_coeffs
  dispersion <- glm_result$dispersion
  causal_effect_params <- c(lin_coeffs[1], lin_coeffs[2], dispersion)

  # Estimating full correlation matrix with covariates and outcome: map every
  # margin to the standard normal scale and take the sample correlation.
  std_gauss_Z <- apply(test_marginal_cdfs$transformed_data, 2, qnorm)
  full_gaussian_ranks <- cbind(std_gauss_Z, qnorm(cdf_values))
  full_gaussian_copula_corr <- cor(full_gaussian_ranks)

  test_data <- generate_realistic_test_data(
    n_samples = n_samples_test,
    corr_matrix_trial = full_gaussian_copula_corr,
    num_covariates = d,
    test_covariate_marginals = test_marginal_cdfs,
    treatment_family = 'bernoulli',
    prop_score_params = prop_score_params, # For randomised bernoulli treatments
    causal_effect_family = 'gamma',
    causal_effect_params = causal_effect_params,
    seed = test_data_seed
  )

  # Simulate train Data
  train_data <- generate_realistic_train_data(
    n_samples = n_samples_train,
    corr_matrix_trial = full_gaussian_copula_corr,
    num_covariates = d,
    train_covariate_marginals = train_marginal_cdfs,
    test_covariate_marginals = test_marginal_cdfs,
    covariate_type = cov_datatype,
    treatment_family = 'bernoulli',
    prop_score_params = prop_score_params, # For randomised bernoulli treatments
    causal_effect_family = 'gamma',
    causal_effect_params = causal_effect_params,
    seed = training_data_seed
  )

  list(
    train_data = train_data,
    test_data = test_data,
    causal_effect_lin_params = c(lin_coeffs[1], lin_coeffs[2]),
    causal_effect_dispersion = dispersion
  )
}


# # Example usage:
#
# # Create a dataset with both continuous and binary variables
# set.seed(123)
# data <- data.frame(
#   X1 = rnorm(1000),  # Continuous variable (normal)
#   X2 = runif(1000),  # Continuous variable (uniform)
#   X3 = rexp(1000),   # Continuous variable (exponential)
#   B1 = rbinom(1000, 1, 0.2),  # Binary variable
#   B2 = rbinom(1000, 1, 0.5)   # Binary variable
# )
#
# # Define the flags: 0 for continuous, 1 for binary
# flags <- c(0, 0, 0, 1, 1)
#
# # Apply the combined transformation function
# result <- combined_empirical_transform(data, flags)
#
# # Access the transformed data
# transformed_data <- result$transformed_data
# print(head(transformed_data))
#
# # Reverse transformation example:
# # Convert a set of quantiles back to the original domain for continuous 'X1'
# quantiles_X1 <- transformed_data$X1[1:100]  # Example quantiles from transformed data
# original_values_X1 <- result$reverse_transform(1, quantiles_X1)  # Column index for 'X1' is 1
# print(original_values_X1)
#
# # Reverse transformation example for binary 'B1'
# # quantiles_B1 <- transformed_data$B1[1:5]  # Example quantiles from transformed data
# original_values_B1 <- result$reverse_transform(4, c(0.1, 0.5, 0.9))  # Column index for 'B1' is 4
# print(original_values_B1)


# Sample a Stan copula model and summarise the quantities of interest.
#
# Args:
#   stan_data_list: data list handed to the model's `$sample()` method.
#   stan_model: model object exposing `$sample()` whose result exposes
#     `$summary()` (presumably a cmdstanr CmdStanModel -- confirm).
#   matrix_size: number of rows/columns of the correlation matrix `Rho`.
#   iter_warmup, iter_sampling: warm-up and sampling iterations per chain.
#   seed, chains, parallel_chains, refresh: sampler settings forwarded to
#     `$sample()`; the defaults reproduce the previously hard-coded values.
#
# Returns a list with
#   rho_matrix:   matrix_size x matrix_size matrix filled from the `mean`
#                 column of summary rows whose variable name contains "Rho"
#                 but not "chol"; entries without a summary row stay NA,
#   linear_terms: `mean` values of variables whose name contains "betas",
#   phi:          `mean` values of variables whose name contains "phi",
#   fit_summary:  the full summary table.
run_copula_inference <- function(stan_data_list, stan_model, matrix_size,
                                 iter_warmup = 500, iter_sampling = 500,
                                 seed = 123, chains = 4,
                                 parallel_chains = chains, refresh = 5) {
  bayesian_copula_fit <- stan_model$sample(
    data = stan_data_list,
    seed = seed,
    refresh = refresh,
    iter_warmup = iter_warmup,
    iter_sampling = iter_sampling,
    chains = chains,
    parallel_chains = parallel_chains
  )
  # Get the fit summary
  fit_summary <- bayesian_copula_fit$summary()

  # Step 1: Extract rho matrix. Row/column indices are parsed from variable
  # names of the form "name[row,col]".
  rho_summary <- fit_summary %>%
    filter(str_detect(variable, 'Rho') & !str_detect(variable, 'chol')) %>%
    mutate(row = as.numeric(str_extract(variable, "(?<=\\[)\\d+")),
           col = as.numeric(str_extract(variable, "(?<=,)\\d+(?=\\])")))

  rho_matrix <- matrix(NA, nrow = matrix_size, ncol = matrix_size)
  # seq_len() keeps the loop empty when no Rho entries are present.
  for (i in seq_len(nrow(rho_summary))) {
    rho_matrix[rho_summary$row[i], rho_summary$col[i]] <- rho_summary$mean[i]
  }

  # Step 2: Extract linear terms (betas)
  linear_terms <- fit_summary %>%
    filter(str_detect(variable, 'betas')) %>%
    select(mean) %>%
    as_vector()

  # Step 3: Extract phi
  phi <- fit_summary %>%
    filter(str_detect(variable, 'phi')) %>%
    select(mean) %>%
    as_vector()

  list(
    rho_matrix = rho_matrix,
    linear_terms = linear_terms,
    phi = phi,
    fit_summary = fit_summary
  )
}


# Fit a log-link Gamma GLM of Y on X and evaluate the fitted Gamma CDF at the
# observed outcomes.
#
# Args:
#   dat: data frame with columns `Y` (response) and `X` (regressor).
#
# Returns a list with
#   cdf_values: fitted Gamma CDF evaluated at each dat$Y,
#   lin_coeffs: named coefficient vector (intercept, slope),
#   dispersion: dispersion estimate taken from the model summary.
fit_gamma_glm <- function(dat) {
  gamma_fit <- glm(Y ~ X, family = Gamma(link = 'log'), data = dat)

  beta <- gamma_fit$coefficients
  phi <- summary(gamma_fit)$dispersion

  # Conditional mean under the log link; with shape = 1 / phi and
  # scale = mean * phi the Gamma distribution has exactly that mean.
  fitted_mean <- exp(beta[1] + beta[2] * dat$X)
  probabilities <- pgamma(dat$Y, shape = 1 / phi, scale = fitted_mean * phi)

  list(cdf_values = probabilities, lin_coeffs = beta, dispersion = phi)
}

# Convert log-link GLM coefficients and a dispersion into Gamma shape/scale
# parameters for the supplied regressor values.
#
# Args:
#   lin_coeffs: c(intercept, slope) on the log-mean scale.
#   dispersion: Gamma dispersion (1 / shape).
#   X: regressor values at which to evaluate the mean.
#
# Returns list(gamma_shape, gamma_scale), where gamma_scale has one entry per
# element of X.
calculate_gamma_params <- function(lin_coeffs, dispersion, X) {
  conditional_mean <- exp(lin_coeffs[1] + lin_coeffs[2] * X)

  list(
    gamma_shape = 1 / dispersion,
    gamma_scale = conditional_mean * dispersion
  )
}

