import numpy as np
import torch
import pandas as pd
import os
import random
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
import time
import gc
import torch.nn as nn
import torch.nn.functional as F
from scipy import sparse

# add src to path
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.data.loading import load_data, MMSimData
from src.functions.train_larrp_multimodal import train_overcomplete_ae_with_pretrained, compute_classification, post_train_multimodal_ae

import argparse
# Command-line interface: selects the dataset, the modalities to use, and the
# run variant (seed, pairing, threshold mode) whose trained models are post-trained.
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
parser.add_argument('--data', type=str, default='mm_sim', help='data name. options are mm_sim, bonemarrow, brain')
parser.add_argument('--modality', type=str, default='rna', help='data modality. options are rna, atac, protein, rna-atac, rna-protein, atac-protein, all')
parser.add_argument('--stage', type=str, default='noisy', help='stage of the data in the data generation process. only valid for mm_sim. options are noisy, raw, processed.')
parser.add_argument('--n_batches', type=int, default=1, help='number of batches (10k samples each)')
parser.add_argument('--seed', type=int, default=0, help='random seed for reproducibility')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use for the computation')
parser.add_argument('--paired', action='store_true', help='whether the data is paired (default: False)')
# The help text must name the real default ('absolute'); it previously claimed 'relative'.
parser.add_argument('--threshold', type=str, default='absolute', choices=['relative', 'absolute'], help='whether to use relative or absolute R2 thresholds (default: absolute)')
args = parser.parse_args()

# Pick the requested CUDA device when CUDA is available; otherwise run on the CPU.
DEVICE = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

# Identifier of this post-training run; also used as the stem of the report file
# and (with the config index appended) as the name of the post-trained model.
run_name = f"larrp_mm-omics-sim_data-{args.data}_mod-{args.modality}_stage-{args.stage}_n-{args.n_batches}_rseed-{args.seed}_paired-{args.paired}_threshold-{args.threshold}_post-training"
out_file = f"03_results/reports/{run_name}.csv"

# Make sure the report directory exists up front, so that writing the results
# cannot fail only after an entire post-training run has finished.
os.makedirs(os.path.dirname(out_file), exist_ok=True)

# if the file already exists, remove it (results are appended config by config below)
if os.path.exists(out_file):
    print(f"File {out_file} already exists, removing it.")
    os.remove(out_file)

###
# load data using mm_sim setup
###
data, feature_counts, modality_indices = load_data(args.data, args.modality, args.paired, args.n_batches, args.stage)

# Cut the column-wise concatenated matrix into one float tensor per modality.
# The modalities are laid out in the order given by --modality, so a running
# column offset is enough to find each slice.
modality_names = args.modality.split('-')
split_data = []
col_offset = 0
for mod_name in modality_names:
    col_end = col_offset + feature_counts[mod_name]
    split_data.append(data[:, col_offset:col_end].float())
    col_offset = col_end

n_samples = data.shape[0]
print(f"loaded data '{args.data}' with shape {data.shape}")
print(f"split into {len(split_data)} modalities with shapes: {[d.shape for d in split_data]}")

# normalize data to have values between 0 and 10 (similar to parametric simulation)
# try to do without for now
#for i in range(len(split_data)):
#    split_data[i] = 10 * (split_data[i] - torch.min(split_data[i])) / (torch.max(split_data[i]) - torch.min(split_data[i]))

# labels
# load causal_variables_batch_{batches}.csv and use every column as a label
# for the following columns, calculate classification accuracy (via compute_classification) and silhouette score on the latents:
# cell_type, stress_level, cell_cycle, mrna_batch_effect, prot_batch_effect
# for the rest, calculate the predictability as R^2 of a linear regression from the latents
# Read the causal-variable table of every batch and stack them into a single
# frame with a fresh 0..N-1 index.
ground_truth = pd.concat(
    [
        pd.read_csv(f"01_data/mm_sim/causal_variables_batch_{batch_idx}.csv")
        for batch_idx in range(args.n_batches)
    ],
    ignore_index=True,
)

# The concatenated matrix is no longer needed once it has been split per modality.
del data
gc.collect()  # Clear memory after loading data

# set up data in the right format for training (similar to parametric simulation)
if not args.paired:
    if len(split_data) >= 2:
        # Unpaired layout: stack two copies of the samples. In the first half
        # only modality 0 carries data (modality 1 is all zeros); in the second
        # half only modality 1 carries data (modality 0 is all zeros).
        # NOTE(review): only the first two modalities are used here; any
        # further modality in split_data is ignored — confirm this is intended.
        first_mod, second_mod = split_data[0], split_data[1]
        y1 = torch.cat([first_mod, torch.zeros_like(first_mod)], dim=0)
        y2 = torch.cat([torch.zeros_like(second_mod), second_mod], dim=0)

        # The samples were stacked twice, so the labels must be stacked twice too.
        ground_truth = pd.concat([ground_truth, ground_truth], ignore_index=True)
    else:
        # Single modality: the samples are NOT duplicated, so the ground truth
        # must not be duplicated either; otherwise labels and latents end up
        # with different row counts in the downstream analysis.
        y1 = split_data[0]
        y2 = torch.zeros_like(split_data[0])
    data = [y1, y2]
else:
    # For paired data, use modalities directly
    data = split_data

data = [torch.FloatTensor(d) for d in data]

# Update n_samples to reflect actual data size (may be doubled for unpaired)
n_samples = data[0].shape[0]
print(f"Final data shape after processing: {[d.shape for d in data]}")
print(f"Ground truth shape: {ground_truth.shape}")

# Seed every random number generator in use (Python, NumPy, PyTorch CPU and
# all CUDA devices) with the same user-supplied seed.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# method parameters (from parametric simulation)
# One hyperparameter grid per threshold mode; the mode is restricted to these
# two keys by the `choices` of the --threshold argument.
_threshold_presets = {
    "relative": {
        "r_square_thresholds": [0.99, 0.9, 0.8],
        "early_stopping": [50],
        "rank_reduction_frequencies": [5],
        "rank_reduction_thresholds": [0.01],
        "patiences": [5],
    },
    "absolute": {
        "r_square_thresholds": [0.005, 0.05, 0.1],
        "early_stopping": [50],
        "rank_reduction_frequencies": [10],
        "rank_reduction_thresholds": [0.01],
        "patiences": [10],
    },
}
method_hyperparameters = _threshold_presets[args.threshold]

from itertools import product
# Full Cartesian grid of the selected preset; its length is the number of runs.
method_combinations = list(product(*method_hyperparameters.values()))
print(f"Number of method combinations: {len(method_combinations)}")

class Args:
    """Fixed training configuration passed to the post-training routine.

    All values are constants except ``gpu``, which is copied from the parsed
    command-line arguments (the module-level ``args``) when the object is
    created. Every setting is stored as a plain instance attribute.
    """

    def __init__(self):
        settings = {
            # latent
            "latent_dim": 500,
            # Training parameters
            "batch_size": 1024,
            "lr": 1e-6,  # Lower learning rate for post-training
            "weight_decay": 2e-5,
            "dropout": 0.1,
            "epochs": 500,  # Post-training epochs
            # Model architecture
            "ae_depth": 2,
            "ae_width": 0.5,
            # Rank reduction parameters
            "rank_or_sparse": 'rank',
            # GPU parameters
            "num_workers": 8,
            "multi_gpu": False,
            "gpu_ids": '',
            "gpu": args.gpu,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

# Training configuration shared by every post-training run.
train_args = Args()


def _encode_latents(model, inputs, ranks, batch_size, device):
    """Encode ``inputs`` in mini-batches and collect the latents on the CPU.

    Args:
        model: Callable returning ``(output, (h_shared, h_specific))`` for a
            list of per-dataset batches; ``h_specific`` is iterated over.
        inputs: List of tensors with the same number of rows.
        ranks: Width of each latent space; entry 0 is the shared space, the
            following entries are the specific spaces.
        batch_size: Number of rows encoded per forward pass.
        device: Device the batches are moved to for the forward pass.

    Returns:
        A list of CPU tensors, one per entry of ``ranks``, each with as many
        rows as ``inputs[0]``.
    """
    n_rows = inputs[0].shape[0]
    # The buffers are kept on the CPU: the results are only consumed as NumPy
    # arrays, so holding them on the GPU would just cost memory and an extra
    # host<->device round-trip per batch.
    latents = [torch.empty((n_rows, rank)) for rank in ranks]

    with torch.no_grad():
        for start in range(0, n_rows, batch_size):
            stop = min(start + batch_size, n_rows)
            batch = [d[start:stop].to(device) for d in inputs]
            _, (h_shared, h_specific) = model(batch)

            latents[0][start:stop] = h_shared.cpu()
            # zip() stops at the shorter side, so surplus specific latents
            # (beyond the available buffers) are ignored.
            for buffer, h_spec in zip(latents[1:], h_specific):
                buffer[start:stop] = h_spec.cpu()
    return latents


# Ground-truth columns and the type of analysis applied to each of them.
classification_cols = ['cell_type', 'stress_level', 'cell_cycle', 'mrna_batch_effect', 'prot_batch_effect']
regression_cols = ['transcription_activity', 'damage_prob', 'ribosome_rate', 'free_ribosomes',
                   'tRNA_availability', 'proteasome_activity']

###
# start training multiple configs
###
# Enumerate the grid by explicit key so the unpacking order below does not
# depend on the key order of the hyperparameter dictionary.
config_grid = product(
    method_hyperparameters["r_square_thresholds"],
    method_hyperparameters["early_stopping"],
    method_hyperparameters["rank_reduction_frequencies"],
    method_hyperparameters["rank_reduction_thresholds"],
    method_hyperparameters["patiences"],
)

config_counter = 0
for config_counter, grid_point in enumerate(config_grid, start=1):
    (r_square_threshold, early_stopping, rank_reduction_frequency,
     rank_reduction_threshold, patience) = grid_point
    print(f"### Run {config_counter}/{len(method_combinations)} ###")

    # The trained model name from the original training script
    original_run_name = f"larrp_mm-omics-sim_data-{args.data}_mod-{args.modality}_stage-{args.stage}_n-{args.n_batches}_rseed-{args.seed}_paired-{args.paired}_threshold-{args.threshold}"
    trained_model_name = original_run_name + str(config_counter)

    # Do post-training on the fully trained model
    # NOTE(review): early_stopping and patience are fixed here and do not
    # follow the grid values of the same name — confirm this is intended.
    model, post_train_losses, post_val_losses = post_train_multimodal_ae(
        data,
        int(0.9 * n_samples),
        train_args.latent_dim,
        DEVICE,
        train_args,
        epochs=train_args.epochs,  # Post-training for up to 500 epochs
        early_stopping=50,
        lr=train_args.lr,  # Use lower learning rate for post-training
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        ae_width=train_args.ae_width,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        patience=10,
        verbose=True,
        model_name=run_name + str(config_counter),
        pretrained_name=trained_model_name,  # Load the fully trained model
        recon_loss_balancing=False,
        paired=args.paired
    )

    # Extract representations from the post-trained model
    model.eval()
    final_ranks = [layer.active_dims for layer in model.adaptive_layers]
    reps = _encode_latents(model, data, final_ranks, train_args.batch_size, DEVICE)

    print(f"Post-training done. Final ranks: {final_ranks}")

    # The model is not needed beyond this point; release it before the analysis.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # Create a mock rank_history for compatibility with analysis code
    rank_history = {
        'ranks': [', '.join(str(r) for r in final_ranks)],
        'total_rank': [sum(final_ranks)],
        'epoch': [len(post_train_losses)],
        'loss': [post_train_losses[-1]],
        'val_loss': [post_val_losses[-1]]
    }

    # Add rsquare placeholders (not meaningful after post-training, but needed for compatibility)
    for j in range(len(data)):
        rank_history[f'rsquare {j}'] = [1.0]  # Placeholder value

    temp_df = pd.DataFrame(rank_history)

    # add all the data and method parameters to the dataframe
    temp_df["r_square_threshold"] = r_square_threshold
    temp_df["early_stopping"] = early_stopping
    temp_df["rank_reduction_frequency"] = rank_reduction_frequency
    temp_df["rank_reduction_threshold"] = rank_reduction_threshold
    temp_df["patience"] = patience

    # add the most important final metrics
    # Join over every dataset instead of hard-coding two entries, so a
    # single-dataset run does not fail and extra datasets are not dropped.
    valid_rsquares = [rank_history[f"rsquare {j}"] for j in range(len(data))]
    temp_df["r_squares_init"] = ', '.join(str(series[0]) for series in valid_rsquares)
    temp_df["r_squares_final"] = ', '.join(str(series[-1]) for series in valid_rsquares)
    temp_df["final_ranks"] = rank_history["ranks"][-1]
    temp_df["config"] = config_counter
    temp_df["post_training_epochs"] = len(post_train_losses)
    temp_df["final_train_loss"] = post_train_losses[-1]
    temp_df["final_val_loss"] = post_val_losses[-1]

    # Convert each latent space to NumPy once; every analysis below reuses these.
    latent_arrays = [rep.numpy() for rep in reps]

    # Calculate classification accuracy and silhouette score for discrete variables
    for col in classification_cols:
        if col not in ground_truth.columns:
            continue
        col_labels = ground_truth[col].values

        # The silhouette score is only computed when the number of distinct
        # labels lies between 2 and n_samples - 1; this depends on the labels
        # only, so it is checked once per column.
        n_unique = len(np.unique(col_labels))
        silhouette_defined = 2 <= n_unique <= len(col_labels) - 1

        accs = []
        sils = []
        for latent in latent_arrays:
            # Classification accuracy
            accs.append(compute_classification(latent, col_labels))

            # Silhouette score
            if silhouette_defined:
                sil = silhouette_score(latent, col_labels)
            else:
                sil = np.nan  # Cannot compute silhouette score
            sils.append(sil)

        temp_df[f"{col}_classification_accuracy"] = ', '.join([str(a) for a in accs])
        temp_df[f"{col}_silhouette_score"] = ', '.join([str(s) for s in sils])
        print(f"{col} Classification: {[f'{a:.4f}' for a in accs]}")
        print(f"{col} Silhouette: {[f'{s:.4f}' if not np.isnan(s) else 'nan' for s in sils]}")

    # Calculate predictability (R²) for continuous variables
    for col in regression_cols:
        if col not in ground_truth.columns:
            continue
        col_labels = ground_truth[col].values

        preds = []
        for latent in latent_arrays:
            reg = LinearRegression()
            reg.fit(latent, col_labels)
            # R² of the fit, evaluated on the same samples it was fitted on
            preds.append(reg.score(latent, col_labels))

        temp_df[f"{col}_prediction_r2"] = ', '.join([str(p) for p in preds])
        print(f"{col} Prediction R²: {[f'{p:.4f}' for p in preds]}")

    # if out_file exists, append to it (no header), otherwise create it with a header
    report_exists = os.path.exists(out_file)
    temp_df.to_csv(out_file, mode='a' if report_exists else 'w', header=not report_exists, index=False)

print(f"All experiments completed. Results saved to {out_file}")
