import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import time
import os
from s1_model import CausalMultiTaskDataset
# NOTE(review): presumably set to tolerate a second copy of the Intel OpenMP
# runtime being loaded into the process (torch + another MKL-linked library);
# confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Pick the compute device once at import time; the module-level `device`
# name is used elsewhere in this script when moving models and tensors.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# -------- Epsilon Layer --------
class EpsilonLayer(nn.Module):
    """A single learnable scalar, broadcast over the batch.

    ``forward`` ignores the values of its argument and uses only its first
    dimension. It returns ``epsilon`` expanded to shape ``(batch_size, 1)``.
    The result is an expanded view, so gradients accumulate into the one
    underlying parameter.
    """

    def __init__(self):
        super().__init__()
        # One scalar drawn from a standard normal at construction time.
        self.epsilon = nn.Parameter(torch.randn(1))

    def forward(self, t_pred):
        rows = t_pred.shape[0]
        return self.epsilon.expand(rows, 1)

# -------- Dragonnet Model --------
class Dragonnet(nn.Module):
    """Dragonnet: a shared representation feeding a treatment head and two
    outcome heads (one per treatment arm).

    ``forward`` returns a single tensor of shape
    ``(batch_size, 2 * num_outcomes + 2)`` whose columns are
    ``[y0_pred, y1_pred, t_pred, eps]``:

    - ``num_outcomes`` columns for each outcome head;
    - one sigmoid treatment probability;
    - the learnable epsilon broadcast over the batch.

    Args:
        input_dim: Number of input features per sample.
        num_outcomes: Output width of each outcome head.
        reg_l2: L2 regularisation strength. It is stored as ``self.reg_l2``
            so training code can read it, but this class does not apply it.
    """

    def __init__(self, input_dim, num_outcomes=3, reg_l2=0.0):
        super(Dragonnet, self).__init__()
        self.num_outcomes = num_outcomes
        # Previously accepted and silently discarded; keep it inspectable.
        self.reg_l2 = reg_l2

        # Shared representation: three 200-unit ELU layers.
        self.shared = nn.Sequential(
            nn.Linear(input_dim, 200),
            nn.ELU(),
            nn.Linear(200, 200),
            nn.ELU(),
            nn.Linear(200, 200),
            nn.ELU()
        )

        # Treatment head: one sigmoid output in (0, 1).
        self.t_pred = nn.Sequential(
            nn.Linear(200, 1),
            nn.Sigmoid()
        )

        # Outcome heads for the control (y0) and treated (y1) arms.
        self.y0_head = nn.Sequential(
            nn.Linear(200, 100),
            nn.ELU(),
            nn.Linear(100, 100),
            nn.ELU(),
            nn.Linear(100, num_outcomes)
        )
        self.y1_head = nn.Sequential(
            nn.Linear(200, 100),
            nn.ELU(),
            nn.Linear(100, 100),
            nn.ELU(),
            nn.Linear(100, num_outcomes)
        )

        # Learnable scalar appended as the last output column.
        self.epsilon_layer = EpsilonLayer()

    def forward(self, x):
        """Return ``[y0_pred, y1_pred, t_pred, eps]`` concatenated along dim 1."""
        rep = self.shared(x)

        t = self.t_pred(rep)
        y0 = self.y0_head(rep)
        y1 = self.y1_head(rep)
        eps = self.epsilon_layer(t)

        return torch.cat([y0, y1, t, eps], dim=1)

    def predict_counterfactuals(self, x):
        """Predict both potential outcomes for every row of ``x``.

        The input is moved to the device and dtype of the model's own
        parameters, so it works wherever the model lives. Inference runs
        without gradient tracking. The module's previous train/eval mode is
        restored afterwards.

        Returns:
            Tuple ``(y0_pred, y1_pred)``, each of shape
            ``(batch_size, num_outcomes)``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                ref = next(self.parameters())
                outputs = self(x.to(device=ref.device, dtype=ref.dtype))
        finally:
            self.train(was_training)
        k = self.num_outcomes
        return outputs[:, :k], outputs[:, k:2 * k]

# -------- Loss functions --------

def dragonnet_loss(y_true, t_true, outputs, num_outcomes=3):
    """
    Loss = Regression Loss (for Y0/Y1) + Binary Crossentropy (for T)

    The regression term is the per-sample MSE of the head that matches each
    sample's observed treatment (y0 for t=0, y1 for t=1), averaged over the
    batch. The epsilon column of ``outputs`` is not used.

    Args:
        y_true: Observed outcomes, shape ``(batch,)`` or ``(batch, num_outcomes)``.
        t_true: Observed binary treatment with ``batch`` elements, in any
            numeric dtype (it is cast to the prediction dtype).
        outputs: Model output of shape ``(batch, 2 * num_outcomes + 2)``
            laid out as ``[y0_pred, y1_pred, t_pred, eps]``.
        num_outcomes: Number of outcome columns per head.

    Returns:
        Scalar loss tensor.
    """
    y0_pred = outputs[:, :num_outcomes]
    y1_pred = outputs[:, num_outcomes:2 * num_outcomes]
    t_pred = outputs[:, -2]

    # Force (batch, num_outcomes) so a 1-D target cannot silently broadcast
    # against the (batch, num_outcomes) predictions into a (batch, batch) grid.
    y_true = y_true.reshape(-1, num_outcomes)
    # binary_cross_entropy needs a floating target of the same dtype as the
    # prediction; treatment labels often arrive as integer tensors.
    t_true = t_true.reshape(-1).to(t_pred.dtype)

    mse0 = torch.mean((y_true - y0_pred) ** 2, dim=1)
    mse1 = torch.mean((y_true - y1_pred) ** 2, dim=1)
    regression_loss = ((1. - t_true) * mse0 + t_true * mse1).mean()

    # Treatment prediction (binary classification) loss; clamp for numerical stability.
    t_pred = torch.clamp(t_pred, 1e-3, 1 - 1e-3)
    classification_loss = F.binary_cross_entropy(t_pred, t_true)

    return regression_loss + classification_loss

# ----- Dataset Classes -----
class IHDPDataset(Dataset):
    """IHDP covariates paired with a randomly assigned binary treatment.

    Reads a header-less CSV and keeps columns 5-29 as ``float64`` covariates.
    ``A`` is drawn i.i.d. Bernoulli(0.5) from NumPy's global RNG, independent
    of the covariates. It is therefore reproducible only if the caller seeds
    ``np.random``.

    Args:
        path: CSV file to read. ``None`` (the default) uses ``DEFAULT_PATH``,
            the original hard-coded location.

    Attributes:
        X: ``(n, d)`` covariate matrix.
        A: ``(n,)`` integer treatment indicators in {0, 1}.
        n: Number of rows.
        d: Number of covariate columns, taken from ``X`` itself.
    """

    DEFAULT_PATH = ("/cbica/home/choii/project/causal_representation/"
                    "causal_representation_learning/simulation/IHDP.csv")

    def __init__(self, path=None):
        print("Loading IHDP dataset...")
        if path is None:
            path = self.DEFAULT_PATH
        temp = pd.read_csv(path, header=None)
        self.X = temp.iloc[:, 5:30].to_numpy(dtype=np.float64)
        self.A = np.random.binomial(1, 0.5, size=len(self.X))
        self.n = len(self.X)
        # Derived from the data instead of hard-coding 25, so it stays
        # correct if the file has fewer columns than expected.
        self.d = self.X.shape[1]
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.X[idx], self.A[idx]

class EHRDataset(Dataset):
    """EHR lab features paired with a randomly assigned binary treatment.

    Reads a CSV with a header row and keeps the columns listed in
    ``FEATURE_COLUMNS``, in that order, as ``float64`` covariates. ``A`` is
    drawn i.i.d. Bernoulli(0.5) from NumPy's global RNG, independent of the
    covariates.

    Args:
        path: CSV file to read. ``None`` (the default) uses ``DEFAULT_PATH``,
            the original hard-coded location.
        n_rows: Number of leading data rows to load (default 1000, as
            before). ``None`` loads every row. Only these rows are parsed
            from disk instead of reading the whole file and truncating it.

    Attributes:
        X: ``(n, d)`` covariate matrix.
        A: ``(n,)`` integer treatment indicators in {0, 1}.
        n: Number of rows.
        d: Number of feature columns.
    """

    DEFAULT_PATH = ("/cbica/home/choii/project/causal_representation/"
                    "causal_representation_learning/simulation/EHR.csv")

    # Tuple so the shared class attribute cannot be mutated by accident.
    FEATURE_COLUMNS = (
        "X..BAND.NEUTROPHILS..count", "X..BAND.NEUTROPHILS..last",
        "X..LYMPHOCYTES.MANUAL..last", "X..LYMPHOCYTES..first", "X..MONOCYTES..last",
        "X..NEUTROPHILS.MANUAL..count", "X..LYMPHOCYTES.MANUAL..count",
        "X..LYMPHOCYTES..count", "X..LYMPHOCYTES..first.1", "X..LYMPHOCYTES..last",
        "ALKALINE.PHOSPHATASE..first", "ALKALINE.PHOSPHATASE..last",
        "ALPHA.1.FETOPROTEINS..first", "ALPHA.1.FETOPROTEINS..last",
        "ALT..count", "ALT..first", "ALT..last", "ANION.GAP..last",
        "AST..first", "AST..last", "BILIRUBIN.DIRECT..count", "BILIRUBIN.INDIRECT..last",
    )

    def __init__(self, path=None, n_rows=1000):
        print("Loading EHR dataset...")
        if path is None:
            path = self.DEFAULT_PATH
        temp = pd.read_csv(path, nrows=n_rows)
        # Must be a list: indexing a DataFrame with a tuple means one key.
        columns_to_keep = list(self.FEATURE_COLUMNS)
        self.X = temp[columns_to_keep].to_numpy(dtype=np.float64)
        self.A = np.random.binomial(1, 0.5, size=len(self.X))
        self.n = len(self.X)
        self.d = len(columns_to_keep)
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.X[idx], self.A[idx]

# ----- Utility Functions -----
def generate_outcomes(X, A, num_outcomes, data_type):
    """Simulate an outcome matrix sized to match ``X`` via ``s2_simulation``.

    Only ``len(X)`` and ``X.shape[1]`` are forwarded to
    ``simulate_simplified_dgp``, along with ``K=num_outcomes``,
    ``utility='or'``, ``weights=None`` and ``data=data_type``. The values
    inside ``X`` and the treatment vector ``A`` are not passed on.

    NOTE(review): because ``A`` is never forwarded, the returned ``Y`` may not
    be generated under the treatments the caller later pairs it with.
    Confirm this against ``simulate_simplified_dgp``.

    Returns:
        The ``'Y'`` entry of the simulator's result dict.
    """
    from s2_simulation import simulate_simplified_dgp

    n_samples, n_features = len(X), X.shape[1]
    simulated = simulate_simplified_dgp(
        n=n_samples,
        d=n_features,
        K=num_outcomes,
        utility='or',
        weights=None,
        data=data_type,
    )
    return simulated['Y']

def _ipw_ate_error(model, X, Y, A, target_device):
    """Return |model-predicted ATE - IPW ATE| for one data split as a float."""
    X_t = torch.as_tensor(X, dtype=torch.float32, device=target_device)
    Y_t = torch.as_tensor(Y, dtype=torch.float32, device=target_device)
    A_t = torch.as_tensor(A, dtype=torch.float32, device=target_device).reshape(-1)
    # A 1-D outcome vector would otherwise broadcast against the (n, 1)
    # weights below into an (n, n) matrix.
    if Y_t.dim() == 1:
        Y_t = Y_t.unsqueeze(1)

    y0_pred, y1_pred = model.predict_counterfactuals(X_t)
    ate_pred = torch.mean(y1_pred - y0_pred)

    # Horvitz-Thompson IPW estimate using the empirical treatment rate as the
    # propensity; averaged over all samples and all outcome columns.
    p_treated = torch.mean(A_t)
    weights = A_t / p_treated - (1 - A_t) / (1 - p_treated)
    ate_ipw = torch.mean(weights.unsqueeze(1) * Y_t)

    return torch.abs(ate_pred - ate_ipw).item()


def evaluate_model(model, X_train, X_test, Y_train, Y_test, A_train, A_test):
    """Evaluate model performance on both training and test data.

    For each split, the model's ATE is the mean of ``y1_pred - y0_pred``
    from ``model.predict_counterfactuals``, taken over samples and outcome
    columns. The reference ATE is an inverse-probability-weighted average of
    the observed outcomes. The reported error is their absolute difference.

    Data is converted to ``float32`` tensors on the device that holds the
    model's parameters (CPU if the model has none).

    Returns:
        dict with keys ``'in_sample_ate_error'`` and
        ``'out_of_sample_ate_error'`` (Python floats).

    Note:
        If a split is all-treated or all-control, the IPW weights divide by
        zero and the error is ``nan``.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        ate_train_error = _ipw_ate_error(model, X_train, Y_train, A_train, target_device)
        ate_test_error = _ipw_ate_error(model, X_test, Y_test, A_test, target_device)

    return {
        'in_sample_ate_error': ate_train_error,
        'out_of_sample_ate_error': ate_test_error
    }

def train_dragonnet(model, train_loader, num_epochs=50, lr=1e-3):
    """Train Dragonnet model with Adam on ``dragonnet_loss``.

    Each batch yielded by ``train_loader`` must unpack as ``(x, a, y)``.
    Batches are moved to the device that holds the model's parameters. After
    every epoch the mean per-batch loss is printed.

    Args:
        model: Network exposing ``num_outcomes`` and returning the
            ``[y0, y1, t, eps]`` layout expected by ``dragonnet_loss``.
        train_loader: Iterable of ``(x, a, y)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        List of per-epoch mean batch losses. An epoch with no batches
        reports ``nan`` instead of raising ZeroDivisionError.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model_device = next(model.parameters()).device
    history = []

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0
        for x_batch, a_batch, y_batch in train_loader:
            x_batch = x_batch.to(model_device)
            a_batch = a_batch.to(model_device)
            y_batch = y_batch.to(model_device)

            optimizer.zero_grad()
            outputs = model(x_batch)
            loss = dragonnet_loss(y_batch, a_batch, outputs, num_outcomes=model.num_outcomes)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1

        mean_loss = total_loss / n_batches if n_batches else float('nan')
        history.append(mean_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss:.4f}")

    return history

# ----- Main Experiment Function -----
def run_experiment(dataset_name, num_epochs=50):
    """
    Train Dragonnet on one dataset and report its ATE errors.

    Args:
        dataset_name (str): One of 'synthetic', 'IHDP' or 'EHR'.
        num_epochs (int): Number of training epochs.

    Returns:
        dict: Column-oriented results with keys 'Model', 'Error_Type',
        'ATE_Error' and 'Dataset'. Each key holds two entries, in-sample
        first and out-of-sample second.

    Raises:
        ValueError: If ``dataset_name`` is not recognised.
    """
    t0 = time.time()
    print(f"\nStarting experiment for {dataset_name}")

    if dataset_name == 'synthetic':
        print("Generating synthetic data...")
        from s2_simulation import simulate_simplified_dgp
        simulated = simulate_simplified_dgp(n=1000, d=10, K=3,
                                            utility='or',
                                            weights=None)
        X, A, Y = simulated['X'], simulated['A'], simulated['Y']
    elif dataset_name in ('IHDP', 'EHR'):
        # NOTE(review): A is the dataset's own random draw, while Y comes from
        # generate_outcomes, which (as written) forwards only X's shape and
        # not A. Confirm that Y is meant to be paired with this A.
        dataset_cls = IHDPDataset if dataset_name == 'IHDP' else EHRDataset
        loaded = dataset_cls()
        X, A = loaded.X, loaded.A
        Y = generate_outcomes(X, A, 3, dataset_name)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Fixed-seed 80/20 split of covariates, treatments and outcomes.
    X_train, X_test, A_train, A_test, Y_train, Y_test = train_test_split(
        X, A, Y, test_size=0.2, random_state=42
    )

    input_dim, num_outcomes = X.shape[1], Y.shape[1]

    print("Training Dragonnet model...")
    loader = DataLoader(CausalMultiTaskDataset(X_train, A_train, Y_train),
                        batch_size=32, shuffle=True)
    model = Dragonnet(input_dim, num_outcomes=num_outcomes).to(device)
    train_dragonnet(model, loader, num_epochs=num_epochs, lr=1e-3)
    metrics = evaluate_model(model, X_train, X_test, Y_train, Y_test, A_train, A_test)

    results = {
        'Model': ['Dragonnet'] * 2,
        'Error_Type': ['In-Sample', 'Out-of-Sample'],
        'ATE_Error': [metrics['in_sample_ate_error'],
                      metrics['out_of_sample_ate_error']],
        'Dataset': [dataset_name] * 2,
    }

    print(f"Experiment completed in {time.time() - t0:.2f} seconds")
    return results

_DEFAULT_RESULTS_PATH = ('/cbica/home/choii/project/causal_representation/'
                         'causal_representation_learning/dragonnet_results.csv')


def main(datasets=('synthetic', 'IHDP', 'EHR'), num_trials=1, num_epochs=50,
         output_path=_DEFAULT_RESULTS_PATH):
    """Run the Dragonnet experiment on each dataset and summarise ATE errors.

    Every (trial, dataset) pair calls ``run_experiment``. After each run the
    accumulated results are written to ``output_path`` as a checkpoint.
    At the end the full table is printed, followed by a mean +/- standard-error
    summary per (Model, Error_Type, Dataset).

    Args:
        datasets: Dataset names passed to ``run_experiment``.
        num_trials: Number of times each dataset is re-run. The standard
            error is only defined (not NaN) when this is at least 2.
        num_epochs: Training epochs per run.
        output_path: CSV path for the results checkpoint.
    """
    results = {
        'Model': [],
        'Error_Type': [],
        'ATE_Error': [],
        'Dataset': []
    }

    total_start_time = time.time()

    for trial in range(num_trials):
        if num_trials > 1:
            print(f"\n===== Trial {trial + 1}/{num_trials} =====")
        for dataset in datasets:
            print(f"\nRunning experiment for {dataset}")
            result = run_experiment(dataset, num_epochs=num_epochs)

            for key in results:
                results[key].extend(result[key])

            # Checkpoint after every run so a later crash keeps finished results.
            pd.DataFrame(results).to_csv(output_path, index=False)

    # Built unconditionally so it exists even when no experiment ran.
    results_df = pd.DataFrame(results)

    total_end_time = time.time()
    print(f"\nTotal execution time: {(total_end_time - total_start_time)/3600:.2f} hours")

    print("\nResults:")
    print(results_df)

    print("\nSummary Statistics:")
    grouped = results_df.groupby(["Model", "Error_Type", "Dataset"])["ATE_Error"]
    mean_errors = grouped.mean()
    # Standard error across trials; NaN for any group with fewer than 2 trials.
    se_errors = grouped.std(ddof=1) / np.sqrt(grouped.count())

    summary_df = pd.DataFrame({
        "Model": [i[0] for i in mean_errors.index],
        "Type": [i[1] for i in mean_errors.index],
        "Dataset": [i[2] for i in mean_errors.index],
        "ATE Error": [f"{m:.4f} $\\pm$ {s:.4f}" for m, s in zip(mean_errors, se_errors)]
    })
    print(summary_df)

if __name__ == '__main__':
    # Run the experiments only when executed as a script, not on import.
    main()

