from s1_model import (
    CausalMultiTaskDataset, MultiTaskCausalModel, train_model, compute_utility_tensor,
    SingleTaskCompositeModel, train_composite_model, CompositeOutcomeDataset,
    IndependentOutcomeModel, train_independent_model
)

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import sklearn.metrics as metrics
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import time
# NOTE(review): presumably the Intel OpenMP "duplicate runtime" workaround.
# It is set here after torch/numpy have already been imported above --
# confirm it still takes effect at this point (it may need to precede them).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Check if GPU is available
# `device` is a module-level global read by evaluate_model() and
# run_ablation_study() below; CUDA is used when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ----- Dataset Classes -----
class IHDPDataset(Dataset):
    """Covariates from the IHDP CSV paired with a randomly drawn treatment.

    The CSV is read without a header row; columns 5..29 (positional) are
    used as the covariate matrix. The treatment indicator is NOT read from
    the file: it is drawn i.i.d. Bernoulli(0.5) from NumPy's global RNG, so
    it changes between instantiations unless the caller seeds ``np.random``.

    Attributes set on the instance:
        X: float64 array of shape (n, d) with the selected covariates.
        A: integer array of length n with the simulated 0/1 treatment.
        n: number of rows loaded.
        d: number of covariate columns actually loaded.
    """

    # Location used when the caller does not supply ``path``.
    DEFAULT_PATH = "/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/IHDP.csv"

    def __init__(self, path=None):
        """Load the dataset.

        Args:
            path: Optional CSV location. Defaults to ``DEFAULT_PATH`` so
                existing ``IHDPDataset()`` calls keep working.
        """
        super().__init__()
        print("Loading IHDP dataset...")
        csv_path = self.DEFAULT_PATH if path is None else path
        temp = pd.read_csv(csv_path, header=None)
        self.X = temp.iloc[:, 5:30].to_numpy(dtype=np.float64)
        self.A = np.random.binomial(1, 0.5, size=len(self.X))
        self.n = len(self.X)
        # Derive the feature count from the data: the positional slice yields
        # fewer than 25 columns when the file is narrower than 30 columns.
        self.d = self.X.shape[1]
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        """Return the number of samples."""
        return self.n

    def __getitem__(self, idx):
        """Return the ``(covariates, treatment)`` pair at position ``idx``."""
        return self.X[idx], self.A[idx]

class EHRDataset(Dataset):
    """Selected lab-value columns from the EHR CSV with a random treatment.

    Only the columns listed in ``FEATURE_COLUMNS`` are kept, and only the
    first ``max_rows`` rows of the file are loaded. The treatment indicator
    is NOT read from the file: it is drawn i.i.d. Bernoulli(0.5) from NumPy's
    global RNG.

    Attributes set on the instance:
        X: float64 array of shape (n, d) with the selected columns.
        A: integer array of length n with the simulated 0/1 treatment.
        n: number of rows loaded.
        d: number of feature columns (``len(FEATURE_COLUMNS)``).
    """

    # Location used when the caller does not supply ``path``.
    DEFAULT_PATH = "/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/EHR.csv"

    # Column names (as they appear in the CSV header) used as covariates.
    FEATURE_COLUMNS = (
        "X..BAND.NEUTROPHILS..count", "X..BAND.NEUTROPHILS..last",
        "X..LYMPHOCYTES.MANUAL..last", "X..LYMPHOCYTES..first", "X..MONOCYTES..last",
        "X..NEUTROPHILS.MANUAL..count", "X..LYMPHOCYTES.MANUAL..count",
        "X..LYMPHOCYTES..count", "X..LYMPHOCYTES..first.1", "X..LYMPHOCYTES..last",
        "ALKALINE.PHOSPHATASE..first", "ALKALINE.PHOSPHATASE..last",
        "ALPHA.1.FETOPROTEINS..first", "ALPHA.1.FETOPROTEINS..last",
        "ALT..count", "ALT..first", "ALT..last", "ANION.GAP..last",
        "AST..first", "AST..last", "BILIRUBIN.DIRECT..count", "BILIRUBIN.INDIRECT..last",
    )

    def __init__(self, path=None, max_rows=1000):
        """Load the dataset.

        Args:
            path: Optional CSV location. Defaults to ``DEFAULT_PATH`` so
                existing ``EHRDataset()`` calls keep working.
            max_rows: Maximum number of leading data rows to load
                (default 1000, the previous hard-coded cap). ``None`` loads
                the whole file.

        Raises:
            KeyError: If any column in ``FEATURE_COLUMNS`` is missing
                from the file (raised by pandas column selection).
        """
        super().__init__()
        print("Loading EHR dataset...")
        csv_path = self.DEFAULT_PATH if path is None else path
        # ``nrows`` stops parsing after the cap instead of reading the whole
        # file and discarding everything past the first ``max_rows`` rows.
        temp = pd.read_csv(csv_path, nrows=max_rows)
        columns_to_keep = list(self.FEATURE_COLUMNS)
        self.X = temp[columns_to_keep].to_numpy(dtype=np.float64)
        self.A = np.random.binomial(1, 0.5, size=len(self.X))
        self.n = len(self.X)
        self.d = len(columns_to_keep)
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        """Return the number of samples."""
        return self.n

    def __getitem__(self, idx):
        """Return the ``(covariates, treatment)`` pair at position ``idx``."""
        return self.X[idx], self.A[idx]

# ----- Utility Functions -----
def generate_outcomes(X, A, num_outcomes, data_type):
    """Simulate an outcome matrix via ``s2_simulation.simulate_simplified_dgp``.

    Args:
        X: Covariate array; only ``len(X)`` and ``X.shape[1]`` are forwarded
            (as ``n`` and ``d``), the values themselves are not passed on.
        A: Treatment array. Not used by this function.
        num_outcomes: Number of outcome components, forwarded as ``K``.
        data_type: Forwarded as the simulator's ``data`` argument.

    Returns:
        The ``'Y'`` entry of the dictionary returned by the simulator.

    NOTE(review): because ``A`` is never forwarded, the returned outcomes
    presumably do not depend on the caller's treatment vector -- confirm
    against ``simulate_simplified_dgp`` that this is intended.
    """
    from s2_simulation import simulate_simplified_dgp

    n_samples = len(X)
    n_features = X.shape[1]
    simulated = simulate_simplified_dgp(
        n=n_samples,
        d=n_features,
        K=num_outcomes,
        utility='or',
        weights=None,
        data=data_type,
    )
    return simulated['Y']

# Model families understood by evaluate_model().
_EVAL_MODEL_TYPES = ('CROME', 'Single-Task', 'Multi-Task No Rep')


def _predict_utilities(model, X_tensor, model_type):
    """Return ``(u0_pred, u1_pred)``: predicted utilities under each treatment."""
    if model_type == 'CROME':
        # Per-outcome counterfactuals, mapped to utilities by the model itself.
        y0_pred, y1_pred, _ = model.predict_counterfactuals_individual(X_tensor)
        return model.compute_utility(y0_pred), model.compute_utility(y1_pred)
    if model_type == 'Single-Task':
        # This model's predictions are used directly as utilities.
        u0_pred, u1_pred, _ = model.predict_counterfactuals(X_tensor)
        return u0_pred, u1_pred
    if model_type == 'Multi-Task No Rep':
        # Per-outcome counterfactuals, mapped to utilities with the shared helper.
        y0_pred, y1_pred = model.predict_counterfactuals(X_tensor)
        return (
            compute_utility_tensor(y0_pred, utility='or'),
            compute_utility_tensor(y1_pred, utility='or'),
        )
    raise ValueError(f"Unknown model type: {model_type}")


def _ate_error(model, X_tensor, Y_tensor, model_type):
    """Return |predicted ATE - "true" ATE| on one data split as a float."""
    u0_pred, u1_pred = _predict_utilities(model, X_tensor, model_type)
    # NOTE(review): both "true" utilities are computed from the same observed
    # outcomes, so (assuming compute_utility_tensor is deterministic) the
    # true-ATE term below is identically zero and the reported error reduces
    # to |mean predicted effect|. A real fix needs ground-truth potential
    # outcomes, which are not passed to evaluate_model; the computation is
    # kept as-is so reported numbers do not change silently.
    u0_true = compute_utility_tensor(Y_tensor, utility='or')
    u1_true = compute_utility_tensor(Y_tensor, utility='or')
    predicted_ate = torch.mean(u1_pred - u0_pred)
    true_ate = torch.mean(u1_true - u0_true)
    return torch.abs(predicted_ate - true_ate).item()


def evaluate_model(model, X_train, X_test, Y_train, Y_test, model_type, num_outcomes):
    """Evaluate model performance on both training and test data.

    Args:
        model: Trained model exposing ``eval()`` and the prediction method
            required for ``model_type`` (``predict_counterfactuals_individual``
            plus ``compute_utility`` for 'CROME'; ``predict_counterfactuals``
            for the other two).
        X_train, X_test: Covariates for the two splits; converted to float
            tensors on the module-level ``device``.
        Y_train, Y_test: Observed outcomes for the two splits; converted the
            same way.
        model_type: One of 'CROME', 'Single-Task', 'Multi-Task No Rep'.
        num_outcomes: Unused; retained for interface compatibility.

    Returns:
        dict with float entries ``'in_sample_ate_error'`` and
        ``'out_of_sample_ate_error'``.

    Raises:
        ValueError: If ``model_type`` is not recognised. Previously this
            surfaced as an ``UnboundLocalError`` at the return statement.
    """
    # Fail fast, before touching the model or moving any data.
    if model_type not in _EVAL_MODEL_TYPES:
        raise ValueError(f"Unknown model type: {model_type}")

    model.eval()
    with torch.no_grad():
        # Move data to device
        X_train_tensor = torch.FloatTensor(X_train).to(device)
        X_test_tensor = torch.FloatTensor(X_test).to(device)
        Y_train_tensor = torch.FloatTensor(Y_train).to(device)
        Y_test_tensor = torch.FloatTensor(Y_test).to(device)

        ate_train_error = _ate_error(model, X_train_tensor, Y_train_tensor, model_type)
        ate_test_error = _ate_error(model, X_test_tensor, Y_test_tensor, model_type)

    return {
        'in_sample_ate_error': ate_train_error,
        'out_of_sample_ate_error': ate_test_error
    }

# ----- Main Experiment Function -----
def run_ablation_study(dataset_name, num_outcomes, num_epochs):
    """
    Train and evaluate the three model variants on one dataset configuration.

    The data are loaded (or simulated), split 80/20 with a fixed
    ``random_state=42``, and then the CROME, Single-Task Composite and
    Multi-Task No Rep models are trained in that order. The CROME and
    Multi-Task No Rep models share the same training ``DataLoader``.

    Args:
        dataset_name (str): Name of the dataset ('synthetic', 'IHDP', 'EHR')
        num_outcomes (int): Number of outcome components to consider
        num_epochs (int): Number of epochs to train each model

    Returns:
        dict: Column-oriented results with keys 'Model', 'Error_Type',
        'ATE_Error', 'Dataset' and 'Num_Components'; two rows (in-sample and
        out-of-sample) per model.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    start_time = time.time()
    print(f"\nStarting experiment for {dataset_name} with {num_outcomes} components")

    # Guard clause: reject unsupported datasets before doing any work.
    if dataset_name not in ('synthetic', 'IHDP', 'EHR'):
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Load or generate the data.
    if dataset_name == 'synthetic':
        print("Generating synthetic data...")
        from s2_simulation import simulate_simplified_dgp
        simulated = simulate_simplified_dgp(
            n=1000, d=10, K=num_outcomes, utility='or', weights=None
        )
        X = simulated['X']
        A = simulated['A']
        Y = simulated['Y']
    elif dataset_name == 'IHDP':
        source = IHDPDataset()
        X, A = source.X, source.A
        Y = generate_outcomes(X, A, num_outcomes, 'IHDP')
    else:  # 'EHR'
        source = EHRDataset()
        X, A = source.X, source.A
        Y = generate_outcomes(X, A, num_outcomes, 'EHR')

    # Hold out 20% of the samples for out-of-sample evaluation.
    X_train, X_test, A_train, A_test, Y_train, Y_test = train_test_split(
        X, A, Y, test_size=0.2, random_state=42
    )

    # Column-oriented accumulator; each key gets its own list.
    results = {
        column: []
        for column in ('Model', 'Error_Type', 'ATE_Error', 'Dataset', 'Num_Components')
    }

    def record(model_name, scores):
        """Append the in-sample / out-of-sample rows for one model."""
        results['Model'].extend([model_name] * 2)
        results['Error_Type'].extend(['In-Sample', 'Out-of-Sample'])
        results['ATE_Error'].extend(
            [scores['in_sample_ate_error'], scores['out_of_sample_ate_error']]
        )
        results['Dataset'].extend([dataset_name] * 2)
        results['Num_Components'].extend([num_outcomes] * 2)

    input_dim = X.shape[1]
    hidden_dim = 64

    # 1. CROME model
    print("Training CROME model...")
    multitask_data = CausalMultiTaskDataset(X_train, A_train, Y_train)
    train_loader = DataLoader(multitask_data, batch_size=32, shuffle=True)
    crome = MultiTaskCausalModel(
        input_dim, hidden_dim, num_outcomes, utility='or', utility_weights=None
    ).to(device)
    train_model(crome, train_loader, num_epochs=num_epochs, lr=1e-3)
    record(
        'CROME',
        evaluate_model(crome, X_train, X_test, Y_train, Y_test, 'CROME', num_outcomes),
    )

    # 2. Single-Task Composite model
    print("Training Single-Task Composite model...")
    composite_data = CompositeOutcomeDataset(X_train, A_train, Y_train)
    composite_loader = DataLoader(composite_data, batch_size=32, shuffle=True)
    single_task = SingleTaskCompositeModel(input_dim, hidden_dim).to(device)
    train_composite_model(single_task, composite_loader, num_epochs=num_epochs, lr=1e-3)
    record(
        'Single-Task',
        evaluate_model(single_task, X_train, X_test, Y_train, Y_test, 'Single-Task', num_outcomes),
    )

    # 3. Multi-Task No Rep model (trained on the same loader as CROME)
    print("Training Multi-Task No Rep model...")
    independent = IndependentOutcomeModel(input_dim, hidden_dim, num_outcomes).to(device)
    train_independent_model(independent, train_loader, num_epochs=num_epochs, lr=1e-3)
    record(
        'Multi-Task No Rep',
        evaluate_model(independent, X_train, X_test, Y_train, Y_test, 'Multi-Task No Rep', num_outcomes),
    )

    elapsed = time.time() - start_time
    print(f"Experiment completed in {elapsed:.2f} seconds")

    return results

def _summarize_results(results_df):
    """Build a per-(model, error type, dataset, components) summary table.

    Each row reports the mean ATE error and, when at least two runs exist for
    the group, its standard error (sample std with ``ddof=1`` divided by
    sqrt(count)). With a single run the standard error is undefined (NaN), so
    only the mean is shown instead of printing ``nan``.
    """
    grouped = results_df.groupby(["Model", "Error_Type", "Dataset", "Num_Components"])["ATE_Error"]
    mean_errors = grouped.mean()
    se_errors = grouped.std(ddof=1) / np.sqrt(grouped.count())

    def format_entry(mean, se):
        if np.isfinite(se):
            return f"{mean:.4f} $\\pm$ {se:.4f}"
        return f"{mean:.4f}"

    return pd.DataFrame({
        "Model": [key[0] for key in mean_errors.index],
        "Type": [key[1] for key in mean_errors.index],
        "Dataset": [key[2] for key in mean_errors.index],
        "Num_Components": [key[3] for key in mean_errors.index],
        "ATE Error": [format_entry(m, s) for m, s in zip(mean_errors, se_errors)]
    })


def main():
    """Run the ablation over all datasets and component counts.

    For every trial, dataset and number of components, the three models are
    trained via ``run_ablation_study``; the accumulated results are rewritten
    to the CSV file after every configuration so partial results survive an
    interrupted run. Finally the raw results and a summary table are printed.
    """
    # Define datasets and number of components to test
    datasets = ['synthetic', 'IHDP', 'EHR']
    num_components_list = [2, 3, 5, 10, 15, 20]  # Different numbers of components to test
    results_path = '/cbica/home/choii/project/causal_representation/causal_representation_learning/ablation_components_results.csv'

    results = {
        'Model': [],
        'Error_Type': [],
        'ATE_Error': [],
        'Dataset': [],
        'Num_Components': []
    }

    total_start_time = time.time()

    # Number of repetitions per configuration. Previously this variable was
    # defined but never used; with the default of 1 the run is unchanged.
    num_trials = 1
    num_epochs = 50

    # Bound up front so the reporting below works even if nothing was run.
    results_df = pd.DataFrame(results)

    for _trial in range(num_trials):
        for dataset in datasets:
            for num_components in num_components_list:
                print(f"\nRunning ablation study for {dataset} with {num_components} components")
                result = run_ablation_study(dataset, num_components, num_epochs=num_epochs)

                # Add results to main results dictionary
                for key in results:
                    results[key].extend(result[key])

                # Save intermediate results
                results_df = pd.DataFrame(results)
                results_df.to_csv(results_path, index=False)

    total_end_time = time.time()
    print(f"\nTotal execution time: {(total_end_time - total_start_time)/3600:.2f} hours")

    # Print final results
    print("\nResults:")
    print(results_df)

    # Print summary statistics
    print("\nSummary Statistics:")
    summary_df = _summarize_results(results_df)
    print(summary_df)

if __name__ == "__main__":
    main()
