import dte_adj
from dte_adj.plot import plot
import numpy as np
import matplotlib.pyplot as plt
import importlib
from collections import defaultdict
from sklearn.linear_model import LinearRegression, LogisticRegression
import tqdm
import time
import warnings
import xgboost as xgb
import pandas as pd

def seed(i=123):
    """Seed the torch, ``random`` and NumPy global generators with one value.

    Parameters
    ----------
    i : int, default 123
        Seed passed unchanged to ``torch.manual_seed``, ``random.seed`` and
        ``np.random.seed`` (in that order).
    """
    import random

    import numpy as np
    import torch

    for set_seed in (torch.manual_seed, random.seed, np.random.seed):
        set_seed(i)



def generate_data(n=1000, S=4, discrete=False):
    """Simulate one sample of a stratified experiment with non-compliance.

    Random draws happen in this fixed order (so seeding reproduces samples):
    W ~ U(0, 1), X ~ N(0, I_20), Z ~ Bernoulli(0.5) independently inside each
    of the ``S`` equal-width strata of W, a shared error ``epsilon`` ~ N(0, 1),
    and finally (only when ``discrete``) a Poisson draw for Y.

    Parameters
    ----------
    n : int
        Number of observations.
    S : int
        Number of equal-width strata of W on [0, 1).
    discrete : bool
        If True, the observed outcome Y is replaced by Poisson(|Y|).

    Returns
    -------
    dict
        Keys 'W', 'X', 'Z', 'D', 'Y', 'D0', 'D1', 'Y0', 'Y1', 'strata'.
        ``discrete`` only changes 'Y'; 'Y0' and 'Y1' stay continuous.
    """
    # Stratification variable and its stratum label (0 .. S-1 since W < 1).
    W = np.random.uniform(0, 1, n)
    upper_edges = np.linspace(0, 1, S + 1)[1:]
    strata = np.digitize(W, upper_edges)

    X = np.random.randn(n, 20)

    # Fair-coin assignment drawn separately inside each stratum.
    Z = np.zeros(n)
    for stratum in range(S):
        members = strata == stratum
        Z[members] = np.random.binomial(1, 0.5, size=int(np.count_nonzero(members)))

    # Nonlinear outcome baseline and take-up index (both depend on X and W).
    baseline = (
        np.sin(np.pi * X[:, 0] * X[:, 1])
        + 2 * (X[:, 2] - 0.5) ** 2
        + X[:, 3]
        + 0.5 * X[:, 4]
        + 0.1 * W
    )
    takeup_index = 0.1 * (X[:, 0] + np.log(1 + np.exp(X[:, 1])) + W)

    # a*: outcome intercepts, b*: take-up shifts, c*: error scales in take-up.
    a1, a0 = 2, 1
    b1, b0 = 1, -1
    c1, c0 = 3, 3

    # One error drives both potential outcomes and take-up (endogenous D).
    epsilon = np.random.randn(n)

    Y0 = a0 + baseline + epsilon
    Y1 = a1 + baseline + epsilon

    # Take-up under control; under treatment, control takers stay takers.
    D0 = (b0 + takeup_index > c0 * epsilon).astype(int)
    D1 = np.where(D0 == 1, 1, (b1 + takeup_index > c1 * epsilon).astype(int))

    # Observed take-up and outcome.
    D = D1 * Z + D0 * (1 - Z)
    Y = Y1 * D + Y0 * (1 - D)
    if discrete:
        Y = np.random.poisson(np.abs(Y))

    return dict(
        W=W, X=X, Z=Z, D=D, Y=Y,
        D0=D0, D1=D1, Y0=Y0, Y1=Y1, strata=strata,
    )

# Approximate the "true" LDTE for the continuous-outcome design on one very
# large sample, using the unadjusted estimator.
seed(123)
test_data = generate_data(10**6)
X_test, D_test, Y_test, Z_test, S_test = (
    test_data["X"], test_data["D"], test_data["Y"], test_data["Z"], test_data["strata"]
)
# Locations at which the distribution effect is evaluated: the deciles of the
# large-sample outcome. (For discrete=True data an integer grid such as
# np.arange(15) would be the natural choice instead.)
locations = np.quantile(Y_test, np.arange(1, 10) * 0.1)
estimator = dte_adj.SimpleStratifiedDistributionEstimator()
# Fitting on 1 - D and evaluating its CDF at 0 gives P(1 - D <= 0) = P(D = 1),
# i.e. the first-stage take-up probability in each arm.
estimator.fit(X_test, Z_test, 1 - D_test, S_test)
treatment_arm = 1
control_arm = 0
d_t_prediction, _, _ = estimator._compute_cumulative_distribution(treatment_arm, np.zeros(1), X_test, Z_test, 1 - D_test)  # (1,)
d_c_prediction, _, _ = estimator._compute_cumulative_distribution(control_arm, np.zeros(1), X_test, Z_test, 1 - D_test)  # (1,)
y_t_prediction, _, _ = estimator._compute_cumulative_distribution(treatment_arm, locations, X_test, Z_test, Y_test)  # (L,)
y_c_prediction, _, _ = estimator._compute_cumulative_distribution(control_arm, locations, X_test, Z_test, Y_test)  # (L,)
# Wald ratio: reduced-form distribution effect over first-stage effect.
ldte_test = (y_t_prediction - y_c_prediction) / (d_t_prediction - d_c_prediction)  # (L,)


def compute_ldte(estimator, X, Z, D, S, Y, locations, weights, treatment_arm, control_arm):
    """Estimate the LDTE at each location with a 95% normal-approximation CI.

    The point estimate at location y is the Wald-type ratio
        (F_treat(y) - F_control(y)) / (P(D=1 | treat) - P(D=1 | control)),
    with both pieces produced by ``estimator``.

    Parameters
    ----------
    estimator : dte_adj stratified distribution estimator (simple or adjusted).
    X : covariates, one row per observation.
    Z : assigned arm of each observation.
    D : take-up indicator (0/1) of each observation.
    S : stratum label of each observation.
    Y : observed outcome.
    locations : 1-D array of points at which the CDFs are evaluated.
    weights : dict mapping each stratum label in ``S`` to the share of that
        stratum assigned to the treatment arm (as built by the caller).
    treatment_arm, control_arm : values of ``Z`` identifying the two arms.

    Returns
    -------
    tuple
        ``(beta, upper_bound, lower_bound)``, each of shape (len(locations),).
    """
    n_obs = len(X)

    # Fitting on 1 - D and evaluating at 0 gives P(1 - D <= 0) = P(D = 1).
    estimator.fit(X, Z, 1 - D, S)
    d_t_prediction, d_t_psi, d_t_eta = estimator._compute_cumulative_distribution(treatment_arm, np.zeros(1), X, Z, 1 - D)
    d_c_prediction, d_c_psi, d_c_eta = estimator._compute_cumulative_distribution(control_arm, np.zeros(1), X, Z, 1 - D)
    y_t_prediction, _, d_t_mu = estimator._compute_cumulative_distribution(treatment_arm, locations, X, Z, Y)
    y_c_prediction, _, d_c_mu = estimator._compute_cumulative_distribution(control_arm, locations, X, Z, Y)
    psi_b = d_t_psi - d_c_psi
    beta = (y_t_prediction - y_c_prediction) / (d_t_prediction - d_c_prediction)

    # Column/row views so every per-observation term broadcasts to (N, L).
    w = np.array([weights[s] for s in S], dtype=float).reshape(-1, 1)
    below = Y.reshape(-1, 1) <= locations.reshape(1, -1)  # 1{Y_i <= y}
    d_col = D.reshape(-1, 1)
    b_row = beta.reshape(1, -1)
    mu_t = np.asarray(d_t_mu).reshape(n_obs, -1)
    mu_c = np.asarray(d_c_mu).reshape(n_obs, -1)
    eta_t = np.asarray(d_t_eta).reshape(n_obs, -1)
    eta_c = np.asarray(d_c_eta).reshape(n_obs, -1)

    # Within-arm influence terms (same formula per row as the former loop).
    xi_t = ((1 - 1 / w) * mu_t - mu_c + below / w) - \
        b_row * ((1 - 1 / w) * eta_t - eta_c + d_col / w)
    xi_c = ((1 / (1 - w) - 1) * mu_c - mu_t + below / (1 - w)) - \
        b_row * ((1 / (1 - w) - 1) * eta_c - eta_t + d_col / (1 - w))

    # Centre each term on its stratum-by-arm mean. Strata are disjoint, so
    # centring one stratum does not affect another stratum's mean.
    strata = np.unique(S)
    treated = Z == treatment_arm
    control = Z == control_arm
    for s in strata:
        in_s = S == s
        xi_t[in_s] -= xi_t[in_s & treated].mean(axis=0)
        xi_c[in_s] -= xi_c[in_s & control].mean(axis=0)

    def adjusted_mean(mask):
        # Mean over `mask` of 1{Y <= y} - beta * D. Uses <= like the CDF above.
        return (below[mask] - b_row * d_col[mask]).mean(axis=0)

    # Between-arm contrast of each stratum, computed once per stratum and then
    # broadcast to that stratum's members (avoids O(n^2) recomputation).
    contrast = {s: adjusted_mean((S == s) & treated) - adjusted_mean((S == s) & control) for s in strata}
    xi_2 = np.array([contrast[s] for s in S])

    # Variance scaled by the squared mean first-stage term; 1.96 is the 97.5%
    # standard-normal quantile (two-sided 95% interval).
    z_col = Z.reshape(-1, 1)
    sigma = (z_col * xi_t ** 2 + (1 - z_col) * xi_c ** 2 + xi_2 ** 2).mean(axis=0) / psi_b.mean() ** 2
    half_width = 1.96 * sigma ** 0.5 / np.sqrt(n_obs)
    return beta, beta + half_width, beta - half_width


empirical_ldtes = []
linear_ldtes = []
xgb_ldtes = []
nn_ldte = []  # never filled: no neural-network estimator is run in this script
times = defaultdict(list)
seed(123)
treatment_arm = 1
control_arm = 0
n = 1000
for epoch in tqdm.tqdm(range(1000)):
    data = generate_data(n=n)
    X, D, Y, Z, S, W = data["X"], data["D"], data["Y"], data["Z"], data["strata"], data["W"]
    # The stratification variable W is also passed to the estimators as a covariate.
    X = np.hstack([X, W.reshape(-1, 1)])
    s_list = np.unique(S)
    # Share of each stratum assigned to the treatment arm.
    weights = {s: np.mean(Z[S == s] == treatment_arm) for s in s_list}

    before = time.time()
    estimator = dte_adj.SimpleStratifiedDistributionEstimator()
    empirical_ldtes.append(compute_ldte(estimator, X, Z, D, S, Y, locations, weights, treatment_arm, control_arm))
    times["empirical"].append(time.time() - before)

    before = time.time()
    estimator = dte_adj.AdjustedStratifiedDistributionEstimator(LinearRegression(), is_multi_task=False, folds=2)
    linear_ldtes.append(compute_ldte(estimator, X, Z, D, S, Y, locations, weights, treatment_arm, control_arm))
    times["linear"].append(time.time() - before)

    before = time.time()
    estimator = dte_adj.AdjustedStratifiedDistributionEstimator(xgb.XGBRegressor(), is_multi_task=False, folds=2)
    xgb_ldtes.append(compute_ldte(estimator, X, Z, D, S, Y, locations, weights, treatment_arm, control_arm))
    times["xgb"].append(time.time() - before)


# Stack replications: each array is (replications, 3, locations) holding
# (estimate, upper bound, lower bound) per replication.
empirical_ldtes = np.array(empirical_ldtes)
linear_ldtes = np.array(linear_ldtes)
xgb_ldtes = np.array(xgb_ldtes)
nn_ldtes = np.array(nn_ldte)  # stays empty; kept for a possible NN estimator

# Per-location summary: mean CI width, CI coverage of the large-sample
# target, and RMSE of the point estimate, for each estimator.
summary = {"locations": locations}
for method_label, method_runs in (
    ("empirical", empirical_ldtes),
    ("linear", linear_ldtes),
    ("xgb", xgb_ldtes),
):
    point, upper, lower = method_runs[:, 0], method_runs[:, 1], method_runs[:, 2]
    summary[f"interval length - {method_label}"] = (upper - lower).mean(axis=0)
    summary[f"coverage probability - {method_label}"] = ((upper >= ldte_test) & (ldte_test >= lower)).mean(axis=0)
    summary[f"RMSE - {method_label}"] = ((point - ldte_test) ** 2).mean(axis=0) ** 0.5
df = pd.DataFrame(summary)

# Relative RMSE gain of each adjusted estimator over the unadjusted one.
for method_label in ("linear", "xgb"):
    df[f"RMSE reduction (%) {method_label} / empirical"] = (
        1 - df[f"RMSE - {method_label}"] / df["RMSE - empirical"]
    ) * 100
df.to_csv(f"dte_{n}.csv")