import numpy as np
import torch
import pandas as pd
import torch
from scipy import sparse
import time
from tqdm import tqdm
import random
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression

# Put the parent of this script's directory on sys.path so that the
# `src.*` imports further down can be resolved when the script is run directly.
import sys
import os
from pathlib import Path
# Absolute path of the directory one level above the folder containing this file.
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.data.simulation_by_variables import TabularDataSimulator
from src.functions.train_larrp_unimodal import train_overcomplete_ae3

import argparse
# Command-line interface: the only option is the index of the GPU to use.
parser = argparse.ArgumentParser(
    description='Compute basic ID estimation metrics',
)
parser.add_argument(
    '--gpu',
    type=int,
    default=0,
    help='GPU to use for the computation',
)
args = parser.parse_args()

# CSV that collects the results of every run (path is relative to the
# current working directory).
out_file = "03_results/paper_results/unimodal/unimodal_param_robustness_metric.csv"


# Use the requested CUDA device when CUDA is available, otherwise the CPU.
DEVICE = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

###
# set up method parameters to test
###

class Args:
    """Bundle of fixed training settings.

    The single instance created below is passed to ``train_overcomplete_ae3``
    as a whole, and several of its fields are also read individually at the
    call site. ``gpu`` mirrors the ``--gpu`` command-line option.
    """

    def __init__(self):
        settings = {
            # latent
            "latent_dim": 20,
            # Training parameters
            "batch_size": 128,
            "lr": 1e-4,
            "weight_decay": 2e-5,
            "dropout": 0.1,
            "epochs": 5000,
            # Model architecture
            "ae_depth": 2,
            "ae_width": 1,
            # Rank reduction parameters
            "rank_or_sparse": 'rank',
            # GPU parameters
            "multi_gpu": False,
            "gpu_ids": '',
            "gpu": args.gpu,
        }
        # Attributes are created in the order listed above.
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)


train_args = Args()

###
# set up data parameters to test
###

# Baseline simulator settings; every robustness run starts from a copy of
# these and overrides a single entry.
default_params = dict(
    n_samples=10000,
    n_hidden_variables=5,
    hidden_dist_type="poisson",
    data_dim=50,
    latent_dim=20,
    nonlinearity_level=1,
    nonlinearity_type="relu",
    hidden_connectivity=0.4,
    data_sparsity=0.0,
    noise_variance=0.0,
    noise_mean=0.0,
    n_noise_components=1,
)

# Decision metrics under test, paired position by position with a flag that
# says whether a larger value of that metric is better.
metrics_to_vary = ['R2', 'MSE', 'RMSE', "ExVarScore", 'McFaddenR2']
higher_is_better = dict(zip(metrics_to_vary, (True, False, False, True, True)))

# Data-generation settings that are varied one at a time.
nonlinearities = [
    "relu",
    "polynomial",
    "sigmoid",
    "trigonometric",
]
sparsities = [
    0.0,
    0.1,
    0.7,
]

# Random seeds; each combination is trained once per seed.
seeds = [
    0,
    42,
    554,
    9306,
    89024,
]

# Running counter over all (combination, seed) runs, in loop order.
run_iter = 0

# Resume support: runs are numbered consecutively and written to the CSV as
# they finish, so the first run still to do is the one after the largest
# run_iter found in an existing results file.
run_iter_start = 0
if os.path.exists(out_file):
    existing_df = pd.read_csv(out_file)
    if "run_iter" not in existing_df.columns:
        print("Found existing file but no 'run_iter' column. Starting from run_iter 0.")
    else:
        # dropna() guards against blank cells; a header-only file yields an
        # empty set, for which max() would raise ValueError.
        completed_iters = set(existing_df["run_iter"].dropna().unique())
        if completed_iters:
            max_iter = int(max(completed_iters))
            run_iter_start = max_iter + 1
            print(f"Found existing file with completed run iters up to {max_iter}. Starting from run_iter {run_iter_start}.")
        else:
            print("Found existing file without any completed run iters. Starting from run_iter 0.")
else:
    print("No existing file found. Starting from run_iter 0.")

# Enumerate every experiment as (param_name, (metric, value)): each decision
# metric is paired with every nonlinearity type first, then with every
# sparsity level.
robustness_combinations = [
    ("nonlinearity", (metric_name, nonlinearity_name))
    for metric_name in metrics_to_vary
    for nonlinearity_name in nonlinearities
]
robustness_combinations.extend(
    ("sparsity", (metric_name, sparsity_level))
    for metric_name in metrics_to_vary
    for sparsity_level in sparsities
)

# Each combination is repeated once per seed.
n_total_runs = len(robustness_combinations) * len(seeds)
print(f"Total combinations to run: {n_total_runs}")

# Main experiment loop.
#
# For every (param_name, (metric, value)) combination the data hyperparameters
# are derived from `default_params`, a dataset is simulated, and one model is
# trained per seed with `metric` as the rank-reduction decision metric. The
# rank history of each run (plus bookkeeping columns) is appended to
# `out_file` as soon as the run finishes, which is what makes resuming via
# `run_iter_start` possible.

# Create the results directory up front so that a completed training run
# cannot be lost to a failing CSV write.
out_dir = os.path.dirname(out_file)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

for param_name, param_values in robustness_combinations:
    # Create modified data hyperparameters
    data_hyperparams = default_params.copy()
    if param_name == "nonlinearity":
        metric, param_value = param_values
        data_hyperparams["nonlinearity_type"] = param_value
    elif param_name == "sparsity":
        metric, param_value = param_values
        data_hyperparams["data_sparsity"] = param_value
    else:
        raise ValueError(f"Unknown param_name: {param_name}")

    # Settings shared by both kinds of metric; only the two thresholds and
    # the threshold type depend on whether larger metric values are better.
    method_hyperparameters = {
        "early_stopping": 50,
        "rank_reduction_frequencies": 10,
        "patiences": 10,
    }
    if higher_is_better[metric]:
        threshold_type = 'absolute'
        method_hyperparameters["r_square_thresholds"] = 0.05
        method_hyperparameters["rank_reduction_thresholds"] = 0.01
    else:
        threshold_type = 'relative'
        method_hyperparameters["r_square_thresholds"] = 1.05
        method_hyperparameters["rank_reduction_thresholds"] = 0.001

    print(f"\n=== Testing {metric} + {param_name} {param_value} ===")

    # Number of samples handed to the trainer and used for the regression
    # check below (90% of the simulated samples).
    n_train = int(0.9 * data_hyperparams["n_samples"])

    # The dataset is simulated lazily, i.e. only once a seed of this
    # combination actually has to be trained; on resume this avoids
    # simulating data for combinations that are already complete.
    tab_data = None
    tab_hidden_orig = None

    for seed in seeds:

        if run_iter < run_iter_start:
            print(f"Skipping run_iter {run_iter} as it is already completed.")
            run_iter += 1
            continue

        if tab_data is None:
            # Create a tabular data simulator (fixed simulator seed, so all
            # seeds of a combination train on the same dataset)
            tab_sim = TabularDataSimulator(
                n_samples=data_hyperparams["n_samples"],
                n_hidden_variables=data_hyperparams["n_hidden_variables"],
                hidden_dist_type=data_hyperparams["hidden_dist_type"],
                data_dim=data_hyperparams["data_dim"],
                nonlinearity_level=data_hyperparams["nonlinearity_level"],
                nonlinearity_type=data_hyperparams["nonlinearity_type"],
                hidden_connectivity=data_hyperparams["hidden_connectivity"],
                data_sparsity=data_hyperparams["data_sparsity"],
                noise_variance=data_hyperparams["noise_variance"],
                noise_mean=data_hyperparams["noise_mean"],
                n_noise_components=data_hyperparams["n_noise_components"],
                random_seed=0,
            )
            tab_data, tab_hidden, tab_hidden_orig = tab_sim.generate_data()
            tab_data = torch.tensor(tab_data, dtype=torch.float32)

        # set the seed (after the data simulation, immediately before training)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # NOTE(review): warmup_epochs reuses the "early_stopping" value;
        # confirm this coupling is intended.
        model, reps, train_loss, r_squares, rank_history, loss_curves = train_overcomplete_ae3(
            tab_data,
            n_train,
            train_args.latent_dim,
            DEVICE,
            train_args,
            epochs=train_args.epochs,
            lr=train_args.lr,
            batch_size=train_args.batch_size,
            ae_depth=train_args.ae_depth,
            ae_width=train_args.ae_width,
            dropout=train_args.dropout,
            wd=train_args.weight_decay,
            early_stopping=method_hyperparameters["early_stopping"],
            initial_rank_ratio=1.0,
            rank_reduction_frequency=method_hyperparameters["rank_reduction_frequencies"],
            rank_reduction_threshold=method_hyperparameters["rank_reduction_thresholds"],
            warmup_epochs=method_hyperparameters["early_stopping"],
            patience=method_hyperparameters["patiences"],
            min_rank=1,
            r_square_threshold=method_hyperparameters["r_square_thresholds"],
            threshold_type=threshold_type,
            verbose=False,
            model_name=None,
            decision_metric=metric,
            higher_is_better=higher_is_better[metric]
        )

        # One row per rank-history entry; the last recorded rank is repeated
        # on every row of the run as `final_ranks`.
        temp_df = pd.DataFrame(rank_history)
        temp_df["final_ranks"] = rank_history["ranks"][-1]

        try:
            # evaluate the predictability of the true hidden variables from
            # the learned representations (presumably `reps` holds one row
            # per training sample - TODO confirm against the trainer)
            reps_np = reps.cpu().numpy()
            hidden_train = tab_hidden_orig[:n_train]
            reg = LinearRegression()
            reg.fit(reps_np, hidden_train)
            # compute the R**2 of the regression fit
            r2 = reg.score(reps_np, hidden_train)
            temp_df["hidden_r2"] = r2
        except Exception as e:
            # Deliberately best-effort: a failed regression must not discard
            # the training result, so the score is recorded as missing.
            print(f"Error in regression: {e}")
            temp_df["hidden_r2"] = np.nan

        temp_df["param_name"] = param_name
        temp_df["param_value"] = param_value
        temp_df["metric"] = metric
        temp_df["seed"] = seed
        temp_df["run_iter"] = run_iter

        # if out_file exists, append to it, otherwise create it (with header)
        if os.path.exists(out_file):
            temp_df.to_csv(out_file, mode='a', header=False, index=False)
        else:
            temp_df.to_csv(out_file, mode='w', header=True, index=False)

        run_iter += 1


# at the end load the file and report the mean +- SEM final_ranks per metric
all_results = pd.read_csv(out_file)


def _sem(values):
    """Return the standard error of the mean of *values* (sample std, ddof=1).

    Yields NaN when fewer than two values are given, because the sample
    standard deviation is undefined in that case.
    """
    n = len(values)
    if n < 2:
        return np.nan
    return np.std(values, ddof=1) / np.sqrt(n)


# The training loop writes one CSV row per rank-history entry of a run, each
# carrying the same `final_ranks`. Collapse to a single row per run so that
# every run counts exactly once; otherwise runs with a longer history would
# dominate the means and the SEM would be computed with an inflated n.
if "run_iter" in all_results.columns:
    run_results = all_results.drop_duplicates(subset="run_iter", keep="last")
else:
    run_results = all_results

# Group by metric and param_name, then calculate mean and SEM
summary = run_results.groupby(['metric', 'param_name', 'param_value']).agg(
    mean_final_ranks=('final_ranks', 'mean'),
    sem_final_ranks=('final_ranks', _sem)
).reset_index()
# Print the summary
print("\nMean ± SEM for each metric and parameter:")
for _, row in summary.iterrows():
    print(f"Metric: {row['metric']}, Param: {row['param_name']}={row['param_value']}, Final Ranks: {row['mean_final_ranks']:.2f} ± {row['sem_final_ranks']:.2f}")

# also report just the mean +- SEM final_ranks per metric
summary_metric = run_results.groupby(['metric']).agg(
    mean_final_ranks=('final_ranks', 'mean'),
    sem_final_ranks=('final_ranks', _sem)
).reset_index()
# Print the summary
print("\nMean ± SEM for each metric:")
for _, row in summary_metric.iterrows():
    print(f"Metric: {row['metric']}, Final Ranks: {row['mean_final_ranks']:.2f} ± {row['sem_final_ranks']:.2f}")

# importance of parameters
# Mean absolute deviation of the final rank from 5, grouped by metric and
# parameter name.
# NOTE(review): 5 presumably is the simulated number of hidden variables -
# confirm. The reported SEM is that of `final_ranks` itself, not of the
# absolute deviation.
param_importance = run_results.groupby(['metric', 'param_name']).agg(
    mean_deviation_from_rank_5=('final_ranks', lambda x: np.mean(np.abs(x - 5))),
    sem_final_ranks=('final_ranks', _sem)
).reset_index()

# Print the importance of each parameter per metric
print("\nImportance of each parameter on rank prediction (mean deviation from rank 5), grouped by metric:")
for _, row in param_importance.iterrows():
    print(f"Metric: {row['metric']}, Param: {row['param_name']}, Mean Deviation: {row['mean_deviation_from_rank_5']:.2f} ± {row['sem_final_ranks']:.2f}")

# Add a 'total' parameter as the average of nonlinearity and sparsity
param_importance_total = param_importance[param_importance['param_name'].isin(['nonlinearity', 'sparsity'])]
param_importance_total = param_importance_total.groupby(['metric']).agg(
    mean_deviation_from_rank_5=('mean_deviation_from_rank_5', 'mean'),
    # SEM of an average of k estimates: sqrt(sum of squared SEMs) / k
    sem_final_ranks=('sem_final_ranks', lambda x: np.sqrt(np.sum(x**2)) / len(x))
).reset_index()
param_importance_total['param_name'] = 'total'

# Append the 'total' rows to the param_importance DataFrame
param_importance = pd.concat([param_importance, param_importance_total], ignore_index=True)

print("\nAdded 'total' parameter as the average of nonlinearity and sparsity.")

import matplotlib.pyplot as plt
import seaborn as sns

# Create a grouped bar plot of the parameter importance: one group per metric,
# one bar per parameter name, with manually drawn SEM error bars.

# Fix the category orders once and hand them to seaborn explicitly, so the
# bar layout and the manually positioned error bars below cannot disagree.
metric_order = param_importance['metric'].unique().tolist()
param_order = param_importance['param_name'].unique().tolist()

plt.figure(figsize=(8, 4))

# NOTE(review): `ci` is the older spelling of this option; recent seaborn
# releases use `errorbar=None` instead - switch once the installed version
# is confirmed.
sns.barplot(
    data=param_importance,
    x='metric',
    y='mean_deviation_from_rank_5',
    hue='param_name',
    order=metric_order,
    hue_order=param_order,
    ci=None  # Disable internal confidence intervals since we calculate SEM manually
)

# Each metric group spans 0.8 x-units (seaborn's default group width), split
# evenly between the hue levels.
bar_width = 0.8 / max(len(param_order), 1)

# Add error bars manually using SEM, ensuring proper alignment with the bars
for _, row in param_importance.iterrows():
    metric_idx = metric_order.index(row['metric'])
    param_idx = param_order.index(row['param_name'])
    # Left edge of the group plus the offset to the centre of this bar
    x_position = metric_idx - 0.4 + (param_idx + 0.5) * bar_width

    plt.errorbar(
        x=x_position,
        y=row['mean_deviation_from_rank_5'],
        yerr=row['sem_final_ranks'],
        fmt='none',
        c='black',
        capsize=3
    )

# Customize the plot
plt.title('Importance of Parameters on Rank Prediction', fontsize=14)
plt.xlabel('Metric', fontsize=12)
plt.ylabel('Mean Deviation from Rank 5', fontsize=12)
plt.legend(title='Parameter Name', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()

# Save the plot in the specified directory (created first if it is missing,
# since savefig does not create directories)
output_path = '03_results/paper_results/figures/param_importance_barplot.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, dpi=300)
print(f"Bar plot saved as '{output_path}'")

# Show the plot
plt.show()

