import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import warnings
warnings.filterwarnings("ignore")

class AE(torch.nn.Module):
    """Fully connected autoencoder with mirrored encoder/decoder stacks.

    The encoder maps ``input_dim -> hidden -> ... -> latent_dim`` and the decoder
    maps ``latent_dim -> hidden -> ... -> output``. The output size is ``input_dim``,
    or ``2 * input_dim`` when ``loss_type == 'nb'`` so that the output holds the
    mean half and dispersion half that ``loss_fn`` reads. Hidden layers are
    followed by ReLU, and a final ReLU makes every decoder output non-negative.

    Args:
        input_dim: Number of input features.
        latent_dim: Size of the bottleneck representation.
        depth: Number of Linear layers in each of encoder and decoder (must be >= 1).
        width: Hidden width as a fraction of ``input_dim`` (clamped to >= 1 unit).
        loss_type: Loss the model will be trained with; only ``'nb'`` changes the
            architecture (doubled output).

    Raises:
        ValueError: If ``depth`` is smaller than 1.
    """

    def __init__(self, input_dim, latent_dim, depth=2, width=0.5, loss_type='mse'):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        hidden_dim = max(1, int(width * input_dim))
        output_dim = input_dim * 2 if loss_type == "nb" else input_dim

        # Layer sizes along each path. For depth == 1 this is a single
        # input->latent (encoder) and latent->output (decoder) Linear.
        encoder_dims = [input_dim] + [hidden_dim] * (depth - 1) + [latent_dim]
        decoder_dims = [latent_dim] + [hidden_dim] * (depth - 1) + [output_dim]

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # Build encoder[i] then decoder[i] so that seeded weight initialisation
        # and the state_dict keys (Linear at even indices) stay stable.
        for i in range(depth):
            self.encoder.append(nn.Linear(encoder_dims[i], encoder_dims[i + 1]))
            self.decoder.append(nn.Linear(decoder_dims[i], decoder_dims[i + 1]))
            if i < depth - 1:
                self.encoder.append(nn.ReLU())
                self.decoder.append(nn.ReLU())
        # Final activation: reconstructions (and NB dispersions) are non-negative.
        self.decoder.append(nn.ReLU())

    def encode(self, x):
        """Run ``x`` through the encoder stack and return the latent representation."""
        for layer in self.encoder:
            x = layer(x)
        return x

    def decode(self, x):
        """Run a latent tensor through the decoder stack (ending in ReLU)."""
        for layer in self.decoder:
            x = layer(x)
        return x

    def forward(self, x):
        """Encode then decode ``x``; returns the (non-negative) reconstruction."""
        return self.decode(self.encode(x))

# loss function
def loss_fn(x, x_hat, loss_type='mse'):
    """Reconstruction loss between the input ``x`` and the model output ``x_hat``.

    Args:
        x: Input batch, indexed as (batch, n_features).
        x_hat: Model output. For 'mse' and 'poisson' it has the same shape as ``x``.
            For 'nb' it has ``2 * n_features`` columns: the first half is the mean
            ``mu`` and the second half the dispersion ``theta``.
        loss_type: One of 'mse', 'nb' or 'poisson'.

    Returns:
        Scalar tensor. For 'nb' and 'poisson' this is the negative log-likelihood
        averaged over all elements (not summed per sample).

    Raises:
        ValueError: For an unknown ``loss_type``, or when an 'nb' output does not
            have twice as many columns as ``x``.
    """
    # Small constant keeping log() and lgamma() finite when the (ReLU) output is 0.
    eps = 1e-6
    if loss_type == 'mse':
        return F.mse_loss(x_hat, x)
    elif loss_type == 'nb':
        n_features = x.shape[1]
        if x_hat.shape[1] != 2 * n_features:
            raise ValueError(
                f"'nb' loss expects x_hat with {2 * n_features} columns (mean and "
                f"dispersion halves), got {x_hat.shape[1]}")
        # Without eps on mu, a zero mean gives 0 * log(0) = NaN for zero counts.
        mu = x_hat[:, :n_features] + eps
        theta = x_hat[:, n_features:] + eps
        log_theta_mu = torch.log(theta + mu)
        log_lik = (torch.lgamma(theta + x) - torch.lgamma(theta) - torch.lgamma(x + 1)
                   + theta * (torch.log(theta) - log_theta_mu)
                   + x * (torch.log(mu) - log_theta_mu))
        return -log_lik.mean()
    elif loss_type == 'poisson':
        # The whole output is the Poisson mean.
        mu = x_hat + eps
        log_lik = x * torch.log(mu) - mu - torch.lgamma(x + 1)
        return -log_lik.mean()
    else:
        raise ValueError("Unsupported loss type. Use 'mse', 'nb', or 'poisson'.")
    
def train_autoencoder(model, data, loss_type, epochs=100, batch_size=64, lr=1e-3, device='cpu', verbose=True,
                      log_every=100):
    """Train an autoencoder with Adam on shuffled mini-batches of ``data``.

    Args:
        model: Module mapping a batch to the output expected by ``loss_fn``.
        data: Input data tensor; rows are samples.
        loss_type: Loss passed to ``loss_fn`` ('mse', 'poisson', or 'nb').
        epochs: Number of passes over the data.
        batch_size: Mini-batch size.
        lr: Adam learning rate.
        device: Device to train on ('cpu' or 'cuda').
        verbose: Whether to print progress.
        log_every: Print the average loss every ``log_every`` epochs when
            ``verbose``; a falsy value disables progress printing.

    Returns:
        Dict with 'model' (the trained model, on ``device``) and 'history'
        (``{'loss': [mean batch loss per epoch]}``).

    Raises:
        ValueError: If ``data`` contains no samples.
    """
    if data.shape[0] == 0:
        # Otherwise the per-epoch average would divide by zero batches.
        raise ValueError("Cannot train on an empty dataset.")

    model = model.to(device)
    # Detach so autograd never tracks (or accumulates .grad into) the caller's tensor.
    data = data.detach().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = {'loss': []}

    dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(data), batch_size=batch_size, shuffle=True)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for (x,) in dataloader:
            x_hat = model(x)
            loss = loss_fn(x, x_hat, loss_type)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        history['loss'].append(avg_loss)

        if verbose and log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

    return {'model': model, 'history': history}

def compute_gradient_contributions(model, data, feature_metadata, device='cpu'):
    """Compute gradient-based feature importance for an autoencoder's latent space.

    For each latent dimension ``i``, the gradient of ``mean_over_samples(z[:, i])``
    with respect to the inputs is taken, and its absolute value is averaged over
    samples. The per-dimension vectors are then averaged over all latent
    dimensions. Because the latent is averaged over N samples, each gradient
    carries a 1/N factor.

    The caller's ``data`` tensor is never modified. Neither its ``requires_grad``
    flag nor its ``.grad`` is touched, and model parameter gradients are left
    untouched, so repeated calls give the same result.

    Args:
        model: Trained autoencoder exposing ``encode``.
        data: Input data tensor (samples x features).
        feature_metadata: DataFrame with 'feature' and 'type' columns, one row per
            input feature in column order.
        device: Device to compute on ('cpu' or 'cuda').

    Returns:
        DataFrame with columns 'feature', 'type', 'importance'.
    """
    model.eval()
    # detach() gives a fresh leaf sharing storage; on CPU, .to() would otherwise
    # return the caller's own tensor and requires_grad/.grad would leak into it.
    inputs = data.detach().to(device).requires_grad_(True)

    z = model.encode(inputs)
    latent_dim = z.shape[1]

    gradients = []
    for i in range(latent_dim):
        # autograd.grad returns the gradient without accumulating into .grad,
        # so no manual zeroing between dimensions is needed.
        (grad,) = torch.autograd.grad(z[:, i].mean(), inputs,
                                      retain_graph=i < latent_dim - 1)
        # Importance is magnitude, not direction.
        gradients.append(grad.abs().mean(dim=0).cpu().numpy())

    importance = np.mean(gradients, axis=0)

    return pd.DataFrame({
        'feature': feature_metadata['feature'],
        'type': feature_metadata['type'],
        'importance': importance
    })


def compute_regression_scores(model, data, feature_metadata, device='cpu'):
    """Measure how linearly decodable each input feature is from the latent space.

    A linear regression (with intercept) is fit from the latent representation to
    every input feature. The in-sample R² for each feature is reported.

    Args:
        model: Trained autoencoder exposing ``encode``.
        data: Input data tensor (samples x features).
        feature_metadata: DataFrame with 'feature' and 'type' columns, one row per
            input feature in column order.
        device: Device to compute on ('cpu' or 'cuda').

    Returns:
        DataFrame with columns 'feature', 'type', 'r2_score'.
    """
    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import r2_score

    model.eval()
    with torch.no_grad():
        z = model.encode(data.to(device)).detach().cpu().numpy()
    # detach(): the caller's tensor may have requires_grad set, and .numpy()
    # refuses such tensors while grad mode is enabled.
    data_np = data.detach().cpu().numpy()

    # Ordinary least squares solves each target column independently, so a single
    # multi-output fit gives the same per-feature models as one fit per feature,
    # while factorising z only once.
    reg = LinearRegression().fit(z, data_np)
    r2_scores = r2_score(data_np, reg.predict(z), multioutput='raw_values')

    return pd.DataFrame({
        'feature': feature_metadata['feature'],
        'type': feature_metadata['type'],
        'r2_score': r2_scores
    })


def compute_reconstruction_error(model, data, feature_metadata, device='cpu'):
    """Per-feature mean squared error between the input and its reconstruction.

    When the model output is wider than the input (the 'nb' layout of mean and
    dispersion halves), only the leading input-width columns (the mean) are
    compared.

    Args:
        model: Trained autoencoder; called as ``model(data)``.
        data: Input data tensor (samples x features).
        feature_metadata: DataFrame with 'feature' and 'type' columns, one row per
            input feature in column order.
        device: Device to compute on ('cpu' or 'cuda').

    Returns:
        DataFrame with columns 'feature', 'type', 'mse'.
    """
    model.eval()
    with torch.no_grad():
        target = data.to(device)
        output = model(target)
        n_features = target.shape[1]

        # Wider output -> keep only the mean half.
        if output.shape[1] > n_features:
            output = output[:, :n_features]

        per_feature_mse = (output - target).pow(2).mean(dim=0).cpu().numpy()

    return pd.DataFrame({
        'feature': feature_metadata['feature'],
        'type': feature_metadata['type'],
        'mse': per_feature_mse
    })

def prepare_data_for_training(data, simulator):
    """Build the training tensor and per-feature metadata.

    TF expression columns come first, followed by gene expression columns. Values
    are used as-is (no log1p transform); the 'poisson' and 'nb' losses evaluate
    count likelihoods on them directly.

    Args:
        data: Mapping with DataFrames under 'tf_expression' and 'gene_expression'
            (rows are samples).
        simulator: Object exposing ``n_regulated_genes``. Assumes the first
            ``n_regulated_genes`` gene columns are the regulated genes -- TODO
            confirm against the simulator.

    Returns:
        ``(data_tensor, feature_metadata)``. ``data_tensor`` is a float32 tensor
        (samples x features). ``feature_metadata`` has one row per column with
        'feature' and 'type' ('TF', 'Regulated Gene', 'Non-regulated Gene') and a
        unique 0..n-1 index.
    """
    tf_expr = data['tf_expression']
    gene_expr = data['gene_expression']
    n_regulated = simulator.n_regulated_genes

    combined_data = np.concatenate([tf_expr.values, gene_expr.values], axis=1)
    data_tensor = torch.tensor(combined_data, dtype=torch.float32)

    tf_meta = pd.DataFrame({
        'feature': tf_expr.columns,
        'type': 'TF'
    })
    regulated_meta = pd.DataFrame({
        'feature': gene_expr.columns[:n_regulated],
        'type': 'Regulated Gene'
    })
    nonregulated_meta = pd.DataFrame({
        'feature': gene_expr.columns[n_regulated:],
        'type': 'Non-regulated Gene'
    })
    # ignore_index: without it each part keeps its own 0-based index, leaving
    # duplicate labels that break label-based alignment downstream.
    combined_feature_metadata = pd.concat(
        [tf_meta, regulated_meta, nonregulated_meta], ignore_index=True)

    return data_tensor, combined_feature_metadata


def run_autoencoder_experiments(data, simulator, latent_dims=(10, 50), loss_types=('mse', 'poisson', 'nb'),
                               epochs=100, batch_size=64, model_depth=2, width_factor=0.5, learning_rate=1e-3):
    """Train and analyse one autoencoder per (latent dimension, loss function) pair.

    Args:
        data: Mapping with 'tf_expression' and 'gene_expression' DataFrames
            (passed to ``prepare_data_for_training``).
        simulator: Object exposing ``n_regulated_genes``.
        latent_dims: Iterable of latent dimensions to try.
        loss_types: Iterable of loss functions to try ('mse', 'poisson', 'nb').
        epochs: Number of training epochs.
        batch_size: Batch size for training.
        model_depth: Number of Linear layers in each of encoder and decoder.
        width_factor: Hidden-layer width as a fraction of the input dimension.
        learning_rate: Adam learning rate.

    Returns:
        Dict keyed by ``"AE_latent{latent_dim}_{loss_type}"``. Each value holds
        'model', 'history', and the 'importance', 'regression' and 'error'
        DataFrames.
    """
    data_tensor, feature_metadata = prepare_data_for_training(data, simulator)
    input_dim = data_tensor.shape[1]

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    results = {}

    for latent_dim in latent_dims:
        for loss_type in loss_types:
            model_name = f"AE_latent{latent_dim}_{loss_type}"
            print(f"\nTraining {model_name}...")

            model = AE(input_dim=input_dim, latent_dim=latent_dim, depth=model_depth,
                       width=width_factor, loss_type=loss_type)

            train_results = train_autoencoder(
                model=model,
                data=data_tensor,
                loss_type=loss_type,
                epochs=epochs,
                batch_size=batch_size,
                device=device,
                lr=learning_rate,
            )
            trained_model = train_results['model']

            print(f"Analyzing {model_name}...")
            importance_df = compute_gradient_contributions(
                model=trained_model,
                data=data_tensor,
                feature_metadata=feature_metadata,
                device=device
            )
            regression_df = compute_regression_scores(
                model=trained_model,
                data=data_tensor,
                feature_metadata=feature_metadata,
                device=device
            )
            error_df = compute_reconstruction_error(
                model=trained_model,
                data=data_tensor,
                feature_metadata=feature_metadata,
                device=device
            )

            results[model_name] = {
                'model': trained_model,
                'history': train_results['history'],
                'importance': importance_df,
                'regression': regression_df,
                'error': error_df,
            }

    return results

def compare_loss_functions(results, latent_dim=50):
    """Summarise and plot how the loss functions compare at one latent size.

    For every model named ``AE_latent{latent_dim}_{loss}`` (loss in mse, poisson,
    nb) present in ``results``, the per-feature importance, R² and MSE tables are
    averaged within each feature type. The averages are drawn as three grouped bar
    charts and printed as one merged table.

    Args:
        results: Mapping of model name -> dict with 'importance', 'regression'
            and 'error' DataFrames (each with a 'type' column).
        latent_dim: Latent dimension whose models are compared.

    Returns:
        The merged summary DataFrame, or None when no matching model exists.
    """
    wanted = [f"AE_latent{latent_dim}_{loss}" for loss in ('mse', 'poisson', 'nb')]
    model_names = [name for name in wanted if name in results]

    if not model_names:
        print(f"No models found with latent dimension {latent_dim}")
        return

    def type_mean(frame, feat_type, column):
        # Mean of `column` over the rows of `frame` whose 'type' equals feat_type.
        return frame[frame['type'] == feat_type][column].mean()

    importance_rows, r2_rows, mse_rows = [], [], []
    for name in model_names:
        entry = results[name]
        loss_label = name.split('_')[-1]
        # Feature types are taken from the importance table for all three metrics.
        for feat_type in entry['importance']['type'].unique():
            importance_rows.append({
                'Loss Function': loss_label,
                'Feature Type': feat_type,
                'Mean Importance': type_mean(entry['importance'], feat_type, 'importance'),
            })
            r2_rows.append({
                'Loss Function': loss_label,
                'Feature Type': feat_type,
                'Mean R²': type_mean(entry['regression'], feat_type, 'r2_score'),
            })
            mse_rows.append({
                'Loss Function': loss_label,
                'Feature Type': feat_type,
                'Mean MSE': type_mean(entry['error'], feat_type, 'mse'),
            })

    importance_df = pd.DataFrame(importance_rows)
    r2_df = pd.DataFrame(r2_rows)
    mse_df = pd.DataFrame(mse_rows)

    # (data, y column, title, y-axis label) for each of the three panels.
    panels = (
        (importance_df, 'Mean Importance',
         f'Mean Feature Importance by Loss Function (Latent Dim={latent_dim})', 'Mean Gradient Importance'),
        (r2_df, 'Mean R²',
         f'Mean R² Score by Loss Function (Latent Dim={latent_dim})', 'Mean R² Score'),
        (mse_df, 'Mean MSE',
         f'Mean Reconstruction Error by Loss Function (Latent Dim={latent_dim})', 'Mean MSE'),
    )

    _, axes = plt.subplots(1, 3, figsize=(12, 5))
    for ax, (frame, y_col, title, y_label) in zip(axes, panels):
        sns.barplot(data=frame, x='Feature Type', y=y_col, hue='Loss Function', ax=ax)
        ax.set_title(title)
        ax.set_ylabel(y_label)
        ax.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

    print("\nSummary of Performance by Loss Function and Feature Type:")
    keys = ['Loss Function', 'Feature Type']
    summary = importance_df.merge(r2_df, on=keys).merge(mse_df, on=keys)
    print(summary.sort_values(keys))

    return summary

# For pair-wise comparison
import itertools

def analyze_top_features_by_loss(results, latent_dim=50, top_n=20):
    """Plot and compare the most important features across loss functions.

    For each model ``AE_latent{latent_dim}_{loss}`` (loss in mse, poisson, nb)
    found in ``results``, the ``top_n`` rows with the highest 'importance' are drawn
    as a bar chart coloured by feature type. When more than one model is present,
    the pairwise overlap of their top-feature sets is printed.

    Args:
        results: Mapping of model name -> dict containing an 'importance'
            DataFrame with 'feature', 'type' and 'importance' columns.
        latent_dim: Latent dimension to use.
        top_n: Number of top features per model.

    Returns:
        Dict mapping loss type (e.g. 'mse') to the set of its top features, or
        None when no matching model exists.
    """
    loss_types = ['mse', 'poisson', 'nb']
    candidates = [f"AE_latent{latent_dim}_{loss}" for loss in loss_types]
    model_names = [name for name in candidates if name in results]

    if not model_names:
        print(f"No models found with latent dimension {latent_dim}")
        return

    # Top-N rows per model, computed once and reused for plotting and overlap.
    top_tables = {
        name: results[name]['importance'].sort_values('importance', ascending=False).head(top_n)
        for name in model_names
    }

    # squeeze=False keeps a 2-D array even for a single model, so indexing is uniform.
    fig, axes = plt.subplots(len(model_names), 1, figsize=(12, 4 * len(model_names)), squeeze=False)
    axes = axes[:, 0]

    for ax, model_name in zip(axes, model_names):
        loss_type = model_name.split('_')[-1].upper()
        sns.barplot(data=top_tables[model_name], x='feature', y='importance', hue='type', ax=ax)
        ax.set_title(f'Top {top_n} Important Features with {loss_type} Loss (Latent Dim={latent_dim})')
        ax.set_xlabel('Feature')
        ax.set_ylabel('Importance')
        ax.tick_params(axis='x', rotation=90)
        ax.legend(title='Feature Type')

    plt.tight_layout()
    plt.show()

    # Always built, so the return value exists even when only one model matched.
    top_features_by_loss = {
        name.split('_')[-1]: set(top_tables[name]['feature'].tolist())
        for name in model_names
    }

    if len(model_names) > 1:
        print("\nAnalyzing overlap in top features between loss functions:")
        for loss1, loss2 in itertools.combinations(loss_types, 2):
            if loss1 in top_features_by_loss and loss2 in top_features_by_loss:
                overlap = top_features_by_loss[loss1] & top_features_by_loss[loss2]
                print(f"Overlap between {loss1.upper()} and {loss2.upper()}: {len(overlap)} features")
                if overlap:
                    # Sorted for reproducible output; str() tolerates non-string names.
                    shown = sorted(overlap, key=str)[:5]
                    print(f"Overlapping features: {', '.join(map(str, shown))}" +
                          ("..." if len(overlap) > 5 else ""))

    return top_features_by_loss

def plot_loss_curves(results):
    """Plot training loss curves, one figure per latent dimension.

    Model names are expected to follow ``"AE_latent{dim}_{loss}"``. Figures are
    produced in ascending latent-dimension order. Within a figure, curves follow
    the insertion order of ``results``.

    Args:
        results: Mapping of model name -> dict with a 'history' entry holding
            ``{'loss': [per-epoch loss, ...]}``.
    """
    # Parse each name once and group model names by latent dimension.
    names_by_latent = {}
    for name in results:
        latent_dim = int(name.split('_')[1].replace('latent', ''))
        names_by_latent.setdefault(latent_dim, []).append(name)

    for latent_dim in sorted(names_by_latent):
        plt.figure(figsize=(10, 6))
        for model_name in names_by_latent[latent_dim]:
            history = results[model_name]['history']
            loss_type = model_name.split('_')[-1].upper()
            plt.plot(history['loss'], label=f'{loss_type} Loss')

        plt.title(f'Training Loss Curves (Latent Dim={latent_dim})')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.show()