import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import time
import os
import sys
import logging

# Add the parent of this script's directory to the Python path (presumably
# where the `cmde_folder` package lives, so the project imports can resolve).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import directly from cmde_folder without modifications
from cmde_folder.model.cmick import CMICK
from cmde_folder.model.ick import ICK
from cmde_folder.kernels.nn import ImplicitDenseNetKernel
from cmde_folder.utils.train import Trainer
from cmde_folder.utils.losses import FactualMSELoss

# Select the compute device once, at import time: CUDA when torch reports it as
# available, otherwise the CPU. This module-level `device` is referenced
# elsewhere in the file (e.g. by run_experiment when placing the model).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ----- Dataset Classes -----
class IHDPDataset(Dataset):
    """IHDP covariates paired with a randomly drawn binary treatment.

    The CSV is read without a header row. Columns 5..29 (``iloc[:, 5:30]``)
    become the float64 covariate matrix ``X``; any other columns in the file
    are ignored here. The treatment ``A`` is drawn independently per row from
    Bernoulli(0.5) using NumPy's global RNG, so it is only reproducible if the
    caller seeds ``np.random``.

    Args:
        path: CSV location. Defaults to the original cluster path, so existing
            ``IHDPDataset()`` calls behave as before.

    Attributes:
        X: ``(n, d)`` float64 covariate matrix.
        A: ``(n,)`` array of 0/1 treatment draws.
        n: number of rows.
        d: number of covariate columns actually read.
    """

    DEFAULT_PATH = "/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/IHDP.csv"

    def __init__(self, path=DEFAULT_PATH):
        print("Loading IHDP dataset...")
        temp = pd.read_csv(path, header=None)
        self.X = temp.iloc[:, 5:30].to_numpy(dtype=np.float64)
        self.A = np.random.binomial(1, 0.5, size=len(self.X))
        self.n = len(self.X)
        # Derived from the data instead of hard-coding 25, so it stays correct
        # if the file has fewer than 30 columns.
        self.d = self.X.shape[1]
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return the ``(covariates, treatment)`` pair for row ``idx``."""
        return self.X[idx], self.A[idx]

class EHRDataset(Dataset):
    """EHR covariates (a fixed set of named columns) paired with a random binary treatment.

    Reads at most ``nrows`` rows of a headered CSV and keeps the columns listed
    in ``FEATURE_COLUMNS`` (in that order) as the float64 covariate matrix
    ``X``. The treatment ``A`` is drawn independently per row from
    Bernoulli(0.5) using NumPy's global RNG, so it is only reproducible if the
    caller seeds ``np.random``.

    Args:
        path: CSV location. Defaults to the original cluster path.
        nrows: number of leading rows to load (default 1000, as before);
            ``None`` loads every row.

    Attributes:
        X: ``(n, d)`` float64 covariate matrix.
        A: ``(n,)`` array of 0/1 treatment draws.
        n: number of rows loaded.
        d: number of covariate columns (``len(FEATURE_COLUMNS)``).
    """

    DEFAULT_PATH = "/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/EHR.csv"

    # Covariate columns, named exactly as in the CSV header.
    FEATURE_COLUMNS = (
        "X..BAND.NEUTROPHILS..count", "X..BAND.NEUTROPHILS..last",
        "X..LYMPHOCYTES.MANUAL..last", "X..LYMPHOCYTES..first", "X..MONOCYTES..last",
        "X..NEUTROPHILS.MANUAL..count", "X..LYMPHOCYTES.MANUAL..count",
        "X..LYMPHOCYTES..count", "X..LYMPHOCYTES..first.1", "X..LYMPHOCYTES..last",
        "ALKALINE.PHOSPHATASE..first", "ALKALINE.PHOSPHATASE..last",
        "ALPHA.1.FETOPROTEINS..first", "ALPHA.1.FETOPROTEINS..last",
        "ALT..count", "ALT..first", "ALT..last", "ANION.GAP..last",
        "AST..first", "AST..last", "BILIRUBIN.DIRECT..count", "BILIRUBIN.INDIRECT..last",
    )

    def __init__(self, path=DEFAULT_PATH, nrows=1000):
        print("Loading EHR dataset...")
        # Only parse the rows we keep instead of reading the whole file and
        # truncating afterwards.
        temp = pd.read_csv(path, nrows=nrows)
        columns_to_keep = list(self.FEATURE_COLUMNS)
        self.X = temp[columns_to_keep].to_numpy(dtype=np.float64)
        self.A = np.random.binomial(1, 0.5, size=len(self.X))
        self.n = len(self.X)
        self.d = self.X.shape[1]
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return the ``(covariates, treatment)`` pair for row ``idx``."""
        return self.X[idx], self.A[idx]

# ----- Utility Functions -----
def generate_outcomes(X, A, data_type):
    """Generate outcomes using the DGP from s2_simulation.py.

    Calls ``simulate_simplified_dgp`` with ``n=len(X)``, ``d=X.shape[1]``,
    ``K=1``, ``utility='or'``, ``weights=None`` and ``data=data_type``, and
    returns the ``'Y'`` entry of its result.

    Args:
        X: covariate array; only ``len(X)`` and ``X.shape[1]`` are used.
        A: treatment assignment. NOTE(review): this argument is never used.
        data_type: forwarded as the simulator's ``data`` argument
            (run_experiment passes 'IHDP' or 'EHR').

    Returns:
        ``data['Y']`` from the simulator (shape/type not visible here).

    NOTE(review): neither the values of ``X`` nor ``A`` reach the simulator,
    so the returned outcomes are presumably generated from the simulator's
    own covariates and treatment draw rather than from the arrays passed in.
    If so, pairing this ``Y`` with the caller's separately drawn ``A`` makes
    the downstream "true ATE" and factual-MSE numbers unreliable. Confirm
    against s2_simulation; consider also returning the simulator's X/A.
    """
    # Local import: s2_simulation is only needed when outcomes are simulated.
    from s2_simulation import simulate_simplified_dgp
    data = simulate_simplified_dgp(n=len(X), d=X.shape[1], K=1,  # Use single outcome
                                 utility='or',
                                 weights=None,
                                 data=data_type)
    return data['Y']

def _split_potential_outcomes(outputs):
    """Return ``(control, treatment)`` predictions, each reshaped to an ``(n, 1)`` column."""
    if isinstance(outputs, tuple) and len(outputs) == 2:
        control, treatment = outputs
    else:
        # Assume output has shape [batch_size, 2] with control and treatment outcomes
        control, treatment = outputs[:, 0], outputs[:, 1]
    # Force column vectors: a 1-D arm would otherwise broadcast against the
    # (n, 1) treatment mask in torch.where and silently produce an (n, n) matrix.
    return control.reshape(-1, 1), treatment.reshape(-1, 1)


def _evaluate_split(model, X, Y, A, eval_device):
    """Compute predicted ATE, difference-in-means ATE, their gap and factual MSE for one split."""
    X_tensor = torch.FloatTensor(X).to(eval_device)
    # Single-outcome setting: flatten Y and A to one value per unit.
    Y_tensor = torch.FloatTensor(Y).to(eval_device).reshape(-1)
    A_tensor = torch.FloatTensor(A).to(eval_device).reshape(-1)

    control, treatment = _split_potential_outcomes(model([X_tensor]))

    ate_pred = torch.mean(treatment - control).item()
    # Difference in mean observed outcome between treated and control units
    # (NaN if either group is empty in this split).
    ate_true = (torch.mean(Y_tensor[A_tensor == 1]) - torch.mean(Y_tensor[A_tensor == 0])).item()

    # Pick the prediction for the arm each unit actually received.
    factual_pred = torch.where(A_tensor.view(-1, 1) == 0, control, treatment)
    factual_mse = torch.nn.functional.mse_loss(factual_pred, Y_tensor.view(-1, 1)).item()

    return {
        'ate_pred': ate_pred,
        'ate_true': ate_true,
        'ate_error': abs(ate_pred - ate_true),
        'factual_mse': factual_mse,
    }


def evaluate_model(model, X_train, X_test, Y_train, Y_test, A_train, A_test):
    """Evaluate model performance on both training and test data.

    Each split is fed through the model in one pass as ``model([X])``. The
    model must return either a ``(control, treatment)`` tuple or a tensor whose
    columns 0 and 1 hold the control and treatment predictions. The model is
    evaluated on the device its parameters live on (CPU if it has none).

    The "true" ATE reported here is the difference in mean observed outcome
    between treated and control units. That is an unbiased ATE estimate only
    when treatment is randomized.

    Args:
        model: torch module producing control/treatment predictions.
        X_train, X_test: covariate arrays.
        Y_train, Y_test: observed (factual) outcomes, one per unit.
        A_train, A_test: binary treatment indicators (0 = control, 1 = treated).

    Returns:
        dict with ``in_sample_ate_error`` and ``out_of_sample_ate_error``
        (absolute gap between predicted and difference-in-means ATE), plus
        ``factual_mse_train`` and ``factual_mse_test``.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        train = _evaluate_split(model, X_train, Y_train, A_train, eval_device)
        test = _evaluate_split(model, X_test, Y_test, A_test, eval_device)

    print(f"Factual MSE (Train): {train['factual_mse']:.4f}")
    print(f"Factual MSE (Test): {test['factual_mse']:.4f}")
    print(f"True Train ATE: {train['ate_true']:.4f}")
    print(f"Pred Train ATE: {train['ate_pred']:.4f}")
    print(f"True Test ATE: {test['ate_true']:.4f}")
    print(f"Pred Test ATE: {test['ate_pred']:.4f}")

    return {
        'in_sample_ate_error': train['ate_error'],
        'out_of_sample_ate_error': test['ate_error'],
        'factual_mse_train': train['factual_mse'],
        'factual_mse_test': test['factual_mse']
    }

def _dense_ick(input_dim, latent_feature_dim, num_blocks, num_layers_per_block,
               num_units, activation, dropout_ratio):
    """Build one ICK component backed by a single ImplicitDenseNetKernel."""
    return ICK(
        kernel_assignment=['ImplicitDenseNetKernel'],
        kernel_params={'ImplicitDenseNetKernel': {
            'input_dim': input_dim,
            'latent_feature_dim': latent_feature_dim,
            'num_blocks': num_blocks,
            'num_layers_per_block': num_layers_per_block,
            'num_units': num_units,
            'activation': activation,
            'dropout_ratio': dropout_ratio,
        }}
    )


def create_cmde_model(input_dim, *, latent_feature_dim=64, num_blocks=2,
                      num_layers_per_block=2, num_units=128, activation='relu',
                      dropout_ratio=0.1):
    """Create a standard CMDE model without modifications.

    Builds a CMICK with one control-arm, one treatment-arm and one shared ICK
    component. Each component is a separate ICK instance, but all three use
    the same ImplicitDenseNetKernel settings. Those settings come from the
    keyword arguments, whose defaults reproduce the original hard-coded
    configuration.

    Args:
        input_dim: number of input covariates.
        latent_feature_dim, num_blocks, num_layers_per_block, num_units,
        activation, dropout_ratio: ImplicitDenseNetKernel hyperparameters,
            applied to every component.

    Returns:
        The CMICK model (``output_binary=False``).
    """
    kernel_settings = dict(
        latent_feature_dim=latent_feature_dim,
        num_blocks=num_blocks,
        num_layers_per_block=num_layers_per_block,
        num_units=num_units,
        activation=activation,
        dropout_ratio=dropout_ratio,
    )
    # Construct in the original order (control, treatment, shared) so any
    # RNG-dependent initialization happens in the same sequence as before.
    control_ick = _dense_ick(input_dim, **kernel_settings)
    treatment_ick = _dense_ick(input_dim, **kernel_settings)
    shared_ick = _dense_ick(input_dim, **kernel_settings)

    model = CMICK(
        control_components=[control_ick],
        treatment_components=[treatment_ick],
        shared_components=[shared_ick],
        output_binary=False
    )
    return model

class CausalDataset(Dataset):
    """Float32 tensor view over ``(X, A, Y)`` arrays for single-outcome causal training.

    Indexing yields ``([x], y, a)``. The covariate row is wrapped in a
    one-element list because that is the input format the CMDE model is
    called with. It is followed by the factual outcome and the treatment
    indicator.
    """

    def __init__(self, X, A, Y):
        # Convert covariates, treatment and outcome with the same float32 constructor.
        self.X, self.A, self.Y = (torch.FloatTensor(values) for values in (X, A, Y))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        covariates = self.X[idx]
        outcome = self.Y[idx]
        treatment = self.A[idx]
        return [covariates], outcome, treatment

# Custom trainer class that handles the group argument correctly
class CustomTrainer(Trainer):
    """Trainer whose loss receives the treatment indicator ("group") of each sample.

    Batches must be ``(data, target, group)`` triples, as produced by
    ``CausalDataset``. ``data`` may be a tensor or a list of tensors. The loss
    is called as ``loss_fn(pred, target, group)``, e.g. ``FactualMSELoss``.
    Only the per-epoch training step and the validation pass are overridden;
    the rest of the training loop comes from the base ``Trainer``.
    """

    def __init__(self, model, data_generators, optim, optim_params, lr_scheduler=None,
                 model_save_dir=None, model_name='model.pt', loss_fn=None,
                 device=torch.device('cpu'), validation=True, epochs=100, patience=float('inf'),
                 verbose=0, stop_criterion='loss', logger=None):
        """Forward all arguments to ``Trainer.__init__``.

        ``loss_fn`` defaults to a fresh ``torch.nn.MSELoss()`` and ``logger`` to
        ``logging.getLogger("Trainer")``. These match the previous defaults,
        but no longer share one loss-module instance across every trainer.
        The overridden steps call ``loss_fn`` with three arguments, while
        ``MSELoss`` accepts only two, so callers should pass a group-aware loss.

        ``patience`` defaults to infinity, presumably disabling early stopping
        in the base ``Trainer``.
        """
        if loss_fn is None:
            loss_fn = torch.nn.MSELoss()
        if logger is None:
            logger = logging.getLogger("Trainer")
        # NOTE(review): arguments are passed positionally, so this relies on
        # Trainer.__init__ declaring them in exactly this order.
        super().__init__(model, data_generators, optim, optim_params,
                         lr_scheduler, model_save_dir, model_name, loss_fn,
                         device, validation, epochs, patience, verbose,
                         stop_criterion, logger)

    def _to_device(self, batch):
        """Unpack a ``(data, target, group)`` batch and move every tensor to ``self.device``."""
        data, target, group = batch[0], batch[1], batch[2]
        if isinstance(data, list):
            data = [x.to(self.device) for x in data]
        else:
            data = data.to(self.device)
        return data, target.to(self.device), group.to(self.device)

    def _train_step(self):
        """Run one training epoch over ``data_generators['train']``.

        Returns:
            ``(train_loss, step)``: ``loss_fn`` evaluated on the concatenated
            per-batch predictions (each recorded before its optimizer step),
            and the index of the last batch.

        Raises:
            ValueError: if the training loader yields no batches.
        """
        self.model.to(self.device)
        self.model.train()

        preds, targets, groups = [], [], []
        step = -1
        for step, batch in enumerate(self.data_generators['train']):
            data, target, group = self._to_device(batch)

            self.optimizer.zero_grad()
            output = self.model(data).float()
            loss = self.loss_fn(output, target, group)  # Pass group argument
            loss.backward()
            self.optimizer.step()

            # Detach so epoch bookkeeping does not keep every batch's autograd
            # graph alive; collect in lists and concatenate once (not per batch).
            preds.append(output.detach())
            targets.append(target)
            groups.append(group)

        if step < 0:
            raise ValueError("training data loader produced no batches")

        with torch.no_grad():
            train_loss = self.loss_fn(torch.cat(preds, dim=0),
                                      torch.cat(targets, dim=0),
                                      torch.cat(groups, dim=0)).item()
        return train_loss, step

    def validate(self):
        """Evaluate the model and return ``loss_fn`` over the whole evaluation loader.

        Uses the 'train' loader when ``self.validation`` is false. Otherwise it
        uses 'val' if present, falling back to 'test'. The concatenated
        predictions, targets and groups are stored on ``self.y_val_pred``,
        ``self.y_val_true`` and ``self.y_val_group``.
        """
        self.model.eval()

        if not self.validation:
            key = 'train'
        elif self.data_generators.get('val') is not None:
            key = 'val'
        else:
            key = 'test'

        preds, targets, groups = [], [], []
        with torch.no_grad():
            for batch in self.data_generators[key]:
                data, target, group = self._to_device(batch)
                preds.append(self.model(data).float())
                targets.append(target)
                groups.append(group)

        # An empty loader keeps the previous behavior of evaluating on empty tensors.
        empty = torch.empty(0).to(self.device)
        self.y_val_pred = torch.cat(preds, dim=0) if preds else empty
        self.y_val_true = torch.cat(targets, dim=0) if targets else empty
        self.y_val_group = torch.cat(groups, dim=0) if groups else empty

        val_loss = self.loss_fn(self.y_val_pred, self.y_val_true, self.y_val_group).item()
        return val_loss

def _load_experiment_data(dataset_name):
    """Load or simulate the raw ``(X, A, Y)`` arrays for ``dataset_name``.

    Raises:
        ValueError: if ``dataset_name`` is not 'synthetic', 'IHDP' or 'EHR'.
    """
    if dataset_name == 'synthetic':
        print("Generating synthetic data...")
        from s2_simulation import simulate_simplified_dgp
        data = simulate_simplified_dgp(n=1000, d=10, K=1,  # Single outcome
                                       utility='or',
                                       weights=None)
        return data['X'], data['A'], data['Y']
    if dataset_name == 'IHDP':
        dataset = IHDPDataset()
        X, A = dataset.X, dataset.A
        return X, A, generate_outcomes(X, A, 'IHDP')
    if dataset_name == 'EHR':
        dataset = EHRDataset()
        X, A = dataset.X, dataset.A
        return X, A, generate_outcomes(X, A, 'EHR')
    raise ValueError(f"Unknown dataset: {dataset_name}")


def _standardize_features(X_train, X_test):
    """Z-score both splits with per-feature mean/std estimated on ``X_train`` only.

    Features that are constant on the training split (std == 0) are only
    centred, not scaled, to avoid division by zero.
    """
    mean = np.mean(X_train, axis=0)
    std = np.std(X_train, axis=0)
    std = np.where(std == 0, 1.0, std)
    return (X_train - mean) / std, (X_test - mean) / std


def run_experiment(dataset_name, num_epochs=50):
    """
    Run experiment for a specific dataset

    Loads or simulates the data, makes a stratified 80/20 train/test split,
    standardizes features using training statistics only, trains a CMDE model
    with ``CustomTrainer`` and evaluates in- and out-of-sample ATE error.

    Args:
        dataset_name (str): Name of the dataset ('synthetic', 'IHDP', 'EHR')
        num_epochs (int): Number of epochs to train the model

    Returns:
        dict of lists with keys 'Model', 'Error_Type', 'ATE_Error' and
        'Dataset', holding two rows (In-Sample and Out-of-Sample).

    Raises:
        ValueError: for an unknown ``dataset_name``.
    """
    start_time = time.time()
    print(f"\nStarting experiment for {dataset_name}")

    X, A, Y = _load_experiment_data(dataset_name)

    # Split first and only then normalize, fitting the scaling on the training
    # split so test-set statistics do not leak into training. The split indices
    # depend only on n, stratify and random_state, not on X's values, so the
    # same units land in each split as before.
    X_train, X_test, A_train, A_test, Y_train, Y_test = train_test_split(
        X, A, Y, test_size=0.2, random_state=42, stratify=A
    )
    X_train, X_test = _standardize_features(X_train, X_test)

    input_dim = X.shape[1]

    print("Training CMDE model...")
    train_dataset = CausalDataset(X_train, A_train, Y_train)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    model = create_cmde_model(input_dim)
    model = model.to(device)

    # CustomTrainer forwards the treatment indicator to FactualMSELoss.
    trainer = CustomTrainer(
        model=model,
        data_generators={'train': train_loader},
        optim='adam',
        optim_params={'lr': 1e-3},
        loss_fn=FactualMSELoss(),
        device=device,
        epochs=num_epochs,
        validation=False
    )
    trainer.train()

    metrics = evaluate_model(model, X_train, X_test, Y_train, Y_test, A_train, A_test)

    results = {
        'Model': ['CMDE', 'CMDE'],
        'Error_Type': ['In-Sample', 'Out-of-Sample'],
        'ATE_Error': [metrics['in_sample_ate_error'], metrics['out_of_sample_ate_error']],
        'Dataset': [dataset_name, dataset_name],
    }

    end_time = time.time()
    print(f"Experiment completed in {end_time - start_time:.2f} seconds")

    return results

def main():
    """Run the CMDE experiment on every dataset and report ATE errors.

    The results file is deleted at start and rewritten after every dataset,
    so completed datasets survive a later failure. At the end the full
    results are printed, followed by a LaTeX table of mean ATE error per
    (model, dataset, error type). The standard error is shown when more than
    one trial is available.
    """
    # Define datasets to test
    datasets = ['synthetic', 'IHDP', 'EHR']

    # Clear any previous results file to avoid confusion
    results_file = '/cbica/home/choii/project/causal_representation/causal_representation_learning/cmde_results.csv'
    if os.path.exists(results_file):
        os.remove(results_file)

    results = {
        'Model': [],
        'Error_Type': [],
        'ATE_Error': [],
        'Dataset': []
    }

    total_start_time = time.time()

    # Run just one trial
    num_epochs = 50

    for dataset in datasets:
        print(f"\nRunning experiment for {dataset}")
        result = run_experiment(dataset, num_epochs=num_epochs)

        for key in results:
            results[key].extend(result[key])

        # Save intermediate results
        pd.DataFrame(results).to_csv(results_file, index=False)

    total_end_time = time.time()
    print(f"\nTotal execution time: {(total_end_time - total_start_time)/3600:.2f} hours")

    # Print final results
    print("\nResults:")
    results_df = pd.DataFrame(results)
    print(results_df)

    # Print summary statistics in LaTeX table format
    print("\nSummary Statistics:")
    summary = (results_df
               .groupby(["Model", "Dataset", "Error_Type"])["ATE_Error"]
               .agg(['mean', 'std', 'count']))

    # The table includes the dataset, since rows for different datasets would
    # otherwise be indistinguishable.
    print("\\begin{tabular}{llll}")
    print("\\toprule")
    print("Model & Dataset & Type & ATE Error \\\\")
    print("\\midrule")

    for (model_name, dataset_name, error_type), row in summary.iterrows():
        se_error = row['std'] / np.sqrt(row['count'])
        # With a single trial the sample std (ddof=1) is NaN; show the mean alone.
        if np.isnan(se_error):
            cell = f"{row['mean']:.4f}"
        else:
            cell = f"{row['mean']:.4f} $\\pm$ {se_error:.4f}"
        print(f"{model_name} & {dataset_name} & {error_type} & {cell} \\\\")

    print("\\bottomrule")
    print("\\end{tabular}")

# Run every experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
