import numpy as np
import torch
import pandas as pd
import torch
from scipy import sparse
import time
from tqdm import tqdm
import random
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression

# Add the project root (the directory containing `src/`) to sys.path so the `src.*` imports resolve
import sys
import os
from pathlib import Path
# Make the repository root importable so the `src.*` imports below resolve.
# Convert to an absolute path *before* walking up two levels: if __file__ is
# relative (e.g. "script.py" on pre-3.9 interpreters), Path("script.py").parent.parent
# is "." and would point at the current directory rather than the repo root.
project_root = Path(__file__).absolute().parent.parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from src.data.simulation_by_variables import TabularDataSimulator
from src.functions.train_larrp_unimodal import train_overcomplete_ae2

import argparse
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use for the computation')
# `choices` makes argparse reject unsupported values up front. Otherwise an
# unknown --threshold would only surface later as an undefined
# method_hyperparameters, and an unknown --exp would silently sweep nothing.
parser.add_argument('--exp', type=str, default="samples", choices=["samples", "data", "noise"],
                    help='experiment type: samples, data, noise')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--threshold', type=str, default="absolute", choices=["relative", "absolute"],
                    help='threshold type: relative or absolute')
args = parser.parse_args()

# One results CSV per (experiment, seed). The rows written by this script do
# not record the threshold type, so runs with the non-default "relative"
# threshold get their own file instead of being mixed into the "absolute" one.
# Filenames for the default threshold are unchanged.
threshold_suffix = "" if args.threshold == "absolute" else f"_threshold-{args.threshold}"
out_file = (
    f"03_results/paper_results/unimodal/"
    f"unimodal_param_robustness_exp-{args.exp}_seed-{args.seed}{threshold_suffix}.csv"
)

# Use the requested CUDA device when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device(f'cuda:{args.gpu}')
else:
    DEVICE = torch.device('cpu')

###
# set up method parameters to test
###

class Args:
    """Fixed training configuration passed to ``train_overcomplete_ae2``.

    The instance is handed to the trainer as its ``args`` object. Several of
    its fields are also passed explicitly as keyword arguments at the call site.

    Args:
        gpu: GPU index stored on the config. When omitted (``None``), the value
            of the ``--gpu`` command-line option (module-level ``args.gpu``) is
            used, which matches the original behaviour of ``Args()``.
    """

    def __init__(self, gpu=None):
        # Latent dimensionality requested for the autoencoder
        self.latent_dim = 20

        # Training parameters
        self.batch_size = 128
        self.lr = 0.0001
        self.weight_decay = 2e-5
        self.dropout = 0.1
        self.epochs = 5000

        # Model architecture
        self.ae_depth = 2
        self.ae_width = 1

        # Rank reduction parameters. Presumably selects rank-based (as opposed
        # to sparsity-based) reduction; confirm in train_overcomplete_ae2.
        self.rank_or_sparse = 'rank'

        # GPU parameters
        self.multi_gpu = False
        self.gpu_ids = ''
        # Fall back to the CLI value so existing `Args()` calls keep working.
        self.gpu = args.gpu if gpu is None else gpu
train_args = Args()

# Rank-reduction / early-stopping settings for train_overcomplete_ae2, keyed by
# --threshold. The selected threshold type is also passed to the trainer as
# threshold_type, so the numeric thresholds are interpreted accordingly.
_METHOD_HYPERPARAMETERS = {
    "relative": {
        "r_square_thresholds": 0.9,
        "early_stopping": 50,
        "rank_reduction_frequencies": 10,
        "rank_reduction_thresholds": 0.001,
        "patiences": 10,
    },
    "absolute": {
        "r_square_thresholds": 0.05,
        "early_stopping": 50,
        "rank_reduction_frequencies": 10,
        "rank_reduction_thresholds": 0.01,
        "patiences": 10,
    },
}
try:
    method_hyperparameters = _METHOD_HYPERPARAMETERS[args.threshold]
except KeyError:
    # Fail immediately rather than leaving method_hyperparameters undefined
    # and crashing much later at the training call.
    raise ValueError(
        f"Unknown threshold type {args.threshold!r}; "
        f"expected one of {sorted(_METHOD_HYPERPARAMETERS)}"
    ) from None

###
# set up data parameters to test
###

# Baseline data-generation settings. Every run starts from a copy of this dict,
# and a robustness experiment overrides exactly one key per run.
# NOTE: "latent_dim" is not forwarded to TabularDataSimulator in the loop below;
# the autoencoder's latent size comes from Args.latent_dim instead.
default_params = dict(
    n_samples=10000,
    n_hidden_variables=5,
    hidden_dist_type="poisson",
    data_dim=50,
    latent_dim=20,
    nonlinearity_level=1,
    nonlinearity_type="sigmoid",
    hidden_connectivity=0.4,
    data_sparsity=0.0,
    noise_variance=0.0,
    noise_mean=0.0,
    n_noise_components=1,
)

# Sweep grid per --exp value: parameter name -> values to try, one at a time.
# Commented-out entries record grids explored in earlier runs.
_EXPERIMENT_GRIDS = {
    "samples": {
        # earlier n_samples: [100, 1000, 10000, 100000, 1000000]
        "n_samples": [500, 2500, 5000, 7500, 50000],
    },
    "data": {
        # earlier data_dim: [10, 25, 50, 100, 500]
        "data_dim": [15, 20],
        # earlier hidden_dist_type: ["poisson", "gaussian", "binomial", "beta", "gumbel", "uniform", "weibull"]
        # earlier hidden_connectivity: [0.2, 0.4, 0.6, 0.8] and [0.1, 0.7, 0.9]
        "hidden_connectivity": [1.0],
        # earlier nonlinearity_level: [0, 1, 2]
        # earlier nonlinearity_type: ["sigmoid", "trigonometric", "relu"]
    },
    "noise": {
        # earlier noise_variance: [0.0, 0.1, 0.5, 1.0, 2.0] and [0.2, 0.3, 0.4]
        # TODO: express noise as a signal-to-noise ratio instead of a raw variance
        "noise_variance": [0.0001, 0.01],
        # earlier data_sparsity: [0.0, 0.2, 0.4, 0.6, 0.8] and [0.01, 0.05, 0.1, 0.15]
    },
}

# An unrecognised experiment type yields an empty grid (no sweep values).
exp_parameters = _EXPERIMENT_GRIDS.get(args.exp, {})

# Build the list of (param_name, param_value) runs for this invocation.
# (None, None) is the baseline run with default_params. It is only scheduled
# for a fresh results file, because the baseline is always the first run
# written to a new file.
completed = set()
if not os.path.exists(out_file):
    robustness_combinations = [(None, None)]  # first leads to default
else:
    robustness_combinations = []
    # Collect sweep values already recorded so a re-run appends only new
    # results instead of duplicating rows. Values are compared as the text
    # to_csv wrote (dtype=str). If the file can't be parsed this way, fall
    # back to re-running every value.
    try:
        existing = pd.read_csv(out_file, usecols=["param_name", "param_value"], dtype=str)
        completed = set(zip(existing["param_name"], existing["param_value"]))
    except ValueError as e:  # includes pandas EmptyDataError / missing columns
        print(f"Could not read completed runs from {out_file} ({e}); re-running all values.")
    print(f"Output file {out_file} already exists. Appending new results only.")

for param_name, param_values in exp_parameters.items():
    for param_value in param_values:
        if (param_name, str(param_value)) in completed:
            print(f"Skipping {param_name} = {param_value}: already present in {out_file}")
            continue
        robustness_combinations.append((param_name, param_value))

# Create the results directory before any (long) training run:
# DataFrame.to_csv does not create missing parent directories.
Path(out_file).parent.mkdir(parents=True, exist_ok=True)

for param_name, param_value in robustness_combinations:
    # Start from the defaults and override only the parameter under test
    # (param_name is None for the baseline run).
    data_hyperparams = default_params.copy()
    if param_name is not None:
        data_hyperparams[param_name] = param_value

    print(f"\n=== Testing {param_name} = {param_value} ===")

    # Simulate the dataset. The simulator seed is pinned to 0, independent of
    # --seed; training randomness is seeded separately below.
    tab_sim = TabularDataSimulator(
        n_samples=data_hyperparams["n_samples"],
        n_hidden_variables=data_hyperparams["n_hidden_variables"],
        hidden_dist_type=data_hyperparams["hidden_dist_type"],
        data_dim=data_hyperparams["data_dim"],
        nonlinearity_level=data_hyperparams["nonlinearity_level"],
        nonlinearity_type=data_hyperparams["nonlinearity_type"],
        hidden_connectivity=data_hyperparams["hidden_connectivity"],
        data_sparsity=data_hyperparams["data_sparsity"],
        noise_variance=data_hyperparams["noise_variance"],
        noise_mean=data_hyperparams["noise_mean"],
        n_noise_components=data_hyperparams["n_noise_components"],
        random_seed=0,
    )
    tab_data, tab_hidden, tab_hidden_orig = tab_sim.generate_data()
    tab_data = torch.tensor(tab_data, dtype=torch.float32)

    # Size of the training split handed to the trainer (first 90% of samples).
    n_train = int(0.9 * data_hyperparams["n_samples"])

    # Seed all RNGs right before training so training randomness follows --seed.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    model, reps, train_loss, r_squares, rank_history, loss_curves = train_overcomplete_ae2(
        tab_data,
        n_train,
        train_args.latent_dim,
        DEVICE,
        train_args,
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        ae_width=train_args.ae_width,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        early_stopping=method_hyperparameters["early_stopping"],
        initial_rank_ratio=1.0,
        rank_reduction_frequency=method_hyperparameters["rank_reduction_frequencies"],
        rank_reduction_threshold=method_hyperparameters["rank_reduction_thresholds"],
        warmup_epochs=method_hyperparameters["early_stopping"],
        patience=method_hyperparameters["patiences"],
        min_rank=1,
        r_square_threshold=method_hyperparameters["r_square_thresholds"],
        threshold_type=args.threshold,
        verbose=False,
        model_name=None
    )

    temp_df = pd.DataFrame(rank_history)
    # Attach the last entry of rank_history["ranks"] to every row as the final rank.
    temp_df["final_ranks"] = rank_history["ranks"][-1]

    try:
        # How well can the learned representation linearly predict the true
        # hidden variables? Assumes `reps` has one row per training sample,
        # aligned with the first n_train simulated samples. TODO: confirm
        # against train_overcomplete_ae2. This R^2 is in-sample: the regression
        # is fit and scored on the same rows.
        reps_np = reps.cpu().numpy()
        hidden_train = tab_hidden_orig[:n_train]
        reg = LinearRegression().fit(reps_np, hidden_train)
        temp_df["hidden_r2"] = reg.score(reps_np, hidden_train)
    except Exception as e:
        # Best effort: a failed evaluation must not discard the rank results.
        print(f"Error in regression: {e}")
        temp_df["hidden_r2"] = np.nan

    temp_df["param_name"] = param_name
    temp_df["param_value"] = param_value
    temp_df["seed"] = args.seed

    # Append to the results file, writing the header only when creating it.
    write_header = not os.path.exists(out_file)
    temp_df.to_csv(out_file, mode='w' if write_header else 'a', header=write_header, index=False)