import numpy as np
import torch
import pandas as pd
import os
import random
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
# add src to path
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.data.mm_parametric_simulation import TabularMMDataSimulator, HierarchicalTabularMMDataSimulator, GeometricTabularMMDataSimulator
from src.functions.train_larrp_multimodal import train_overcomplete_ae, compute_classification

import argparse
# Command-line interface of the experiment.
# Defaults are rendered through argparse's %(default)s placeholder so the help
# text cannot drift away from the values that are actually used.
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
parser.add_argument('--n_samples', type=int, default=10000, help='number of samples to use for the computation')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use for the computation')
parser.add_argument('--paired', action='store_true', help='whether the data is paired (default: False)')
parser.add_argument(
    '--data_version',
    type=str,
    default='small',
    help=(
        'version of the data to use (default: %(default)s). One of: small, small-c, small-d, '
        'imbalanced, imbalanced-b, imbalanced-c, imbalanced-d, imbalanced2, imbalanced3, '
        'large, large2, large-c, large-d'
    ),
)
parser.add_argument(
    '--threshold',
    type=str,
    default='absolute',
    choices=['relative', 'absolute'],
    help='whether to use relative or absolute R2 thresholds (default: %(default)s)',
)
args = parser.parse_args()

# Select the compute device: the requested CUDA device when CUDA is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device(f'cuda:{args.gpu}')
else:
    DEVICE = torch.device('cpu')

# Noise variance and sparsity passed to the data simulator configuration.
noise = 0.0
sparsity = 0.0

# The run name identifies every output artifact (CSV report and plots) of this run.
run_name = f"larrp_mm-parametric-sim6-ortho_n-{args.n_samples}_rseed-{args.seed}_paired-{args.paired}_v-{args.data_version}"
out_file = f"03_results/reports/{run_name}.csv"

# Create the output directories up front. Without this, a missing directory only
# surfaces when the first figure or report is written, which for the report is
# after a full training run.
os.makedirs(os.path.dirname(out_file), exist_ok=True)
os.makedirs("03_results/plots", exist_ok=True)

# Results are appended to the report one configuration at a time, so a report
# left over from an earlier run with the same name must be removed first.
if os.path.exists(out_file):
    print(f"File {out_file} already exists, removing it.")
    os.remove(out_file)

# Latent structure of every supported data version. Only these four settings
# differ between versions:
#   (n_shared_variables, shared_hidden_dist_type, n_hidden_variables, hidden_dist_types)
# All remaining simulator settings are identical and filled in below.
_DATA_VERSIONS = {
    'small':        (2,  'binomial',          [3, 5],   ['poisson', 'weibull']),
    'small-c':      (2,  'binomial',          [3, 5],   ['gumbel', 'weibull']),
    'small-d':      (2,  'binomial',          [3, 5],   ['poisson', 'geometric']),
    'imbalanced':   (20, '10-class',          [2, 2],   ['poisson', 'weibull']),
    'imbalanced-b': (20, '20-class',          [2, 2],   ['poisson', 'weibull']),
    'imbalanced-c': (20, '20-class',          [2, 2],   ['gumbel', 'weibull']),
    'imbalanced-d': (20, '20-class',          [2, 2],   ['poisson', 'geometric']),
    'imbalanced2':  (2,  'binomial',          [2, 20],  ['poisson', 'weibull']),
    'imbalanced3':  (20, '20-class-gaussian', [2, 2],   ['poisson', 'weibull']),
    'large':        (20, '20-class-gaussian', [20, 20], ['poisson', 'weibull']),
    'large2':       (20, '20-class',          [20, 20], ['poisson', 'weibull']),
    'large-c':      (20, '20-class',          [20, 20], ['gumbel', 'weibull']),
    'large-d':      (20, '20-class',          [20, 20], ['poisson', 'geometric']),
}

if args.data_version not in _DATA_VERSIONS:
    # Same exception type as before, but the message now lists every valid
    # version instead of only 'small' and 'large'.
    raise ValueError(
        f"data_version must be one of {', '.join(_DATA_VERSIONS)}; got {args.data_version!r}"
    )

_n_shared, _shared_dist, _n_hidden, _hidden_dists = _DATA_VERSIONS[args.data_version]

data_hyperparams = {
    'n_samples': args.n_samples,
    'n_shared_variables': _n_shared,
    'shared_hidden_dist_type': _shared_dist,
    # Copy the lists so the lookup table itself is never mutated downstream.
    'n_hidden_variables': list(_n_hidden),
    'hidden_dist_types': list(_hidden_dists),
    'data_dims': [200, 200],
    'nonlinearity_level': 0,
    'nonlinearity_type': 'polynomial',
    'hidden_connectivities': [0.5, 0.5, 0.5],
    'data_sparsity': sparsity,
    'noise_variance': noise,
    'random_seed': 42,
}

# Build the simulator. The keys of data_hyperparams are exactly the keyword
# names of the constructor call, so the dict is unpacked directly.
tab_sim = TabularMMDataSimulator(**data_hyperparams)
print("Data generated. Computing initial analysis.")

y1, y2, x0, x1, x2, labels = tab_sim.generate_data()


def _rescale_to_0_10(values):
    """Min-max rescale *values* (over all entries) to the range [0, 10].

    NOTE(review): a constant input makes the denominator zero and yields
    NaN/inf, exactly as the previous inline expression did.
    """
    lowest = np.min(values)
    highest = np.max(values)
    return 10 * (values - lowest) / (highest - lowest)


# Both modalities are put on the same 0-10 scale for every data version (the
# former 'large' / other branches were identical). Author's rationale: either
# this or loss balancing is presumably needed, because the second modality has
# a larger magnitude and otherwise ends up with a higher rank.
y1 = _rescale_to_0_10(y1)
y2 = _rescale_to_0_10(y2)

# save a control image
# One column per representation; three PCA rows (one per label column that is
# plotted) plus one histogram row.
n_cols = 3 + 2
n_rows = 4
reps = [y1, y2, x0, x1, x2]

# The labels used for the silhouette score are the same for every
# representation, so they are resolved once instead of once per iteration.
class_labels = labels[:, 0]
n_unique = len(np.unique(class_labels))
n_samples = len(class_labels)
silhouette_defined = 2 <= n_unique <= n_samples - 1
if silhouette_defined:
    sil_labels = class_labels
elif n_unique > n_samples // 2:
    # Too many unique values, likely continuous: bin into at most 10 classes.
    from sklearn.preprocessing import KBinsDiscretizer
    discretizer = KBinsDiscretizer(n_bins=min(10, n_samples//10), encode='ordinal', strategy='uniform')
    sil_labels = discretizer.fit_transform(class_labels.reshape(-1, 1)).flatten().astype(int)
else:
    sil_labels = None  # silhouette score cannot be computed

accs = []   # classification result on label 0, one entry per representation
sils = []   # silhouette score on label 0 (NaN when undefined)
preds = []  # per representation: [R^2 for label 1, R^2 for label 2]
for j, rep in enumerate(tqdm(reps)):
    accs.append(compute_classification(rep, class_labels))

    if not silhouette_defined:
        print(f"Cannot compute silhouette score for latent {j} with {n_unique} unique labels and {n_samples} samples.")
    sils.append(silhouette_score(rep, sil_labels) if sil_labels is not None else np.nan)

    # R^2 of a linear regression predicting label columns 1 and 2 from the representation
    # (fit and scored on the same data).
    preds.append([
        LinearRegression().fit(rep, labels[:, i]).score(rep, labels[:, i])
        for i in (1, 2)
    ])

print("Initial analysis complete.")

# Control figure: for each representation, a 2-component PCA scatter coloured by
# each of the first three label columns, followed by a histogram of its values.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
for col in range(n_cols):
    # The first two columns are titled "Data", the remaining ones "Latent".
    is_data = col < 2
    projected = PCA(n_components=2).fit_transform(reps[col])

    # Rows 0-2: PCA scatter plots, one per label column.
    for row in range(3):
        ax = axs[row, col]
        ax.scatter(projected[:, 0], projected[:, 1], c=labels[:, row], cmap='viridis', alpha=0.5)
        title_text = f'Data {col}, Label {row}' if is_data else f'Latent {col-2}, Label {row}'
        if row == 0:
            title_text += f', Acc ({accs[col]:.2f}), Sil ({sils[col]:.2f})'
        else:
            title_text += f', Pred ({preds[col][row-1]:.2f})'
        ax.set_title(title_text)
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')

    # Row 3: distribution of all values of the representation.
    hist_ax = axs[3, col]
    hist_ax.hist(reps[col].flatten(), bins=100, alpha=0.7, edgecolor='black')
    hist_ax.set_title(
        f'Data {col} Value Distribution' if is_data else f'Latent {col-2} Value Distribution'
    )
    hist_ax.set_xlabel('Value')
    hist_ax.set_ylabel('Count')
    hist_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"03_results/plots/{run_name}_control.png")
plt.close(fig)
# Drop the list of representations; the name is reused for the model latents later.
del reps

# Bring the data into the format used for training.
# Unpaired setting: stack the modalities so that every row carries only one of
# them. The first half holds y1 (with zeros for y2), the second half holds y2
# (with zeros for y1); the labels are repeated to match.
if not args.paired:
    y1, y2 = (
        np.concatenate([y1, np.zeros_like(y1)], axis=0),
        np.concatenate([np.zeros_like(y2), y2], axis=0),
    )
    labels = np.concatenate([labels, labels], axis=0)

data = [torch.FloatTensor(modality) for modality in (y1, y2)]

# Seed every random number generator in use (numpy, torch CPU and CUDA, random)
# with the run seed.
for seed_everything in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
    seed_everything(args.seed)

# method parameters
# Orthogonality-loss end weights swept by the training loop. The list is shared
# by both threshold modes: the loop indexes "ortho_values" unconditionally, so
# omitting it from the "relative" settings raised a KeyError.
_ORTHO_VALUES = [0.001, 0.01, 0.1, 1.0, 10.0]

if args.threshold == "relative":
    method_hyperparameters = {
        # alternative sweeps:
        #"r_square_thresholds": [0.95, 0.9, 0.875, 0.85, 0.825, 0.8, 0.775, 0.75, 0.725, 0.7],
        #"r_square_thresholds": [0.95, 0.9, 0.85, 0.8, 0.75],
        #"r_square_thresholds": [0.8],
        "r_square_thresholds": [0.99, 0.9, 0.8],
        "early_stopping": [50],
        "rank_reduction_frequencies": [5],
        "rank_reduction_thresholds": [0.01],
        "patiences": [5],
        "ortho_values": list(_ORTHO_VALUES),
    }
elif args.threshold == "absolute":
    method_hyperparameters = {
        # alternative sweeps:
        #"r_square_thresholds": [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3],
        #"r_square_thresholds": [0.2],
        "r_square_thresholds": [0.05],
        "early_stopping": [50],
        "rank_reduction_frequencies": [10],
        "rank_reduction_thresholds": [0.01],
        "patiences": [10],
        "ortho_values": list(_ORTHO_VALUES),
    }

# Only the R^2 threshold and the ortho weight are swept; the remaining settings
# are single-valued and read out once here.
rank_reduction_frequency = method_hyperparameters["rank_reduction_frequencies"][0]
rank_reduction_threshold = method_hyperparameters["rank_reduction_thresholds"][0]
early_stopping = method_hyperparameters["early_stopping"][0]
patience = method_hyperparameters["patiences"][0]

# Total number of configurations (Cartesian product of all settings), used for
# progress reporting.
from itertools import product
method_combinations = list(product(*method_hyperparameters.values()))
print(f"Number of method combinations: {len(method_combinations)}")

class Args:
    """Plain attribute container with the fixed training settings of this script.

    All values are constants except ``gpu``, which is copied from the
    module-level command-line ``args`` when the object is created.
    """

    def __init__(self):
        settings = {
            # latent
            "latent_dim": 100,
            # Training parameters
            "batch_size": 128,
            "lr": 0.0001,
            "weight_decay": 1e-5,
            "dropout": 0.1,
            "epochs": 5000,
            # Model architecture
            "ae_depth": 2,
            "ae_width": 1,
            # Rank reduction parameters
            "rank_or_sparse": 'rank',
            # GPU parameters
            "multi_gpu": False,
            "gpu_ids": '',
            "gpu": args.gpu,
        }
        # Expose every setting as an instance attribute, in declaration order.
        for attribute, value in settings.items():
            setattr(self, attribute, value)
train_args = Args()

###
# start training multiple configs
###

# The labels used for the silhouette score are identical for every latent and
# every configuration, so they are resolved once before the sweep.
class_labels = labels[:, 0]
n_unique = len(np.unique(class_labels))
n_samples = len(class_labels)
if 2 <= n_unique <= n_samples - 1:
    sil_labels = class_labels
elif n_unique > n_samples // 2:
    # Too many unique values, likely continuous: bin into at most 10 classes.
    from sklearn.preprocessing import KBinsDiscretizer
    discretizer = KBinsDiscretizer(n_bins=min(10, n_samples//10), encode='ordinal', strategy='uniform')
    sil_labels = discretizer.fit_transform(class_labels.reshape(-1, 1)).flatten().astype(int)
else:
    sil_labels = None  # silhouette score cannot be computed

config_counter = 0
for r_square_threshold in method_hyperparameters["r_square_thresholds"]:
    for ortho_weight in method_hyperparameters["ortho_values"]:
        config_counter += 1
        print(f"### Run {config_counter}/{len(method_combinations)} ###")
        # NOTE(review): when --paired is not set, the arrays in `data` hold
        # 2 * n_samples rows; confirm that int(0.9 * args.n_samples) is the
        # intended value for the second argument in that case.
        model, reps, train_loss, r_squares, rank_history, loss_curves = train_overcomplete_ae(
            data,
            int(0.9 * args.n_samples),
            train_args.latent_dim,
            DEVICE,
            train_args,
            epochs=train_args.epochs,
            lr=train_args.lr,
            batch_size=train_args.batch_size,
            ae_depth=train_args.ae_depth,
            ae_width=train_args.ae_width,
            dropout=train_args.dropout,
            wd=train_args.weight_decay,
            early_stopping=early_stopping,
            initial_rank_ratio=1.0,
            rank_reduction_frequency=rank_reduction_frequency,
            rank_reduction_threshold=rank_reduction_threshold,
            warmup_epochs=early_stopping,
            patience=patience,
            min_rank=1,
            r_square_threshold=r_square_threshold,
            threshold_type=args.threshold,
            compressibility_type='direct',
            verbose=False,
            #compute_jacobian=True,  # compute the contractive loss
            compute_jacobian=False,
            include_ortholoss=True,
            ortho_loss_anneal_epochs=100,
            ortho_loss_end_weight=ortho_weight,
            ortho_loss_warmup=100,
            #l2_norm_adaptivelayers=1e-3,
            sharedwhenall=False,
            model_name=run_name + str(config_counter)
        )

        # Convert every latent to numpy once; all metrics and plots below reuse these arrays.
        reps_np = [reps[j].cpu().numpy() for j in range(len(reps))]

        # One row per entry of the rank history; run-level values are broadcast to all rows.
        temp_df = pd.DataFrame(rank_history)

        # add all the data and method parameters to the dataframe
        # (the column order must stay fixed: later configurations are appended without a header)
        temp_df["r_square_threshold"] = r_square_threshold
        temp_df["early_stopping"] = early_stopping
        temp_df["rank_reduction_frequency"] = rank_reduction_frequency
        temp_df["rank_reduction_threshold"] = rank_reduction_threshold
        temp_df["patience"] = patience

        # add the most important final metrics
        valid_rsquares = [rank_history[f"rsquare {j}"] for j in range(len(data))]
        temp_df["r_squares_init"] = ', '.join([str(valid_rsquares[0][0]), str(valid_rsquares[1][0])])
        temp_df["r_squares_final"] = ', '.join([str(valid_rsquares[0][-1]), str(valid_rsquares[1][-1])])
        temp_df["final_ranks"] = rank_history["ranks"][-1]
        temp_df["config"] = config_counter
        temp_df["ortho_weight"] = ortho_weight

        # calculate classification accuracy and silhouette score on the latents for label 0
        accs = []
        sils = []
        for rep in reps_np:
            accs.append(compute_classification(rep, class_labels))
            sils.append(silhouette_score(rep, sil_labels) if sil_labels is not None else np.nan)
        temp_df["classification_accuracy"] = ', '.join([str(a) for a in accs])
        temp_df["silhouette_score"] = ', '.join([str(s) for s in sils])
        print(f"Classification: {[str(a) for a in accs]}")

        # predictability of labels 1 and 2 from each latent
        # these labels are the means of hidden modality 1 and 2, so we need regression
        preds = []
        for j in [1, 2]:
            # R^2 of a linear regression fit and scored on the same data
            temp_preds = [
                LinearRegression().fit(rep, labels[:, j]).score(rep, labels[:, j])
                for rep in reps_np
            ]
            preds.append(temp_preds)
            temp_df[f"label_{j}_pred"] = ', '.join([str(p) for p in temp_preds])
            print(f"Label {j} Prediction: {[str(p) for p in temp_preds]}")

        # if out_file exists, append to it, otherwise create it (with header)
        report_exists = os.path.exists(out_file)
        temp_df.to_csv(
            out_file,
            mode='a' if report_exists else 'w',
            header=not report_exists,
            index=False,
        )

        # Plotting is best-effort: a failure is reported but must not abort the sweep.
        fig = None
        try:
            # create PCA plots of the reps colored by the labels and save them
            n_cols = len(reps_np)
            n_rows = labels.shape[1]
            # squeeze=False keeps axs two-dimensional even for a single latent or label.
            fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
            for i, rep in enumerate(reps_np):
                if rep.shape[1] > 2:
                    pca_reps = PCA(n_components=2).fit_transform(rep)
                elif rep.shape[1] == 2:
                    pca_reps = rep
                else:
                    # latent with fewer than two dimensions: plot it on the first axis only
                    pca_reps = np.zeros((rep.shape[0], 2))
                    pca_reps[:, 0] = rep.flatten()
                for j in range(n_rows):
                    axs[j, i].scatter(pca_reps[:, 0], pca_reps[:, 1], c=labels[:, j], cmap='viridis', alpha=0.5)
                    if j == 0:
                        axs[j, i].set_title(f'Latent {i}, Label {j}, Acc ({accs[i]:.2f}), Sil ({sils[i]:.2f})')
                    else:
                        axs[j, i].set_title(f'Latent {i}, Label {j}, Pred ({preds[j-1][i]:.2f})')
                    axs[j, i].set_xlabel('PC1')
                    axs[j, i].set_ylabel('PC2')
            plt.tight_layout()
            plt.savefig(f"03_results/plots/{run_name}_config-{config_counter}.png")
        except Exception as exc:
            # Exception (not a bare except) so that Ctrl-C and SystemExit still propagate.
            print(f"Error in plotting PCA of latents: {exc!r}")
        finally:
            # Always release the figure, also when plotting failed half-way.
            if fig is not None:
                plt.close(fig)