library(clusterGeneration)
library(e1071)
library(ggplot2)
library(MASS)
library(tidyverse)
library(truncnorm)
library(stats)


# Sample a random correlation matrix.
#
# A covariance matrix is generated with clusterGeneration::genPositiveDefMat
# using the "c-vine" method, and its `Sigma` component is rescaled to a unit
# diagonal with stats::cov2cor.
#
# Args:
#   num_vars: forwarded as genPositiveDefMat(dim = num_vars).
#   eta: forwarded unchanged as genPositiveDefMat(eta = eta).
#
# Returns:
#   The correlation matrix obtained from cov2cor().
sampleCorrelationMatrix <- function(num_vars, eta) {
  generated <- genPositiveDefMat(
    dim = num_vars,
    covMethod = "c-vine",
    eta = eta
  )
  covariance <- generated$Sigma
  cov2cor(covariance)
}

# Sample the last variable of a zero-mean Gaussian vector conditional on the
# other variables, several draws per conditioning row.
#
# R is the covariance (e.g. correlation) matrix of the full vector. The last
# variable (index nrow(R)) is conditioned on variables 1..nrow(R)-1; for the
# original 3 x 3 use case this is variable 3 given variables 1 and 2. No mean
# offset is applied, i.e. all variables are assumed to have mean zero.
#
# Args:
#   X2_samples: numeric matrix (or something as.matrix() accepts) with
#     nrow(R) - 1 columns; each row is one set of conditioning values.
#   R: covariance/correlation matrix of the full vector (must have an
#     invertible leading (d-1) x (d-1) block).
#   n_samples: number of conditional draws for each row of X2_samples.
#
# Returns:
#   A list with
#     conditional_means: nrow(X2_samples) x 1 matrix of conditional means.
#     conditional_variance: 1 x 1 matrix, the (row-independent) conditional
#       variance.
#     generated_samples: nrow(X2_samples) x n_samples matrix; row i holds the
#       draws for row i of X2_samples.
bivariate_conditional_mean_and_samples <- function(X2_samples, R, n_samples) {
  X2_samples <- as.matrix(X2_samples)
  d <- nrow(R)
  target <- d
  given <- seq_len(d - 1L)
  if (ncol(X2_samples) != length(given)) {
    stop("X2_samples must have nrow(R) - 1 columns.", call. = FALSE)
  }

  R11 <- R[target, target]
  # Keep R12 as a 1 x (d-1) row matrix. As a plain vector, t(R12) would also be
  # a row and the quadratic form below would be non-conformable.
  R12 <- matrix(R[target, given], nrow = 1L)
  R22 <- R[given, given, drop = FALSE]

  R22_inv <- solve(R22)
  # Regression weights of the target on the conditioning variables (1 x (d-1)).
  weights <- R12 %*% R22_inv

  conditional_variance <- R11 - weights %*% t(R12)
  conditional_means <- X2_samples %*% t(weights)

  # One vectorised draw consumes the RNG stream in the same order as drawing
  # row by row; byrow = TRUE places the first n_samples draws in row 1, etc.
  n_rows <- nrow(X2_samples)
  draws <- rnorm(
    n_rows * n_samples,
    mean = rep(as.vector(conditional_means), each = n_samples),
    sd = sqrt(as.vector(conditional_variance))
  )
  generated_samples <- matrix(draws, nrow = n_rows, ncol = n_samples, byrow = TRUE)

  list(
    conditional_means = conditional_means,
    conditional_variance = conditional_variance,
    generated_samples = generated_samples
  )
}


# Draw one sample of the last variable of a zero-mean Gaussian vector
# conditional on the remaining variables, for each conditioning row.
#
# The last variable (index nrow(R)) is conditioned on variables
# 1..nrow(R)-1. No mean offset is applied, i.e. all variables are assumed to
# have mean zero.
#
# Args:
#   X2_samples: numeric matrix (or something as.matrix() accepts, e.g. a
#     data frame or, when nrow(R) == 2, a vector) with nrow(R) - 1 columns;
#     each row is one set of conditioning values.
#   R: covariance/correlation matrix of the full vector (must have an
#     invertible leading (d-1) x (d-1) block).
#
# Returns:
#   A list with
#     generated_samples: nrow(X2_samples) x 1 matrix, one draw per row.
#     conditional_means: nrow(X2_samples) x 1 matrix of conditional means.
#     conditional_variance: 1 x 1 matrix, the conditional variance.
multivariate_conditional_mean_and_samples <- function(X2_samples, R) {
  # Coerce first so nrow()/ncol() below are defined for vectors and data frames.
  X2_samples <- as.matrix(X2_samples)
  n_samples <- nrow(X2_samples)
  d <- nrow(R)

  indices_1 <- d              # conditioned variable (X1)
  indices_2 <- seq_len(d - 1L) # conditioning variables (X2)
  if (ncol(X2_samples) != length(indices_2)) {
    stop("X2_samples must have nrow(R) - 1 columns.", call. = FALSE)
  }

  R11 <- R[indices_1, indices_1]
  # Row matrix so that t(R12) is a column and the products below conform.
  R12 <- matrix(R[indices_1, indices_2], nrow = 1L)
  R22 <- R[indices_2, indices_2, drop = FALSE]

  R22_inv <- solve(R22)

  conditional_variance <- R11 - R12 %*% R22_inv %*% t(R12)
  conditional_means <- t(R12 %*% R22_inv %*% t(X2_samples))

  # Vectorised rnorm draws from the RNG stream in the same order as one call
  # per row, and yields an empty 0 x 1 matrix for empty input.
  generated_samples <- matrix(
    rnorm(
      n_samples,
      mean = as.vector(conditional_means),
      sd = sqrt(as.vector(conditional_variance))
    ),
    nrow = n_samples,
    ncol = 1L
  )

  list(
    generated_samples = generated_samples,
    conditional_means = conditional_means,
    conditional_variance = conditional_variance
  )
}


# Simulate a test data set (covariates Z, treatment X, outcome Y) from a
# Gaussian copula.
#
# n_samples draws of a zero-mean Gaussian vector with correlation
# corr_matrix_trial are mapped to uniform ranks with pnorm. Covariate columns
# come first (continuous, then discrete); the last column is the outcome.
# Continuous ranks are pushed through the continuous margin, discrete ranks
# through the discrete margin. Unrecognised margin families leave the uniform
# ranks in place (no error). The treatment is drawn from a model on
# cbind(1, Z). The outcome rank is pushed through the outcome family, whose
# linear predictor is causal_effect_params[1] + X * causal_effect_params[2].
#
# Args:
#   n_samples: number of rows to simulate.
#   corr_matrix_trial: (num_covariates + 1) x (num_covariates + 1) correlation
#     matrix of the latent Gaussian vector.
#   num_cont_covariates, num_disc_covariates: number of continuous / discrete
#     covariate columns.
#   test_cont_margin_params, test_cont_margin_family: "gamma" (shape, rate) or
#     "gaussian" (mean, sd).
#   test_disc_margin_params, test_disc_margin_family: "bernoulli", mapped with
#     qbinom(size = params[1], prob = params[2]).
#   treatment_family, prop_score_params:
#     "bernoulli": P(X = 1) = sigmoid(cbind(1, Z) %*% prop_score_params).
#     "gaussian" / "truncated-gaussian": mean cbind(1, Z) %*% params[-last],
#       sd = last entry; the truncated variant uses rtruncnorm with a = 0.
#   causal_effect_family, causal_effect_params:
#     "gaussian": sd = params[3]; "exp": rate = exp(linear predictor);
#     "gamma": shape = 1 / params[3], mean = exp(linear predictor);
#     "bernoulli": prob = 1 / (1 + exp(-linear predictor)).
#
# Returns:
#   list(Z_test = n_samples x num_covariates matrix, X_test, Y_test,
#        test_covariate_ranks = n_samples x num_covariates uniform ranks).
generate_test_data <- function(
  n_samples,
  corr_matrix_trial,
  num_cont_covariates,
  num_disc_covariates,
  test_cont_margin_params,
  test_cont_margin_family,
  test_disc_margin_params,
  test_disc_margin_family,
  treatment_family,
  prop_score_params,
  causal_effect_family,
  causal_effect_params
) {
  num_covariates <- num_disc_covariates + num_cont_covariates
  cov_cols <- seq_len(num_covariates)
  cont_cols <- seq_len(num_cont_covariates)
  disc_cols <- num_cont_covariates + seq_len(num_disc_covariates)
  outcome_col <- num_covariates + 1

  # MASS::mvrnorm returns a plain vector when n = 1; reshaping keeps a
  # n_samples x (num_covariates + 1) matrix in every case (no-op otherwise).
  latent <- mvrnorm(n = n_samples, mu = rep(0, num_covariates + 1), Sigma = corr_matrix_trial)
  test_ranks <- pnorm(matrix(latent, nrow = n_samples, ncol = num_covariates + 1))
  # drop = FALSE keeps a matrix even with a single covariate column.
  Z_test <- test_ranks[, cov_cols, drop = FALSE]

  if (num_cont_covariates > 0) {
    cont_ranks <- test_ranks[, cont_cols]
    if (test_cont_margin_family == "gamma") {
      Z_test[, cont_cols] <- qgamma(
        cont_ranks,
        shape = test_cont_margin_params[1],
        rate = test_cont_margin_params[2]
      )
    } else if (test_cont_margin_family == "gaussian") {
      Z_test[, cont_cols] <- qnorm(
        cont_ranks,
        mean = test_cont_margin_params[1],
        sd = test_cont_margin_params[2]
      )
    }
  }
  if (num_disc_covariates > 0) {
    if (test_disc_margin_family == "bernoulli") {
      Z_test[, disc_cols] <- qbinom(
        test_ranks[, disc_cols],
        size = test_disc_margin_params[1],
        prob = test_disc_margin_params[2]
      )
    }
  }

  design <- cbind(rep(1, nrow(Z_test)), Z_test)
  if (treatment_family == "bernoulli") {
    p <- sigmoid(design %*% matrix(prop_score_params))
    X_test <- rbinom(n_samples, 1, p)
  } else if (treatment_family %in% c("gaussian", "truncated-gaussian")) {
    # The last entry is the noise sd; the rest are the regression coefficients.
    n_params <- length(prop_score_params)
    mu <- design %*% matrix(prop_score_params[seq_len(n_params - 1)])
    sigma <- prop_score_params[n_params]
    if (treatment_family == "gaussian") {
      X_test <- rnorm(n_samples, mean = mu, sd = sigma)
    } else {
      X_test <- rtruncnorm(a = 0, n = n_samples, mean = mu, sd = sigma)
    }
  } else {
    stop("Unsupported treatment_family: ", treatment_family, call. = FALSE)
  }

  outcome_ranks <- test_ranks[, outcome_col]
  linear_pred <- causal_effect_params[1] + X_test * causal_effect_params[2]
  Y_test <- switch(
    causal_effect_family,
    gaussian = qnorm(outcome_ranks, mean = linear_pred, sd = causal_effect_params[3]),
    exp = qexp(outcome_ranks, rate = exp(linear_pred)),
    gamma = {
      shape <- 1 / causal_effect_params[3]
      qgamma(outcome_ranks, shape = shape, scale = exp(linear_pred) / shape)
    },
    bernoulli = qbinom(outcome_ranks, size = 1, prob = 1 / (exp(-linear_pred) + 1)),
    stop("Unsupported causal_effect_family: ", causal_effect_family, call. = FALSE)
  )

  list(
    Z_test = Z_test,
    X_test = X_test,
    Y_test = Y_test,
    test_covariate_ranks = test_ranks[, cov_cols, drop = FALSE]
  )
}


# Simulate a training data set (covariates Z, treatment X, outcome Y) whose
# covariates follow the *train* margins, with outcome ranks drawn
# conditionally on where those covariates fall under the *test* margins.
#
# Steps:
#   1. Draw n_samples latent Gaussian vectors with correlation
#      corr_matrix_trial and map them to uniform ranks with pnorm. Covariate
#      columns come first (continuous, then discrete); the last column
#      belongs to the outcome.
#   2. Map continuous ranks through the train continuous margin, then
#      evaluate the test continuous CDF at those values.
#   3. Map discrete ranks through the train "bernoulli" margin
#      (qbinom(size, prob)). Re-express each rank on the test CDF by
#      interpolating linearly between F(z - 1) and F(z).
#   4. Gaussianise the rescaled covariate ranks with qnorm. Draw outcome
#      ranks from multivariate_conditional_mean_and_samples, i.e. the last
#      variable of corr_matrix_trial conditioned on the covariates.
#   5. Draw the treatment from a model on cbind(1, Z_train). Push the outcome
#      rank through the outcome family with linear predictor
#      causal_effect_params[1] + X * causal_effect_params[2].
#
# Args:
#   n_samples, corr_matrix_trial, num_cont_covariates, num_disc_covariates:
#     as for the test-data generator.
#   train_* / test_* margin params and families:
#     continuous: "gamma" (shape, rate) or "gaussian" (mean, sd);
#     discrete: "bernoulli" (size, prob passed to qbinom/pbinom).
#     Other families raise an error.
#   treatment_family, prop_score_params:
#     "bernoulli": P(X = 1) = sigmoid(cbind(1, Z) %*% prop_score_params).
#     "gaussian" / "truncated-gaussian": mean cbind(1, Z) %*% params[-last],
#       sd = last entry; the truncated variant uses rtruncnorm with a = 0.
#   causal_effect_family, causal_effect_params: "gaussian" (sd = params[3]),
#     "exp" (rate = exp(lp)), "gamma" (shape = 1 / params[3], mean = exp(lp)),
#     "bernoulli" (prob = 1 / (1 + exp(-lp))).
#
# Returns:
#   list(Z_train, X_train, Y_train,
#        train_ranks = cbind(rescaled covariate ranks, outcome ranks)).
generate_train_data <- function(
    n_samples,
    corr_matrix_trial,
    num_cont_covariates,
    num_disc_covariates,
    train_cont_margin_params,
    train_cont_margin_family,
    train_disc_margin_params,
    train_disc_margin_family,
    test_cont_margin_params,
    test_cont_margin_family,
    test_disc_margin_params,
    test_disc_margin_family,
    treatment_family,
    prop_score_params,
    causal_effect_family,
    causal_effect_params
) {
  num_covariates <- num_disc_covariates + num_cont_covariates
  cov_cols <- seq_len(num_covariates)
  cont_cols <- seq_len(num_cont_covariates)
  disc_cols <- num_cont_covariates + seq_len(num_disc_covariates)

  # MASS::mvrnorm returns a plain vector when n = 1; reshaping keeps a
  # n_samples x (num_covariates + 1) matrix in every case (no-op otherwise).
  latent <- mvrnorm(n = n_samples, mu = rep(0, num_covariates + 1), Sigma = corr_matrix_trial)
  base_train_ranks <- pnorm(matrix(latent, nrow = n_samples, ncol = num_covariates + 1))
  # drop = FALSE keeps a matrix even with a single covariate column.
  Z_train <- base_train_ranks[, cov_cols, drop = FALSE]
  rescaled_train_covariate_ranks <- matrix(0, nrow = nrow(Z_train), ncol = ncol(Z_train))

  if (num_cont_covariates > 0) {
    cont_ranks <- base_train_ranks[, cont_cols]
    Z_cont <- switch(
      train_cont_margin_family,
      gamma = qgamma(
        cont_ranks, shape = train_cont_margin_params[1], rate = train_cont_margin_params[2]
      ),
      gaussian = qnorm(
        cont_ranks, mean = train_cont_margin_params[1], sd = train_cont_margin_params[2]
      ),
      stop("Unsupported train_cont_margin_family: ", train_cont_margin_family, call. = FALSE)
    )
    Z_train[, cont_cols] <- Z_cont
    # Ranks of the train-margin values under the test margin.
    rescaled_train_covariate_ranks[, cont_cols] <- switch(
      test_cont_margin_family,
      gamma = pgamma(
        Z_cont, shape = test_cont_margin_params[1], rate = test_cont_margin_params[2]
      ),
      gaussian = pnorm(
        Z_cont, mean = test_cont_margin_params[1], sd = test_cont_margin_params[2]
      ),
      stop("Unsupported test_cont_margin_family: ", test_cont_margin_family, call. = FALSE)
    )
  }

  if (num_disc_covariates > 0) {
    if (train_disc_margin_family != "bernoulli") {
      stop("Unsupported train_disc_margin_family: ", train_disc_margin_family, call. = FALSE)
    }
    if (test_disc_margin_family != "bernoulli") {
      stop("Unsupported test_disc_margin_family: ", test_disc_margin_family, call. = FALSE)
    }
    train_size <- train_disc_margin_params[1]
    train_prob <- train_disc_margin_params[2]
    test_size <- test_disc_margin_params[1]
    test_prob <- test_disc_margin_params[2]

    disc_ranks <- base_train_ranks[, disc_cols]
    Z_disc <- qbinom(disc_ranks, size = train_size, prob = train_prob)
    Z_train[, disc_cols] <- Z_disc

    # A rank u that produced value z lies in [F_train(z - 1), F_train(z)].
    # Map it linearly onto [F_test(z - 1), F_test(z)].
    train_lo <- pbinom(Z_disc - 1, size = train_size, prob = train_prob)
    train_hi <- pbinom(Z_disc, size = train_size, prob = train_prob)
    test_lo <- pbinom(Z_disc - 1, size = test_size, prob = test_prob)
    test_hi <- pbinom(Z_disc, size = test_size, prob = test_prob)
    rescaled_train_covariate_ranks[, disc_cols] <- test_lo + (
      (test_hi - test_lo) * (disc_ranks - train_lo) / (train_hi - train_lo)
    )
  }

  rescaled_train_covariate_ranks_gauss <- qnorm(rescaled_train_covariate_ranks)
  # Outcome ranks drawn conditionally on the Gaussianised covariate ranks.
  train_causal_effect_ranks <- pnorm(
    multivariate_conditional_mean_and_samples(
      rescaled_train_covariate_ranks_gauss,
      corr_matrix_trial
    )$generated_samples
  )

  design <- cbind(rep(1, nrow(Z_train)), Z_train)
  if (treatment_family == "bernoulli") {
    p <- sigmoid(design %*% matrix(prop_score_params))
    X_train <- rbinom(n_samples, 1, p)
  } else if (treatment_family %in% c("gaussian", "truncated-gaussian")) {
    # The last entry is the noise sd; the rest are the regression coefficients.
    n_params <- length(prop_score_params)
    mu <- design %*% matrix(prop_score_params[seq_len(n_params - 1)])
    sigma <- prop_score_params[n_params]
    if (treatment_family == "gaussian") {
      X_train <- rnorm(n_samples, mean = mu, sd = sigma)
    } else {
      X_train <- rtruncnorm(a = 0, n = n_samples, mean = mu, sd = sigma)
    }
  } else {
    stop("Unsupported treatment_family: ", treatment_family, call. = FALSE)
  }

  linear_pred <- causal_effect_params[1] + X_train * causal_effect_params[2]
  Y_train <- switch(
    causal_effect_family,
    gaussian = qnorm(
      train_causal_effect_ranks, mean = linear_pred, sd = causal_effect_params[3]
    ),
    gamma = {
      shape <- 1 / causal_effect_params[3]
      qgamma(train_causal_effect_ranks, shape = shape, scale = exp(linear_pred) / shape)
    },
    exp = qexp(train_causal_effect_ranks, rate = exp(linear_pred)),
    bernoulli = qbinom(
      train_causal_effect_ranks, size = 1, prob = 1 / (exp(-linear_pred) + 1)
    ),
    stop("Unsupported causal_effect_family: ", causal_effect_family, call. = FALSE)
  )

  list(
    Z_train = Z_train,
    X_train = X_train,
    Y_train = Y_train,
    train_ranks = cbind(rescaled_train_covariate_ranks, train_causal_effect_ranks)
  )
}