import numpy as np
import torch
import pandas as pd
from scipy import sparse
import tqdm
import os
import argparse
import gc
import wandb
import torch.nn as nn
from datetime import datetime
import optuna  # Add Optuna import

# Add project root directory to Python path
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from config import WANDB_ENTITY, get_project_root
from src.utils.logging import setup_wandb, log_wandb

def _str2bool(value):
    """Convert a command-line string such as ``"False"`` into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` and ``"0"`` -- becomes ``True``.
    This converter interprets the usual spellings instead.

    Args:
        value: The raw argument (a string, or already a ``bool``).

    Returns:
        bool: The parsed value. The empty string maps to ``False``, as it did
        with ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Set up the argument parser
parser = argparse.ArgumentParser(description='Hyperparameter sweep for Overcomplete Autoencoder')
# Data arguments
parser.add_argument('--data', type=str, default='mm_sim', help='data name. options are mm_sim, bonemarrow, brain')
parser.add_argument('--modality', type=str, default='rna', help='data modality. options are rna, atac, protein, rna-atac, rna-protein, atac-protein, all')
parser.add_argument('--n_batches', type=int, default=3, help='number of batches (10k samples each)')
parser.add_argument('--single_batch', type=_str2bool, default=False, help='if True, only one batch (noise in data) is used for the computation')
parser.add_argument('--norm', type=_str2bool, default=False, help='if True, the data is normalized before computing the metrics')
# Sweep arguments
parser.add_argument('--latent_dim', type=int, default=1000, help='fixed latent dimension for the sweep')
parser.add_argument('--seed', type=int, default=0, help='random seed for reproducibility')
parser.add_argument('--gpu', type=int, default=0, help='gpu id to use')
parser.add_argument('--output_dir', type=str, default='03_results/sweeps', help='Directory to save results')
parser.add_argument('--wandb', type=_str2bool, default=False, help='Whether to use wandb for logging')
parser.add_argument('--wandb_entity', type=str, default=WANDB_ENTITY, help='Wandb entity name')
parser.add_argument('--wandb_project', type=str, default='intrinsic_dimensionality', help='Wandb project name')
parser.add_argument('--epochs', type=int, default=500, help='number of epochs to train the autoencoder')
parser.add_argument('--n_trials', type=int, default=100, help='Number of Optuna trials to run')
# NOTE: arguments are parsed at import time; the rest of the module reads the global `args`
args = parser.parse_args()

# Select the compute device: the requested GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.gpu}')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Seed NumPy and PyTorch from --seed so a run can be repeated
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Hyperparameter search space for the sweep.
# Lists are sampled as categorical choices; 2-tuples are (low, high) bounds
# that `objective` samples log-uniformly.
sweep_config = {
    'ae_depth': [2, 3, 4, 6],  # number of linear layers in the encoder (and in the decoder)
    'ae_width': [0.25, 0.5, 0.75, 1.0],  # hidden width as a fraction of the input dimension
    'lr': (1e-5, 1e-3),  # learning-rate bounds (log-uniform)
    'batch_size': [64, 128, 256, 512],
    'early_stopping': [10, 50],  # patience: epochs without validation improvement before stopping
    'wd': (1e-6, 1e-4),  # weight-decay bounds (log-uniform)
    'dropout': [0.0, 0.1, 0.2],  # dropout rate after each hidden ReLU (0.0 adds no dropout layer)
}

# Define the Overcomplete Autoencoder model
class OvercompleteAE(torch.nn.Module):
    """Fully connected autoencoder with a configurable latent size.

    The encoder maps ``input_dim -> hidden -> ... -> latent_dim`` and the
    decoder mirrors it as ``latent_dim -> hidden -> ... -> input_dim``, where
    ``hidden = int(width * input_dim)``. Every linear layer except the last one
    of each half is followed by a ReLU and, if ``dropout > 0``, a dropout layer.

    Args:
        input_dim: Number of input (and reconstructed) features.
        latent_dim: Size of the latent representation.
        depth: Number of linear layers in the encoder and in the decoder.
            Must be at least 1; ``depth=1`` gives a single linear map each way.
        width: Hidden layer width as a fraction of ``input_dim``.
        dropout: Dropout probability applied after hidden activations.

    Raises:
        ValueError: If ``depth`` is smaller than 1.
    """

    def __init__(self, input_dim, latent_dim, depth=2, width=0.5, dropout=0.0):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be at least 1, got {depth}")

        hidden_dim = int(width * input_dim)
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # Encoder and decoder layers are created alternately; keeping this order
        # keeps the seeded weight initialisation reproducible.
        for i in range(depth):
            # The first layer reads the raw input (encoder) or the latent code
            # (decoder), even when it is also the last layer (depth == 1).
            enc_in = input_dim if i == 0 else hidden_dim
            dec_in = latent_dim if i == 0 else hidden_dim
            if i == depth - 1:
                # Output layers carry no activation
                self.encoder.append(nn.Linear(enc_in, latent_dim))
                self.decoder.append(nn.Linear(dec_in, input_dim))
            else:
                self.encoder.append(nn.Linear(enc_in, hidden_dim))
                self.decoder.append(nn.Linear(dec_in, hidden_dim))
                self.encoder.append(nn.ReLU())
                self.decoder.append(nn.ReLU())
                # Add dropout after ReLU activations if dropout rate > 0
                if dropout > 0.0:
                    self.encoder.append(nn.Dropout(dropout))
                    self.decoder.append(nn.Dropout(dropout))

    def encode(self, x):
        """Return the latent representation of ``x``."""
        for layer in self.encoder:
            x = layer(x)
        return x

    def decode(self, x):
        """Reconstruct inputs from the latent representation ``x``."""
        for layer in self.decoder:
            x = layer(x)
        return x

    def forward(self, x):
        """Encode and then decode ``x``, returning the reconstruction."""
        return self.decode(self.encode(x))

def parallel_linear_regression(x, y, n_samples, n_samples_train, n_epochs=500, early_stopping=50):
    """Fit a linear probe from ``x`` to ``y`` and score it on held-out rows.

    Rows ``[:n_samples_train]`` train a single ``nn.Linear`` layer with Adam and
    MSE loss; rows ``[n_samples_train:n_samples]`` are used for validation,
    early stopping and the final per-feature R^2 scores. The layer is placed on
    the module-level ``device``.

    Args:
        x: Predictor tensor with one row per sample (here: latent codes).
        y: Target tensor with one row per sample (here: the original data).
        n_samples: Exclusive end index of the validation rows.
        n_samples_train: Number of leading rows used for training.
        n_epochs: Maximum number of training epochs.
        early_stopping: Stop once the best validation loss is older than this
            many epochs.

    Returns:
        tuple: ``(r_squares, best_val_loss)`` where ``r_squares`` is a CPU tensor
        with one R^2 value per target feature (features that are constant on
        the validation rows give non-finite values) and ``best_val_loss`` is
        the lowest batch-averaged validation MSE seen during training.
    """
    x_val_all = x[n_samples_train:n_samples]
    y_val_all = y[n_samples_train:n_samples]
    y_mean = y_val_all.mean(dim=0)

    # loaders
    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x[:n_samples_train], y[:n_samples_train]), batch_size=128, shuffle=True)
    val_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_val_all, y_val_all), batch_size=128, shuffle=False)

    # set up a linear layer for regression
    linear = nn.Linear(x.shape[1], y.shape[1]).to(device)
    optimizer = torch.optim.Adam(linear.parameters(), lr=0.0001, weight_decay=0)
    loss_fn = nn.MSELoss()

    # train the linear layer
    val_losses = []
    pbar = tqdm.tqdm(range(n_epochs), desc="Training linear model")
    for epoch in pbar:
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            y_pred = linear(x_batch)
            loss = loss_fn(y_pred, y_batch)
            loss.backward()
            optimizer.step()

        # Validation needs no gradients; skipping autograd avoids building
        # graphs that are never back-propagated.
        val_loss = 0
        with torch.no_grad():
            for x_val, y_val in val_loader:
                x_val, y_val = x_val.to(device), y_val.to(device)
                val_loss += loss_fn(linear(x_val), y_val).item()
        val_losses.append(val_loss / len(val_loader))
        pbar.set_postfix({'val loss': round(val_losses[-1], 4)})

        # Stop when none of the last `early_stopping` epochs matched the best loss
        if epoch > early_stopping and min(val_losses[-early_stopping:]) > min(val_losses):
            print(f"Early stopping in linear regression at epoch {epoch}")
            break

    # Compute R^2 scores per target feature on the validation rows
    with torch.no_grad():
        y_pred = linear(x_val_all.to(device)).cpu()
    r_squares = 1 - (((y_val_all - y_pred)**2).sum(0) / ((y_val_all - y_mean)**2).sum(0))

    # Clean up
    del linear, optimizer
    torch.cuda.empty_cache()

    return r_squares, min(val_losses)

def train_and_evaluate(data, params):
    """Fit one autoencoder configuration and score its latent space.

    The first 90% of the rows of ``data`` are used for training and the rest
    for validation. After training (with patience-based early stopping), the
    whole data set is encoded and a linear probe from the latent codes back to
    the data is evaluated via ``parallel_linear_regression``.

    Args:
        data: Float tensor with one row per sample.
        params: Dict with the keys ``ae_depth``, ``ae_width``, ``lr``,
            ``batch_size``, ``early_stopping``, ``wd`` and ``dropout``.
            Latent size and the epoch budget are read from the global ``args``.

    Returns:
        dict: Final and best losses, the linear probe's validation loss, mean
        and median R^2 over the finite per-feature scores, and the number of
        epochs actually trained.
    """
    n_total = data.shape[0]
    n_train = int(n_total * 0.9)

    # Hyperparameters of this trial
    depth = params['ae_depth']
    width = params['ae_width']
    learning_rate = params['lr']
    batch_size = params['batch_size']
    max_epochs = args.epochs
    patience = params['early_stopping']
    weight_decay = params['wd']
    dropout_rate = params['dropout']

    # Model, optimizer and reconstruction loss
    model = OvercompleteAE(
        data.shape[1], args.latent_dim, depth=depth, width=width, dropout=dropout_rate
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = torch.nn.MSELoss()

    # Only the training split is shuffled
    train_loader = torch.utils.data.DataLoader(data[:n_train], batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(data[n_train:], batch_size=batch_size, shuffle=False)

    train_history = []
    val_history = []
    best_val_loss = float('inf')
    stale_epochs = 0

    progress = tqdm.tqdm(range(max_epochs), desc="Training autoencoder")
    for epoch in progress:
        # --- optimisation pass over the training split ---
        model.train()
        summed_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            batch_loss = criterion(model(batch), batch)
            batch_loss.backward()
            optimizer.step()
            # weight by batch size so the epoch loss is a per-sample mean
            summed_loss += batch_loss.item() * batch.size(0)
        epoch_train_loss = summed_loss / len(train_loader.dataset)
        train_history.append(epoch_train_loss)

        # --- evaluation pass over the validation split ---
        model.eval()
        summed_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                summed_loss += criterion(model(batch), batch).item() * batch.size(0)
        epoch_val_loss = summed_loss / len(val_loader.dataset)
        val_history.append(epoch_val_loss)

        progress.set_postfix({
            'train_loss': f'{epoch_train_loss:.6f}',
            'val_loss': f'{epoch_val_loss:.6f}'
        })

        if args.wandb:
            log_wandb({
                'epoch': epoch,
                'train_loss': epoch_train_loss,
                'val_loss': epoch_val_loss
            })

        # Patience-based early stopping on the validation loss
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            stale_epochs = 0
            continue
        stale_epochs += 1
        if stale_epochs >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    # Encode the full data set in one pass
    model.eval()
    with torch.no_grad():
        latents = model.encode(data.to(device)).cpu()

    # Linear probe from the latent codes back to the data
    r2_scores, lin_val_loss = parallel_linear_regression(
        latents, data, n_total, n_train, n_epochs=500, early_stopping=50
    )

    # Drop NaN/Inf scores before aggregating
    r2_scores = r2_scores[torch.isfinite(r2_scores)]
    mean_r2 = r2_scores.mean().item()
    median_r2 = r2_scores.median().item()

    # Release the model before the next trial
    del model, optimizer
    torch.cuda.empty_cache()
    gc.collect()

    return {
        'final_train_loss': train_history[-1],
        'final_val_loss': val_history[-1],
        'best_val_loss': best_val_loss,
        'lin_val_loss': lin_val_loss,
        'mean_r_square': mean_r2,
        'median_r_square': median_r2,
        'epochs_trained': len(train_history)
    }

# Modalities concatenated (in this order) for each supported mm_sim --modality value
_MM_SIM_MODALITIES = {
    'rna': ('rna',),
    'atac': ('atac',),
    'protein': ('protein',),
    'rna-atac': ('rna', 'atac'),
    'rna-protein': ('rna', 'protein'),
    'atac-protein': ('atac', 'protein'),
    'all': ('rna', 'atac', 'protein'),
}

# File-name prefix of the per-batch .npz file of each mm_sim modality
_MM_SIM_FILE_PREFIXES = {
    'rna': 'observed_transcription',
    'atac': 'peaks',
    'protein': 'prot_counts',
}


def _row_normalize(x):
    """Divide every row of ``x`` by its L2 norm.

    The norm is computed on a float copy because ``torch.norm`` does not accept
    integer tensors. Rows that are entirely zero are divided by zero and end up
    non-finite.
    """
    return x / torch.norm(x.float(), dim=1, keepdim=True)


def _load_mm_sim_batch(data_dir, batch_idx, modalities, normalize):
    """Load one mm_sim batch and concatenate the requested modalities column-wise.

    Normalisation is applied per modality and only when several modalities are
    combined; a single modality is returned as stored.
    """
    parts = [
        torch.tensor(sparse.load_npz(data_dir + f"{_MM_SIM_FILE_PREFIXES[name]}_batch_{batch_idx}.npz").toarray())
        for name in modalities
    ]
    if len(parts) == 1:
        return parts[0]
    if normalize:
        parts = [_row_normalize(part) for part in parts]
    return torch.cat(parts, dim=1)


def load_data():
    """Load the data set selected by the command-line arguments.

    Reads the global ``args``: ``data`` selects the data set, ``modality`` the
    feature blocks, ``n_batches`` how much data to load, ``single_batch``
    whether to restrict to one batch, and ``norm`` whether each modality of a
    multimodal combination is row-normalised before concatenation (this now
    also applies to mm_sim ``rna-atac``, which previously ignored ``--norm``).

    Returns:
        torch.Tensor: Float tensor with one row per sample.

    Raises:
        ValueError: If the data set or the modality is not supported.
    """
    if args.data == 'mm_sim':
        data_dir = './01_data/mm_sim/'
        if args.modality not in _MM_SIM_MODALITIES:
            raise ValueError("data modality not supported")
        modalities = _MM_SIM_MODALITIES[args.modality]
        data = torch.cat(
            [_load_mm_sim_batch(data_dir, i, modalities, args.norm) for i in range(args.n_batches)],
            dim=0,
        )
        if args.single_batch:
            # load the batch info
            metadata = pd.concat([pd.read_csv(data_dir + f"causal_variables_batch_{i}.csv") for i in range(args.n_batches)])
            if args.modality == 'rna':
                indices = np.where(metadata['mrna_batch_effect'] == 0.0)[0]
            elif args.modality == 'protein':
                indices = np.where(metadata['prot_batch_effect'] == 0.0)[0]
            else:
                # no batch-effect column is used for the other modalities: keep every sample
                indices = np.arange(data.shape[0])
            print(f"using {len(indices)} samples from batch 0")
            data = data[indices, :]
    elif args.data == 'bonemarrow':
        data_dir = '../../data/singlecell/'
        import anndata as ad  # optional dependency, only needed for this data set
        data_file = ad.read_h5ad(data_dir + "human_bonemarrow.h5ad")
        # index of the first ATAC feature; the features before it are used as RNA
        modality_switch = np.where(data_file.var['modality'] == 'ATAC')[0][0]
        if args.single_batch:
            # take the donor-site combination that gives the most samples
            data_file = data_file[(data_file.obs['covariate_Site'] == 'site4') & (data_file.obs['DonorID'] == 19593)]
        counts = data_file.layers['counts']
        if args.modality == 'rna':
            data = torch.tensor(np.asarray(counts[:, :modality_switch].todense()))
        elif args.modality == 'atac':
            data = torch.tensor(np.asarray(counts[:, modality_switch:].todense()))
        elif args.modality == 'rna-atac':
            if args.norm:
                rna = _row_normalize(torch.tensor(np.asarray(counts[:, :modality_switch].todense())))
                atac = _row_normalize(torch.tensor(np.asarray(counts[:, modality_switch:].todense())))
                data = torch.cat([rna, atac], dim=1)
            else:
                data = torch.tensor(np.asarray(counts.todense()))
        else:
            raise ValueError("data modality not supported")
        if not args.single_batch:
            data = data[:args.n_batches*10000, :]
        data = data.cpu()
    else:
        # Previously an unknown data set fell through to an UnboundLocalError
        raise ValueError(f"dataset '{args.data}' not supported")

    return data.float()

def objective(trial, data):
    """Evaluate one Optuna trial.

    Draws a hyperparameter set from ``sweep_config``, trains and evaluates an
    autoencoder with it, and reports the two objectives of the study.

    Args:
        trial: The Optuna trial used to sample hyperparameters.
        data: Tensor passed on to ``train_and_evaluate``.

    Returns:
        tuple: ``(mean_r_square, best_val_loss)``; ``(-1.0, inf)`` if the trial
        raised an exception.
    """
    lr_low, lr_high = sweep_config['lr'][0], sweep_config['lr'][1]
    wd_low, wd_high = sweep_config['wd'][0], sweep_config['wd'][1]

    # Sampling order is kept fixed; list-valued entries are categorical,
    # learning rate and weight decay are drawn log-uniformly.
    params = {
        'ae_depth': trial.suggest_categorical('ae_depth', sweep_config['ae_depth']),
        'ae_width': trial.suggest_categorical('ae_width', sweep_config['ae_width']),
        'lr': trial.suggest_float('lr', lr_low, lr_high, log=True),
        'batch_size': trial.suggest_categorical('batch_size', sweep_config['batch_size']),
        'early_stopping': trial.suggest_categorical('early_stopping', sweep_config['early_stopping']),
        'wd': trial.suggest_float('wd', wd_low, wd_high, log=True),
        'dropout': trial.suggest_categorical('dropout', sweep_config['dropout']),
    }

    # Record the sampled configuration in the wandb run
    if args.wandb:
        wandb.config.update(params, allow_val_change=True)

    try:
        outcome = train_and_evaluate(data, params)

        # Keep the hyperparameters next to the metrics
        outcome.update(params)

        if args.wandb:
            log_wandb(outcome, mode="eval")

        # First objective is maximised (R^2), second is minimised (validation loss)
        return outcome['mean_r_square'], outcome['best_val_loss']

    except Exception as err:
        print(f"Error in trial: {err}")
        # Report the worst possible values so a failed trial never looks attractive
        return -1.0, float('inf')

def _select_balanced_trial(pareto_trials):
    """Pick the Pareto-optimal trial with the best trade-off between both objectives.

    Both objectives are min-max normalised across the Pareto front to ``[0, 1]``
    with higher meaning better; the trial with the largest geometric mean of the
    two normalised scores wins. Trials with a non-positive R^2 or an infinite
    loss score 0 on that objective.

    Args:
        pareto_trials: Trials whose ``values`` are ``(mean_r_square, val_loss)``.

    Returns:
        The selected trial, or ``None`` if no trial has a positive R^2 or no
        trial has a finite loss.
    """
    r2_values = [t.values[0] for t in pareto_trials if t.values[0] > 0]
    loss_values = [t.values[1] for t in pareto_trials if t.values[1] < float('inf')]
    if not r2_values or not loss_values:
        return None

    # Front-wide statistics are computed once, not once per trial
    min_r2, max_r2 = min(r2_values), max(r2_values)
    min_loss, max_loss = min(loss_values), max(loss_values)

    # Avoid division by zero when all trials share the same value
    r2_range = max_r2 - min_r2 if max_r2 > min_r2 else 1.0
    loss_range = max_loss - min_loss if max_loss > min_loss else 1.0

    best_trial = None
    best_score = float('-inf')
    for trial in pareto_trials:
        r2, loss = trial.values[0], trial.values[1]
        # Normalize to [0,1], higher is better for both
        norm_r2 = (r2 - min_r2) / r2_range if r2 > 0 else 0
        norm_loss = 1.0 - (loss - min_loss) / loss_range if loss < float('inf') else 0
        # Geometric mean rewards trials that do well on both objectives
        score = (norm_r2 * norm_loss) ** 0.5
        if score > best_score:
            best_score = score
            best_trial = trial
    return best_trial


def main():
    """Run the Optuna sweep and write its results to a timestamped directory.

    Loads the data, optionally starts a wandb run, optimises mean R^2 (maximise)
    and validation loss (minimise) over ``args.n_trials`` trials, and saves the
    balanced configuration, the Pareto front and all trials under
    ``args.output_dir``.
    """
    import json  # only needed here, to write the selected configuration

    # Load data
    print("Loading data...")
    data = load_data()
    print(f"Loaded data with shape: {data.shape}")

    # Create output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(args.output_dir, f"probing_sweep_{args.data}_{args.modality}_{timestamp}")
    os.makedirs(output_dir, exist_ok=True)

    # Initialize wandb if needed
    if args.wandb:
        hyperparams = {
            'data': args.data,
            'modality': args.modality,
            'n_batches': args.n_batches,
            'single_batch': args.single_batch,
            'norm': args.norm,
            'latent_dim': args.latent_dim,
            'seed': args.seed,
            'n_trials': args.n_trials
        }
        run_name = f"probing_sweep_{args.data}_{args.modality}_{timestamp}"
        setup_wandb(run_name, hyperparams, args.wandb_project, args.wandb_entity)

    # Create an Optuna study for multi-objective optimization
    study_name = f"probing_sweep_{args.data}_{args.modality}"
    storage_name = f"sqlite:///{output_dir}/optuna.db"
    study = optuna.create_study(
        study_name=study_name,
        storage=storage_name,
        directions=["maximize", "minimize"],  # Maximize R², minimize loss
        load_if_exists=True
    )

    print(f"Running Optuna optimization with {args.n_trials} trials")

    # Run the optimization
    study.optimize(lambda trial: objective(trial, data), n_trials=args.n_trials)

    # Get the best trials (Pareto front)
    best_trials = study.best_trials

    print("\n=== Optimization Complete ===")
    print(f"Found {len(best_trials)} Pareto-optimal solutions")

    # Select a balanced solution from the Pareto front
    best_balanced_trial = _select_balanced_trial(best_trials)

    if best_balanced_trial is not None:
        print("Selected balanced solution:")
        print(f"Trial: {best_balanced_trial.number}")
        print(f"Parameters: {best_balanced_trial.params}")
        print(f"Mean R² score: {best_balanced_trial.values[0]}")
        print(f"Validation loss: {best_balanced_trial.values[1]}")

        # Save the selected balanced solution
        with open(os.path.join(output_dir, "best_balanced_config.json"), 'w') as f:
            json.dump(best_balanced_trial.params, f, indent=2)

    # Save all Pareto-optimal solutions
    pareto_results = [
        {
            'trial_number': trial.number,
            'mean_r_square': trial.values[0],
            'best_val_loss': trial.values[1],
            **trial.params
        }
        for trial in best_trials
    ]
    pd.DataFrame(pareto_results).to_csv(os.path.join(output_dir, "pareto_optimal_results.csv"), index=False)

    # Save all trial results
    trials_df = study.trials_dataframe()
    trials_df.to_csv(os.path.join(output_dir, "optuna_trials.csv"), index=False)

    # Close wandb
    if args.wandb:
        wandb.finish()

# Entry point: start the sweep when this file is executed as a script
if __name__ == "__main__":
    main()
