import numpy as np
import torch
import pandas as pd
import os
import random
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
# add the project root (the directory containing `src`) to sys.path so the `src.` imports below resolve
import sys
from pathlib import Path
# The project root is two levels above this file; appending it lets `src.` resolve as a package.
_this_file = Path(__file__)
project_root = _this_file.parent.parent.absolute()
sys.path.append(os.fspath(project_root))

from src.data.mm_parametric_simulation import TabularMMDataSimulator, HierarchicalTabularMMDataSimulator, GeometricTabularMMDataSimulator
from src.functions.train_larrp_multimodal import train_overcomplete_ae, compute_classification

import argparse
# Command-line interface for one run of the sweep below.
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
parser.add_argument('--n_samples', type=int, default=10000, help='number of samples to use for the computation')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use for the computation')
parser.add_argument('--paired', action='store_true', help='whether the data is paired (default: False)')
# The help text lists every version the data-configuration section accepts
# (the previous text claimed only 'small' and 'large' existed).
parser.add_argument(
    '--data_version', type=str, default='small',
    help=('version of the data to use (default: small). One of: small, small-c, small-d, '
          'imbalanced, imbalanced-b, imbalanced-c, imbalanced-d, imbalanced2, imbalanced3, '
          'large, large2, large-c, large-d, large2-separate'),
)
parser.add_argument('--threshold', type=str, default='relative', choices=['relative', 'absolute'], help='whether to use relative or absolute R2 thresholds (default: relative)')
args = parser.parse_args()

# set device
if torch.cuda.is_available():
    DEVICE = torch.device(f'cuda:{args.gpu}')
else:
    DEVICE = torch.device('cpu')

noise = 0.0
sparsity = 0.0

# Identifier of this run; it names the CSV report and every saved plot.
run_name = f"larrp_mm-parametric-sim4_n-{args.n_samples}_rseed-{args.seed}_paired-{args.paired}_v-{args.data_version}"
out_file = f"03_results/reports/{run_name}.csv"

# Create the output directories up front. Without this, the first to_csv/savefig call
# raises FileNotFoundError, possibly only after a long computation.
os.makedirs(os.path.dirname(out_file), exist_ok=True)
os.makedirs("03_results/plots", exist_ok=True)

# Per-configuration results are appended to out_file later, so start from a clean file.
if os.path.exists(out_file):
    print(f"File {out_file} already exists, removing it.")
    os.remove(out_file)

# Simulator settings per data version. All versions share the fixed settings assembled
# below and differ only in these four fields:
#   (n_shared_variables, shared_hidden_dist_type, n_hidden_variables, hidden_dist_types)
_DATA_VERSIONS = {
    'small':           (2,  'binomial',          [3, 5],   ['poisson', 'weibull']),
    'small-c':         (2,  'binomial',          [3, 5],   ['gumbel', 'weibull']),
    'small-d':         (2,  'binomial',          [3, 5],   ['poisson', 'geometric']),
    'imbalanced':      (20, '10-class',          [2, 2],   ['poisson', 'weibull']),
    'imbalanced-b':    (20, '20-class',          [2, 2],   ['poisson', 'weibull']),
    'imbalanced-c':    (20, '20-class',          [2, 2],   ['gumbel', 'weibull']),
    'imbalanced-d':    (20, '20-class',          [2, 2],   ['poisson', 'geometric']),
    'imbalanced2':     (2,  'binomial',          [2, 20],  ['poisson', 'weibull']),
    'imbalanced3':     (20, '20-class-gaussian', [2, 2],   ['poisson', 'weibull']),
    'large':           (20, '20-class-gaussian', [20, 20], ['poisson', 'weibull']),
    'large2':          (20, '20-class',          [20, 20], ['poisson', 'weibull']),
    'large-c':         (20, '20-class',          [20, 20], ['gumbel', 'weibull']),
    'large-d':         (20, '20-class',          [20, 20], ['poisson', 'geometric']),
    'large2-separate': (0,  '20-class',          [20, 20], ['poisson', 'weibull']),
}

if args.data_version not in _DATA_VERSIONS:
    # The old message claimed only 'small' or 'large' were valid.
    raise ValueError(
        f"data_version must be one of {list(_DATA_VERSIONS)}, got {args.data_version!r}"
    )

_n_shared, _shared_dist, _n_hidden, _hidden_dists = _DATA_VERSIONS[args.data_version]
data_hyperparams = {
    'n_samples': args.n_samples,
    'n_shared_variables': _n_shared,
    'shared_hidden_dist_type': _shared_dist,
    'n_hidden_variables': list(_n_hidden),      # copy so the table is never mutated
    'hidden_dist_types': list(_hidden_dists),
    'data_dims': [200, 200],
    'nonlinearity_level': 0,
    'nonlinearity_type': 'polynomial',
    'hidden_connectivities': [0.5, 0.5, 0.5],
    'data_sparsity': sparsity,
    'noise_variance': noise,
    'random_seed': 42,  # the simulated data is fixed; --seed only affects training
}

tab_sim = TabularMMDataSimulator(
    n_samples=data_hyperparams['n_samples'],
    n_shared_variables=data_hyperparams['n_shared_variables'],
    n_hidden_variables=data_hyperparams['n_hidden_variables'],
    shared_hidden_dist_type=data_hyperparams['shared_hidden_dist_type'],
    hidden_dist_types=data_hyperparams['hidden_dist_types'],
    data_dims=data_hyperparams['data_dims'],
    nonlinearity_level=data_hyperparams['nonlinearity_level'],
    nonlinearity_type=data_hyperparams['nonlinearity_type'],
    hidden_connectivities=data_hyperparams['hidden_connectivities'],
    data_sparsity=data_hyperparams['data_sparsity'],
    noise_variance=data_hyperparams['noise_variance'],
    random_seed=data_hyperparams['random_seed'],
)
print("Data generated. Computing initial analysis.")

y1, y2, x0, x1, x2, labels = tab_sim.generate_data()


def _rescale_to_0_10(values):
    """Min-max rescale ``values`` globally (over all entries) to the range [0, 10].

    Per the author's note, without this (or loss balancing) the larger-magnitude
    modality tends to end up with a higher rank. A constant array (zero range) is
    mapped to zeros instead of producing NaNs from a 0/0 division.
    """
    lo, hi = np.min(values), np.max(values)
    if hi == lo:
        return np.zeros_like(values, dtype=float)
    return 10 * (values - lo) / (hi - lo)


# The former 'large' and non-'large' branches were identical, so both modalities are
# always rescaled the same way.
y1 = _rescale_to_0_10(y1)
y2 = _rescale_to_0_10(y2)

# Control figure layout. There is one column per representation: the observed
# modalities y1, y2 followed by the simulator's x0, x1, x2. There are three rows of
# PCA scatter plots (one per label column) plus a fourth row of value histograms.
reps = [y1, y2, x0, x1, x2]
n_cols = len(reps)
n_rows = 4

# The silhouette labels depend only on labels[:, 0], so prepare them once.
from sklearn.preprocessing import KBinsDiscretizer

label0 = labels[:, 0]
n_unique = len(np.unique(label0))
n_label_rows = len(label0)
if 2 <= n_unique <= n_label_rows - 1:
    sil_labels = label0
else:
    print(f"Cannot compute silhouette score directly with {n_unique} unique labels and {n_label_rows} samples.")
    if n_unique > n_label_rows // 2:
        # Too many unique values (likely a continuous label): bin into at most 10
        # discrete classes. KBinsDiscretizer requires at least 2 bins.
        discretizer = KBinsDiscretizer(n_bins=max(2, min(10, n_label_rows // 10)), encode='ordinal', strategy='uniform')
        sil_labels = discretizer.fit_transform(label0.reshape(-1, 1)).flatten().astype(int)
    else:
        sil_labels = None  # fewer than 2 distinct labels: silhouette score is undefined

accs = []
sils = []
preds = []
for j, rep in enumerate(tqdm(reps)):
    if args.data_version != 'large2-separate':
        acc = compute_classification(rep, label0)
    else:
        acc = np.nan
    sil = silhouette_score(rep, sil_labels) if sil_labels is not None else np.nan
    accs.append(acc)
    sils.append(sil)

    # In-sample linear-regression R^2 of labels 1 and 2 from this representation.
    # NOTE(review): regression is skipped for reps[0] (y1) but run for reps[1] (y2);
    # confirm this asymmetry is intended.
    if args.data_version != 'large2-separate' and j != 0:
        temp_preds = []
        for col in (1, 2):
            reg = LinearRegression().fit(rep, labels[:, col])
            temp_preds.append(reg.score(rep, labels[:, col]))
    else:
        temp_preds = [np.nan, np.nan]
    preds.append(temp_preds)

print("Initial analysis complete.")

fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
for i in range(n_cols):
    rep = reps[i]
    if rep.shape[1] >= 2:
        pca_reps = PCA(n_components=2).fit_transform(rep)
    else:
        # Fewer than two columns: PCA(n_components=2) would raise. Plot the single column
        # (if any) on the first axis and keep the second axis at zero.
        pca_reps = np.zeros((rep.shape[0], 2))
        if rep.shape[1] == 1:
            pca_reps[:, 0] = np.asarray(rep).ravel()

    # First 3 rows: PCA plots colored by label columns 0, 1 and 2.
    for j in range(3):
        ax = axs[j, i]
        ax.scatter(pca_reps[:, 0], pca_reps[:, 1], c=labels[:, j], cmap='viridis', alpha=0.5)
        # Columns 0 and 1 hold the observed modalities; the rest hold the simulator latents.
        title_text = f'Data {i}, Label {j}' if i < 2 else f'Latent {i-2}, Label {j}'
        if j == 0:
            title_text += f', Acc ({accs[i]:.2f}), Sil ({sils[i]:.2f})'
        else:
            title_text += f', Pred ({preds[i][j-1]:.2f})'
        ax.set_title(title_text)
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')

    # Fourth row: histogram of all values of the representation.
    hist_ax = axs[3, i]
    hist_ax.hist(rep.flatten(), bins=100, alpha=0.7, edgecolor='black')
    hist_ax.set_title(f'Data {i} Value Distribution' if i < 2 else f'Latent {i-2} Value Distribution')
    hist_ax.set_xlabel('Value')
    hist_ax.set_ylabel('Count')
    hist_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"03_results/plots/{run_name}_control.png")
plt.close(fig)
del reps

# Unpaired setting: stack the modalities so each row observes only one of them.
# The first half of the rows carries y1 (y2 zeroed) and the second half carries y2
# (y1 zeroed). Labels are duplicated so they stay aligned with the rows.
if not args.paired:
    y1, y2 = (
        np.concatenate([y1, np.zeros_like(y1)], axis=0),
        np.concatenate([np.zeros_like(y2), y2], axis=0),
    )
    labels = np.concatenate([labels, labels], axis=0)

data = [torch.FloatTensor(modality) for modality in (y1, y2)]

# Seed every RNG used from here on (the simulated data above uses its own fixed seed).
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Grid of rank-reduction settings to sweep, keyed by how R^2 thresholds are
# interpreted (--threshold). The key order inside each grid fixes the tuple order
# produced by itertools.product below.
_THRESHOLD_GRIDS = {
    "relative": {
        # earlier grids: [0.95, 0.9, 0.875, ..., 0.7], [0.95, 0.9, 0.85, 0.8, 0.75], [0.8]
        "r_square_thresholds": [0.99, 0.9, 0.8],
        "early_stopping": [50],
        "rank_reduction_frequencies": [5],
        "rank_reduction_thresholds": [0.01],
        "patiences": [5],
    },
    "absolute": {
        # earlier grids: [0.01, 0.05, ..., 0.3], [0.001, 0.005, 0.01, 0.05, 0.1], [0.05], [0.2]
        "r_square_thresholds": [0.001, 0.0025, 0.005, 0.0075, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2],
        "early_stopping": [50],
        "rank_reduction_frequencies": [10],
        "rank_reduction_thresholds": [0.01],
        "patiences": [10],
    },
}
method_hyperparameters = _THRESHOLD_GRIDS[args.threshold]

from itertools import product
method_combinations = list(product(*method_hyperparameters.values()))
print(f"Number of method combinations: {len(method_combinations)}")

class Args:
    """Bundle of autoencoder training settings; an instance is passed to train_overcomplete_ae.

    Every setting is a plain instance attribute. The GPU index comes from the parsed
    command-line ``args``.
    """

    def __init__(self):
        settings = {
            # latent space
            'latent_dim': 100,
            # optimisation
            'batch_size': 128,
            'lr': 0.0001,
            'weight_decay': 1e-5,
            'dropout': 0.1,
            'epochs': 5000,
            # autoencoder architecture
            'ae_depth': 2,
            'ae_width': 1,
            # rank reduction
            'rank_or_sparse': 'rank',
            # GPU
            'multi_gpu': False,
            'gpu_ids': '',
            'gpu': args.gpu,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)


train_args = Args()

###
# start training multiple configs
###
from itertools import product  # repeated here so this section does not depend on the import above
from sklearn.preprocessing import KBinsDiscretizer

# Silhouette labels for label 0 are the same for every configuration: prepare them once.
_label0 = labels[:, 0]
_n_unique0 = len(np.unique(_label0))
_n_label_rows = len(_label0)
if 2 <= _n_unique0 <= _n_label_rows - 1:
    _sil_labels = _label0
elif _n_unique0 > _n_label_rows // 2:
    # Too many unique values (likely a continuous label): bin into at most 10 classes.
    # KBinsDiscretizer requires at least 2 bins.
    _discretizer = KBinsDiscretizer(n_bins=max(2, min(10, _n_label_rows // 10)), encode='ordinal', strategy='uniform')
    _sil_labels = _discretizer.fit_transform(_label0.reshape(-1, 1)).flatten().astype(int)
else:
    _sil_labels = None  # fewer than 2 distinct labels: silhouette score is undefined

# Same iteration order as the former five nested loops (first key varies slowest).
_config_grid = product(
    method_hyperparameters["r_square_thresholds"],
    method_hyperparameters["early_stopping"],
    method_hyperparameters["rank_reduction_frequencies"],
    method_hyperparameters["rank_reduction_thresholds"],
    method_hyperparameters["patiences"],
)
for config_counter, config in enumerate(_config_grid, start=1):
    r_square_threshold, early_stopping, rank_reduction_frequency, rank_reduction_threshold, patience = config
    print(f"### Run {config_counter}/{len(method_combinations)} ###")
    # NOTE(review): the second argument is derived from args.n_samples even when --paired
    # is off and the data was stacked to twice as many rows; confirm how
    # train_overcomplete_ae uses it.
    model, reps, train_loss, r_squares, rank_history, loss_curves = train_overcomplete_ae(
        data,
        int(0.9 * args.n_samples),
        train_args.latent_dim,
        DEVICE,
        train_args,
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        ae_width=train_args.ae_width,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        early_stopping=early_stopping,
        initial_rank_ratio=1.0,
        rank_reduction_frequency=rank_reduction_frequency,
        rank_reduction_threshold=rank_reduction_threshold,
        warmup_epochs=early_stopping,  # warm-up length is tied to the early-stopping value
        patience=patience,
        min_rank=1,
        r_square_threshold=r_square_threshold,
        threshold_type=args.threshold,
        compressibility_type='direct',
        verbose=False,
        # alternatives: compute_jacobian=True (contractive loss), include_ortholoss=True,
        # ortho_loss_anneal_epochs=1000, ortho_loss_end_weight=0.1, ortho_loss_warmup=100,
        # l2_norm_adaptivelayers=1e-3
        compute_jacobian=False,
        sharedwhenall=False,
        model_name=run_name + str(config_counter)
    )

    temp_df = pd.DataFrame(rank_history)

    # Record the method parameters of this configuration on every row.
    temp_df["r_square_threshold"] = r_square_threshold
    temp_df["early_stopping"] = early_stopping
    temp_df["rank_reduction_frequency"] = rank_reduction_frequency
    temp_df["rank_reduction_threshold"] = rank_reduction_threshold
    temp_df["patience"] = patience

    # First and last recorded R^2 per modality (comma-separated, modality order).
    valid_rsquares = [rank_history[f"rsquare {m}"] for m in range(len(data))]
    temp_df["r_squares_init"] = ', '.join(str(r[0]) for r in valid_rsquares)
    temp_df["r_squares_final"] = ', '.join(str(r[-1]) for r in valid_rsquares)
    temp_df["final_ranks"] = rank_history["ranks"][-1]
    temp_df["config"] = config_counter

    # Move each latent to the CPU once; it is reused by every metric and by the plots.
    latents = [rep.cpu().numpy() for rep in reps]

    # Classification accuracy and silhouette score of each latent w.r.t. label 0.
    accs = []
    sils = []
    for latent in latents:
        if args.data_version != 'large2-separate':
            accs.append(compute_classification(latent, _label0))
        else:
            accs.append(np.nan)
        sils.append(silhouette_score(latent, _sil_labels) if _sil_labels is not None else np.nan)
    temp_df["classification_accuracy"] = ', '.join(str(a) for a in accs)
    temp_df["silhouette_score"] = ', '.join(str(s) for s in sils)
    print(f"Classification: {[str(a) for a in accs]}")

    # Predictability of labels 1 and 2 from each latent: in-sample linear-regression R^2
    # (per the original note, these labels are the means of hidden modality 1 and 2).
    preds = []
    for label_idx in (1, 2):
        target = labels[:, label_idx]
        temp_preds = [LinearRegression().fit(latent, target).score(latent, target) for latent in latents]
        preds.append(temp_preds)
        temp_df[f"label_{label_idx}_pred"] = ', '.join(str(p) for p in temp_preds)
        print(f"Label {label_idx} Prediction: {[str(p) for p in temp_preds]}")

    # Append to the report, writing the header only when the file is first created.
    write_header = not os.path.exists(out_file)
    temp_df.to_csv(out_file, mode='w' if write_header else 'a', header=write_header, index=False)

    # PCA plots of the latents colored by each label column. Plotting is best-effort: a
    # failure is reported but does not stop the sweep, and the figure is always closed.
    fig = None
    try:
        n_cols = len(latents)
        n_rows = labels.shape[1]
        # squeeze=False keeps axs 2-D even for a single row or column.
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
        for i, latent in enumerate(latents):
            if latent.shape[1] > 2:
                pca_reps = PCA(n_components=2).fit_transform(latent)
            elif latent.shape[1] == 2:
                pca_reps = latent
            else:
                # 0 or 1 latent dimension: plot it (if present) on the first axis, zeros on the second.
                pca_reps = np.zeros((latent.shape[0], 2))
                if latent.shape[1] == 1:
                    pca_reps[:, 0] = latent[:, 0]
            for j in range(n_rows):
                ax = axs[j, i]
                ax.scatter(pca_reps[:, 0], pca_reps[:, 1], c=labels[:, j], cmap='viridis', alpha=0.5)
                if j == 0:
                    ax.set_title(f'Latent {i}, Label {j}, Acc ({accs[i]:.2f}), Sil ({sils[i]:.2f})')
                elif j - 1 < len(preds):
                    ax.set_title(f'Latent {i}, Label {j}, Pred ({preds[j-1][i]:.2f})')
                else:
                    ax.set_title(f'Latent {i}, Label {j}')
                ax.set_xlabel('PC1')
                ax.set_ylabel('PC2')
        plt.tight_layout()
        plt.savefig(f"03_results/plots/{run_name}_config-{config_counter}.png")
    except Exception as err:
        print(f"Error in plotting PCA of latents: {err}")
    finally:
        if fig is not None:
            plt.close(fig)