import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import time
import os
import sys
from tqdm import tqdm

# Add the code directory to Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from s2_simulation import simulate_simplified_dgp

# Pick the compute device once at import time: CUDA when PyTorch reports a
# usable GPU, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# ----- Model Architecture -----
class RepresentationNetwork(nn.Module):
    """Shared encoder of TARNet.

    A two-layer MLP (Linear -> ReLU -> Linear -> ReLU) that maps inputs whose
    last dimension is ``input_dim`` to a non-negative representation whose
    last dimension is ``hidden_dim``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        # Stored as ``self.model`` so checkpoint keys stay model.0.* / model.2.*.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the learned representation of ``x``."""
        representation = self.model(x)
        return representation

class OutcomeHead(nn.Module):
    """Outcome regressor applied on top of the shared representation.

    Maps inputs of width ``hidden_dim`` through Linear -> ReLU -> Linear to a
    single output per row (last dimension 1).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        hidden_layer = nn.Linear(hidden_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, 1)
        # Layer creation order matches the Sequential order, so initialisation
        # under a fixed torch seed is unchanged; keys stay model.0.* / model.2.*.
        self.model = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Return the scalar outcome prediction for each row of ``x``."""
        prediction = self.model(x)
        return prediction

class TARNet(nn.Module):
    """Treatment-Agnostic Representation Network.

    One shared representation network feeds two separate outcome heads: one
    for the control arm (``y0_head``) and one for the treated arm
    (``y1_head``). The factual prediction selects, per row, the head that
    matches the row's treatment indicator.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Submodules are built in this order (encoder, control head, treated
        # head) so initialisation under a fixed torch seed is reproducible.
        self.representation = RepresentationNetwork(input_dim, hidden_dim)
        self.y0_head = OutcomeHead(hidden_dim)
        self.y1_head = OutcomeHead(hidden_dim)

    def forward(self, x, t):
        """Predict the factual outcome and both potential outcomes.

        Args:
            x: covariate batch fed to the representation network.
            t: treatment indicators, one per row; reshaped to a column.

        Returns:
            Tuple ``(y, y0, y1)`` where ``y = t*y1 + (1-t)*y0`` is the factual
            prediction and ``y0``/``y1`` are the raw outputs of the two heads.
        """
        features = self.representation(x)
        y0, y1 = self.y0_head(features), self.y1_head(features)
        t_col = t.view(-1, 1)
        y = t_col * y1 + (1 - t_col) * y0
        return y, y0, y1

# ----- Dataset Classes -----
class IHDPDataset(Dataset):
    """IHDP covariates paired with a randomly assigned binary treatment.

    Reads a header-less CSV and keeps columns 5..29 (positional slice
    ``iloc[:, 5:30]``) as float64 covariates ``X``. It draws
    ``A ~ Bernoulli(0.5)`` independently per row from NumPy's global RNG, so
    call ``np.random.seed`` beforehand for reproducible assignments. Items are
    ``(X[idx], A[idx])`` pairs.

    Args:
        csv_path: location of the IHDP CSV. Defaults to the original cluster
            path, so existing ``IHDPDataset()`` calls behave as before.
    """

    DEFAULT_CSV_PATH = "/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/IHDP.csv"

    def __init__(self, csv_path=DEFAULT_CSV_PATH):
        print("Loading IHDP dataset...")
        frame = pd.read_csv(csv_path, header=None)
        self.X = frame.iloc[:, 5:30].to_numpy(dtype=np.float64)
        self.A = np.random.binomial(1, 0.5, size=len(self.X))
        self.n = len(self.X)
        # Derive the width from the loaded data instead of hard-coding 25.
        # A file with fewer than 30 columns yields a narrower slice, and d
        # must agree with X.
        self.d = self.X.shape[1]
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.X[idx], self.A[idx]

class EHRDataset(Dataset):
    """EHR lab covariates paired with a randomly assigned binary treatment.

    Reads the first ``n_rows`` data rows of a CSV that has a header row. It
    keeps the columns in ``COLUMNS_TO_KEEP``, in that order, as float64
    covariates ``X``. It draws ``A ~ Bernoulli(0.5)`` independently per row
    from NumPy's global RNG. Items are ``(X[idx], A[idx])`` pairs.

    Args:
        csv_path: EHR CSV location. Defaults to the original cluster path.
        n_rows: number of leading data rows to load. The default of 1000
            matches the previous behaviour; ``None`` loads the whole file.
    """

    DEFAULT_CSV_PATH = "/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/EHR.csv"

    # Covariate columns, in output order. "X..LYMPHOCYTES..first.1" looks like
    # pandas' renaming of a duplicated header; confirm against the raw file.
    COLUMNS_TO_KEEP = (
        "X..BAND.NEUTROPHILS..count", "X..BAND.NEUTROPHILS..last",
        "X..LYMPHOCYTES.MANUAL..last", "X..LYMPHOCYTES..first", "X..MONOCYTES..last",
        "X..NEUTROPHILS.MANUAL..count", "X..LYMPHOCYTES.MANUAL..count",
        "X..LYMPHOCYTES..count", "X..LYMPHOCYTES..first.1", "X..LYMPHOCYTES..last",
        "ALKALINE.PHOSPHATASE..first", "ALKALINE.PHOSPHATASE..last",
        "ALPHA.1.FETOPROTEINS..first", "ALPHA.1.FETOPROTEINS..last",
        "ALT..count", "ALT..first", "ALT..last", "ANION.GAP..last",
        "AST..first", "AST..last", "BILIRUBIN.DIRECT..count", "BILIRUBIN.INDIRECT..last",
    )

    def __init__(self, csv_path=DEFAULT_CSV_PATH, n_rows=1000):
        print("Loading EHR dataset...")
        # nrows stops the parser after the first n_rows data rows. This gives
        # the same rows as reading the whole file and calling .head(n_rows),
        # without parsing the remainder of a potentially large file.
        frame = pd.read_csv(csv_path, nrows=n_rows)
        columns_to_keep = list(self.COLUMNS_TO_KEEP)
        self.X = frame[columns_to_keep].to_numpy(dtype=np.float64)
        self.A = np.random.binomial(1, 0.5, size=len(self.X))
        self.n = len(self.X)
        self.d = len(columns_to_keep)
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.X[idx], self.A[idx]

# ----- Utility Functions -----
def generate_outcomes(X, A, data_type):
    """Generate outcomes using the DGP from s2_simulation.py.

    Calls ``simulate_simplified_dgp`` with ``n=len(X)``, ``d=X.shape[1]``,
    ``K=1``, ``utility='or'``, ``weights=None`` and ``data=data_type``, and
    returns the ``'Y'`` entry of the dict it produces.

    NOTE(review): only the *sizes* of ``X`` reach the simulator. Neither the
    covariate values nor the treatment vector ``A`` are passed. Unless
    ``simulate_simplified_dgp(data=...)`` reproduces exactly these covariates
    and assignments, the returned ``Y`` was not generated from this ``X``/``A``.
    Downstream comparisons of ``Y`` by ``A`` would then mix unrelated draws.
    Confirm against s2_simulation.

    Args:
        X: covariate matrix; only ``len(X)`` and ``X.shape[1]`` are used.
        A: treatment vector (currently unused; see note above).
        data_type: forwarded as the simulator's ``data`` argument
            (this file passes 'IHDP' or 'EHR').

    Returns:
        ``data['Y']`` from the simulator's output.
    """
    # simulate_simplified_dgp is already imported at module level, so no
    # function-local re-import is needed.
    data = simulate_simplified_dgp(n=len(X), d=X.shape[1], K=1,
                                   utility='or',
                                   weights=None,
                                   data=data_type)
    return data['Y']

def _predicted_ate(model, X, A, target_device):
    """Return the model's mean predicted effect, mean(y1_hat - y0_hat), over X."""
    x_tensor = torch.as_tensor(np.asarray(X), dtype=torch.float32, device=target_device)
    a_tensor = torch.as_tensor(np.asarray(A), dtype=torch.float32, device=target_device)
    # Only the per-arm outputs are used; the factual prediction (first element)
    # is discarded.
    _, y0_pred, y1_pred = model(x_tensor, a_tensor)
    return torch.mean(y1_pred - y0_pred).item()


def _difference_in_means(Y, A):
    """Return mean(Y[A == 1]) - mean(Y[A == 0]).

    If either arm is empty in the given split, NumPy yields NaN (and emits a
    RuntimeWarning) for that mean.
    """
    return np.mean(Y[A == 1]) - np.mean(Y[A == 0])


def evaluate_model(model, X_train, X_test, Y_train, Y_test, A_train, A_test):
    """Evaluate model performance on both training and test data.

    For each split, the predicted ATE is ``mean(y1_hat - y0_hat)`` over all rows.
    The reference ("true") ATE is the difference between the mean observed
    outcome of rows with ``A == 1`` and rows with ``A == 0``. That reference
    is a naive difference-in-means estimate, not a ground-truth
    potential-outcome ATE.

    The model is switched to eval mode and run under ``torch.no_grad()`` on
    the device its parameters already live on.

    Args:
        model: callable ``model(x, t) -> (y, y0, y1)`` such as ``TARNet``.
        X_train, X_test: covariate arrays for each split.
        Y_train, Y_test: observed outcomes, indexable by boolean masks on A.
        A_train, A_test: binary (0/1) treatment arrays.

    Returns:
        dict with ``'in_sample_ate_error'`` and ``'out_of_sample_ate_error'``,
        the absolute differences between predicted and reference ATE.
    """
    model.eval()
    # Follow the model's own device rather than assuming it was moved to the
    # module-level ``device``. That global is only consulted for a model
    # without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    with torch.no_grad():
        ate_train_pred = _predicted_ate(model, X_train, A_train, target_device)
        ate_test_pred = _predicted_ate(model, X_test, A_test, target_device)

    ate_train_true = _difference_in_means(Y_train, A_train)
    ate_test_true = _difference_in_means(Y_test, A_test)

    ate_train_error = np.abs(ate_train_pred - ate_train_true)
    ate_test_error = np.abs(ate_test_pred - ate_test_true)

    # Print detailed metrics
    print(f"True Train ATE: {ate_train_true:.4f}")
    print(f"Pred Train ATE: {ate_train_pred:.4f}")
    print(f"True Test ATE: {ate_test_true:.4f}")
    print(f"Pred Test ATE: {ate_test_pred:.4f}")

    return {
        'in_sample_ate_error': ate_train_error,
        'out_of_sample_ate_error': ate_test_error
    }

def _load_experiment_data(dataset_name):
    """Return ``(X, A, Y)`` for 'synthetic', 'IHDP' or 'EHR'; raise ValueError otherwise."""
    if dataset_name == 'synthetic':
        print("Generating synthetic data...")
        data = simulate_simplified_dgp(n=1000, d=10, K=1,
                                       utility='or',
                                       weights=None)
        return data['X'], data['A'], data['Y']
    if dataset_name == 'IHDP':
        dataset = IHDPDataset()
    elif dataset_name == 'EHR':
        dataset = EHRDataset()
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    # For the real-data covariates, outcomes come from the simulator.
    return dataset.X, dataset.A, generate_outcomes(dataset.X, dataset.A, dataset_name)


def _train_tarnet(X_train, A_train, Y_train, num_epochs, hidden_dim, batch_size, lr):
    """Fit a TARNet on the factual outcomes with MSE loss and Adam; return the model."""
    print("Training TARNet model...")
    train_dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(X_train),
        torch.FloatTensor(A_train),
        torch.FloatTensor(Y_train)
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model = TARNet(X_train.shape[1], hidden_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for x_batch, a_batch, y_batch in train_loader:
            x_batch = x_batch.to(device)
            a_batch = a_batch.to(device)
            y_batch = y_batch.to(device)

            optimizer.zero_grad()
            # Only the factual prediction is supervised; the heads are trained
            # through it according to each row's treatment.
            y_pred, _, _ = model(x_batch, a_batch)
            loss = criterion(y_pred, y_batch.view(-1, 1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}")

    return model


def run_experiment(dataset_name, num_epochs=50, *, hidden_dim=64, batch_size=32,
                   lr=1e-3, test_size=0.2, random_state=42):
    """
    Run experiment for a specific dataset.

    Loads (or simulates) the data and makes a stratified train/test split on
    the treatment. It then trains a TARNet and reports absolute ATE errors
    from ``evaluate_model``.

    Args:
        dataset_name (str): Name of the dataset ('synthetic', 'IHDP', 'EHR')
        num_epochs (int): Number of epochs to train the model
        hidden_dim (int): Width of the TARNet representation and heads.
        batch_size (int): Mini-batch size for training.
        lr (float): Adam learning rate.
        test_size (float): Fraction of rows held out for the test split.
        random_state (int): Seed for the train/test split only; model
            initialisation and batch shuffling use torch's global RNG.

    Returns:
        dict of parallel lists with keys 'Model', 'Error_Type', 'ATE_Error'
        and 'Dataset' (one In-Sample and one Out-of-Sample row).

    Raises:
        ValueError: if ``dataset_name`` is not recognised.
    """
    start_time = time.time()
    print(f"\nStarting experiment for {dataset_name}")

    X, A, Y = _load_experiment_data(dataset_name)

    # Stratify on A so both arms appear in each split in similar proportions.
    X_train, X_test, A_train, A_test, Y_train, Y_test = train_test_split(
        X, A, Y, test_size=test_size, random_state=random_state, stratify=A
    )

    model = _train_tarnet(X_train, A_train, Y_train, num_epochs,
                          hidden_dim, batch_size, lr)

    metrics = evaluate_model(model, X_train, X_test, Y_train, Y_test, A_train, A_test)

    results = {
        'Model': ['TARNet', 'TARNet'],
        'Error_Type': ['In-Sample', 'Out-of-Sample'],
        'ATE_Error': [metrics['in_sample_ate_error'], metrics['out_of_sample_ate_error']],
        'Dataset': [dataset_name, dataset_name],
    }

    end_time = time.time()
    print(f"Experiment completed in {end_time - start_time:.2f} seconds")

    return results

# Default location of the CSV that main() (re)writes with per-dataset results.
DEFAULT_RESULTS_FILE = '/cbica/home/choii/project/causal_representation/causal_representation_learning/tarnet_results.csv'


def main(datasets=None, results_file=DEFAULT_RESULTS_FILE, num_epochs=50):
    """Run TARNet on each dataset, save results to CSV, and print a LaTeX summary.

    Args:
        datasets: dataset names passed to ``run_experiment``. Defaults to
            ``['synthetic', 'IHDP', 'EHR']``.
        results_file: CSV output path. Any existing file is deleted first.
            The file is then rewritten after every dataset, so partial
            results survive a failure in a later experiment.
        num_epochs: training epochs per dataset.
    """
    if datasets is None:
        datasets = ['synthetic', 'IHDP', 'EHR']

    # Clear any previous results
    if os.path.exists(results_file):
        os.remove(results_file)

    results = {
        'Model': [],
        'Error_Type': [],
        'ATE_Error': [],
        'Dataset': []
    }
    # Bound before the loop so the reporting below never hits an unbound name,
    # even when ``datasets`` is empty.
    results_df = pd.DataFrame(results)

    total_start_time = time.time()

    # Run experiments
    for dataset in datasets:
        print(f"\nRunning experiment for {dataset}")
        result = run_experiment(dataset, num_epochs=num_epochs)

        # Add results to main results dictionary
        for key in results:
            results[key].extend(result[key])

        # Save intermediate results
        results_df = pd.DataFrame(results)
        results_df.to_csv(results_file, index=False)

    total_end_time = time.time()
    print(f"\nTotal execution time: {(total_end_time - total_start_time)/3600:.2f} hours")

    # Print final results
    print("\nResults:")
    print(results_df)

    if results_df.empty:
        print("No experiments were run; nothing to summarise.")
        return

    # Print summary statistics in LaTeX table format
    print("\nSummary Statistics:")
    grouped = results_df.groupby(["Model", "Error_Type", "Dataset"])["ATE_Error"]
    mean_errors = grouped.mean()
    # With the default dataset list each group holds a single run. The ddof=1
    # standard deviation, and hence this standard error, is then NaN.
    se_errors = grouped.std(ddof=1) / np.sqrt(grouped.count())

    # Create LaTeX table
    print("\\begin{tabular}{llll}")
    print("\\toprule")
    print("Model & Dataset & Type & ATE Error \\\\")
    print("\\midrule")

    # Both series come from the same groupby, so their indices align row-for-row.
    for (model_name, error_type, dataset_name), mean_error, se_error in zip(
            mean_errors.index, mean_errors, se_errors):
        print(f"{model_name} & {dataset_name} & {error_type} & {mean_error:.4f} $\\pm$ {se_error:.4f} \\\\")

    print("\\bottomrule")
    print("\\end{tabular}")


if __name__ == "__main__":
    main()


