import numpy as np
import torch
import pandas as pd
import os
import random
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
# add the project root to sys.path so that the `src` package can be imported
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.data.mm_parametric_simulation import TabularMMDataSimulator, HierarchicalTabularMMDataSimulator, GeometricTabularMMDataSimulator
from src.functions.train_larrp_multimodal import train_overcomplete_ae, compute_classification
from src.models.mm_baselines import CCA, JIVE, AJIVE, DIVAS, PPD, SLIDE, ShIndICA

import argparse
# Command-line interface of the script.
# Defaults are rendered with argparse's %(default)s interpolation so the help
# text cannot drift away from the actual default values.
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
parser.add_argument('--n_samples', type=int, default=10000, help='number of samples to use for the computation')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use for the computation')
parser.add_argument('--paired', action='store_true', help='whether the data is paired (default: False)')
parser.add_argument(
    '--data_version',
    type=str,
    default='small',
    help=(
        "version of the data to use (default: %(default)s). One of: small, small-c, small-d, "
        "imbalanced, imbalanced-b, imbalanced-c, imbalanced-d, imbalanced2, imbalanced3, "
        "large, large2, large-c, large-d"
    ),
)
parser.add_argument(
    '--threshold',
    type=str,
    default='absolute',
    choices=['relative', 'absolute'],
    help='whether to use relative or absolute R2 thresholds (default: %(default)s)',
)
args = parser.parse_args()

# set device: the CUDA GPU selected via --gpu when CUDA is available, otherwise the CPU
DEVICE = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

# Noise variance and sparsity handed to the data simulator (same for every data version).
noise = 0.0
sparsity = 0.0

# All artefacts of this run (CSV report and plots) share this name.
run_name = f"baselines_mm-parametric_n-{args.n_samples}_rseed-{args.seed}_paired-{args.paired}_v-{args.data_version}"
out_file = f"03_results/reports/{run_name}.csv"

# Create the output directories up front so a fresh checkout does not crash
# at the first savefig()/to_csv() call after the expensive computations.
os.makedirs(os.path.dirname(out_file), exist_ok=True)
os.makedirs("03_results/plots", exist_ok=True)

# if the file already exists, remove it (results are appended row by row later on)
if os.path.exists(out_file):
    print(f"File {out_file} already exists, removing it.")
    os.remove(out_file)

# Settings that differ between data versions, stored per version as
#   (n_shared_variables, shared_hidden_dist_type, n_hidden_variables, hidden_dist_types).
# Every other simulator hyperparameter is identical across versions and is
# filled in once when `data_hyperparams` is assembled below.
_DATA_VERSION_SETTINGS = {
    'small':        (2,  'binomial',          [3, 5],   ['poisson', 'weibull']),
    'small-c':      (2,  'binomial',          [3, 5],   ['gumbel', 'weibull']),
    'small-d':      (2,  'binomial',          [3, 5],   ['poisson', 'geometric']),
    'imbalanced':   (20, '10-class',          [2, 2],   ['poisson', 'weibull']),
    'imbalanced-b': (20, '20-class',          [2, 2],   ['poisson', 'weibull']),
    'imbalanced-c': (20, '20-class',          [2, 2],   ['gumbel', 'weibull']),
    'imbalanced-d': (20, '20-class',          [2, 2],   ['poisson', 'geometric']),
    'imbalanced2':  (2,  'binomial',          [2, 20],  ['poisson', 'weibull']),
    'imbalanced3':  (20, '20-class-gaussian', [2, 2],   ['poisson', 'weibull']),
    'large':        (20, '20-class-gaussian', [20, 20], ['poisson', 'weibull']),
    'large2':       (20, '20-class',          [20, 20], ['poisson', 'weibull']),
    'large-c':      (20, '20-class',          [20, 20], ['gumbel', 'weibull']),
    'large-d':      (20, '20-class',          [20, 20], ['poisson', 'geometric']),
}

if args.data_version not in _DATA_VERSION_SETTINGS:
    # Same exception type as before, but the message now lists every valid option.
    raise ValueError(
        f"Unknown data_version '{args.data_version}'; "
        f"expected one of: {', '.join(_DATA_VERSION_SETTINGS)}"
    )

_n_shared, _shared_dist, _n_hidden, _hidden_dists = _DATA_VERSION_SETTINGS[args.data_version]
data_hyperparams = {
    'n_samples': args.n_samples,
    'n_shared_variables': _n_shared,
    'shared_hidden_dist_type': _shared_dist,
    'n_hidden_variables': _n_hidden,
    'hidden_dist_types': _hidden_dists,
    'data_dims': [200, 200],
    'nonlinearity_level': 0,
    'nonlinearity_type': 'polynomial',
    'hidden_connectivities': [0.5, 0.5, 0.5],
    'data_sparsity': sparsity,
    'noise_variance': noise,
    # The data seed is fixed on purpose: --seed is applied separately, after data generation.
    'random_seed': 42
}

# Build the simulator. The keys of `data_hyperparams` coincide one-to-one with
# the constructor's keyword names, so the dict is unpacked directly.
tab_sim = TabularMMDataSimulator(**data_hyperparams)
print("Data generated. Computing initial analysis.")

y1, y2, x0, x1, x2, labels = tab_sim.generate_data()


def _rescale_to_0_10(values):
    """Min-max rescale ``values`` so its global minimum maps to 0 and its maximum to 10."""
    lowest = np.min(values)
    return 10 * (values - lowest) / (np.max(values) - lowest)


# normalize y1 and y2 to have values between 0 and 10 for every data version
# (I think we need either this or loss balancing because the second modality is
# larger magnitude and ends up having higher rank)
y1 = _rescale_to_0_10(y1)
y2 = _rescale_to_0_10(y2)

# save a control image
n_cols = 3 + 2  # one column per entry of `reps` (y1, y2, x0, x1, x2)
n_rows = 4  # three label-coloured PCA rows plus one histogram row
reps = [y1, y2, x0, x1, x2]

# Label column 0 is used for classification accuracy and the silhouette score.
class_labels = labels[:, 0]
n_unique = len(np.unique(class_labels))
n_samples = len(class_labels)

# Decide once which labels the silhouette score is computed on; this depends
# only on the labels, not on the representation being scored.
# silhouette_score needs between 2 and n_samples - 1 distinct labels.
silhouette_ok = 2 <= n_unique <= n_samples - 1
if silhouette_ok:
    sil_labels = class_labels
elif n_unique > n_samples // 2:  # Too many unique values, likely continuous
    # Bin into (at most) 10 discrete classes
    from sklearn.preprocessing import KBinsDiscretizer
    discretizer = KBinsDiscretizer(n_bins=min(10, n_samples//10), encode='ordinal', strategy='uniform')
    sil_labels = discretizer.fit_transform(class_labels.reshape(-1, 1)).flatten().astype(int)
else:
    sil_labels = None  # Cannot compute silhouette score

accs = []   # classification result per representation (label 0)
sils = []   # silhouette score per representation (label 0), NaN when undefined
preds = []  # per representation: [R^2 for label 1, R^2 for label 2]
for j, rep in enumerate(tqdm(reps)):
    accs.append(compute_classification(rep, class_labels))

    if not silhouette_ok:
        print(f"Cannot compute silhouette score for latent {j} with {n_unique} unique labels and {n_samples} samples.")
    sils.append(silhouette_score(rep, sil_labels) if sil_labels is not None else np.nan)

    # R**2 of a linear regression fit of labels 1 and 2 on this representation
    # (fitted and scored on the same data)
    preds.append([
        LinearRegression().fit(rep, labels[:, col]).score(rep, labels[:, col])
        for col in (1, 2)
    ])

print("Initial analysis complete.")

# Control figure: one column per representation; rows 0-2 are 2-D PCA scatter
# plots coloured by labels 0-2, row 3 is a histogram of all raw values.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
for col in range(n_cols):
    rep = reps[col]
    projected = PCA(n_components=2).fit_transform(rep)
    # The first two representations are the data matrices, the rest are latents.
    source_name = f'Data {col}' if col < 2 else f'Latent {col-2}'

    # First 3 rows: PCA plots colored by labels
    for row in range(3):
        ax = axs[row, col]
        ax.scatter(projected[:, 0], projected[:, 1], c=labels[:, row], cmap='viridis', alpha=0.5)
        if row == 0:
            metric_text = f', Acc ({accs[col]:.2f}), Sil ({sils[col]:.2f})'
        else:
            metric_text = f', Pred ({preds[col][row-1]:.2f})'
        ax.set_title(f'{source_name}, Label {row}' + metric_text)
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')

    # Fourth row: Histograms of value distributions
    hist_ax = axs[3, col]
    hist_ax.hist(rep.flatten(), bins=100, alpha=0.7, edgecolor='black')
    hist_ax.set_title(f'{source_name} Value Distribution')
    hist_ax.set_xlabel('Value')
    hist_ax.set_ylabel('Count')
    hist_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"03_results/plots/{run_name}_control.png")
plt.close(fig)
del reps

# set it in the right format
if not args.paired:
    # Unpaired setting: stack the modalities block-wise. The first half of the
    # rows carries y1 with an all-zero y2, the second half carries y2 with an
    # all-zero y1; the labels are repeated for both halves.
    y1 = np.concatenate((y1, np.zeros_like(y1)), axis=0)
    y2 = np.concatenate((np.zeros_like(y2), y2), axis=0)
    labels = np.concatenate((labels, labels), axis=0)

# One float tensor per modality, in the order [modality 1, modality 2].
data = [torch.FloatTensor(modality) for modality in (y1, y2)]

# Seed NumPy, PyTorch (CPU and all CUDA devices) and Python's `random` with --seed.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
    _seed_fn(args.seed)

# Baseline methods to run, as (name, zero-argument factory) pairs.
# Methods that use their default settings pass the class itself as the factory;
# a lambda is only needed where constructor arguments are supplied.
baseline_methods = [
    ('CCA', lambda: CCA(energy_threshold=0.8)),
    ('DIVAS', DIVAS),
    ('JIVE', JIVE),
    ('AJIVE', AJIVE),
    ('PPD', PPD),
    ('SLIDE', SLIDE),
    ('ShIndICA', lambda: ShIndICA(joint_rank_options=[1,5,10,20])),
]

print(f"Number of baseline methods: {len(baseline_methods)}")

###
# start running baseline methods
###

# Move data to CPU for baselines (they use PyTorch but may not need GPU).
# The tensors are the same for every method, so this is done once.
data_cpu = [d.cpu() for d in data]

# Labels for the silhouette score of label 0. They depend only on `labels`
# (which may have been duplicated for the unpaired setting), not on the method,
# so they are derived once. silhouette_score needs between 2 and n_samples - 1
# distinct labels.
class_labels = labels[:, 0]
n_unique = len(np.unique(class_labels))
n_samples = len(class_labels)
if 2 <= n_unique <= n_samples - 1:
    sil_labels = class_labels
elif n_unique > n_samples // 2:  # Too many unique values, likely continuous
    # Bin into (at most) 10 discrete classes
    from sklearn.preprocessing import KBinsDiscretizer
    discretizer = KBinsDiscretizer(n_bins=min(10, n_samples//10), encode='ordinal', strategy='uniform')
    sil_labels = discretizer.fit_transform(class_labels.reshape(-1, 1)).flatten().astype(int)
else:
    sil_labels = None  # Cannot compute silhouette score

for config_counter, (method_name, method_constructor) in enumerate(baseline_methods, start=1):
    print(f"### Running {method_name} ({config_counter}/{len(baseline_methods)}) ###")

    # Create method instance
    method = method_constructor()
    print(data_cpu[0].shape, data_cpu[1].shape)

    # Decompose the data
    decomposed_reps, rank_info = method.decompose(data_cpu)
    # Sort the subspace names and use the SAME order for the representations,
    # so the names written to the CSV line up with the per-subspace metrics.
    subspace_names = sorted(decomposed_reps.keys())
    reps = [decomposed_reps[name] for name in subspace_names]
    # Convert to numpy once; every metric and plot below works on these arrays.
    reps_np = [rep.cpu().numpy() if torch.is_tensor(rep) else rep for rep in reps]
    #if method_name == "CCA":
    #    reps = [reps[1]]  # third one are correlations, not a representation
    #    subspace_names = [subspace_names[1]]

    print(decomposed_reps.keys())
    print([rep.shape for rep in reps])
    print(rank_info)

    # One CSV row per method; per-subspace values are stored as comma-separated strings.
    temp_df = pd.DataFrame()
    temp_df["method"] = [method_name]
    temp_df["joint_rank"] = [rank_info.get('joint_rank', np.nan)]
    temp_df["individual_rank_X"] = [rank_info.get('individual_rank_X', np.nan)]
    temp_df["individual_rank_Y"] = [rank_info.get('individual_rank_Y', np.nan)]
    temp_df["subspace_names"] = [', '.join(subspace_names)]

    # calculate classification accuracy and silhouette score on the latents for label 0
    accs = []
    sils = []
    for rep_np in reps_np:
        accs.append(compute_classification(rep_np, class_labels))
        sils.append(silhouette_score(rep_np, sil_labels) if sil_labels is not None else np.nan)
    temp_df["classification_accuracy"] = ', '.join([str(a) for a in accs])
    temp_df["silhouette_score"] = ', '.join([str(s) for s in sils])
    print(f"Classification: {[str(a) for a in accs]}")

    # predictability of labels 1 and 2 from each latent
    # these labels are the means of hidden modality 1 and 2, so we need regression
    preds = []
    for j in [1, 2]:
        # R**2 of a linear regression fit, fitted and scored on the same data
        temp_preds = [
            LinearRegression().fit(rep_np, labels[:, j]).score(rep_np, labels[:, j])
            for rep_np in reps_np
        ]
        preds.append(temp_preds)
        temp_df[f"label_{j}_pred"] = ', '.join([str(p) for p in temp_preds])
        print(f"Label {j} Prediction: {[str(p) for p in temp_preds]}")

    # append to out_file; the header is written only when the file does not exist yet
    temp_df.to_csv(out_file, mode='a', header=not os.path.exists(out_file), index=False)

    # Plotting is best-effort: a failure is reported but must not abort the
    # remaining baselines. The figure is closed in all cases to avoid leaking it.
    fig = None
    try:
        # create PCA plots of the reps colored by the labels and save them
        n_cols = len(reps_np)
        n_rows = labels.shape[1]
        # squeeze=False keeps `axs` two-dimensional even for a single row/column
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
        for i, rep_np in enumerate(reps_np):
            if rep_np.shape[1] > 2:
                pca_reps = PCA(n_components=2).fit_transform(rep_np)
            elif rep_np.shape[1] == 2:
                pca_reps = rep_np
            else:
                # fewer than two features: plot the single feature against an all-zero second axis
                pca_reps = np.zeros((rep_np.shape[0], 2))
                pca_reps[:, 0] = rep_np.flatten()
            for j in range(n_rows):
                ax = axs[j, i]
                ax.scatter(pca_reps[:, 0], pca_reps[:, 1], c=labels[:, j], cmap='viridis', alpha=0.5)
                if j == 0:
                    ax.set_title(f'{method_name} Mod {i}, Label {j}, Acc ({accs[i]:.2f}), Sil ({sils[i]:.2f})')
                else:
                    ax.set_title(f'{method_name} Mod {i}, Label {j}, Pred ({preds[j-1][i]:.2f})')
                ax.set_xlabel('PC1')
                ax.set_ylabel('PC2')
        plt.tight_layout()
        plt.savefig(f"03_results/plots/{run_name}_{method_name}_config{config_counter}.png")
    except Exception as e:
        print(f"Failed to create plot for {method_name}: {e}")
    finally:
        if fig is not None:
            plt.close(fig)