import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import time
import os
import sys

# Add the code directory to Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from s1_model import CausalMultiTaskDataset
from s2_simulation import simulate_simplified_dgp

# Pick the compute device once at import time: CUDA when a GPU is visible,
# otherwise the CPU. Other code in this module moves models/tensors here.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# ----- Imbalance Metrics for Domain Adaptation -----

def mmd2_lin(X, t, p=0.5):
    """
    Linear-kernel MMD between the treated and control rows of ``X``.

    Computes the squared Euclidean distance between the two group means and
    multiplies it by ``1/(n_t * p**2) + 1/(n_c * (1 - p)**2)``.

    Args:
        X: Representations, one row per unit.
        t: Treatment indicators. They are flattened; rows with ``t > 0`` are
            treated and rows with ``t < 1`` are control.
        p: Assumed treated proportion used in the scale factor.

    Returns:
        Scalar tensor. It is zero when either group is empty, because the
        group means (and the scale factor) are undefined in that case.
    """
    t = t.flatten()

    X_treated = X[t > 0]
    X_control = X[t < 1]
    n_t = X_treated.shape[0]
    n_c = X_control.shape[0]

    # A minibatch may contain only one arm. The mean over zero rows is NaN and
    # the scale factor divides by zero, so a NaN would poison the whole loss.
    if n_t == 0 or n_c == 0:
        return X.new_zeros(())

    # Normalize for balanced classes
    scale_factor = 1.0 / (n_t * p * p) + 1.0 / (n_c * (1 - p) * (1 - p))

    mean_diff = torch.mean(X_treated, dim=0) - torch.mean(X_control, dim=0)
    return scale_factor * torch.sum(torch.square(mean_diff))

def mmd2_rbf(X, t, p=0.5, sigma=1.0):
    """
    Gaussian-RBF MMD between the treated and control rows of ``X``.

    Uses the kernel ``k(a, b) = exp(-||a - b||^2 / (2 * sigma^2))``. The
    estimate is ``mean k(T, T) + mean k(C, C) - 2 * mean k(T, C)``, multiplied
    by ``1/(n_t * p**2) + 1/(n_c * (1 - p)**2)``.

    Args:
        X: Representations, one row per unit (2-D, as required by
            ``torch.cdist``).
        t: Treatment indicators. They are flattened; rows with ``t > 0`` are
            treated and rows with ``t < 1`` are control.
        p: Assumed treated proportion used in the scale factor.
        sigma: RBF bandwidth.

    Returns:
        Scalar tensor. It is zero when either group is empty, because the
        kernel means (and the scale factor) are undefined in that case.
    """
    t = t.flatten()

    X_treated = X[t > 0]
    X_control = X[t < 1]
    n_t = X_treated.shape[0]
    n_c = X_control.shape[0]

    # A minibatch may contain only one arm. Averaging over zero pairs gives
    # NaN and the scale factor divides by zero.
    if n_t == 0 or n_c == 0:
        return X.new_zeros(())

    # Normalize for balanced classes
    scale_factor = 1.0 / (n_t * p * p) + 1.0 / (n_c * (1 - p) * (1 - p))
    two_sigma_sq = 2 * sigma ** 2

    def _mean_kernel(A, B):
        # Average Gaussian kernel value over every (row of A, row of B) pair.
        return torch.mean(torch.exp(-torch.cdist(A, B) ** 2 / two_sigma_sq))

    mmd = (_mean_kernel(X_treated, X_treated)
           + _mean_kernel(X_control, X_control)
           - 2 * _mean_kernel(X_treated, X_control))
    return scale_factor * mmd

def wasserstein(X, t, p=0.5, lam=10, its=10, sq=False, backpropT=False):
    """
    Entropy-regularized (Sinkhorn) approximation of the Wasserstein distance
    between the treated and control rows of ``X``.

    Args:
        X: Representations, one row per unit (2-D, as required by
            ``torch.cdist``).
        t: Treatment indicators. They are flattened; rows with ``t > 0`` are
            treated and rows with ``t < 1`` are control.
        p: Not used. It is kept for signature compatibility with the MMD
            penalties; both groups get uniform marginals (1/n_t and 1/n_c).
        lam: Entropic regularization strength. It is divided by the mean cost,
            so its effect does not depend on the scale of ``X``. Larger values
            give a sharper plan, closer to exact optimal transport.
        its: Number of Sinkhorn iterations.
        sq: If True, use squared Euclidean cost instead of Euclidean cost.
        backpropT: If False, the transport plan is detached, so gradients flow
            only through the cost matrix.

    Returns:
        ``(distance, T)``. ``distance`` is the scalar ``sum(T * cost)`` and
        ``T`` is the ``(n_t, n_c)`` transport plan. If either group is empty,
        the distance is zero and ``T`` is an empty plan.
    """
    t = t.flatten()

    X_treated = X[t > 0]
    X_control = X[t < 1]
    n_t = X_treated.shape[0]
    n_c = X_control.shape[0]

    # With one arm missing there is nothing to transport.
    if n_t == 0 or n_c == 0:
        return X.new_zeros(()), X.new_zeros((n_t, n_c))

    M = torch.cdist(X_treated, X_control)
    cost = M ** 2 if sq else M

    # Effective regularization is relative to the (constant) mean cost. The
    # clamp guards the all-identical-points case, where the mean cost is 0.
    eff_lam = lam / cost.mean().detach().clamp_min(1e-12)
    # Skip building an autograd graph through the iterations when T is
    # detached anyway.
    log_K = -eff_lam * (cost if backpropT else cost.detach())

    # Sinkhorn in the log domain, so exp(-eff_lam * cost) cannot underflow to 0.
    log_a = X.new_full((n_t,), 1.0 / n_t).log()  # treated marginal
    log_b = X.new_full((n_c,), 1.0 / n_c).log()  # control marginal
    f = X.new_zeros(n_t)  # log row scaling
    g = X.new_zeros(n_c)  # log column scaling
    for _ in range(its):
        f = log_a - torch.logsumexp(log_K + g.unsqueeze(0), dim=1)
        g = log_b - torch.logsumexp(log_K + f.unsqueeze(1), dim=0)

    T = torch.exp(f.unsqueeze(1) + log_K + g.unsqueeze(0))
    if not backpropT:
        T = T.detach()

    return torch.sum(T * cost), T

# ----- CFRNet Model in PyTorch -----

class CFRNet(nn.Module):
    """
    PyTorch implementation of the Counterfactual Regression Network (CFRNet).

    A shared representation network maps covariates to ``h``. The outcome is
    then predicted either by two arm-specific heads (``split_output=True``) or
    by one head fed with ``[h, t]``.

    Config keys read here:
        ``input_dim``, ``rep_dim``, ``out_dim``, ``n_rep_layers``,
        ``n_out_layers``, ``batch_norm``, ``nonlin`` (``'relu'``; any other
        value selects ELU), ``dropout_in``, ``dropout_out``,
        ``normalization`` and ``split_output``. With
        ``normalization='divide'`` each representation row is rescaled to unit
        L2 norm; any other value leaves it unchanged.
    """

    def __init__(self, config):
        super(CFRNet, self).__init__()
        self.config = config

        # Extract network dimensions
        input_dim = config['input_dim']
        rep_dim = config['rep_dim']
        out_dim = config['out_dim']

        # Representation network: n_rep_layers x [Linear, (BatchNorm), act, Dropout]
        rep_layers = []
        prev_dim = input_dim
        for _ in range(config['n_rep_layers']):
            rep_layers.append(nn.Linear(prev_dim, rep_dim))
            if config['batch_norm']:
                rep_layers.append(nn.BatchNorm1d(rep_dim))
            rep_layers.append(self._activation())
            rep_layers.append(nn.Dropout(config['dropout_in']))
            prev_dim = rep_dim
        self.representation_net = nn.Sequential(*rep_layers)
        # Width of h; equals input_dim when n_rep_layers == 0 (identity network).
        rep_out_dim = prev_dim

        # The shared head also receives the treatment indicator as one extra column.
        head_in = rep_out_dim if config['split_output'] else rep_out_dim + 1
        # With n_out_layers == 0 the output net is an identity, so the final
        # Linear must accept the head's input width rather than out_dim.
        pred_in = out_dim if config['n_out_layers'] > 0 else head_in

        if config['split_output']:
            # Separate networks for treated and control
            self.output_net_treated = self._build_output_net(head_in, out_dim)
            self.output_net_control = self._build_output_net(head_in, out_dim)
            self.pred_treated = nn.Linear(pred_in, 1)
            self.pred_control = nn.Linear(pred_in, 1)
        else:
            # Combined network
            self.output_net = self._build_output_net(head_in, out_dim)
            self.pred = nn.Linear(pred_in, 1)

    def _activation(self):
        """Return a new activation module: ReLU for 'relu', ELU otherwise."""
        return nn.ReLU() if self.config['nonlin'].lower() == 'relu' else nn.ELU()

    def _build_output_net(self, input_dim, output_dim):
        """Build n_out_layers x [Linear, activation, Dropout] (empty if 0 layers)."""
        layers = []
        prev_dim = input_dim
        for _ in range(self.config['n_out_layers']):
            layers.append(nn.Linear(prev_dim, output_dim))
            layers.append(self._activation())
            layers.append(nn.Dropout(self.config['dropout_out']))
            prev_dim = output_dim
        return nn.Sequential(*layers)

    def _represent(self, x):
        """Return the representation of ``x``, L2-normalized per row if configured."""
        h_rep = self.representation_net(x)
        if self.config['normalization'] == 'divide':
            # The 1e-8 keeps an all-zero row from dividing by zero.
            return h_rep / torch.sqrt(torch.sum(torch.square(h_rep), dim=1, keepdim=True) + 1e-8)
        return h_rep

    def forward(self, x, t):
        """
        Predict the factual outcome for each unit.

        Args:
            x: Features.
            t: Treatment indicators, one per row of ``x``; a flat vector or a
                single column. Values > 0 select the treated head.

        Returns:
            Predictions of shape ``(n, 1)``.

        Side effects:
            Stores ``self.h_rep_norm``, which the training loop reads for the
            imbalance penalty. With ``split_output`` it also stores
            ``self.y0`` and ``self.y1``.
        """
        h_rep_norm = self._represent(x)
        # One column per unit, in h's dtype, so it can be compared and concatenated.
        t_col = t.reshape(-1, 1).to(h_rep_norm.dtype)

        if self.config['split_output']:
            y0 = self.pred_control(self.output_net_control(h_rep_norm))
            y1 = self.pred_treated(self.output_net_treated(h_rep_norm))
            outputs = torch.where(t_col > 0, y1, y0)

            # Store for counterfactual prediction
            self.y0 = y0
            self.y1 = y1
        else:
            outputs = self.pred(self.output_net(torch.cat([h_rep_norm, t_col], dim=1)))

        # Store representation for imbalance penalty
        self.h_rep_norm = h_rep_norm

        return outputs

    def predict_counterfactuals(self, x, t=None):
        """
        Predict both potential outcomes for every row of ``x``.

        Args:
            x: Features.
            t: Ignored; kept for backward compatibility.

        Returns:
            ``(y0, y1)``: predictions under control and under treatment, each
            of shape ``(n, 1)``. They are computed under ``torch.no_grad()``;
            the caller controls train/eval mode.
        """
        with torch.no_grad():
            h_rep_norm = self._represent(x)

            if self.config['split_output']:
                y0 = self.pred_control(self.output_net_control(h_rep_norm))
                y1 = self.pred_treated(self.output_net_treated(h_rep_norm))
            else:
                n = x.shape[0]
                t0 = h_rep_norm.new_zeros((n, 1))
                t1 = h_rep_norm.new_ones((n, 1))
                y0 = self.pred(self.output_net(torch.cat([h_rep_norm, t0], dim=1)))
                y1 = self.pred(self.output_net(torch.cat([h_rep_norm, t1], dim=1)))

        return y0, y1

# ----- Dataset Classes -----
class IHDPDataset(Dataset):
    """
    IHDP covariates with a randomly assigned binary treatment.

    Reads a header-less CSV and keeps columns 5..29 as features. Treatment is
    drawn as Bernoulli(0.5) per row with NumPy's global RNG, independent of
    the file contents.

    Args:
        path: Location of the IHDP CSV. The default is the original
            cluster path.
    """

    def __init__(self, path="/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/IHDP.csv"):
        print("Loading IHDP dataset...")
        temp = pd.read_csv(path, header=None)
        # NOTE(review): columns 5:30 are assumed to be the covariates in this
        # file layout; confirm against the IHDP export.
        self.X = temp.iloc[:, 5:30].to_numpy(dtype=np.float64)
        self.A = np.random.binomial(1, 0.5, size=len(self.X))
        self.n = len(self.X)
        # Derive the width from the data instead of assuming 25 columns exist.
        self.d = self.X.shape[1]
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.X[idx], self.A[idx]

class EHRDataset(Dataset):
    """
    Selected EHR lab-value columns with a randomly assigned binary treatment.

    Treatment is drawn as Bernoulli(0.5) per row with NumPy's global RNG,
    independent of the file contents.

    Args:
        path: Location of the EHR CSV (with a header row). The default is the
            original cluster path.
        n_rows: Number of data rows to read (default 1000, as before);
            ``None`` reads the whole file.
    """

    # Feature columns kept from the EHR export, in output order.
    COLUMNS = (
        "X..BAND.NEUTROPHILS..count", "X..BAND.NEUTROPHILS..last",
        "X..LYMPHOCYTES.MANUAL..last", "X..LYMPHOCYTES..first", "X..MONOCYTES..last",
        "X..NEUTROPHILS.MANUAL..count", "X..LYMPHOCYTES.MANUAL..count",
        "X..LYMPHOCYTES..count", "X..LYMPHOCYTES..first.1", "X..LYMPHOCYTES..last",
        "ALKALINE.PHOSPHATASE..first", "ALKALINE.PHOSPHATASE..last",
        "ALPHA.1.FETOPROTEINS..first", "ALPHA.1.FETOPROTEINS..last",
        "ALT..count", "ALT..first", "ALT..last", "ANION.GAP..last",
        "AST..first", "AST..last", "BILIRUBIN.DIRECT..count", "BILIRUBIN.INDIRECT..last",
    )

    def __init__(self, path="/cbica/home/choii/project/causal_representation/causal_representation_learning/simulation/EHR.csv",
                 n_rows=1000):
        print("Loading EHR dataset...")
        # nrows stops parsing early instead of loading the whole file and then
        # keeping only the first rows.
        temp = pd.read_csv(path, nrows=n_rows)
        columns_to_keep = list(self.COLUMNS)
        self.X = temp[columns_to_keep].to_numpy(dtype=np.float64)
        self.A = np.random.binomial(1, 0.5, size=len(self.X))
        self.n = len(self.X)
        self.d = len(columns_to_keep)
        print(f"Loaded {self.n} samples with {self.d} features")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.X[idx], self.A[idx]

# ----- Utility Functions -----
def generate_outcomes(X, A, num_outcomes, data_type):
    """
    Simulate outcomes with ``simulate_simplified_dgp`` from s2_simulation.py.

    Only the shape of ``X`` (row count and column count) is forwarded to the
    simulator. The values of ``X`` and the treatment vector ``A`` are not
    passed.

    NOTE(review): because ``A`` is never used, the returned outcomes may not
    correspond to the caller's treatment assignment. Confirm how
    ``simulate_simplified_dgp`` assigns treatment for ``data=data_type``.

    Returns:
        The ``'Y'`` entry of the simulator's result.
    """
    from s2_simulation import simulate_simplified_dgp

    n_rows, n_cols = len(X), X.shape[1]
    simulated = simulate_simplified_dgp(
        n=n_rows,
        d=n_cols,
        K=num_outcomes,
        utility='or',
        weights=None,
        data=data_type,
    )
    return simulated['Y']

def _evaluate_split(model, X, Y, A, eval_device):
    """Compute factual MSE and true/predicted ATE for one data split."""
    X_t = torch.as_tensor(X, dtype=torch.float32, device=eval_device)
    # One row per unit: a 1-D outcome vector becomes a single column, so it
    # lines up with the (n, 1) predictions instead of broadcasting to (n, n).
    Y_t = torch.as_tensor(Y, dtype=torch.float32, device=eval_device)
    Y_t = Y_t.reshape(Y_t.shape[0], -1)
    A_t = torch.as_tensor(A, dtype=torch.float32, device=eval_device).reshape(-1, 1)

    with torch.no_grad():
        y0_pred, y1_pred = model.predict_counterfactuals(X_t)
        ate_pred = torch.mean(y1_pred - y0_pred, dim=0)

        treated = A_t.flatten() > 0
        control = A_t.flatten() < 1
        if treated.any() and control.any():
            # Difference in mean factual outcomes between the two arms.
            ate_true = Y_t[treated].mean(dim=0) - Y_t[control].mean(dim=0)
        else:
            # With one arm empty the treated share is 0 or 1, so neither a
            # difference in means nor an IPW estimate is defined.
            ate_true = torch.full((Y_t.shape[1],), float('nan'), device=eval_device)

        factual_pred = torch.where(A_t > 0, y1_pred, y0_pred)
        factual_mse = F.mse_loss(factual_pred, Y_t).item()
        ate_error = torch.mean(torch.abs(ate_pred - ate_true)).item()

    return {
        'factual_mse': factual_mse,
        'ate_true': ate_true.detach().cpu().numpy(),
        'ate_pred': ate_pred.detach().cpu().numpy(),
        'ate_error': ate_error,
    }


def evaluate_cfr_model(model, X_train, X_test, Y_train, Y_test, A_train, A_test):
    """
    Evaluate a CFR model on the training and test splits.

    For each split, the reference ("true") ATE is the difference between the
    mean factual outcomes of treated and control units. This is a
    difference-in-means estimate, not a ground-truth effect, and it is NaN
    when either arm is empty. The predicted ATE is the mean of ``y1 - y0``
    from ``model.predict_counterfactuals``. The model is put into eval mode.

    Returns:
        dict of floats with keys 'in_sample_ate_error',
        'out_of_sample_ate_error', 'factual_mse_train' and
        'factual_mse_test'. Each ATE error is the mean absolute difference
        between the predicted and reference ATE.
    """
    model.eval()
    # Evaluate wherever the model's parameters live.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device('cpu')

    train = _evaluate_split(model, X_train, Y_train, A_train, eval_device)
    test = _evaluate_split(model, X_test, Y_test, A_test, eval_device)

    # Print detailed metrics
    print(f"Factual MSE (Train): {train['factual_mse']:.4f}")
    print(f"Factual MSE (Test): {test['factual_mse']:.4f}")
    print(f"True Train ATE: {train['ate_true']}")
    print(f"Pred Train ATE: {train['ate_pred']}")
    print(f"True Test ATE: {test['ate_true']}")
    print(f"Pred Test ATE: {test['ate_pred']}")

    return {
        'in_sample_ate_error': train['ate_error'],
        'out_of_sample_ate_error': test['ate_error'],
        'factual_mse_train': train['factual_mse'],
        'factual_mse_test': test['factual_mse'],
    }

# ----- Training Loop -----
def train_cfr_model(model, train_loader, config, device):
    """
    Train a CFR model with a factual loss plus a weighted imbalance penalty.

    Each batch from ``train_loader`` must be ``(inputs, treatments, targets)``.
    The per-batch loss is ``config['loss']`` ('mse' or 'l1') on the factual
    prediction, plus ``config['p_alpha']`` times the imbalance measure named
    by ``config['imb_fun']``. Valid measures are 'mmd2_lin', 'mmd2_rbf' and
    'wasserstein'; any other value disables the penalty. The measure is
    computed on ``model.h_rep_norm``, which ``model``'s forward pass must set.
    Optimization uses Adam with ``config['learning_rate']`` and
    ``config['weight_decay']`` for ``config['num_epochs']`` epochs.

    Returns:
        The same ``model`` object, trained in place.

    Raises:
        ValueError: if ``config['loss']`` is not 'mse' or 'l1'.
    """
    print("Training CFR model...")

    # Validate up front instead of failing with an unbound pred_loss mid-epoch.
    loss_fns = {'mse': F.mse_loss, 'l1': F.l1_loss}
    if config['loss'] not in loss_fns:
        raise ValueError(f"Unknown loss {config['loss']!r}; expected 'mse' or 'l1'")
    pred_loss_fn = loss_fns[config['loss']]

    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
    n_batches = max(len(train_loader), 1)  # avoid dividing by zero on an empty loader

    for epoch in range(config['num_epochs']):
        model.train()
        running_loss = 0.0

        for inputs, treatments, targets in train_loader:
            inputs = inputs.to(device)
            treatments = treatments.view(-1, 1).to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs, treatments)

            # Align targets with the (batch, 1) predictions: a 1-D target would
            # otherwise broadcast to (batch, batch) inside the loss.
            targets = targets.reshape(outputs.shape[0], -1).to(outputs.dtype)
            pred_loss = pred_loss_fn(outputs, targets)

            # Imbalance penalty on the stored representation
            # (treated proportion fixed at 0.5).
            h_rep_norm = model.h_rep_norm
            if config['imb_fun'] == 'mmd2_lin':
                imb_loss = config['p_alpha'] * mmd2_lin(h_rep_norm, treatments, p=0.5)
            elif config['imb_fun'] == 'mmd2_rbf':
                imb_loss = config['p_alpha'] * mmd2_rbf(h_rep_norm, treatments, p=0.5, sigma=config['rbf_sigma'])
            elif config['imb_fun'] == 'wasserstein':
                imb_dist, _ = wasserstein(h_rep_norm, treatments, p=0.5)
                imb_loss = config['p_alpha'] * imb_dist
            else:
                imb_loss = torch.tensor(0.0, device=device)

            loss = pred_loss + imb_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch {epoch+1}/{config['num_epochs']}, Loss: {running_loss/n_batches:.4f}")

    print("Training complete")
    return model

def run_experiment(dataset_name, config):
    """
    Run experiment for a specific dataset.

    Loads or simulates the data and splits it 80/20, stratified by
    treatment. It standardizes features with training-split statistics,
    trains a CFRNet and evaluates it.

    Args:
        dataset_name (str): Name of the dataset ('synthetic', 'IHDP', 'EHR')
        config (dict): Configuration for the experiment. It is modified in
            place: ``config['input_dim']`` is set from the data.

    Returns:
        ``(results, model)``. ``results`` is a dict of parallel lists with
        keys 'Model', 'Error_Type', 'ATE_Error' and 'Dataset'; ``model`` is
        the trained CFRNet.

    Raises:
        ValueError: for an unknown ``dataset_name``.
    """
    start_time = time.time()
    print(f"\nStarting experiment for {dataset_name}")

    # Load or generate dataset based on dataset_name
    if dataset_name == 'synthetic':
        print("Generating synthetic data...")
        data = simulate_simplified_dgp(n=1000, d=10, K=1,
                                       utility='or',
                                       weights=None)
        X, A, Y = data['X'], data['A'], data['Y']
    elif dataset_name == 'IHDP':
        dataset = IHDPDataset()
        X, A = dataset.X, dataset.A
        Y = generate_outcomes(X, A, 1, 'IHDP')
    elif dataset_name == 'EHR':
        dataset = EHRDataset()
        X, A = dataset.X, dataset.A
        Y = generate_outcomes(X, A, 1, 'EHR')
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Split first (stratified by treatment), so that normalization statistics
    # come from the training rows only and no test information leaks in.
    X_train, X_test, A_train, A_test, Y_train, Y_test = train_test_split(
        X, A, Y, test_size=0.2, random_state=42, stratify=A
    )

    X_mean = np.mean(X_train, axis=0)
    X_std = np.std(X_train, axis=0)
    X_std[X_std == 0] = 1.0  # Prevent division by zero on constant columns
    X_train = (X_train - X_mean) / X_std
    X_test = (X_test - X_mean) / X_std

    # Create dataset and dataloader
    train_dataset = CausalMultiTaskDataset(X_train, A_train, Y_train)
    # BatchNorm1d cannot run in training mode on a batch of one sample, so
    # drop the trailing batch only when it would have exactly one row.
    drop_last = bool(config['batch_norm']) and len(X_train) % config['batch_size'] == 1
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                              shuffle=True, drop_last=drop_last)

    # Update config with dataset dimensions
    config['input_dim'] = X.shape[1]

    # Create and train CFR model
    model = CFRNet(config).to(device)
    model = train_cfr_model(model, train_loader, config, device)

    # Evaluate model
    metrics = evaluate_cfr_model(model, X_train, X_test, Y_train, Y_test, A_train, A_test)

    results = {
        'Model': ['CFRNet', 'CFRNet'],
        'Error_Type': ['In-Sample', 'Out-of-Sample'],
        'ATE_Error': [metrics['in_sample_ate_error'], metrics['out_of_sample_ate_error']],
        'Dataset': [dataset_name, dataset_name]
    }

    end_time = time.time()
    print(f"Experiment completed in {end_time - start_time:.2f} seconds")

    return results, model

def main():
    """
    Run CFRNet on every dataset and save the results CSV after each one.

    Also prints the results and a LaTeX summary table of the ATE errors.
    """
    # Default configuration
    default_config = {
        'rep_dim': 64,            # Representation dimension
        'out_dim': 32,            # Output dimension
        'n_rep_layers': 2,        # Number of layers in representation network
        'n_out_layers': 1,        # Number of layers in output network
        'batch_norm': True,       # Use batch normalization
        'nonlin': 'relu',         # Activation function (relu or elu)
        'dropout_in': 0.1,        # Dropout rate for representation network
        'dropout_out': 0.1,       # Dropout rate for output network
        'normalization': None,    # Normalization method (None or 'divide')
        'split_output': True,     # Whether to use separate output networks for treated and control
        'imb_fun': 'mmd2_rbf',    # Imbalance penalty ('mmd2_lin', 'mmd2_rbf', or 'wasserstein')
        'rbf_sigma': 1.0,         # Bandwidth parameter for RBF kernel
        'p_alpha': 1.0,           # Weight for imbalance penalty
        'loss': 'mse',            # Loss function ('mse' or 'l1')
        'learning_rate': 0.001,   # Learning rate
        'weight_decay': 0.0001,   # Weight decay (L2 regularization)
        'batch_size': 64,         # Batch size
        'num_epochs': 50          # Number of epochs
    }

    datasets = ['synthetic', 'IHDP', 'EHR']
    results_path = '/cbica/home/choii/project/causal_representation/causal_representation_learning/cfrnet_results.csv'

    results = {
        'Model': [],
        'Error_Type': [],
        'ATE_Error': [],
        'Dataset': []
    }

    # Clear any previous results
    if os.path.exists(results_path):
        os.remove(results_path)

    total_start_time = time.time()
    results_df = pd.DataFrame(results)  # defined even if no dataset runs

    for dataset in datasets:
        print(f"\nRunning experiment for {dataset}")
        # Each run gets its own copy because run_experiment writes 'input_dim' into it.
        result, _ = run_experiment(dataset, default_config.copy())

        for key in results:
            results[key].extend(result[key])

        # Save after every dataset so a later failure keeps earlier results.
        results_df = pd.DataFrame(results)
        results_df.to_csv(results_path, index=False)

    total_end_time = time.time()
    print(f"\nTotal execution time: {(total_end_time - total_start_time)/3600:.2f} hours")

    print("\nResults:")
    print(results_df)

    # Summary statistics in LaTeX table format
    print("\nSummary Statistics:")
    grouped = results_df.groupby(["Model", "Error_Type", "Dataset"])["ATE_Error"]
    mean_errors = grouped.mean()
    se_errors = grouped.std(ddof=1) / np.sqrt(grouped.count())

    # The Dataset column is required: otherwise rows from different datasets
    # are indistinguishable.
    print("\\begin{tabular}{llll}")
    print("\\toprule")
    print("Model & Type & Dataset & ATE Error \\\\")
    print("\\midrule")

    for (model, error_type, dataset), mean_error, se_error in zip(mean_errors.index, mean_errors, se_errors):
        # A cell with a single run has no sample standard error (std is NaN).
        if pd.isna(se_error):
            cell = f"{mean_error:.4f}"
        else:
            cell = f"{mean_error:.4f} $\\pm$ {se_error:.4f}"
        print(f"{model} & {error_type} & {dataset} & {cell} \\\\")

    print("\\bottomrule")
    print("\\end{tabular}")

# Script entry point: run all experiments only when executed directly, not on import.
if __name__ == '__main__':
    main()
