import torch
import numpy as np
import pandas as pd
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split
import gpytorch
import time
import os
import sys
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.decomposition import PCA

# Add the code directory to Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from s2_simulation import simulate_simplified_dgp

# Select the torch device once at import time: CUDA when available, else CPU.
# Functions below that build tensors (predict_ite, evaluate_cmgp_model,
# run_experiment) move them to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ----- Multi-Task Gaussian Process Model -----

class CMGPModel(gpytorch.models.ExactGP):
    """
    Exact GP over inputs made of feature columns plus a trailing task column.

    The prior mean is a single constant; the prior covariance is the product
    of a scaled ARD RBF kernel on the feature columns and a rank-1
    IndexKernel on the task column.
    """
    def __init__(self, train_x, train_y, likelihood, num_tasks=2):
        """
        Build the mean and kernel modules for the given training data.

        Args:
            train_x: Training inputs; the last column holds the task index
            train_y: Training targets (flattened to 1D before use)
            likelihood: GPyTorch likelihood object
            num_tasks: Number of tasks (default: 2 for control and treatment)
        """
        flat_targets = train_y.view(-1)
        super().__init__(train_x, flat_targets, likelihood)

        # Every column except the last one is a feature with its own lengthscale.
        n_features = train_x.shape[1] - 1

        self.mean_module = gpytorch.means.ConstantMean()
        rbf_kernel = gpytorch.kernels.RBFKernel(ard_num_dims=n_features)
        self.shared_covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)
        self.task_covar_module = gpytorch.kernels.IndexKernel(num_tasks=num_tasks, rank=1)

    def forward(self, x):
        """
        Return the prior distribution at ``x``.

        Args:
            x: Input tensor with task indices in the last column
        """
        features = x[:, :-1]
        task_indices = x[:, -1].long()

        prior_mean = self.mean_module(features)
        feature_cov = self.shared_covar_module(features)
        task_cov = self.task_covar_module(task_indices)
        return gpytorch.distributions.MultivariateNormal(prior_mean, feature_cov.mul(task_cov))

# ----- Risk-Based Empirical Bayes Objective -----
class RiskBasedObjective:
    """
    Risk-based empirical Bayes objective for CMGP model.

    The objective is the mean squared error of the posterior mean on the
    factual (observed) inputs plus the summed posterior predictive variance
    on the counterfactual inputs (same features, flipped treatment).

    Both terms are computed in closed form from the prior returned by
    ``model.forward`` on the stacked factual and counterfactual inputs and
    from ``likelihood.noise``. Nothing is wrapped in ``torch.no_grad()`` and
    no parameter copy is involved, so the returned value stays attached to
    the autograd graph of ``model`` / ``likelihood`` and can be
    back-propagated by the caller.
    """

    # Diagonal jitter added before the Cholesky factorisation for stability.
    _JITTER = 1e-6

    def __init__(self, model, likelihood, x_train, y_train, t_train):
        """
        Args:
            model: Object whose ``forward(x)`` returns a distribution exposing
                ``mean`` and ``covariance_matrix`` for inputs ``x`` that carry
                the task index in the last column
            likelihood: Object exposing the observation noise variance as
                ``noise`` (scalar or one-element tensor)
            x_train: Factual inputs with the treatment indicator as last column
            y_train: Observed outcomes, one per row of ``x_train``
            t_train: Treatment indicators (0/1), one per row of ``x_train``
        """
        self.model = model
        self.likelihood = likelihood
        self.x_train = x_train
        self.y_train = y_train
        self.t_train = t_train

        # Counterfactual inputs: identical features, flipped treatment flag.
        t_cf = (1 - t_train).reshape(-1, 1).to(dtype=x_train.dtype, device=x_train.device)
        self.x_cf = torch.cat([x_train[:, :-1], t_cf], dim=1)

        # Kept for backward compatibility. Counterfactual predictions share the
        # parameters of ``model``, so no separate copy has to be kept in sync
        # or moved between devices.
        self.cf_model = model

    def _posterior(self):
        """
        Return ``(factual_mean, cf_variance)``.

        ``factual_mean`` is the posterior mean at the factual inputs and
        ``cf_variance`` the posterior predictive variance (latent variance
        plus observation noise) at the counterfactual inputs, both
        conditioned on ``(x_train, y_train)``.
        """
        n = self.x_train.shape[0]

        # ``forward`` is called directly so that one joint prior over factual
        # and counterfactual points is obtained in a single evaluation.
        joint_inputs = torch.cat([self.x_train, self.x_cf], dim=0)
        prior = self.model.forward(joint_inputs)
        prior_mean = prior.mean.reshape(-1)
        prior_cov = prior.covariance_matrix
        noise = self.likelihood.noise

        k_ff = prior_cov[:n, :n]                    # factual vs factual
        k_cf = prior_cov[n:, :n]                    # counterfactual vs factual
        k_cc_diag = torch.diagonal(prior_cov)[n:]   # counterfactual prior variance

        eye = torch.eye(n, dtype=prior_cov.dtype, device=prior_cov.device)
        chol = torch.linalg.cholesky(k_ff + (noise + self._JITTER) * eye)

        residual = (self.y_train.reshape(-1) - prior_mean[:n]).unsqueeze(-1)
        weights = torch.cholesky_solve(residual, chol)
        factual_mean = prior_mean[:n] + (k_ff @ weights).squeeze(-1)

        # diag(K_cf K_y^{-1} K_fc) without forming the full matrix product.
        solved = torch.cholesky_solve(k_cf.transpose(0, 1), chol)
        explained = (k_cf * solved.transpose(0, 1)).sum(dim=1)
        # Clamp guards against tiny negative values from round-off.
        cf_variance = (k_cc_diag - explained).clamp_min(0.0) + noise
        return factual_mean, cf_variance

    @staticmethod
    def _factual_loss(factual_mean, targets):
        """Mean squared error between posterior mean and observed outcomes."""
        return torch.mean((factual_mean - targets.reshape(-1)) ** 2)

    def compute_empirical_factual_loss(self):
        """Compute the empirical factual loss (differentiable scalar)."""
        factual_mean, _ = self._posterior()
        return self._factual_loss(factual_mean, self.y_train)

    def compute_counterfactual_variance(self):
        """Compute the summed counterfactual variance (differentiable scalar)."""
        _, cf_variance = self._posterior()
        return torch.sum(cf_variance)

    def __call__(self):
        """Compute the total objective value as a differentiable scalar."""
        # One factorisation serves both terms.
        factual_mean, cf_variance = self._posterior()
        factual_loss = self._factual_loss(factual_mean, self.y_train)
        return factual_loss + torch.sum(cf_variance)

# ----- Trainer -----
class CMGPTrainer:
    """
    Runs Adam on a CMGP model against the RiskBasedObjective.
    """
    def __init__(self, model, likelihood, x_train, y_train, t_train, 
                 lr=0.01, weight_decay=1e-4, max_iter=1000):
        """
        Store the model, likelihood, training tensors and optimiser settings.

        Args:
            model: Model to optimise
            likelihood: Likelihood paired with the model
            x_train: Training inputs
            y_train: Training targets
            t_train: Treatment indicators
            lr: Adam learning rate
            weight_decay: Adam weight decay
            max_iter: Number of optimisation steps
        """
        self.model, self.likelihood = model, likelihood
        self.x_train, self.y_train, self.t_train = x_train, y_train, t_train
        self.lr, self.weight_decay, self.max_iter = lr, weight_decay, max_iter

    def train(self):
        """
        Optimise the model for ``max_iter`` steps and return it in eval mode.
        """
        net, lik = self.model, self.likelihood
        net.train()
        lik.train()

        adam = torch.optim.Adam(
            net.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        risk = RiskBasedObjective(net, lik, self.x_train, self.y_train, self.t_train)

        progress = tqdm(range(self.max_iter))
        for step in progress:
            adam.zero_grad()
            current = risk()
            current.backward()
            adam.step()

            # Refresh the displayed loss only every tenth step.
            if step % 10 == 0:
                progress.set_postfix(loss=current.item())

        net.eval()
        lik.eval()
        return net

# ----- Dataset Classes -----
class IHDPDataset:
    """
    Covariates read from a header-less IHDP CSV plus a random treatment vector.

    Attributes set by the constructor:
        X: float64 array built from CSV columns 5 up to (not including) 30
        A: Bernoulli(0.5) treatment assignment, one entry per row of X
        n: number of rows of X
        d: number of columns of X
    """

    # Location used when no path is passed to the constructor.
    DEFAULT_CSV_PATH = "/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/IHDP.csv"

    def __init__(self, csv_path=None):
        """
        Args:
            csv_path: Path of the header-less CSV file; defaults to
                ``DEFAULT_CSV_PATH`` when omitted.
        """
        print("Loading IHDP dataset...")
        source = self.DEFAULT_CSV_PATH if csv_path is None else csv_path
        temp = pd.read_csv(source, header=None)
        self.X = temp.iloc[:, 5:30].to_numpy(dtype=np.float64)
        # Take both sizes from the array itself: the slice silently yields
        # fewer than 25 columns when the file is narrower than 30 columns.
        self.n, self.d = self.X.shape
        self.A = np.random.binomial(1, 0.5, size=self.n)
        print(f"Loaded {self.n} samples with {self.d} features")

class EHRDataset:
    """
    Selected lab-value columns from an EHR CSV plus a random treatment vector.

    Attributes set by the constructor:
        X: float64 array holding the FEATURE_COLUMNS of the first rows
        A: Bernoulli(0.5) treatment assignment, one entry per row of X
        n: number of rows of X
        d: number of columns of X
    """

    # Location used when no path is passed to the constructor.
    DEFAULT_CSV_PATH = "/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/EHR.csv"

    # CSV columns used as covariates, in the order they appear in X.
    FEATURE_COLUMNS = (
        "X..BAND.NEUTROPHILS..count", "X..BAND.NEUTROPHILS..last",
        "X..LYMPHOCYTES.MANUAL..last", "X..LYMPHOCYTES..first", "X..MONOCYTES..last",
        "X..NEUTROPHILS.MANUAL..count", "X..LYMPHOCYTES.MANUAL..count",
        "X..LYMPHOCYTES..count", "X..LYMPHOCYTES..first.1", "X..LYMPHOCYTES..last",
        "ALKALINE.PHOSPHATASE..first", "ALKALINE.PHOSPHATASE..last",
        "ALPHA.1.FETOPROTEINS..first", "ALPHA.1.FETOPROTEINS..last",
        "ALT..count", "ALT..first", "ALT..last", "ANION.GAP..last",
        "AST..first", "AST..last", "BILIRUBIN.DIRECT..count", "BILIRUBIN.INDIRECT..last",
    )

    def __init__(self, csv_path=None, max_rows=1000):
        """
        Args:
            csv_path: Path of the CSV file (with header row); defaults to
                ``DEFAULT_CSV_PATH`` when omitted.
            max_rows: Keep at most this many leading rows (default: 1000).

        Raises:
            KeyError: If any of FEATURE_COLUMNS is missing from the file
                (raised by the pandas column selection).
        """
        print("Loading EHR dataset...")
        source = self.DEFAULT_CSV_PATH if csv_path is None else csv_path
        columns_to_keep = list(self.FEATURE_COLUMNS)
        temp = pd.read_csv(source).head(max_rows)
        self.X = temp[columns_to_keep].to_numpy(dtype=np.float64)
        self.n = len(self.X)
        self.d = len(columns_to_keep)
        self.A = np.random.binomial(1, 0.5, size=self.n)
        print(f"Loaded {self.n} samples with {self.d} features")

# ----- Utility Functions -----
def generate_outcomes(X, A, data_type):
    """
    Return the ``'Y'`` entry produced by ``simulate_simplified_dgp``.

    Args:
        X: Covariate array; only its row count and column count are used
        A: Treatment vector (accepted but not used by this function)
        data_type: Value forwarded as the ``data`` argument of the simulator
    """
    from s2_simulation import simulate_simplified_dgp

    # NOTE(review): neither the values of X nor A reach the simulator, which
    # presumably loads the named dataset itself -- confirm that the returned
    # outcomes line up row-by-row with the X and A used by the caller.
    dgp_kwargs = {
        'n': len(X),
        'd': X.shape[1],
        'K': 1,
        'utility': 'or',
        'weights': None,
        'data': data_type,
    }
    simulated = simulate_simplified_dgp(**dgp_kwargs)
    return simulated['Y']

def predict_ite(model, likelihood, X):
    """
    Predict individual treatment effects as treated-minus-control means.

    Args:
        model: Model called on inputs with the treatment flag as last column
        likelihood: Likelihood applied to the model output
        X: Feature matrix without the treatment column

    Returns:
        NumPy array with one predicted effect per row of ``X``.
    """
    features = torch.tensor(X, dtype=torch.float32).to(device)
    n_rows = features.shape[0]

    # Same features twice: once flagged as control (0), once as treated (1).
    control_inputs = torch.cat([features, torch.zeros(n_rows, 1, device=device)], dim=1)
    treated_inputs = torch.cat([features, torch.ones(n_rows, 1, device=device)], dim=1)

    arm_means = []
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        for arm_inputs in (control_inputs, treated_inputs):
            arm_mean = likelihood(model(arm_inputs)).mean
            # Drop a trailing singleton dimension so both arms are 1D.
            if arm_mean.dim() > 1:
                arm_mean = arm_mean.squeeze(-1)
            arm_means.append(arm_mean)

    control_mean, treated_mean = arm_means
    effect = treated_mean - control_mean
    return effect.cpu().numpy()

def evaluate_cmgp_model(model, likelihood, X_train, X_test, Y_train, Y_test, A_train, A_test):
    """
    Compare the predicted ATE with the difference in group means.

    For both the training and the test split, the reference ATE is the mean
    outcome of rows with ``A == 1`` minus the mean outcome of rows with
    ``A == 0``; the predicted ATE is the mean of ``predict_ite``.

    Returns:
        Dict with the absolute ATE errors under the keys
        ``'in_sample_ate_error'`` and ``'out_of_sample_ate_error'``.
    """
    def _flat_outcomes(values):
        # Outcomes as a float32 tensor on the module device, without a
        # trailing singleton dimension.
        tensor = torch.tensor(values, dtype=torch.float32).to(device)
        return tensor.squeeze(-1) if tensor.dim() > 1 else tensor

    def _group_mean_gap(outcomes, arms):
        treated = torch.mean(outcomes[arms == 1])
        control = torch.mean(outcomes[arms == 0])
        return treated - control

    y_tr = _flat_outcomes(Y_train)
    y_te = _flat_outcomes(Y_test)

    # Model-based effects per individual, train split first.
    ite_tr = predict_ite(model, likelihood, X_train)
    ite_te = predict_ite(model, likelihood, X_test)

    ate_tr_ref = _group_mean_gap(y_tr, A_train)
    ate_te_ref = _group_mean_gap(y_te, A_test)

    ate_tr_hat = np.mean(ite_tr)
    ate_te_hat = np.mean(ite_te)

    err_tr = np.abs(ate_tr_hat - ate_tr_ref.item())
    err_te = np.abs(ate_te_hat - ate_te_ref.item())

    print(f"True Train ATE: {ate_tr_ref.item():.4f}")
    print(f"Pred Train ATE: {ate_tr_hat:.4f}")
    print(f"True Test ATE: {ate_te_ref.item():.4f}")
    print(f"Pred Test ATE: {ate_te_hat:.4f}")

    return {
        'in_sample_ate_error': err_tr,
        'out_of_sample_ate_error': err_te,
    }

def plot_treatment_effects(model, likelihood, X, A, Y, title="Treatment Effects"):
    """
    Scatter predicted ITEs against the first principal component of ``X``.

    A horizontal line marks the difference in mean outcome between rows with
    ``A == 1`` and rows with ``A == 0``. The figure is written to
    ``<title with spaces replaced by underscores>.png`` and then closed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    effects = predict_ite(model, likelihood, X)

    # One-dimensional projection of the covariates for the x-axis.
    first_pc = PCA(n_components=1).fit_transform(X)
    ax.scatter(first_pc, effects, c='blue', alpha=0.5, label='Predicted ITE')

    # Reference level: difference in group means.
    mean_gap = np.mean(Y[A == 1]) - np.mean(Y[A == 0])
    ax.axhline(y=mean_gap, color='r', linestyle='-', label='True ATE')

    ax.set_xlabel('First Principal Component')
    ax.set_ylabel('Treatment Effect')
    ax.set_title(title)
    ax.legend()

    out_name = f"{title.replace(' ', '_')}.png"
    fig.savefig(out_name)
    plt.close(fig)

def run_experiment(dataset_name, max_iterations=1000):
    """
    Train and evaluate a CMGP model on one dataset.

    Steps: obtain (X, A, Y), standardise X column-wise, make an 80/20 split
    stratified by treatment, fit the model with CMGPTrainer, report ATE
    errors and save a treatment-effect plot for the test split.

    Args:
        dataset_name (str): Name of the dataset ('synthetic', 'IHDP', 'EHR')
        max_iterations (int): Maximum number of optimization iterations

    Returns:
        Dict of two-element lists keyed by 'Model', 'Error_Type',
        'ATE_Error' and 'Dataset' (in-sample first, out-of-sample second).

    Raises:
        ValueError: If ``dataset_name`` is not one of the known names.
    """
    started = time.time()
    print(f"\nStarting experiment for {dataset_name}")

    # --- data -------------------------------------------------------------
    if dataset_name == 'synthetic':
        print("Generating synthetic data...")
        sim = simulate_simplified_dgp(n=1000, d=10, K=1,
                                      utility='or',
                                      weights=None)
        X, A, Y = sim['X'], sim['A'], sim['Y']
    elif dataset_name in ('IHDP', 'EHR'):
        loader = IHDPDataset if dataset_name == 'IHDP' else EHRDataset
        dataset = loader()
        X, A = dataset.X, dataset.A
        Y = generate_outcomes(X, A, dataset_name)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # --- standardise features ----------------------------------------------
    col_mean = np.mean(X, axis=0)
    col_std = np.std(X, axis=0)
    col_std[col_std == 0] = 1.0  # constant columns: avoid dividing by zero
    X_normalized = (X - col_mean) / col_std

    # --- stratified train/test split ---------------------------------------
    X_train, X_test, A_train, A_test, Y_train, Y_test = train_test_split(
        X_normalized, A, Y, test_size=0.2, random_state=42, stratify=A
    )

    # --- tensors: treatment flag appended as last input column ---------------
    design = np.hstack([X_train, A_train.reshape(-1, 1)])
    x_train_tensor = torch.tensor(design, dtype=torch.float32).to(device)
    y_train_tensor = torch.tensor(Y_train, dtype=torch.float32).to(device)
    t_train_tensor = torch.tensor(A_train, dtype=torch.float32).to(device)

    # --- model ---------------------------------------------------------------
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    likelihood.to(device)

    model = CMGPModel(x_train_tensor, y_train_tensor, likelihood)
    model.to(device)

    # --- fit -----------------------------------------------------------------
    trained_model = CMGPTrainer(
        model, likelihood, x_train_tensor, y_train_tensor, t_train_tensor,
        max_iter=max_iterations
    ).train()

    # --- evaluate and plot -----------------------------------------------------
    metrics = evaluate_cmgp_model(
        trained_model, likelihood,
        X_train, X_test, Y_train, Y_test, A_train, A_test
    )
    plot_treatment_effects(
        trained_model, likelihood,
        X_test, A_test, Y_test,
        title=f"CMGP on {dataset_name} dataset"
    )

    results = {
        'Model': ['CMGP'] * 2,
        'Error_Type': ['In-Sample', 'Out-of-Sample'],
        'ATE_Error': [metrics['in_sample_ate_error'], metrics['out_of_sample_ate_error']],
        'Dataset': [dataset_name] * 2,
    }

    elapsed = time.time() - started
    print(f"Experiment completed in {elapsed:.2f} seconds")

    return results

def main():
    """
    Run every experiment, save the results as CSV and print a LaTeX table.

    The results file is rewritten after each dataset so partial results
    survive a later failure. In the printed table the standard error is
    shown only when it is finite: with a single run per
    (model, error type, dataset) group the sample standard deviation is NaN,
    and the row then contains the mean alone instead of ``$\\pm$ nan``.
    """
    # Clear any previous results
    results_file = '/cbica/home/choii/project/causal_representation/causal_representation_learning/cmgp_results.csv'
    try:
        os.remove(results_file)
    except FileNotFoundError:
        pass  # nothing to clear

    # Define datasets to test
    datasets = ['synthetic', 'IHDP', 'EHR']

    # Column-oriented accumulator, extended after every experiment
    results = {
        'Model': [],
        'Error_Type': [],
        'ATE_Error': [],
        'Dataset': []
    }
    results_df = pd.DataFrame(results)

    total_start_time = time.time()

    # Run experiments
    for dataset in datasets:
        print(f"\nRunning experiment for {dataset}")
        result = run_experiment(dataset, max_iterations=1000)

        # Add results to main results dictionary
        for key in results:
            results[key].extend(result[key])

        # Save intermediate results
        results_df = pd.DataFrame(results)
        results_df.to_csv(results_file, index=False)

    total_end_time = time.time()
    print(f"\nTotal execution time: {(total_end_time - total_start_time)/3600:.2f} hours")

    # Print final results
    print("\nResults:")
    print(results_df)

    # Print summary statistics in LaTeX table format
    print("\nSummary Statistics:")
    grouped = results_df.groupby(["Model", "Error_Type", "Dataset"])["ATE_Error"]
    mean_errors = grouped.mean()
    se_errors = grouped.std(ddof=1) / np.sqrt(grouped.count())

    # Create LaTeX table
    print("\\begin{tabular}{llll}")
    print("\\toprule")
    print("Model & Dataset & Type & ATE Error \\\\")
    print("\\midrule")

    # Print each row in LaTeX format
    for (model, error_type, dataset), mean_error, se_error in zip(mean_errors.index, mean_errors, se_errors):
        if np.isfinite(se_error):
            cell = f"{mean_error:.4f} $\\pm$ {se_error:.4f}"
        else:
            # Single-observation group: no standard error to report.
            cell = f"{mean_error:.4f}"
        print(f"{model} & {dataset} & {error_type} & {cell} \\\\")

    print("\\bottomrule")
    print("\\end{tabular}")

# Run the full experiment suite only when this file is executed as a script.
if __name__ == "__main__":
    main()
