from s1_model import CausalMultiTaskDataset, MultiTaskCausalModel, train_model, compute_utility_tensor

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import sklearn.metrics as metrics
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import time
# Permit more than one OpenMP runtime in this process (the KMP_DUPLICATE_LIB_OK
# switch) — presumably needed because several native libraries each load one;
# NOTE(review): confirm which libraries require this on the target cluster.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

# Module-wide compute device: the default CUDA device when one is visible,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# ----- Dataset Classes -----
class IHDPDataset(Dataset):
    """IHDP covariates paired with a randomly assigned binary treatment.

    Covariates are columns 5-29 (the positional slice ``5:30``) of a
    header-less CSV, cast to float64. The treatment ``A`` is drawn i.i.d.
    Bernoulli(0.5), independently of the covariates.

    Args:
        path: CSV location; defaults to the original cluster path.
        seed: optional seed for the treatment draw. ``None`` (the default)
            uses NumPy's global RNG exactly as before; an int makes the draw
            reproducible via ``np.random.default_rng(seed)``.

    Attributes:
        X: ``(n, d)`` float64 covariate matrix.
        A: length-``n`` array of 0/1 treatment indicators.
        n: number of samples.
        d: number of covariate columns actually read.
    """

    DEFAULT_PATH = "/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/IHDP.csv"

    def __init__(self, path=DEFAULT_PATH, seed=None):
        print("Loading IHDP dataset...")
        temp = pd.read_csv(path, header=None)
        self.X = temp.iloc[:, 5:30].to_numpy(dtype=np.float64)
        self.n = len(self.X)
        # Derive the width from the data instead of hard-coding 25: positional
        # slicing silently returns fewer columns when the file is narrower, and
        # a stale ``d`` would then disagree with ``X``.
        self.d = self.X.shape[1]
        # seed=None keeps the original global-RNG draw for backward compatibility.
        rng = np.random if seed is None else np.random.default_rng(seed)
        self.A = rng.binomial(1, 0.5, size=self.n)
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return the ``(covariates, treatment)`` pair for sample ``idx``."""
        return self.X[idx], self.A[idx]

class EHRDataset(Dataset):
    """Selected EHR lab features paired with a randomly assigned binary treatment.

    Reads a CSV with a header row, keeps at most the first ``n_rows`` rows and
    the columns listed in ``FEATURE_COLUMNS`` (cast to float64). The treatment
    ``A`` is drawn i.i.d. Bernoulli(0.5), independently of the features.
    Missing values are not imputed here; they would appear as NaN in ``X``.

    Args:
        path: CSV location; defaults to the original cluster path.
        n_rows: keep only the first ``n_rows`` rows (default 1000, as before);
            ``None`` keeps every row.
        seed: optional seed for the treatment draw. ``None`` (the default)
            uses NumPy's global RNG exactly as before; an int makes the draw
            reproducible via ``np.random.default_rng(seed)``.

    Raises:
        KeyError: if any of ``FEATURE_COLUMNS`` is missing from the CSV.
    """

    DEFAULT_PATH = "/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/EHR.csv"

    # Tuple so the shared class-level list cannot be mutated by an instance.
    FEATURE_COLUMNS = (
        "X..BAND.NEUTROPHILS..count", "X..BAND.NEUTROPHILS..last",
        "X..LYMPHOCYTES.MANUAL..last", "X..LYMPHOCYTES..first", "X..MONOCYTES..last",
        "X..NEUTROPHILS.MANUAL..count", "X..LYMPHOCYTES.MANUAL..count",
        "X..LYMPHOCYTES..count", "X..LYMPHOCYTES..first.1", "X..LYMPHOCYTES..last",
        "ALKALINE.PHOSPHATASE..first", "ALKALINE.PHOSPHATASE..last",
        "ALPHA.1.FETOPROTEINS..first", "ALPHA.1.FETOPROTEINS..last",
        "ALT..count", "ALT..first", "ALT..last", "ANION.GAP..last",
        "AST..first", "AST..last", "BILIRUBIN.DIRECT..count", "BILIRUBIN.INDIRECT..last",
    )

    def __init__(self, path=DEFAULT_PATH, n_rows=1000, seed=None):
        print("Loading EHR dataset...")
        temp = pd.read_csv(path)
        if n_rows is not None:
            temp = temp.head(n_rows)
        columns_to_keep = list(self.FEATURE_COLUMNS)
        self.X = temp[columns_to_keep].to_numpy(dtype=np.float64)
        self.n = len(self.X)
        self.d = len(columns_to_keep)
        # seed=None keeps the original global-RNG draw for backward compatibility.
        rng = np.random if seed is None else np.random.default_rng(seed)
        self.A = rng.binomial(1, 0.5, size=self.n)
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return the ``(features, treatment)`` pair for sample ``idx``."""
        return self.X[idx], self.A[idx]

# ----- Utility Functions -----
def generate_outcomes(X, A, num_outcomes, utility_type, weights, data_type,
                      return_all=False):
    """Simulate outcomes with ``simulate_simplified_dgp`` from s2_simulation.

    WARNING: only the shape of ``X`` (``len(X)`` samples, ``X.shape[1]``
    features) is forwarded to the simulator. Neither the covariate values nor
    the treatment vector ``A`` are passed on, so the returned outcomes are
    generated under the simulator's own treatment assignment, not under ``A``.
    Pairing the returned ``Y`` with ``A`` therefore produces mismatched
    (treatment, outcome) data. Use ``return_all=True`` and take the
    simulator's own covariates/treatments to keep them consistent.

    Args:
        X: covariate matrix; only its shape is used.
        A: treatment vector; accepted for backward compatibility but unused.
        num_outcomes: number of outcomes, forwarded as ``K``.
        utility_type: utility name, forwarded as ``utility``.
        weights: utility weights, forwarded as ``weights``.
        data_type: data source name, forwarded as ``data``.
        return_all: if True, return the simulator's full output dict instead
            of only its ``'Y'`` entry.

    Returns:
        The simulator's ``'Y'`` entry, or the whole output dict when
        ``return_all`` is True.
    """
    from s2_simulation import simulate_simplified_dgp
    n_samples, n_features = len(X), X.shape[1]
    data = simulate_simplified_dgp(n=n_samples, d=n_features, K=num_outcomes,
                                   utility=utility_type,
                                   weights=weights,
                                   data=data_type)
    return data if return_all else data['Y']

def _observed_utility_ate(utilities, treatment):
    """Return the difference-in-means ATE of observed utilities (A == 1 minus A == 0).

    Args:
        utilities: tensor of observed utilities. It is flattened, so it is
            assumed to hold one value per sample (TODO confirm the output
            shape of compute_utility_tensor).
        treatment: array-like of 0/1 treatment indicators, one per sample.

    Returns:
        float: mean utility of the treated minus mean utility of the controls,
        or NaN when either arm is empty (the contrast is undefined there).

    Raises:
        ValueError: if the number of utilities and treatment labels differ.
    """
    utilities = utilities.reshape(-1)
    treatment = torch.as_tensor(treatment, device=utilities.device).reshape(-1)
    if treatment.numel() != utilities.numel():
        raise ValueError(
            f"Got {utilities.numel()} utilities but {treatment.numel()} treatment labels"
        )
    treated = treatment == 1
    control = treatment == 0
    if not bool(treated.any()) or not bool(control.any()):
        return float('nan')
    return (utilities[treated].mean() - utilities[control].mean()).item()


def _split_ate_error(model, X, Y, A, utility_type, weights):
    """Return |predicted ATE - reference ATE| on one data split.

    The predicted ATE is the mean of the model's per-sample utility contrast
    u(y1_hat) - u(y0_hat). The reference ATE is the difference in mean observed
    utility between treated and control units. When ``A`` is None no observed
    contrast exists and the reference falls back to 0.
    """
    X_tensor = torch.FloatTensor(X).to(device)
    y0_pred, y1_pred, _ = model.predict_counterfactuals_individual(X_tensor)
    predicted_ate = torch.mean(
        model.compute_utility(y1_pred) - model.compute_utility(y0_pred)
    ).item()
    if A is None:
        reference_ate = 0.0
    else:
        Y_tensor = torch.FloatTensor(Y).to(device)
        observed_u = compute_utility_tensor(Y_tensor, utility=utility_type, weights=weights)
        reference_ate = _observed_utility_ate(observed_u, A)
    return abs(predicted_ate - reference_ate)


def evaluate_model(model, X_train, X_test, Y_train, Y_test, utility_type, weights,
                   A_train=None, A_test=None):
    """Evaluate the model's utility-scale ATE error on the training and test splits.

    For each split, the model's predicted ATE (mean of u(y1_hat) - u(y0_hat))
    is compared with a reference ATE. The reference is the difference in mean
    observed utility u(Y) between units with A == 1 and A == 0. This is an
    unbiased reference only under randomized treatment assignment.
    NOTE(review): the simulator's assignment mechanism is not visible here;
    if it exposes true potential outcomes, prefer those.

    Args:
        model: trained model exposing ``predict_counterfactuals_individual``
            and ``compute_utility``.
        X_train, X_test: covariates for each split.
        Y_train, Y_test: observed outcomes for each split.
        utility_type, weights: utility definition passed to
            ``compute_utility_tensor``.
        A_train, A_test: observed treatment labels for each split. If omitted,
            the reference ATE for that split is 0 (a warning is emitted), so the
            reported error is just |predicted ATE|.

    Returns:
        dict with keys ``'in_sample_ate_error'`` and ``'out_of_sample_ate_error'``
        (floats; NaN when a split lacks treated or control units).
    """
    if A_train is None or A_test is None:
        import warnings
        warnings.warn(
            "evaluate_model called without treatment labels; the reference ATE "
            "defaults to 0, so the reported error is just |predicted ATE|.",
            stacklevel=2,
        )
    model.eval()
    with torch.no_grad():
        in_sample = _split_ate_error(model, X_train, Y_train, A_train, utility_type, weights)
        out_of_sample = _split_ate_error(model, X_test, Y_test, A_test, utility_type, weights)
    return {
        'in_sample_ate_error': in_sample,
        'out_of_sample_ate_error': out_of_sample,
    }

# ----- Main Experiment Function -----
def run_ablation_study(dataset_name, utility_type, num_outcomes=3):
    """
    Run the ablation study for one dataset and utility type.

    Covariates, treatments and outcomes are all taken from a single
    ``simulate_simplified_dgp`` call, so every (A, Y) pair used for training
    and evaluation comes from the same simulated draw. For 'IHDP' and 'EHR',
    the corresponding loader is used only to fix the sample size and
    covariate width passed to the simulator.

    Args:
        dataset_name (str): Name of the dataset ('synthetic', 'IHDP', 'EHR')
        utility_type (str): Type of utility to test (e.g. 'or', 'weighted_sum', 'tanh_reward')
        num_outcomes (int): Number of outcomes to consider

    Returns:
        dict with 'dataset', 'utility_type', 'num_outcomes',
        'in_sample_ate_error' and 'out_of_sample_ate_error'.

    Raises:
        ValueError: for an unknown dataset name, or if the simulator returns
            X, A and Y of different lengths.
    """
    start_time = time.time()
    print(f"\nStarting experiment for {dataset_name} with {utility_type}")
    # Validate before importing the simulator so a typo fails fast.
    if dataset_name not in ('synthetic', 'IHDP', 'EHR'):
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Equal weight for every outcome.
    utility_weights = [1.0] * num_outcomes

    from s2_simulation import simulate_simplified_dgp
    if dataset_name == 'synthetic':
        print("Generating synthetic data...")
        data = simulate_simplified_dgp(n=1000, d=10, K=num_outcomes,
                                       utility=utility_type,
                                       weights=utility_weights)
    else:
        # The loader's own random treatment draw is never seen by the
        # simulator, so its A must not be paired with simulated Y. Only the
        # loader's shape is used here.
        dataset = IHDPDataset() if dataset_name == 'IHDP' else EHRDataset()
        n_samples, n_features = dataset.X.shape
        data = simulate_simplified_dgp(n=n_samples, d=n_features, K=num_outcomes,
                                       utility=utility_type,
                                       weights=utility_weights,
                                       data=dataset_name)
    # NOTE(review): assumes the simulator returns 'X', 'A' and 'Y' for every
    # data source, as it does in the synthetic branch — confirm for IHDP/EHR.
    X, A, Y = data['X'], data['A'], data['Y']
    if not (len(X) == len(A) == len(Y)):
        raise ValueError(
            f"Simulator returned mismatched lengths: X={len(X)}, A={len(A)}, Y={len(Y)}"
        )

    # Split data into train and test (fixed split seed).
    X_train, X_test, A_train, A_test, Y_train, Y_test = train_test_split(
        X, A, Y, test_size=0.2, random_state=42
    )

    # Only the training loader is needed; evaluation uses the full test arrays.
    train_dataset = CausalMultiTaskDataset(X_train, A_train, Y_train)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Initialize and train CROME model
    input_dim = X.shape[1]
    hidden_dim = 64
    model = MultiTaskCausalModel(input_dim, hidden_dim, num_outcomes,
                                 utility=utility_type,
                                 utility_weights=utility_weights)
    model = model.to(device)

    print("Training CROME model...")
    train_model(model, train_loader, num_epochs=50, lr=1e-3)

    print("Evaluating model...")
    # Named eval_metrics so it does not shadow the module-level sklearn.metrics import.
    eval_metrics = evaluate_model(model, X_train, X_test, Y_train, Y_test,
                                  utility_type, utility_weights,
                                  A_train=A_train, A_test=A_test)

    end_time = time.time()
    print(f"Experiment completed in {end_time - start_time:.2f} seconds")

    return {
        'dataset': dataset_name,
        'utility_type': utility_type,
        'num_outcomes': num_outcomes,
        **eval_metrics
    }

DEFAULT_RESULTS_PATH = '/cbica/home/choii/project/causal_representation/causal_representation_learning/ablation_results.csv'


def main(datasets=None, utility_types=None, num_repeats=1,
         output_path=DEFAULT_RESULTS_PATH):
    """Run the ablation grid, checkpoint results to CSV and print a summary.

    Args:
        datasets: dataset names passed to ``run_ablation_study``; defaults to
            ``['synthetic', 'IHDP', 'EHR']``.
        utility_types: utility types to sweep; defaults to
            ``['or', 'weighted_sum', 'tanh_reward']``.
        num_repeats: independent runs per (dataset, utility) cell. The summary's
            standard error uses the sample SD (``ddof=1``), which is NaN for a
            cell with a single run, so use >= 2 for meaningful error bars.
        output_path: CSV file rewritten after every run as an intermediate
            checkpoint.

    Returns:
        pandas.DataFrame with columns Model, Error_Type, ATE_Error,
        Utility_Type and Dataset: two rows (in-/out-of-sample) per run.

    Raises:
        ValueError: if ``num_repeats`` is less than 1.
    """
    if datasets is None:
        datasets = ['synthetic', 'IHDP', 'EHR']
    if utility_types is None:
        utility_types = ['or', 'weighted_sum', 'tanh_reward']
    if num_repeats < 1:
        raise ValueError(f"num_repeats must be >= 1, got {num_repeats}")

    results = {
        'Model': [],
        'Error_Type': [],
        'ATE_Error': [],
        'Utility_Type': [],
        'Dataset': [],
    }
    # Bound up front so the reporting below works even when nothing runs.
    results_df = pd.DataFrame(results)
    # (label written to Error_Type, key in run_ablation_study's result)
    error_kinds = (('In-Sample', 'in_sample_ate_error'),
                   ('Out-of-Sample', 'out_of_sample_ate_error'))

    total_start_time = time.time()

    for dataset in datasets:
        for utility_type in utility_types:
            for repeat in range(num_repeats):
                suffix = f" (repeat {repeat + 1}/{num_repeats})" if num_repeats > 1 else ""
                print(f"\nRunning ablation study for {dataset} with {utility_type}{suffix}")
                result = run_ablation_study(dataset, utility_type)

                for error_label, result_key in error_kinds:
                    results['Model'].append('CROME')
                    results['Error_Type'].append(error_label)
                    results['ATE_Error'].append(result[result_key])
                    results['Utility_Type'].append(utility_type)
                    results['Dataset'].append(dataset)

                # Save intermediate results so a crash mid-sweep loses little work.
                results_df = pd.DataFrame(results)
                results_df.to_csv(output_path, index=False)

    total_end_time = time.time()
    print(f"\nTotal execution time: {(total_end_time - total_start_time)/3600:.2f} hours")

    print("\nResults:")
    print(results_df)

    if results_df.empty:
        print("\nNo experiments were run; skipping summary statistics.")
        return results_df

    print("\nSummary Statistics:")
    grouped = results_df.groupby(["Model", "Error_Type", "Utility_Type", "Dataset"])["ATE_Error"]
    mean_errors = grouped.mean()
    # Standard error of the mean; NaN for cells with a single run (ddof=1).
    se_errors = grouped.std(ddof=1) / np.sqrt(grouped.count())

    summary_df = pd.DataFrame({
        "Model": [key[0] for key in mean_errors.index],
        "Type": [key[1] for key in mean_errors.index],
        "Utility": [key[2] for key in mean_errors.index],
        "Dataset": [key[3] for key in mean_errors.index],
        "ATE Error": [f"{m:.4f} $\\pm$ {s:.4f}" for m, s in zip(mean_errors, se_errors)]
    })
    print(summary_df)
    return results_df

if __name__ == "__main__":
    main()



