import numpy as np
import torch
import pandas as pd
import os
import random
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from sklearn.linear_model import LinearRegression
from tqdm import tqdm
# add src to path
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.data.mm_parametric_simulation import TabularMMDataSimulator
from src.functions.train_larrp_multimodal import train_overcomplete_ae, compute_classification
from src.models.larrp_multimodal_cnn import MMSimData

import argparse

# Command-line interface. Every run artifact (report CSV, plots, checkpoints)
# is keyed on --n_samples, --seed, --paired and --data_version via run_name.
parser = argparse.ArgumentParser(description='Compute basic ID estimation metrics')
parser.add_argument('--n_samples', type=int, default=10000, help='number of samples to use for the computation')
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use for the computation')
parser.add_argument('--paired', action='store_true', help='whether the data is paired (default: False)')
parser.add_argument(
    '--data_version', type=str, default='small',
    choices=['small', 'imbalanced', 'imbalanced-b', 'imbalanced2', 'large', 'large2'],
    help='version of the simulated data to use (default: small)',
)
parser.add_argument('--threshold', type=str, default='absolute', choices=['relative', 'absolute'], help='whether to use relative or absolute R2 thresholds (default: absolute)')
parser.add_argument('--force_retrain', action='store_true', help='Force retraining even if checkpoints exist')
args = parser.parse_args()

# Pick the compute device: the GPU selected with --gpu when CUDA is available,
# the CPU otherwise.
DEVICE = (
    torch.device(f'cuda:{args.gpu}')
    if torch.cuda.is_available()
    else torch.device('cpu')
)

# Noise and sparsity of the simulated data are fixed for this experiment.
noise = 0.0
sparsity = 0.0

# Every artifact of this run (report CSV, plots, checkpoints) is keyed on run_name.
run_name = f"larrp_mm-parametric-sim4_n-{args.n_samples}_rseed-{args.seed}_paired-{args.paired}_v-{args.data_version}"
out_file = f"03_results/reports/{run_name}.csv"

# Create the report directory up front. The CSV is only written at the very end
# of the script, so a missing directory would otherwise fail after training.
os.makedirs(os.path.dirname(out_file), exist_ok=True)

# Start from a fresh report. An existing file would otherwise be appended to
# (without a header) when the results are written.
if os.path.exists(out_file):
    print(f"File {out_file} already exists, removing it.")
    os.remove(out_file)

# Settings that differ between data versions. Every other simulator
# hyperparameter is identical for all versions (see data_hyperparams below).
_DATA_VERSION_SETTINGS = {
    'small': {'n_shared_variables': 2, 'shared_hidden_dist_type': 'binomial', 'n_hidden_variables': [3, 5]},
    'imbalanced': {'n_shared_variables': 20, 'shared_hidden_dist_type': '10-class', 'n_hidden_variables': [2, 2]},
    'imbalanced-b': {'n_shared_variables': 20, 'shared_hidden_dist_type': '20-class', 'n_hidden_variables': [2, 2]},
    'imbalanced2': {'n_shared_variables': 2, 'shared_hidden_dist_type': 'binomial', 'n_hidden_variables': [2, 20]},
    'large': {'n_shared_variables': 20, 'shared_hidden_dist_type': '20-class-gaussian', 'n_hidden_variables': [20, 20]},
    'large2': {'n_shared_variables': 20, 'shared_hidden_dist_type': '20-class', 'n_hidden_variables': [20, 20]},
}

if args.data_version not in _DATA_VERSION_SETTINGS:
    raise ValueError(
        f"data_version must be one of {list(_DATA_VERSION_SETTINGS)}, got {args.data_version!r}"
    )

_version_settings = _DATA_VERSION_SETTINGS[args.data_version]
data_hyperparams = {
    'n_samples': args.n_samples,
    'n_shared_variables': _version_settings['n_shared_variables'],
    'shared_hidden_dist_type': _version_settings['shared_hidden_dist_type'],
    # copy so the table above is never mutated through data_hyperparams
    'n_hidden_variables': list(_version_settings['n_hidden_variables']),
    'hidden_dist_types': ['poisson', 'weibull'],
    'data_dims': [200, 200],
    'nonlinearity_level': 0,
    'nonlinearity_type': 'polynomial',
    'hidden_connectivities': [0.5, 0.5, 0.5],
    'data_sparsity': sparsity,
    'noise_variance': noise,
    'random_seed': 42,
}

tab_sim = TabularMMDataSimulator(
    n_samples=data_hyperparams['n_samples'],
    n_shared_variables=data_hyperparams['n_shared_variables'],
    n_hidden_variables=data_hyperparams['n_hidden_variables'],
    shared_hidden_dist_type=data_hyperparams['shared_hidden_dist_type'],
    hidden_dist_types=data_hyperparams['hidden_dist_types'],
    data_dims=data_hyperparams['data_dims'],
    nonlinearity_level=data_hyperparams['nonlinearity_level'],
    nonlinearity_type=data_hyperparams['nonlinearity_type'],
    hidden_connectivities=data_hyperparams['hidden_connectivities'],
    data_sparsity=data_hyperparams['data_sparsity'],
    noise_variance=data_hyperparams['noise_variance'],
    random_seed=data_hyperparams['random_seed'],
)
print("Data generated. Computing initial analysis.")

y1, y2, x0, x1, x2, labels = tab_sim.generate_data()


def _minmax_scale(values, upper=10):
    """Linearly rescale *values* so the global min maps to 0 and the max to *upper*.

    A constant input (max == min) has no range to stretch. It is mapped to zeros
    instead of the NaNs that a 0/0 division would produce.
    """
    lo = np.min(values)
    span = np.max(values) - lo
    if span == 0:
        # multiplying by 0.0 yields the same dtype as the regular branch
        return (values - lo) * 0.0
    return upper * (values - lo) / span


# Bring both modalities to the same [0, 10] range (for every data version).
# Without this, or reconstruction-loss balancing, the larger-magnitude modality
# dominates the loss and tends to end up with a higher rank.
y1 = _minmax_scale(y1)
y2 = _minmax_scale(y2)

# Control analysis of the raw modalities (y1, y2) and the simulator latents
# (x0, x1, x2): how well each representation encodes the labels. The control
# figure below uses 3 rows of PCA scatter plots (one per label) plus a row of
# value histograms.
n_cols = 3 + 2
n_rows = 4
reps = [y1, y2, x0, x1, x2]

# Label 0 is the same for every representation, so the statistics that decide
# whether a silhouette score is computable are computed once, outside the loop.
label0 = labels[:, 0]
unique_labels = np.unique(label0)
n_unique = len(unique_labels)
n_samples = len(label0)
# silhouette_score requires 2 <= n_labels <= n_samples - 1
silhouette_feasible = 2 <= n_unique <= n_samples - 1
binned_label0 = None
if not silhouette_feasible and n_unique > n_samples // 2:
    # Too many distinct values: label 0 looks continuous, so bin it instead.
    from sklearn.preprocessing import KBinsDiscretizer
    # KBinsDiscretizer requires n_bins >= 2, which min(10, n // 10) alone does
    # not guarantee for fewer than 20 samples.
    discretizer = KBinsDiscretizer(n_bins=max(2, min(10, n_samples // 10)), encode='ordinal', strategy='uniform')
    binned_label0 = discretizer.fit_transform(label0.reshape(-1, 1)).flatten().astype(int)

accs = []
sils = []
preds = []
for j in tqdm(range(len(reps))):
    rep = reps[j]
    if args.data_version != 'large2-separate':
        acc = compute_classification(rep, label0)
    else:
        acc = np.nan

    if silhouette_feasible:
        sil = silhouette_score(rep, label0)
    else:
        print(f"Cannot compute silhouette score for latent {j} with {n_unique} unique labels and {n_samples} samples.")
        if binned_label0 is not None:
            sil = silhouette_score(rep, binned_label0)
        else:
            sil = np.nan  # Cannot compute silhouette score

    accs.append(acc)
    sils.append(sil)

    # In-sample R^2 of a linear regression from the representation to labels 1
    # and 2. NOTE(review): the first representation (y1, j == 0) is skipped
    # while y2 is not; the reason is not visible here, so confirm it is intended.
    if args.data_version != 'large2-separate' and j != 0:
        temp_preds = []
        for i in [1, 2]:
            reg = LinearRegression()
            reg.fit(rep, labels[:, i])
            temp_preds.append(reg.score(rep, labels[:, i]))
    else:
        temp_preds = [np.nan, np.nan]
    preds.append(temp_preds)

print("Initial analysis complete.")

# Control figure: rows 0-2 show 2-D projections of each representation coloured
# by label 0/1/2, row 3 the histogram of all of its values.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
for i in range(n_cols):
    rep = reps[i]
    # PCA(n_components=2) needs at least two features. 1-D representations are
    # drawn along the first axis; empty ones collapse onto the origin.
    if rep.shape[1] >= 2:
        pca_reps = PCA(n_components=2).fit_transform(rep)
    else:
        pca_reps = np.zeros((rep.shape[0], 2))
        if rep.shape[1] == 1:
            pca_reps[:, 0] = np.asarray(rep).ravel()

    # The first two entries of reps are the observed modalities, the rest latents.
    panel_name = f'Data {i}' if i < 2 else f'Latent {i-2}'

    # First 3 rows: PCA plots colored by labels
    for j in range(3):
        ax = axs[j, i]
        ax.scatter(pca_reps[:, 0], pca_reps[:, 1], c=labels[:, j], cmap='viridis', alpha=0.5)
        title_text = f'{panel_name}, Label {j}'
        if j == 0:
            title_text += f', Acc ({accs[i]:.2f}), Sil ({sils[i]:.2f})'
        else:
            title_text += f', Pred ({preds[i][j-1]:.2f})'
        ax.set_title(title_text)
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')

    # Fourth row: histogram of the value distribution
    ax = axs[3, i]
    ax.hist(rep.flatten(), bins=100, alpha=0.7, edgecolor='black')
    ax.set_title(f'{panel_name} Value Distribution')
    ax.set_xlabel('Value')
    ax.set_ylabel('Count')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
# The plots directory is not guaranteed to exist on a fresh checkout.
os.makedirs("03_results/plots", exist_ok=True)
plt.savefig(f"03_results/plots/{run_name}_control.png")
plt.close(fig)
del reps

# Arrange the modalities in the layout the trainer consumes. For unpaired data
# every sample appears twice: once with only modality 1 present (modality 2
# zero-filled) and once with only modality 2 present. The labels are duplicated
# so that their rows stay aligned with the stacked data.
if not args.paired:
    y1 = np.concatenate([y1, np.zeros_like(y1)], axis=0)
    y2 = np.concatenate([np.zeros_like(y2), y2], axis=0)
    labels = np.concatenate([labels, labels], axis=0)

data = [torch.FloatTensor(modality) for modality in (y1, y2)]

################################
# post-training function

def train_overcomplete_ae_continued(data, n_samples_train, model, device, epochs=100,
                                    lr=0.001, batch_size=128, wd=1e-5,
                                    verbose=True, recon_loss_balancing=False,
                                    ):
    """Fine-tune an already trained autoencoder on the reconstruction loss only.

    No rank reduction is performed here. The current ``active_dims`` of the
    model's adaptive layers are kept as they are, and all parameters are
    optimised with Adam on the per-modality MSE reconstruction loss.

    Args:
        data: list of per-modality tensors, all with the same number of rows.
        n_samples_train: number of randomly permuted rows used for training;
            the remaining rows form the validation split.
        model: autoencoder called as ``model(x) -> (x_hat, _)``. It must also
            provide ``encode`` and ``adaptive_layers`` (each with ``active_dims``).
        device: device the batches are moved to.
        epochs, lr, batch_size, wd: optimisation settings (``wd`` is Adam's
            weight decay).
        verbose: print a warning when a modality loss is NaN.
        recon_loss_balancing: scale each modality loss by ``min_ema / ema_i``,
            where ``ema_i`` is an exponential moving average of that modality's
            loss, so that the modalities contribute comparably.

    Returns:
        ``(model, reps)``: the fine-tuned model, and one latent tensor per
        adaptive layer computed over all rows of ``data``. Each tensor is
        truncated to that layer's ``active_dims`` and lives on ``device``.
    """
    mse_loss = torch.nn.functional.mse_loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Random train/validation split over rows. NOTE: with the unpaired layout
    # built in this script ([y1, 0] stacked on [0, y2]), each row has only one
    # modality observed, and the permutation mixes both halves.
    data_indices = torch.randperm(data[0].shape[0])
    train_indices = data_indices[:n_samples_train]
    val_indices = data_indices[n_samples_train:]
    train_data = MMSimData([d[train_indices] for d in data])
    val_data = MMSimData([d[val_indices] for d in data])
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

    train_losses = []
    val_losses = []

    # State for the optional reconstruction-loss balancing.
    modality_loss_emas = [None] * len(data)
    ema_decay = 0.9

    pbar = tqdm(range(epochs))
    for _ in pbar:
        # ---- training phase ----
        model.train()
        train_loss = 0.0
        per_modality_losses = [0.0] * len(data)

        for x, _mask in data_loader:
            # Dataset masks are intentionally ignored (both here and in the
            # validation phase): the loss covers every entry of each modality.
            x = [x_m.to(device) for x_m in x]
            x_hat, _ = model(x)

            modality_losses = []
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                m_loss = mse_loss(x_hat_m, x_m)
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    # Drop the NaN term so it cannot poison the gradients.
                    m_loss = torch.tensor(0.0, device=device)
                modality_losses.append(m_loss)
                per_modality_losses[i] += m_loss.item()

            if recon_loss_balancing:
                for i, m_loss in enumerate(modality_losses):
                    value = m_loss.item()
                    prev = modality_loss_emas[i]
                    modality_loss_emas[i] = value if prev is None else ema_decay * prev + (1 - ema_decay) * value
                # Scale every modality towards the smallest positive EMA. If no
                # EMA is positive, the losses are used unscaled.
                positive_emas = [ema for ema in modality_loss_emas if ema is not None and ema > 0]
                min_ema = min(positive_emas) if positive_emas else 0.0
                loss = sum(
                    (min_ema / ema) * m_loss if ema is not None and ema > 0 else m_loss
                    for ema, m_loss in zip(modality_loss_emas, modality_losses)
                )
            else:
                loss = sum(modality_losses)

            # Skip the update when no modality produced a usable loss (e.g. all NaN).
            if not (torch.is_tensor(loss) and loss.requires_grad):
                continue

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        n_train_batches = len(data_loader)
        if n_train_batches > 0:
            train_loss /= n_train_batches
            per_modality_losses = [l / n_train_batches for l in per_modality_losses]
        train_losses.append(train_loss)

        # ---- validation phase ----
        # eval() disables dropout so that the validation loss is deterministic.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x_val, _mask in val_data_loader:
                x_val = [x_m.to(device) for x_m in x_val]
                x_val_hat, _ = model(x_val)
                batch_losses = [mse_loss(x_hat_m, x_m).item() for x_m, x_hat_m in zip(x_val, x_val_hat)]
                # Mean over modalities; NaN modality losses contribute 0.
                val_loss += sum(l for l in batch_losses if not np.isnan(l)) / len(x_val)
        n_val_batches = len(val_data_loader)
        val_loss = val_loss / n_val_batches if n_val_batches > 0 else float('nan')
        val_losses.append(val_loss)

        pbar.set_postfix({
            'loss': round(train_loss, 4),
            'val_loss': round(val_loss, 4),
            'mod_losses': [round(l, 3) for l in per_modality_losses],
        })

    # ---- latent representations of all rows, computed in batches ----
    n_samples = data[0].shape[0]
    final_ranks = [layer.active_dims for layer in model.adaptive_layers]
    reps = [torch.empty((n_samples, rank), device=device) for rank in final_ranks]
    model.eval()
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            end = min(start + batch_size, n_samples)
            x_batch = [d[start:end].to(device) for d in data]
            batch_reps = model.encode(x_batch)
            # encode() returns (first_rep, sequence_of_reps). They are flattened
            # in that order, assumed to match model.adaptive_layers — TODO confirm.
            batch_rep_list = [batch_reps[0], *batch_reps[1]]
            for k, rep in enumerate(reps):
                rep[start:end, :] = batch_rep_list[k][:, :final_ranks[k]]

            # Free memory
            del x_batch, batch_reps, batch_rep_list
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return model, reps

################################

# Seed every RNG used from here on (NumPy, PyTorch CPU and CUDA, Python's random)
# with --seed. The data simulator above receives its own fixed random_seed.
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
    seed_fn(args.seed)

from itertools import product

# Hyperparameter grid of the method; each entry lists the values to sweep.
method_hyperparameters = dict(
    r_square_thresholds=[0.05],
    early_stopping=[50],
    rank_reduction_frequencies=[10],
    rank_reduction_thresholds=[0.01],
    patiences=[10],
)
method_combinations = list(product(*method_hyperparameters.values()))
print(f"Number of method combinations: {len(method_combinations)}")

class Args:
    """Static training configuration passed to ``train_overcomplete_ae``.

    Acts like an argparse namespace. Apart from ``gpu``, which is taken from the
    parsed command line, every value is hard-coded for this experiment.
    """

    def __init__(self):
        self.latent_dim = 100          # latent dimensionality handed to the trainer

        # optimisation
        self.batch_size = 128
        self.lr = 1e-4
        self.weight_decay = 1e-5
        self.dropout = 0.1
        self.epochs = 5000

        # autoencoder architecture
        self.ae_depth, self.ae_width = 2, 1

        self.rank_or_sparse = 'rank'   # reduction mode passed through to the trainer

        # device selection
        self.multi_gpu = False
        self.gpu_ids = ''
        self.gpu = args.gpu


train_args = Args()

###
# start training multiple configs
###
config_counter = 0

# A single configuration is run: the first value of every hyperparameter list.
r_square_threshold = method_hyperparameters["r_square_thresholds"][0]
early_stopping = method_hyperparameters["early_stopping"][0]
rank_reduction_frequency = method_hyperparameters["rank_reduction_frequencies"][0]
rank_reduction_threshold = method_hyperparameters["rank_reduction_thresholds"][0]
patience = method_hyperparameters["patiences"][0]

config_counter += 1
print(f"### Run {config_counter}/{len(method_combinations)} ###")

# Checkpoints allow a finished run to be re-analysed without retraining.
checkpoint_dir = Path("03_results/checkpoints")
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = checkpoint_dir / f"{run_name}_config-{config_counter}.pt"

# Rows of `data` used for training; the remaining rows are held out.
# NOTE(review): for unpaired data `data` holds 2 * n_samples rows, so this
# trains on ~45% of the rows. Confirm that this is intended.
n_samples_train = int(0.9 * args.n_samples)

if checkpoint_path.exists() and not args.force_retrain:
    print(f"Loading checkpoint from {checkpoint_path}")
    # The checkpoint is written by this script and contains non-tensor objects
    # (numpy labels, history dicts), so full unpickling is required
    # (weights_only=False). Only load checkpoints you trust. map_location lets a
    # checkpoint saved on one device (e.g. GPU) be restored on another (e.g. CPU).
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
    reps = checkpoint['reps']
    rank_history = checkpoint['rank_history']
    loss_curves = checkpoint['loss_curves']
    r_squares = checkpoint['r_squares']
    train_loss = checkpoint.get('train_loss', None)
    model = None  # only the state dict is stored; the model object is not rebuilt
else:
    model, reps, train_loss, r_squares, rank_history, loss_curves = train_overcomplete_ae(
        data,
        n_samples_train,
        train_args.latent_dim,
        DEVICE,
        train_args,
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        ae_width=train_args.ae_width,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        early_stopping=early_stopping,
        initial_rank_ratio=1.0,
        rank_reduction_frequency=rank_reduction_frequency,
        rank_reduction_threshold=rank_reduction_threshold,
        warmup_epochs=early_stopping,
        patience=patience,
        min_rank=1,
        r_square_threshold=r_square_threshold,
        threshold_type=args.threshold,
        compressibility_type='direct',
        verbose=False,
        compute_jacobian=False,
        sharedwhenall=False,
        model_name=run_name + str(config_counter)
    )

    # Fine-tune the rank-reduced model at a 10x smaller learning rate.
    print("Post-training")
    model, reps = train_overcomplete_ae_continued(
        data, n_samples_train, model, DEVICE,
        epochs=1000, lr=train_args.lr * 0.1, batch_size=train_args.batch_size,
        wd=train_args.weight_decay, verbose=False,
    )

    # Save checkpoint
    torch.save({
        'model_state_dict': model.state_dict(),
        'reps': reps,
        'labels': labels,
        'rank_history': rank_history,
        'loss_curves': loss_curves,
        'r_squares': r_squares,
        'train_loss': train_loss
    }, checkpoint_path)

# Report table: one row per entry of rank_history, tagged with this
# configuration's method hyperparameters and summary metrics.
temp_df = pd.DataFrame(rank_history)

# add all the data and method parameters to the dataframe
temp_df["r_square_threshold"] = r_square_threshold
temp_df["early_stopping"] = early_stopping
temp_df["rank_reduction_frequency"] = rank_reduction_frequency
temp_df["rank_reduction_threshold"] = rank_reduction_threshold
temp_df["patience"] = patience

# First and last recorded R^2 per modality, as comma-separated strings
# (works for any number of modalities).
valid_rsquares = [rank_history[f"rsquare {j}"] for j in range(len(data))]
temp_df["r_squares_init"] = ', '.join(str(r[0]) for r in valid_rsquares)
temp_df["r_squares_final"] = ', '.join(str(r[-1]) for r in valid_rsquares)

final_ranks = rank_history["ranks"][-1]
# A per-layer sequence of ranks cannot be broadcast into one DataFrame column
# (pandas needs one value per row), so store it as a string like the other
# multi-valued metrics. Scalars and strings are stored unchanged.
if isinstance(final_ranks, (list, tuple, np.ndarray)):
    final_ranks = ', '.join(str(r) for r in np.asarray(final_ranks).ravel().tolist())
temp_df["final_ranks"] = final_ranks
temp_df["config"] = config_counter

# NumPy copies of the latents, converted once and reused by every metric below.
reps_np = [rep.cpu().numpy() for rep in reps]

# Classification accuracy and silhouette score of label 0 on each latent.
# Label 0 is shared by all latents, so the silhouette feasibility is decided once.
label0 = labels[:, 0]
unique_labels = np.unique(label0)
n_unique = len(unique_labels)
n_samples = len(label0)
silhouette_feasible = 2 <= n_unique <= n_samples - 1  # silhouette_score requirement
binned_label0 = None
if not silhouette_feasible and n_unique > n_samples // 2:
    # Too many distinct values: label 0 looks continuous, so bin it instead.
    from sklearn.preprocessing import KBinsDiscretizer
    # KBinsDiscretizer requires n_bins >= 2 (not guaranteed by min(10, n // 10)).
    discretizer = KBinsDiscretizer(n_bins=max(2, min(10, n_samples // 10)), encode='ordinal', strategy='uniform')
    binned_label0 = discretizer.fit_transform(label0.reshape(-1, 1)).flatten().astype(int)

accs = []
sils = []
for rep_np in reps_np:
    if args.data_version != 'large2-separate':
        acc = compute_classification(rep_np, label0)
    else:
        acc = np.nan

    if silhouette_feasible:
        sil = silhouette_score(rep_np, label0)
    elif binned_label0 is not None:
        sil = silhouette_score(rep_np, binned_label0)
    else:
        sil = np.nan  # Cannot compute silhouette score

    accs.append(acc)
    sils.append(sil)
temp_df["classification_accuracy"] = ', '.join([str(a) for a in accs])
temp_df["silhouette_score"] = ', '.join([str(s) for s in sils])
print(f"Classification: {[str(a) for a in accs]}")

# Predictability of labels 1 and 2 from each latent: in-sample R^2 of a linear
# regression. Labels 1 and 2 are continuous, so regression is used.
# preds[j - 1][i] is the R^2 of label j from latent i.
preds = []
for j in [1, 2]:
    target = labels[:, j]
    temp_preds = []
    for rep_np in reps_np:
        reg = LinearRegression().fit(rep_np, target)
        temp_preds.append(reg.score(rep_np, target))
    preds.append(temp_preds)
    temp_df[f"label_{j}_pred"] = ', '.join([str(p) for p in temp_preds])
    print(f"Label {j} Prediction: {[str(p) for p in temp_preds]}")

# Append to an existing report (header written only once), otherwise create it.
os.makedirs(os.path.dirname(out_file), exist_ok=True)
write_header = not os.path.exists(out_file)
temp_df.to_csv(out_file, mode='w' if write_header else 'a', header=write_header, index=False)

# Publication figure: one column per learned representation, one row per label
# column. Each panel is a 2-D projection coloured by that label and titled with
# the metrics computed above. Plotting is best-effort: errors are reported but
# not raised, because the report CSV has already been written at this point.
fig = None
try:
    n_cols = len(reps)
    n_rows = labels.shape[1]

    # Qualitative colormap for the categorical label 0, sized to its class count.
    n_classes = len(np.unique(labels[:, 0]))
    if n_classes > 20:
        categorical_cmap = 'nipy_spectral'  # Fallback for many classes
    elif n_classes > 10:
        categorical_cmap = 'tab20'
    else:
        categorical_cmap = 'tab10'
    # Row j uses cmaps[j]; rows beyond this list fall back to 'viridis_r'.
    cmaps = [categorical_cmap, 'viridis_r', 'plasma_r']

    style = {
        'axes.titlesize': 10,
        'axes.titleweight': 'bold',
        'axes.labelsize': 8,
        'xtick.labelsize': 6,
        'ytick.labelsize': 6,
        'figure.dpi': 300,
        'figure.figsize': (n_cols * 2, n_rows * 2),
        'legend.fontsize': 6,
        'lines.markersize': 4,
    }
    # rc_context scopes these settings to this figure instead of changing the
    # global matplotlib state for the rest of the process.
    with plt.rc_context(style):
        # squeeze=False keeps `axs` 2-D even with a single row or column.
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), dpi=300, squeeze=False)

        for i, rep in enumerate(reps):
            rep_np = rep.cpu().numpy()
            if rep_np.shape[1] > 2:
                pca_reps = PCA(n_components=2).fit_transform(rep_np)
            elif rep_np.shape[1] == 2:
                pca_reps = rep_np
            else:
                # 1-D representation: plot it along the x-axis only.
                # An empty (0-D) representation collapses onto the origin.
                pca_reps = np.zeros((rep_np.shape[0], 2))
                if rep_np.shape[1] == 1:
                    pca_reps[:, 0] = rep_np[:, 0]

            for j in range(n_rows):
                ax = axs[j, i]
                sc = ax.scatter(
                    pca_reps[:, 0],
                    pca_reps[:, 1],
                    c=labels[:, j],
                    cmap=cmaps[j] if j < len(cmaps) else 'viridis_r',
                    alpha=0.6,
                    s=10,
                    edgecolors='none'
                )

                # Titles with metrics
                if j == 0:
                    title = f'Representation {i}\nAcc: {accs[i]:.2f}, Sil: {sils[i]:.2f}'
                elif j - 1 < len(preds):
                    title = f'Representation {i}\nLabel {j} R²: {preds[j-1][i]:.2f}'
                else:
                    # no R² was computed for this label column
                    title = f'Representation {i}\nLabel {j}'
                ax.set_title(title, fontsize=10, fontweight='bold')

                # No axis labels or ticks, but a full box around the panel.
                ax.set_xlabel('')
                ax.set_ylabel('')
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(True)

                # Colorbar for the continuous labels, only on the rightmost column.
                if j > 0 and i == n_cols - 1:
                    cbar = plt.colorbar(sc, ax=ax)
                    cbar.ax.tick_params(labelsize=6)

        plt.tight_layout()
        save_path = f"03_results/plots/{run_name}_config-{config_counter}_publication.png"
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"Saved publication figure to {save_path}")
except Exception as e:
    print(f"Error in plotting PCA of latents: {e}")
finally:
    # Release the figure whether or not plotting succeeded.
    if fig is not None:
        plt.close(fig)

