# Seed the RNG first so every random draw made later is reproducible.
set.seed(1001)

## setting
n1 <- 400 # sample size
m <- 5 # number of subgroups
num_stage <- 14 # not referenced in the visible part of this script

# Per-subgroup treatment effects; these agree with mu1_vec - mu0_vec
# (defined below) to roughly five decimal places.
tau <- c(
  -2.769924,
  10.531104,
  -1.212645,
  10.886470,
  -1.458249
)

# Subgroup sampling probabilities, labelled A, B, ... in subgroup order.
true_p <- setNames(
  c(0.2788462, 0.1250000, 0.2980769, 0.1121795, 0.1858974),
  LETTERS[seq_len(m)]
)

# Outcome means by subgroup: treatment arm (mu1) and control arm (mu0).
mu1_vec <- c(
  42.56808, 50.44094, 44.36636, 44.30218, 37.70726
)
mu0_vec <- c(
  45.33801, 39.90983, 45.57901, 33.41571, 39.16550
)

# Outcome standard deviations by subgroup: treatment (sd1) and control (sd0).
sd1_vec <- c(
  10.84663, 12.29309, 12.64410, 14.28020, 14.64300
)
sd0_vec <- c(
  11.49844, 15.18284, 14.56797, 13.08778, 15.06209
)

## Generate covariates
# Draw n subgroup labels from 1..m, independently and with replacement.
#   n             number of labels to draw
#   m             number of subgroups (labels are the integers 1..m)
#   sampling_prob length-m vector of selection weights, one per subgroup
# Returns an integer vector of length n.
GenX <- function(n, m, sampling_prob) {
  sample.int(m, size = n, replace = TRUE, prob = sampling_prob)
}

## Generate subgroup membership
# Expand a vector of subgroup labels into a one-column-per-subgroup indicator
# matrix.
#   X  vector of subgroup labels, compared against 1..m
#   m  number of subgroups
# Returns a length(X) x m logical matrix whose [i, j] entry is X[i] == j;
# columns are named with the first m capital letters.
GenS <- function(X, m) {
  groups <- seq_len(m)
  # unname() keeps the result free of row names, as in an element-wise fill.
  membership <- outer(unname(X), groups, FUN = "==")
  colnames(membership) <- LETTERS[groups]
  membership
}

## Generate outcomes
# Simulate one normally distributed outcome per observation, with mean and SD
# chosen by the observation's subgroup and treatment arm.
#   n        number of observations; entries/rows 1..n of Tr and S are used
#   Tr       treatment indicators, each 1 (treatment) or 0 (control)
#   S        membership matrix with exactly one TRUE/1 entry per row
#   mu1_vec  per-subgroup outcome means in the treatment arm
#   mu0_vec  per-subgroup outcome means in the control arm
#   sd1_vec  per-subgroup outcome SDs in the treatment arm
#   sd0_vec  per-subgroup outcome SDs in the control arm
# Returns a numeric vector of length n (numeric(0) when n is 0).
# Stops with an error if Tr holds anything other than 0/1 or if a row of S
# does not flag exactly one subgroup, instead of silently returning NA or a
# vector shorter than n.
GenY <- function(n, Tr, S, mu1_vec, mu0_vec, sd1_vec, sd0_vec) {
  rows <- seq_len(n)
  if (length(rows) == 0L) {
    return(numeric(0))
  }

  arm <- Tr[rows]
  if (!all(arm %in% c(0, 1))) {
    stop("Tr must contain only 0 (control) and 1 (treatment)", call. = FALSE)
  }

  member <- S[rows, , drop = FALSE] == 1
  if (!isTRUE(all(rowSums(member) == 1))) {
    stop("each row of S must flag exactly one subgroup", call. = FALSE)
  }

  # Transpose so that column-major extraction walks observation by
  # observation; row() then yields each observation's subgroup index.
  member_t <- t(member)
  idx <- row(member_t)[member_t]

  treated <- arm == 1
  mu <- ifelse(treated, mu1_vec[idx], mu0_vec[idx])
  sigma <- ifelse(treated, sd1_vec[idx], sd0_vec[idx])

  # A single vectorised call draws the variates in observation order, the same
  # order as one rnorm(1, ...) call per row.
  rnorm(length(rows), mean = mu, sd = sigma)
}

# Simulate the stage-1 sample. The order of the calls below fixes the order in
# which random numbers are consumed, so reordering them changes the data.

# Draw a subgroup label (1..m) for each of the n1 units, using true_p as the
# sampling probabilities
X_1 <- GenX(n1, m, true_p)

# Expand the labels into an n1 x m subgroup-membership indicator matrix
S_1 <- GenS(X_1, m)

# Assign treatment by an independent fair coin flip per unit (1 = treated)
T_1 <- rbinom(n1, 1, 0.5)

# Generate outcomes from the arm- and subgroup-specific means and SDs
Y_1 <- GenY(n1, T_1, S_1, mu1_vec, mu0_vec, sd1_vec, sd0_vec)

# Design treatment probability in each subgroup: 1/2, matching the rbinom()
# call above. This only records the value; it does not assign treatments.
e_1 <- rep(1/2, m)

## Estimation
# Observed share of the sample falling in each subgroup
p_1 <- colSums(S_1) / n1

# Estimated propensity score per subgroup: treated units in the subgroup
# divided by the subgroup size
e_1.hat <- NULL
for (k in seq_len(ncol(S_1))) {
  e_1.hat[k] <- sum(T_1 * S_1[, k]) / sum(S_1[, k])
}

# Per-subgroup inverse-probability-weighted ATE estimate, plus the outcome SD
# within each arm
tau_1 <- sd_1.t <- sd_1.c <- NULL
for (j in seq_len(ncol(S_1))) {
  # Rows belonging to subgroup j; every column is stored as double
  dat1 <- data.frame(
    Y = Y_1[S_1[, j]],
    Tr = as.numeric(T_1[S_1[, j]]),
    X = as.numeric(X_1[S_1[, j]])
  )

  tau_1[j] <- with(
    dat1,
    mean(Y * Tr / e_1.hat[j] - Y * (1 - Tr) / (1 - e_1.hat[j]))
  )
  sd_1.t[j] <- with(dat1, sd(Y[Tr == 1]))
  sd_1.c[j] <- with(dat1, sd(Y[Tr == 0]))
}

# Label the subgroup ATEs and SDs with the subgroup letters
names(tau_1) <- names(sd_1.t) <- names(sd_1.c) <- LETTERS[seq_len(m)]

# Copy the stage-1 estimates and data under *_old names
# (presumably so later code can update them stage by stage -- confirm).
tau_old <- tau_1
sd_old.t <- sd_1.t
sd_old.c <- sd_1.c
S_old <- S_1
n_old <- n1
T_old <- T_1
X_old <- X_1
Y_old <- Y_1
e_1.hat_old <- e_1.hat
p_old <- p_1
# Per-subgroup variance term: (1 / p) * (sd_t^2 / e + sd_c^2 / (1 - e)),
# built from the estimated proportions, arm SDs and propensity scores.
var_opt <- 1 / p_old * (sd_old.t^2 / e_1.hat_old + sd_old.c^2 / (1 - e_1.hat_old))
tau_opt <- tau_old

# Subgroup sizes implied by the TRUE proportions (not the observed counts)
nk <- round(true_p * n1)
s <- var_opt
# Diagonal matrix of var_opt / n1; not referenced again in the visible code
sigm <- diag(s) / n1
theta <- tau_old
# From here on `s` holds var_opt / n1 (sigm above used the unscaled value)
s <- s / n1
thetahh <- c(theta)
# Ratio of the average scaled variance to the sample variance of the
# subgroup estimates
r <- mean(s) / var(thetahh)
sig <- sqrt(sum(s * nk) / m)
# Not referenced again in the visible code
d <- 1 / 4

# Result table: one row per Delta value, one thetah_<subgroup> column per
# subgroup (filled in by the loop below)
Delta_values <- seq(0, 0.1, by = 0.01)
results <- data.frame(
  Delta = Delta_values,
  matrix(NA, nrow = length(Delta_values), ncol = m)
)
names(results)[-1] <- paste0("thetah_", LETTERS[seq_len(m)])

# For each Delta, shrink the subgroup estimates towards their common mean
for (i in seq_along(Delta_values)) {
  Delta <- Delta_values[i]
  temp <- r * (sig / sqrt(mean(s)))^(2 * Delta)
  # Shrinkage weight on the common mean, capped at 1
  tri <- min(1, temp)
  thetah <- tri * mean(thetahh) + (1 - tri) * thetahh
  results[i, -1] <- thetah
}

# Display the results
results
