import numpy as np
import torch
import pandas as pd
import torch
from scipy import sparse
import time
from tqdm import tqdm
import random
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
import gc

# add the project root to sys.path so the `src` package can be imported
import sys
import os
from pathlib import Path
# Make the directory two levels above this script importable, so the
# `src.functions...` import that follows can be resolved.
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(os.fspath(project_root))

from src.functions.train_larrp_unimodal import train_overcomplete_ae2_with_pretrained

import argparse
parser = argparse.ArgumentParser(description='Compute unimodal omics ID estimation with different modalities')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use for the computation')
# `choices` makes argparse reject unsupported modality/stage values at parse
# time (as --threshold_type already does), instead of failing later inside
# data loading.
parser.add_argument('--modality', type=str, default='rna',
                    choices=['rna', 'atac', 'protein', 'rna-atac', 'rna-protein', 'atac-protein', 'all'],
                    help='data modality. options are rna, atac, protein, rna-atac, rna-protein, atac-protein, all')
parser.add_argument('--stage', type=str, default='noisy',
                    choices=['noisy', 'raw', 'processed'],
                    help='stage of the data in the data generation process. options are noisy, raw, processed.')
parser.add_argument('--n_batches', type=int, default=3, help='number of batches (10k samples each)')
parser.add_argument('--threshold_type', type=str, default='absolute', choices=['relative', 'absolute'], help='threshold type for rank reduction')
parser.add_argument('--seed', type=int, default=0, help='random seed')
args = parser.parse_args()

# CSV that every threshold configuration of this run is appended to; the name
# encodes the command-line configuration.
out_file = (
    "03_results/paper_results/unimodal_omics/"
    f"unimodal_omics_{args.modality}_{args.stage}_n{args.n_batches}batches"
    f"_thresh-{args.threshold_type}_seed-{args.seed}_additionalThresholds2.csv"
)

# use the requested GPU index when CUDA is available, otherwise the CPU
DEVICE = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

###
# set up method parameters to test
###

class Args:
    """Fixed training configuration handed to the autoencoder trainer.

    Parameters
    ----------
    gpu : int, optional
        GPU index stored as ``self.gpu``. When omitted, the module-level
        ``args.gpu`` (the ``--gpu`` command-line value) is used. Passing it
        explicitly lets the class be built without the parsed CLI arguments.
    """

    def __init__(self, gpu=None):
        # latent
        self.latent_dim = 1000

        # Training parameters
        self.batch_size = 512
        self.lr = 1e-5
        self.weight_decay = 2e-5
        self.dropout = 0.1
        self.epochs = 5000

        # Model architecture
        self.ae_depth = 2
        self.ae_width = 0.5

        # Rank reduction parameters
        # NOTE(review): interpreted by the trainer; its meaning is not visible here
        self.rank_or_sparse = 'rank'

        # GPU parameters
        self.multi_gpu = False
        self.gpu_ids = ''
        # the global `args` is only consulted when no explicit gpu is given
        self.gpu = args.gpu if gpu is None else gpu
train_args = Args()

# Method hyperparameters. "r_square_thresholds" and "rank_reduction_thresholds"
# MUST be lists: the training loop iterates over every combination of the two.
if args.threshold_type == "relative":
    method_hyperparameters = {
        "r_square_thresholds": [0.95, 0.8],
        "early_stopping": 50,
        "rank_reduction_frequencies": 10,
        # was the bare float 0.001, which made `for rrt in ...` raise TypeError
        "rank_reduction_thresholds": [0.001],
        "patiences": 10,
    }
elif args.threshold_type == "absolute":
    method_hyperparameters = {
        "r_square_thresholds": [0.001, 0.005, 0.05, 0.01, 0.1],
        "early_stopping": 50,
        "rank_reduction_frequencies": 10,
        "rank_reduction_thresholds": [0.01],
        "patiences": 10,
    }
else:
    # unreachable while --threshold_type restricts its choices; guards future edits
    raise ValueError(f"unsupported threshold_type: {args.threshold_type}")

# distortion metric passed to the trainer: "ExVarScore" for rna, "R2" otherwise
method_hyperparameters["metric"] = "ExVarScore" if args.modality == "rna" else "R2"

###
# load data
###

data_dir = './01_data/mm_sim/'

# File-name stem of each single modality at each stage of the data generation
# process; batch i is stored as data_dir + f"{stem}_batch_{i}.npz".
_FILE_STEMS = {
    'rna': {
        'noisy': 'observed_transcription',
        'raw': 'raw_transcription',
        'processed': 'processed_transcription',
    },
    'atac': {
        'noisy': 'peaks',
        'raw': 'raw_peaks',
        'processed': 'processed_peaks',
    },
    'protein': {
        'noisy': 'prot_counts',
        'raw': 'raw_prot_counts',
        'processed': 'processed_prot_counts',
    },
}

# Single modalities that make up each supported --modality value. Multi-modal
# inputs are the per-modality matrices concatenated along dim=1, in this order.
_MODALITY_PARTS = {
    'rna': ('rna',),
    'atac': ('atac',),
    'protein': ('protein',),
    'rna-atac': ('rna', 'atac'),
    'rna-protein': ('rna', 'protein'),
    'atac-protein': ('atac', 'protein'),
    'all': ('rna', 'atac', 'protein'),
}


def _load_batch(modality, stage, batch_idx):
    """Load batch ``batch_idx`` of ``modality`` at ``stage`` as one dense tensor.

    Every component file is a scipy sparse ``.npz`` matrix that is densified
    with ``toarray()``; components of a multi-modal batch are concatenated
    along dim=1 in the order given by ``_MODALITY_PARTS``.
    """
    parts = [
        torch.tensor(
            sparse.load_npz(data_dir + f"{_FILE_STEMS[part][stage]}_batch_{batch_idx}.npz").toarray()
        )
        for part in _MODALITY_PARTS[modality]
    ]
    # a single modality needs no concatenation (avoids an extra copy)
    return parts[0] if len(parts) == 1 else torch.cat(parts, dim=1)


# Validate the configuration once, before any (large) file is read. The error
# messages are the same as those of the per-batch checks this replaces.
if args.modality not in _MODALITY_PARTS:
    raise ValueError("data modality not supported")
if args.stage not in ('noisy', 'raw', 'processed'):
    raise ValueError(f"stage not supported for {args.modality} modality")
if args.n_batches < 1:
    # torch.cat below would otherwise fail with an opaque empty-list error
    raise ValueError("n_batches must be at least 1")

# batches are stacked along dim=0
data = torch.cat([_load_batch(args.modality, args.stage, i) for i in range(args.n_batches)], dim=0)

data = data.float()
n_samples = data.shape[0]
print(f"loaded data with modality '{args.modality}', stage '{args.stage}' with shape {data.shape}")

###
# run training
###

print(f"\n=== Training on {args.modality} modality, {args.stage} stage ===")

# Loop-invariant setup, done once rather than per configuration.
# The pretrained-model name does not include the thresholds, so every
# threshold combination below passes the same name to the trainer.
pretrained_name = f"unimodal_omics_{args.modality}_{args.stage}_n{args.n_batches}batches_seed-{args.seed}"
# Create the output directory up front so a bad path fails before training.
os.makedirs(os.path.dirname(out_file), exist_ok=True)

# Loop over every (r_square_threshold, rank_reduction_threshold) combination
for r_square_threshold in method_hyperparameters["r_square_thresholds"]:
    for rrt in method_hyperparameters["rank_reduction_thresholds"]:
        print(f"\n--- Testing r_square_threshold: {r_square_threshold}, rank_reduction_threshold: {rrt} ---")

        # re-seed so each configuration starts from the same RNG state
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        random.seed(args.seed)

        model, reps, train_loss, r_squares, rank_history, loss_curves = train_overcomplete_ae2_with_pretrained(
            data,
            int(0.9 * n_samples),  # NOTE(review): presumably the training-split size (90%) - confirm
            train_args.latent_dim,
            DEVICE,
            train_args,
            epochs=train_args.epochs,
            lr=train_args.lr,
            batch_size=train_args.batch_size,
            ae_depth=train_args.ae_depth,
            ae_width=train_args.ae_width,
            dropout=train_args.dropout,
            wd=train_args.weight_decay,
            early_stopping=method_hyperparameters["early_stopping"],
            initial_rank_ratio=1.0,
            rank_reduction_frequency=method_hyperparameters["rank_reduction_frequencies"],
            rank_reduction_threshold=rrt,
            warmup_epochs=method_hyperparameters["early_stopping"],
            patience=method_hyperparameters["patiences"],
            distortion_metric=method_hyperparameters["metric"],
            min_rank=1,
            r_square_threshold=r_square_threshold,
            threshold_type=args.threshold_type,
            verbose=True,
            pretrained_name=pretrained_name
        )
        # Only train_loss, r_squares and rank_history are used below. Drop the
        # other outputs as well as the model, so that if they hold tensors,
        # gc and the CUDA caching allocator can actually reclaim that memory.
        del model, reps, loss_curves
        gc.collect()
        torch.cuda.empty_cache()

        # assumes rank_history is a dict of equal-length sequences with a "ranks" key - TODO confirm
        temp_df = pd.DataFrame(rank_history)
        temp_df["final_ranks"] = rank_history["ranks"][-1]

        # Since we don't have hidden variables like in the simulated data,
        # we'll evaluate reconstruction quality instead
        temp_df["reconstruction_r2"] = r_squares

        temp_df["modality"] = args.modality
        temp_df["stage"] = args.stage
        temp_df["n_batches"] = args.n_batches
        temp_df["threshold_type"] = args.threshold_type
        temp_df["r_square_threshold"] = r_square_threshold
        temp_df["rank_reduction_threshold"] = rrt
        temp_df["seed"] = args.seed

        # Append to the results file; write the header only when the file is new
        # (mode 'a' creates the file if it does not exist yet).
        temp_df.to_csv(out_file, mode='a', header=not os.path.exists(out_file), index=False)

        print(f"Results for r_square_threshold={r_square_threshold}, rank_reduction_threshold={rrt}:")
        print(f"  Final ranks: {rank_history['ranks'][-1]}")
        print(f"  Final reconstruction R²: {r_squares}")
        print(f"  Final loss: {train_loss}")

print(f"\nAll results saved to {out_file}")