import numpy as np
import torch
import pandas as pd
import os
import random
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
# add src to path
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.data.mm_parametric_simulation import TabularMMDataSimulator, HierarchicalTabularMMDataSimulator, GeometricTabularMMDataSimulator
from src.functions.train_larrp_multimodal_vae import compute_classification, train_multimodal_vae

import argparse
# Command-line options for this cross-modal experiment.
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
parser.add_argument('--n_samples', type=int, default=10000, help='number of samples to use for the computation')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use for the computation')
parser.add_argument('--paired', action='store_true', help='whether the data is paired (default: False)')
# Help text lists the versions actually handled when building data_hyperparams.
parser.add_argument('--data_version', type=str, default='small',
                    help='version of the data to use (default: small). '
                         'One of: small, imbalanced, imbalanced-b, imbalanced2, large2')
parser.add_argument('--threshold', type=str, default='absolute', choices=['relative', 'absolute'],
                    help='whether to use relative or absolute R2 thresholds (default: absolute)')
parser.add_argument('--kl_anneal_epochs', type=int, default=100, help='number of epochs over which to anneal the KL divergence weight (default: 100)')
parser.add_argument('--kl_weight', type=float, default=1.0, help='initial KL divergence weight (default: 1.0)')
args = parser.parse_args()

# Use the CUDA device selected with --gpu when CUDA is available, else the CPU.
DEVICE = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

# The simulated data is generated without sparsity or additive noise here.
sparsity = 0.0
noise = 0.0

# Output file names are derived from this run name, which encodes the CLI options.
run_name = "crossmodal_" + "_".join([
    f"n-{args.n_samples}",
    f"rseed-{args.seed}",
    f"paired-{args.paired}",
    f"v-{args.data_version}",
    f"klw-{args.kl_weight}",
    f"klae-{args.kl_anneal_epochs}",
])
out_file = f"03_results/reports/crossmodal/{run_name}.csv"

# Start from a clean report: drop a CSV left behind by an earlier run of the same name.
stale_report = os.path.exists(out_file)
if stale_report:
    print(f"File {out_file} already exists, removing it.")
    os.remove(out_file)

# Simulator settings shared by every data version. Each version below only
# overrides the latent structure: the number/type of shared variables and the
# number of modality-specific hidden variables.
_base_data_hyperparams = {
    'n_samples': args.n_samples,
    'n_shared_variables': 2,
    'shared_hidden_dist_type': 'binomial',
    'n_hidden_variables': [3, 5],
    'hidden_dist_types': ['poisson', 'weibull'],
    'data_dims': [200, 200],
    'nonlinearity_level': 0,
    'nonlinearity_type': 'polynomial',
    'hidden_connectivities': [0.5, 0.5, 0.5],
    'data_sparsity': sparsity,
    'noise_variance': noise,
    'random_seed': 42,
}
_data_version_overrides = {
    'small': {},
    'imbalanced': {
        'n_shared_variables': 20,
        'shared_hidden_dist_type': '10-class',
        'n_hidden_variables': [2, 2],
    },
    'imbalanced-b': {
        'n_shared_variables': 20,
        'shared_hidden_dist_type': '20-class',
        'n_hidden_variables': [2, 2],
    },
    'imbalanced2': {
        'n_hidden_variables': [2, 20],
    },
    'large2': {
        'n_shared_variables': 20,
        'shared_hidden_dist_type': '20-class',
        'n_hidden_variables': [20, 20],
    },
}
if args.data_version not in _data_version_overrides:
    # List the versions that are actually supported (previously claimed 'small' or 'large').
    raise ValueError(
        f"data_version must be one of {list(_data_version_overrides)}, got {args.data_version!r}"
    )
data_hyperparams = {**_base_data_hyperparams, **_data_version_overrides[args.data_version]}

# The keys of data_hyperparams are exactly TabularMMDataSimulator's keyword
# arguments, so the configuration can be splatted straight in.
tab_sim = TabularMMDataSimulator(**data_hyperparams)
print("Data generated. Computing initial analysis.")

# y1/y2 are the two data modalities, x0/x1/x2 the latent matrices (plotted as
# "Latent 0-2" below), and labels holds (at least) three label columns.
y1, y2, x0, x1, x2, labels = tab_sim.generate_data()
def _rescale_to_0_10(values):
    """Min-max rescale ``values`` to the range [0, 10].

    A constant array (max == min) is mapped to zeros rather than to the all-NaN
    array that the plain 0 / 0 division would produce.
    """
    lo = np.min(values)
    span = np.max(values) - lo
    if span == 0:
        return np.zeros_like(values, dtype=float)
    return 10 * (values - lo) / span


# Put both modalities on a common [0, 10] scale. Per the original author's note,
# this (or loss balancing) is needed because otherwise the larger-magnitude
# modality ends up with higher rank. Every data version is treated the same way.
y1 = _rescale_to_0_10(y1)
y2 = _rescale_to_0_10(y2)

# save a control image
n_cols = 3 + 2  # one column per representation: y1, y2 (data) + x0, x1, x2 (latents)
n_rows = 4  # three label rows plus one histogram row
reps = [y1, y2, x0, x1, x2]

# Label 0 is scored by classification accuracy and silhouette score. The
# silhouette target is the same for every representation, so derive it once.
label0 = labels[:, 0]
n_unique = len(np.unique(label0))
n_samples = len(label0)
sil_direct = 2 <= n_unique <= n_samples - 1
if sil_direct:
    sil_labels = label0
elif n_unique > n_samples // 2:
    # Too many unique values, likely continuous: bin into (at most) 10 ordinal classes.
    # KBinsDiscretizer needs n_bins >= 2, which min(10, n // 10) alone breaks for n < 20.
    from sklearn.preprocessing import KBinsDiscretizer
    discretizer = KBinsDiscretizer(n_bins=max(2, min(10, n_samples // 10)), encode='ordinal', strategy='uniform')
    sil_labels = discretizer.fit_transform(label0.reshape(-1, 1)).flatten().astype(int)
else:
    sil_labels = None  # silhouette score cannot be computed

accs = []
sils = []
preds = []
for j, rep in enumerate(tqdm(reps)):
    accs.append(compute_classification(rep, label0))
    if not sil_direct:
        print(f"Cannot compute silhouette score for latent {j} with {n_unique} unique labels and {n_samples} samples.")
    sils.append(np.nan if sil_labels is None else silhouette_score(rep, sil_labels))

    # In-sample R^2 of a linear regression from this representation to labels 1 and 2.
    preds.append([
        LinearRegression().fit(rep, labels[:, i]).score(rep, labels[:, i])
        for i in (1, 2)
    ])

print("Initial analysis complete.")

# Control figure: one column per representation. Rows 0-2 show a 2-D PCA
# projection coloured by label 0/1/2; row 3 is a histogram of all values.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
for i in range(n_cols):
    rep_name = f'Data {i}' if i < 2 else f'Latent {i-2}'
    pca_reps = PCA(n_components=2).fit_transform(reps[i])

    for j in range(3):
        ax = axs[j, i]
        ax.scatter(pca_reps[:, 0], pca_reps[:, 1], c=labels[:, j], cmap='viridis', alpha=0.5)
        if j == 0:
            metrics = f'Acc ({accs[i]:.2f}), Sil ({sils[i]:.2f})'
        else:
            metrics = f'Pred ({preds[i][j-1]:.2f})'
        ax.set_title(f'{rep_name}, Label {j}, {metrics}')
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')

    hist_ax = axs[3, i]
    hist_ax.hist(reps[i].flatten(), bins=100, alpha=0.7, edgecolor='black')
    hist_ax.set_title(f'{rep_name} Value Distribution')
    hist_ax.set_xlabel('Value')
    hist_ax.set_ylabel('Count')
    hist_ax.grid(True, alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("03_results/plots", exist_ok=True)
plt.savefig(f"03_results/plots/{run_name}_control.png")
plt.close(fig)
del reps

# Unpaired setting: stack the modalities so that each row carries only one of
# them. Modality 0 (y1) fills the first block of rows, modality 1 (y2) the second,
# and the absent modality is zero-filled. Labels are repeated to match.
if not args.paired:
    y1 = np.concatenate((y1, np.zeros_like(y1)), axis=0)
    y2 = np.concatenate((np.zeros_like(y2), y2), axis=0)
    labels = np.concatenate((labels, labels), axis=0)

data = [torch.FloatTensor(arr) for arr in (y1, y2)]

# Seed the NumPy, torch (CPU and all CUDA devices) and Python RNGs before training.
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
    seed_fn(args.seed)

# Rank-reduction search grids, selected by --threshold. Each entry lists the
# values swept for one hyperparameter; the sweep is their Cartesian product.
_threshold_grids = {
    "relative": {
        # grids tried before: [0.95, 0.9, 0.875, 0.85, 0.825, 0.8, 0.775, 0.75, 0.725, 0.7],
        # [0.95, 0.9, 0.85, 0.8, 0.75], [0.8]
        "r_square_thresholds": [0.99, 0.9, 0.8],
        "early_stopping": [50],
        "rank_reduction_frequencies": [5],
        "rank_reduction_thresholds": [0.01],
        "patiences": [5],
    },
    "absolute": {
        # grids tried before: [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3], [0.2]
        "r_square_thresholds": [0.05],
        "early_stopping": [50],
        "rank_reduction_frequencies": [10],
        "rank_reduction_thresholds": [0.01],
        "patiences": [10],
    },
}
method_hyperparameters = _threshold_grids[args.threshold]

from itertools import product
method_combinations = list(product(*method_hyperparameters.values()))
print(f"Number of method combinations: {len(method_combinations)}")

class Args:
    """Fixed training configuration handed to ``train_multimodal_vae``.

    Only ``gpu`` comes from the command line (``args.gpu``); every other
    setting is hard-coded for this experiment.
    """

    def __init__(self):
        # size of the latent space
        self.latent_dim = 100

        # optimisation settings
        self.batch_size = 128
        self.lr = 1e-4
        self.weight_decay = 1e-5
        self.dropout = 0.1
        self.epochs = 5_000

        # encoder/decoder shape
        self.ae_depth = 2
        self.ae_width = 1

        # which reduction mode to use ('rank')
        self.rank_or_sparse = 'rank'

        # device selection
        self.multi_gpu = False
        self.gpu_ids = ''
        self.gpu = args.gpu


train_args = Args()

###
# start training multiple configs
###
from itertools import product

# Latents returned by train_multimodal_vae are scored against the first
# n_samples_train label rows, the split size passed to the trainer below.
n_samples_train = int(0.9 * args.n_samples)
train_labels = labels[:n_samples_train]


def _silhouette_target_labels(label_values):
    """Return the labels to compute a silhouette score against, or None.

    ``label_values`` is used as-is when it has between 2 and n - 1 distinct
    values. When more than half of the values are distinct (the label looks
    continuous), it is binned uniformly into at most 10 ordinal classes.
    Otherwise no silhouette score can be computed and ``None`` is returned.
    """
    n_total = len(label_values)
    n_distinct = len(np.unique(label_values))
    if 2 <= n_distinct <= n_total - 1:
        return label_values
    if n_distinct > n_total // 2:
        from sklearn.preprocessing import KBinsDiscretizer
        # KBinsDiscretizer rejects n_bins < 2, which min(10, n // 10) yields for n < 20.
        discretizer = KBinsDiscretizer(n_bins=max(2, min(10, n_total // 10)),
                                       encode='ordinal', strategy='uniform')
        return discretizer.fit_transform(label_values.reshape(-1, 1)).flatten().astype(int)
    return None


def _cross_modal_generations(model, val_data):
    """Decode every target modality from every other (source) modality.

    For each source modality, all other modalities are replaced with zeros before
    encoding, so the shared latent is inferred from the source alone. Encoding
    all modalities jointly would let the target's own data leak into its
    "cross-modal" prediction. The shared posterior mean is decoded together with
    an N(0, I) draw for the target's modality-specific latent; the other
    modality-specific latents are set to zero.

    NOTE(review): zeros are how this script represents an absent modality in the
    unpaired setting; confirm the model's encoder treats an all-zero modality the
    same way (see train_larrp_multimodal_vae).

    Returns a dict mapping 'cross_mod{source}_to_mod{target}' to a numpy array.
    """
    n_mods = len(val_data)
    generations = {}
    for source_mod in range(n_mods):
        source_only = [x if m == source_mod else torch.zeros_like(x)
                       for m, x in enumerate(val_data)]
        z_shared, _, src_specific_mus, _, _ = model.encode(source_only)
        for target_mod in range(n_mods):
            if target_mod == source_mod:
                continue  # same-modality generation is just reconstruction
            z_specific = [
                torch.randn(z_shared.size(0), src_specific_mus[m].size(1), device=z_shared.device)
                if m == target_mod else torch.zeros_like(src_specific_mus[m])
                for m in range(n_mods)
            ]
            cross_gen = model.decode(z_shared, z_specific)
            generations[f'cross_mod{source_mod}_to_mod{target_mod}'] = cross_gen[target_mod].cpu().numpy()
    return generations


# The silhouette target for label 0 does not depend on the config, so derive it once.
sil_labels_train = _silhouette_target_labels(train_labels[:, 0])

# Same iteration order as nesting the loops in this key order (last key varies fastest).
config_grid = list(product(
    method_hyperparameters["r_square_thresholds"],
    method_hyperparameters["early_stopping"],
    method_hyperparameters["rank_reduction_frequencies"],
    method_hyperparameters["rank_reduction_thresholds"],
    method_hyperparameters["patiences"],
))
for config_counter, (r_square_threshold, early_stopping, rank_reduction_frequency,
                     rank_reduction_threshold, patience) in enumerate(config_grid, start=1):
    print(f"### Run {config_counter}/{len(config_grid)} ###")
    model, reps, train_loss, r_squares, rank_history, loss_curves = train_multimodal_vae(
        data,
        n_samples_train,
        train_args.latent_dim,
        DEVICE,
        train_args,
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        ae_width=train_args.ae_width,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        early_stopping=early_stopping,
        initial_rank_ratio=1.0,
        rank_reduction_frequency=rank_reduction_frequency,
        rank_reduction_threshold=rank_reduction_threshold,
        warmup_epochs=early_stopping,
        patience=patience,
        min_rank=1,
        r_square_threshold=r_square_threshold,
        threshold_type=args.threshold,
        kl_weight=args.kl_weight,  # KL weight for VAE
        kl_anneal_epochs=args.kl_anneal_epochs,  # KL warmup length in epochs
        verbose=False,
        model_name=run_name + str(config_counter)
    )
    print(f"Final ranks: {rank_history['ranks'][-1]}")

    # Tabulate rank_history and tag every row with this config's method parameters.
    temp_df = pd.DataFrame(rank_history)
    temp_df["r_square_threshold"] = r_square_threshold
    temp_df["early_stopping"] = early_stopping
    temp_df["rank_reduction_frequency"] = rank_reduction_frequency
    temp_df["rank_reduction_threshold"] = rank_reduction_threshold
    temp_df["patience"] = patience

    # Summary metrics (first/last R^2 per modality, final ranks) repeated on every row.
    valid_rsquares = [rank_history[f"rsquare {j}"] for j in range(len(data))]
    temp_df["r_squares_init"] = ', '.join([str(valid_rsquares[0][0]), str(valid_rsquares[1][0])])
    temp_df["r_squares_final"] = ', '.join([str(valid_rsquares[0][-1]), str(valid_rsquares[1][-1])])
    temp_df["final_ranks"] = str(rank_history["ranks"][-1])
    temp_df["config"] = config_counter

    # Label 0: classification accuracy and silhouette score of each latent.
    latents = [rep.cpu().numpy() for rep in reps]
    accs = []
    sils = []
    for latent in latents:
        accs.append(compute_classification(latent, train_labels[:, 0]))
        sils.append(np.nan if sil_labels_train is None
                    else silhouette_score(latent, sil_labels_train))
    temp_df["classification_accuracy"] = ', '.join(str(a) for a in accs)
    temp_df["silhouette_score"] = ', '.join(str(s) for s in sils)
    print(f"Classification: {[str(a) for a in accs]}")

    # Labels 1 and 2: in-sample R^2 of a linear regression from each latent
    # (per the original note these are means of hidden modality 1 and 2, hence regression).
    preds = []
    for j in (1, 2):
        temp_preds = [
            LinearRegression().fit(latent, train_labels[:, j]).score(latent, train_labels[:, j])
            for latent in latents
        ]
        preds.append(temp_preds)
        temp_df[f"label_{j}_pred"] = ', '.join(str(p) for p in temp_preds)
        print(f"Label {j} Prediction: {[str(p) for p in temp_preds]}")

    # Persist this config's rows. out_file is deleted at start-up, so the first
    # config writes the header and later configs append below it.
    os.makedirs(os.path.dirname(out_file), exist_ok=True)
    temp_df.to_csv(out_file, mode='a', header=not os.path.exists(out_file), index=False)

    # ========================================================================
    # Cross-modal generation and reconstruction saving
    # ========================================================================
    print("\nPerforming cross-modal generation...")

    # Validation rows are those after the training split.
    # NOTE(review): without --paired, `data` holds 2 * n_samples rows
    # ([y1; 0] and [0; y2]), so this slice contains real modality-0 data but
    # zero-padded modality-1 data; confirm this is the intended validation set.
    val_start = n_samples_train
    val_end = args.n_samples
    val_data = [d[val_start:val_end].to(DEVICE) for d in data]

    model.eval()
    with torch.no_grad():
        # 1. Plain reconstructions of the validation data.
        val_recon, vae_params = model(val_data)
        val_recon_dict = {f'recon_mod{i}': r.cpu().numpy() for i, r in enumerate(val_recon)}
        val_recon_dict.update({f'original_mod{i}': d.cpu().numpy() for i, d in enumerate(val_data)})

        # 2. Joint posterior over all modalities (saved for analysis) and
        #    source-only cross-modal generations.
        shared_mu, shared_logvar, specific_mus, specific_logvars, _ = model.encode(val_data)
        cross_modal_generations = _cross_modal_generations(model, val_data)

        # 3. Save everything for this config.
        output_dir = os.path.join("03_results", "cross_modal_generations")
        os.makedirs(output_dir, exist_ok=True)
        output_file = os.path.join(output_dir, f"{run_name}_config-{config_counter}.npz")

        save_dict = {**val_recon_dict, **cross_modal_generations}
        save_dict['shared_mu'] = shared_mu.cpu().numpy()
        save_dict['shared_logvar'] = shared_logvar.cpu().numpy()
        for i, (mu, logvar) in enumerate(zip(specific_mus, specific_logvars)):
            save_dict[f'specific_mu_mod{i}'] = mu.cpu().numpy()
            save_dict[f'specific_logvar_mod{i}'] = logvar.cpu().numpy()

        np.savez(output_file, **save_dict)
        print(f"Saved cross-modal generations to {output_file}")
        print(f"  - Validation reconstructions: {len(val_recon)} modalities")
        print(f"  - Cross-modal generations: {len(cross_modal_generations)} pairs")