import numpy as np
import torch
import pandas as pd
import torch
from scipy import sparse
import time
from tqdm import tqdm
import random
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression

# add the project root (which contains the `src` package) to the import path
import sys
import os
from pathlib import Path
# Two levels up from this file, i.e. the parent of the directory that holds
# this script. Appending it to sys.path lets `from src....` imports resolve.
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.data.simulation_by_variables import TabularDataSimulator
from src.functions.train_larrp_unimodal import train_overcomplete_ae2

import argparse
# Command-line interface of the sweep.
# `--threshold` decides which hyperparameter grid is built further down, so it
# is restricted to the two supported values: a typo now fails at parse time
# with a clear message instead of crashing later on an undefined grid.
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use for the computation')
parser.add_argument(
    '--threshold',
    type=str,
    default="relative",
    choices=["relative", "absolute"],
    help='threshold type: relative or absolute',
)
parser.add_argument('--n_samples', type=int, default=10000, help='number of samples to use for the computation')
args = parser.parse_args()

# One CSV per (threshold type, seed); every configuration of the sweep writes
# its rows into this file.
out_file = f"03_results/paper_results/unimodal/unimodal_param_sweep_{args.threshold}-threshold_seed_{args.seed}.csv"

# set device: the requested CUDA device when CUDA is available, else the CPU
DEVICE = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

###
# generate data
###

# Create a tabular data simulator.
# The sample count follows --n_samples so that the simulated data agrees with
# the split size int(0.9 * args.n_samples) used for training and for the
# hidden-variable regression, and with the `n_samples` value written to the
# results. (The default of 10000 reproduces the previous fixed size.)
tab_sim = TabularDataSimulator(
    n_samples=args.n_samples,
    n_hidden_variables=5,
    hidden_dist_type="poisson",
    data_dim=50,
    nonlinearity_level=1,
    nonlinearity_type="sigmoid",
    hidden_connectivity=0.4,
    data_sparsity=0.0,
    noise_variance=0.0,
    noise_mean=0.0,
    n_noise_components=1,
    random_seed=args.seed,
)
# Of the three returned values only the data matrix and `tab_hidden_orig`
# (regression target later on) are used by this script.
tab_data, tab_hidden, tab_hidden_orig = tab_sim.generate_data()
tab_data = torch.tensor(tab_data, dtype=torch.float32)

# Latent dimension handed to train_overcomplete_ae2.
latent_dim = 20

###
# set up method parameters to test
###

# Seed torch, NumPy and Python's `random` (in that order) with the same
# user-supplied seed.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(args.seed)

class Args:
    """Fixed training configuration passed to ``train_overcomplete_ae2``.

    All values are constants except ``gpu``, which is taken from the parsed
    command-line arguments when the object is constructed.
    """

    def __init__(self):
        # --- latent space ---
        self.latent_dim: int = 20

        # --- optimisation ---
        self.batch_size: int = 128
        self.lr: float = 0.0001
        self.weight_decay: float = 2e-5
        self.dropout: float = 0.1
        self.epochs: int = 5000

        # --- autoencoder architecture ---
        self.ae_depth: int = 2
        self.ae_width: int = 1

        # --- rank reduction mode ---
        self.rank_or_sparse: str = 'rank'

        # --- device selection ---
        self.multi_gpu: bool = False
        self.gpu_ids: str = ''
        self.gpu: int = args.gpu


# Single shared configuration instance used for every run of the sweep.
train_args = Args()

# method parameters
# Both threshold types sweep the same grid; only the R^2 thresholds differ.
_r_square_thresholds_by_type = {
    "relative": [0.99, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7],
    "absolute": [0.005, 0.01, 0.05, 0.1, 0.15, 0.2],
}
if args.threshold not in _r_square_thresholds_by_type:
    # Fail with an explicit message instead of a NameError on the undefined
    # grid a few lines further down.
    raise ValueError(
        f"Unknown threshold type {args.threshold!r}; "
        f"expected one of {sorted(_r_square_thresholds_by_type)}"
    )

# Key order matters: it defines the order of the values inside each tuple of
# `method_combinations`.
method_hyperparameters = {
    "r_square_thresholds": _r_square_thresholds_by_type[args.threshold],
    "early_stopping": [50],
    "rank_reduction_frequencies": [5, 10, 20],
    "rank_reduction_thresholds": [0.0001, 0.001, 0.01, 0.1],
    "patiences": [5, 10, 20, 50, 100],
}
from itertools import product
method_combinations = list(product(*method_hyperparameters.values()))
print(f"Number of method combinations: {len(method_combinations)}")

###
# start training multiple configs
###

# Create the results directory up front: without it the first to_csv call
# would fail only after a complete (expensive) training run.
Path(out_file).parent.mkdir(parents=True, exist_ok=True)

# Size of the leading slice used for training and for the hidden-variable
# regression (90% of the requested sample count).
n_train = int(0.9 * args.n_samples)
total_configs = len(method_combinations)

# Same iteration order as five nested loops over the grid, written flat.
# The keys are named explicitly so the unpacking below does not depend on the
# insertion order of `method_hyperparameters`.
sweep = product(
    method_hyperparameters["r_square_thresholds"],
    method_hyperparameters["early_stopping"],
    method_hyperparameters["rank_reduction_frequencies"],
    method_hyperparameters["rank_reduction_thresholds"],
    method_hyperparameters["patiences"],
)

config_counter = 0
for config_counter, config in enumerate(sweep, start=1):
    (
        r_square_threshold,
        early_stopping,
        rank_reduction_frequency,
        rank_reduction_threshold,
        patience,
    ) = config

    print(f"### Run {config_counter}/{total_configs} ###")
    model, reps, train_loss, r_squares, rank_history, loss_curves = train_overcomplete_ae2(
        tab_data,
        n_train,
        latent_dim,
        DEVICE,
        train_args,
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        ae_width=train_args.ae_width,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        early_stopping=early_stopping,
        initial_rank_ratio=1.0,
        rank_reduction_frequency=rank_reduction_frequency,
        rank_reduction_threshold=rank_reduction_threshold,
        # the warm-up length is tied to the early-stopping value
        warmup_epochs=early_stopping,
        patience=patience,
        min_rank=1,
        r_square_threshold=r_square_threshold,
        threshold_type=args.threshold,
        verbose=False,
        model_name=None
    )

    # One row per entry of the rank history, plus the last recorded rank
    # repeated on every row.
    temp_df = pd.DataFrame(rank_history)
    temp_df["final_ranks"] = rank_history["ranks"][-1]

    # add all the data and method parameters to the dataframe
    temp_df["r_square_threshold"] = r_square_threshold
    temp_df["early_stopping"] = early_stopping
    temp_df["rank_reduction_frequency"] = rank_reduction_frequency
    temp_df["rank_reduction_threshold"] = rank_reduction_threshold
    temp_df["patience"] = patience
    temp_df["seed"] = args.seed
    temp_df["n_samples"] = args.n_samples
    temp_df["threshold"] = args.threshold
    temp_df["config"] = config_counter

    # Evaluate how well the true hidden variables can be predicted linearly
    # from the learned representations. This is best-effort: a failure is
    # reported and recorded as NaN so the sweep keeps going.
    try:
        # detach() so a tensor that still tracks gradients can be converted;
        # the conversion is done once and reused for fit and score.
        features = reps.detach().cpu().numpy()
        targets = tab_hidden_orig[:n_train]
        reg = LinearRegression()
        reg.fit(features, targets)
        # compute the R**2 of the regression fit (on the same data)
        temp_df["hidden_r2"] = reg.score(features, targets)
    except Exception as e:
        print(f"Error in regression: {e}")
        temp_df["hidden_r2"] = np.nan

    # Append to out_file; the header is written only when the file does not
    # exist yet (append mode creates it in that case).
    write_header = not os.path.exists(out_file)
    temp_df.to_csv(out_file, mode='a', header=write_header, index=False)