from s1_model import CausalMultiTaskDataset
from s1_model import SharedRepresentation
from s1_model import MultiTaskCausalModel
from s1_model import train_model
from s1_model import SingleTaskCompositeModel
from s1_model import train_composite_model
from s1_model import CompositeOutcomeDataset
from s1_model import IndependentOutcomeModel
from s1_model import train_independent_model
from s1_model import compute_utility_tensor

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import sklearn.metrics as metrics
from tqdm import trange
from sklearn.model_selection import train_test_split
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"



from econml.dml import LinearDML
from econml.dr import DRLearner
from econml.metalearners import TLearner, XLearner
from econml.dml import CausalForestDML
from sklearn.linear_model import LassoCV, LogisticRegressionCV
from sklearn.ensemble import RandomForestRegressor
from econml.metalearners import SLearner
from xgboost import XGBRegressor



# ----- Synthetic Data -----
def simulate_simplified_dgp(n=1000, d=10, K=3, utility='or', weights=None, seed=42, data = 'synthetic', data_path=None):
    """Simulate K binary potential outcomes and a composite utility.

    Covariates are either drawn as standard normals (``data='synthetic'``)
    or read from a CSV file (``data='IHDP'`` / ``data='EHR'``). Treatment is
    assigned by a fair coin flip, and every outcome is obtained by
    thresholding a logit made of a linear term, a tanh nonlinearity and a
    heterogeneous treatment term.

    Args:
        n: Number of units. Only used for ``data='synthetic'``; otherwise it
            is replaced by the number of rows that were read.
        d: Number of covariates. Only used for ``data='synthetic'``; otherwise
            it is replaced by the number of selected columns.
        K: Number of binary outcomes.
        utility: Composite utility, one of ``'or'``, ``'weighted_sum'`` or
            ``'tanh'``.
        weights: Optional length-K weights for ``'weighted_sum'`` and
            ``'tanh'``; defaults to all ones.
        seed: Value passed to ``np.random.seed``. Note that this reseeds
            NumPy's *global* generator.
        data: Covariate source: ``'synthetic'``, ``'IHDP'`` or ``'EHR'``.
        data_path: Optional CSV path overriding the built-in default location
            for ``'IHDP'`` / ``'EHR'``. Ignored for ``'synthetic'``.

    Returns:
        dict with keys ``'X'`` (n, d), ``'A'`` (n,), ``'Y'``, ``'Y0'``,
        ``'Y1'`` (each (n, K)), and ``'U'``, ``'U0'``, ``'U1'``, ``'tau'``
        (each (n,)), where ``tau = U1 - U0``.

    Raises:
        ValueError: If ``data`` or ``utility`` is not one of the supported
            values.
    """
    np.random.seed(seed)

    # Coefficient scales and shared treatment shift; the EHR branch overrides them.
    main_scale, nonlinear_scale, modifier_scale = 0.5, 0.5, 0.3
    tau_global = 1.0

    if data == 'synthetic':
        X = np.random.randn(n, d)
    elif data == 'IHDP':
        if data_path is None:
            data_path = "C:/Users/Kan Chen/Dropbox/causal_representation_learning/simulation/IHDP.csv"
        temp = pd.read_csv(data_path, header=None)
        X = temp.iloc[:, 5:30].to_numpy(dtype=np.float64)
    elif data == 'EHR':
        if data_path is None:
            data_path = "C:/Users/Kan Chen/Dropbox/causal_representation_learning/simulation/EHR.csv"
        temp = pd.read_csv(data_path).head(1000)
        columns_to_keep = [
            "X..BAND.NEUTROPHILS..count", "X..BAND.NEUTROPHILS..last",
            "X..LYMPHOCYTES.MANUAL..last", "X..LYMPHOCYTES..first", "X..MONOCYTES..last",
            "X..NEUTROPHILS.MANUAL..count", "X..LYMPHOCYTES.MANUAL..count",
            "X..LYMPHOCYTES..count", "X..LYMPHOCYTES..first.1", "X..LYMPHOCYTES..last",
            "ALKALINE.PHOSPHATASE..first", "ALKALINE.PHOSPHATASE..last",
            "ALPHA.1.FETOPROTEINS..first", "ALPHA.1.FETOPROTEINS..last",
            "ALT..count", "ALT..first", "ALT..last", "ANION.GAP..last",
            "AST..first", "AST..last", "BILIRUBIN.DIRECT..count", "BILIRUBIN.INDIRECT..last"]
        X = temp[columns_to_keep].to_numpy(dtype=np.float64)
        main_scale, nonlinear_scale, modifier_scale = 1.5, 1.5, 1.3
        tau_global = 3.0
    else:
        raise ValueError(f"Unsupported data source: {data}")

    # Take the dimensions from the covariates actually in use.
    n, d = X.shape

    # Draw order (A, beta, gamma, theta) is kept fixed so a given seed
    # reproduces the same random stream for every data source.
    A = np.random.binomial(1, 0.5, size=n)
    beta = np.random.randn(K, d) * main_scale        # linear main effects
    gamma = np.random.randn(K, d) * nonlinear_scale  # nonlinear effects
    theta = np.random.randn(K, d) * modifier_scale   # heterogeneous treatment modifier

    def compute_logit(X, A, k):
        linear = X @ beta[k]
        # Single-argument call: a second positional argument to a NumPy ufunc
        # is the `out` buffer, which would turn this term into tanh(0) == 0.
        nonlinear = np.tanh(X @ gamma[k])
        # Heterogeneous effect centred on the positive shift tau_global.
        treatment_effect = A * (tau_global + np.tanh(X @ theta[k]))
        return linear + nonlinear + treatment_effect

    # Potential outcomes: each arm is thresholded at its own 40th percentile,
    # so roughly 60% of units are positive under control and under treatment.
    Y0 = np.zeros((n, K))
    Y1 = np.zeros((n, K))
    control, treated = np.zeros(n), np.ones(n)
    for k in range(K):
        logit0 = compute_logit(X, control, k)
        logit1 = compute_logit(X, treated, k)
        Y0[:, k] = (logit0 > np.quantile(logit0, 0.4)).astype(float)
        Y1[:, k] = (logit1 > np.quantile(logit1, 0.4)).astype(float)

    # Observed outcomes follow the assigned treatment.
    Y = np.where(A[:, None] == 1, Y1, Y0)

    def compute_utility(Y, utility, weights):
        if utility == 'or':
            # 1 when at least one of the K outcomes is 1, else 0.
            return 1 - np.prod(1 - Y, axis=1)
        elif utility == 'weighted_sum':
            w = np.array(weights if weights is not None else [1.0] * K)
            return Y @ w
        elif utility == 'tanh':
            w = np.array(weights if weights is not None else [1.0] * K)
            return np.tanh(Y @ w)
        else:
            raise ValueError(f"Unsupported utility: {utility}")

    U0 = compute_utility(Y0, utility, weights)
    U1 = compute_utility(Y1, utility, weights)
    U = np.where(A == 1, U1, U0)
    tau = U1 - U0

    return {
        'X': X,
        'A': A,
        'Y': Y,
        'Y0': Y0,
        'Y1': Y1,
        'U': U,
        'U0': U0,
        'U1': U1,
        'tau': tau
    }





# Long-format accumulator: three parallel lists that receive one entry per
# (trial, model, in-/out-of-sample) evaluation and later become a DataFrame.
results = {column: [] for column in ('Model', 'Error_Type', 'ATE_Error')}

# Experiment size: number of repeated trials and training epochs per network.
num_trials, n_epoches = 1000, 200

def _record_ate_error(results, model_name, error_type, ate_pred, tau_true):
    """Append |ate_pred - mean(tau_true)| as one row of the results lists."""
    results['Model'].append(model_name)
    results['Error_Type'].append(error_type)
    results['ATE_Error'].append(np.abs(ate_pred - np.mean(tau_true)))


for trial in trange(num_trials):
    # The simulator reseeds NumPy's global generator from `seed` (default 42)
    # on every call. Without a per-trial seed each trial would reuse the same
    # simulated data and the same train/test split, so the reported standard
    # errors would not reflect sampling variability.
    # n and d are placeholders here: for 'EHR' the simulator replaces them
    # with the dimensions of the loaded covariates.
    data = simulate_simplified_dgp(n=1000, d=10, K=3, data = 'EHR', seed=trial)
    X, A, Y = data['X'], data['A'], data['Y']
    U, tau_true = data['U'], data['tau']

    K = Y.shape[1]
    p = X.shape[1]

    # --- Train/Test Split ---
    idx_train, idx_test = train_test_split(np.arange(X.shape[0]), test_size=0.2, random_state=None)
    X_train, A_train, Y_train = X[idx_train], A[idx_train], Y[idx_train]
    U_train, tau_true_train = U[idx_train], tau_true[idx_train]
    X_test, tau_true_test = X[idx_test], tau_true[idx_test]

    x_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    x_test_tensor = torch.tensor(X_test, dtype=torch.float32)

    # Every model is scored on both splits, in this order:
    # (label, numpy covariates, tensor covariates, true unit-level effects).
    splits = (
        ('In-Sample', X_train, x_train_tensor, tau_true_train),
        ('Out-of-Sample', X_test, x_test_tensor, tau_true_test),
    )

    # === Multi-Task Shared Rep ===
    dataset = CausalMultiTaskDataset(X_train, A_train, Y_train)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    model = MultiTaskCausalModel(input_dim=p, hidden_dim=64, num_outcomes=K)
    train_model(model, dataloader, num_epochs=n_epoches, lr=1e-3)
    model.eval()

    for error_type, _, x_tensor, tau_true_split in splits:
        _, _, tau_hat = model.predict_counterfactuals(x_tensor)
        _record_ate_error(results, 'Multi-Task + Shared Rep', error_type,
                          tau_hat.mean().item(), tau_true_split)

    # === Single-Task Composite ===
    composite_dataset = CompositeOutcomeDataset(X_train, A_train, U_train)
    composite_loader = DataLoader(composite_dataset, batch_size=32, shuffle=True)
    model_c = SingleTaskCompositeModel(input_dim=p, hidden_dim=64)
    train_composite_model(model_c, composite_loader, num_epochs=n_epoches, lr=1e-3)
    model_c.eval()

    for error_type, _, x_tensor, tau_true_split in splits:
        _, _, tau_hat = model_c.predict_counterfactuals(x_tensor)
        _record_ate_error(results, 'Single-Task Composite', error_type,
                          tau_hat.mean().item(), tau_true_split)

    # === Independent Model (trained on the same multi-outcome loader) ===
    model_indep = IndependentOutcomeModel(input_dim=p, hidden_dim=64, num_outcomes=K)
    train_independent_model(model_indep, dataloader, num_epochs=n_epoches, lr=1e-3)
    model_indep.eval()

    for error_type, _, x_tensor, tau_true_split in splits:
        y0_hat, y1_hat = model_indep.predict_counterfactuals(x_tensor)
        # Per-outcome predictions are combined with the same 'or' utility
        # that the simulator uses by default.
        u0_hat = compute_utility_tensor(y0_hat, utility='or')
        u1_hat = compute_utility_tensor(y1_hat, utility='or')
        tau_hat = (u1_hat - u0_hat).detach().numpy().flatten()
        _record_ate_error(results, 'Multi-Task No Rep', error_type,
                          tau_hat.mean(), tau_true_split)

    # === EconML Benchmarks ===
    # NOTE(review): the "BART" entry is a DRLearner with XGBoost nuisance and
    # final models, not a BART implementation. The label is left unchanged so
    # the keys of the output table stay the same.
    econml_models = {
        "DML": LinearDML(discrete_treatment=True),
        "Causal Forest": CausalForestDML(discrete_treatment=True),
        "DR Learner": DRLearner(model_regression=RandomForestRegressor(), model_propensity=LogisticRegressionCV()),
        "T Learner": TLearner(models=RandomForestRegressor()),
        "X Learner": XLearner(models=RandomForestRegressor()),
        "S Learner": SLearner(overall_model=RandomForestRegressor()),
        "BART": DRLearner(
            model_propensity=LogisticRegressionCV(),
            model_regression=XGBRegressor(n_estimators=100, max_depth=3),
            model_final=XGBRegressor(n_estimators=100, max_depth=3),
        ),
    }

    for name, est in econml_models.items():
        # Benchmarks are fit directly on the composite utility.
        est.fit(Y=U_train, T=A_train, X=X_train)

        for error_type, x_split, _, tau_true_split in splits:
            tau_pred = est.effect(x_split)
            _record_ate_error(results, name, error_type,
                              np.mean(tau_pred), tau_true_split)
    
    
    

# --- Summarize results ---
# Turn the accumulated records into a table and aggregate the ATE error for
# every (model, in-/out-of-sample) pair.
df = pd.DataFrame(results)
grouped = df.groupby(["Model", "Error_Type"])["ATE_Error"]
mean_errors = grouped.mean()
# Standard error of the mean: sample std (ddof=1) over sqrt(group size).
se_errors = grouped.std(ddof=1) / np.sqrt(grouped.count())

# One formatted "mean ± se" row per group, in the group index order.
summary_rows = [
    (model_name, error_type, f"{mean_val:.4f} $\\pm$ {se_val:.4f}")
    for (model_name, error_type), mean_val, se_val in zip(
        mean_errors.index, mean_errors, se_errors
    )
]
summary_df = pd.DataFrame(summary_rows, columns=["Model", "Type", "ATE Error"])

# LaTeX table; escape=False keeps the math markup in the "ATE Error" cells intact.
latex_table = summary_df.to_latex(index=False, escape=False)
print(latex_table)



