import time

import lightning
import numpy as np
import pandas as pd
import torch
from scipy.stats import spearmanr
from torch import Tensor
from tqdm import tqdm

from data import EasyTPPDataModule


@torch.no_grad()
def generate_samples(
    datamodule: EasyTPPDataModule,
    module: lightning.LightningModule,
    **sample_kwargs,
) -> dict[str, object]:
    """Run ``module.sample`` and ``module.rejection_sample`` on every test batch.

    Both samplers receive the same ``(x, t)`` inputs. Each call is timed, and its
    output is scored with the model's own joint log-likelihood.

    Args:
        datamodule: Provides ``test_dataloader()``; each batch unpacks as ``(x, t, _)``.
        module: Model exposing ``encoder``, ``decoder``, ``sample``,
            ``rejection_sample`` and a ``device`` attribute.
        **sample_kwargs: Forwarded unchanged to both samplers.

    Returns:
        A dict with the following entries:

        - Samples and log-probabilities from both methods, stacked along a new
          leading per-batch dimension. Every test batch must therefore produce
          tensors of the same shape.
        - Wall-clock seconds per sampler call.
        - The raw per-batch time profiles.
        - The concatenated step indices returned by ``rejection_sample``.
        - The per-batch mark/time constants it reports.

        All tensors are on the CPU.
    """

    def get_log_prob(x: Tensor, t: Tensor) -> Tensor:
        # Joint log-likelihood of marks and times under the model's predictive distributions.
        pm, pt = module.decoder(module.encoder(x, t))
        # NOTE(review): t is not clamped away from 0 before pt.log_prob (a clamp was
        # previously tried and left commented out) — confirm whether it is needed.
        return pm.log_prob(x) + pt.log_prob(t)

    def synchronize() -> None:
        # CUDA kernels run asynchronously. Without a sync, the wall-clock timer can stop
        # before the sampler's GPU work has actually finished.
        if torch.device(module.device).type == 'cuda':
            torch.cuda.synchronize(module.device)

    time_old, time_new = [], []
    x_old_sample, t_old_sample = [], []
    x_new_sample, t_new_sample = [], []
    mark_consts, time_consts = [], []
    all_inds = []
    log_prob_old, log_prob_new = [], []
    time_profile_old, time_profile_new = [], []

    for x, t, _ in tqdm(datamodule.test_dataloader()):
        x = x.to(module.device)
        t = t.to(module.device)

        # Old (plain) sampler.
        synchronize()
        start_time = time.perf_counter()
        x_, t_, times_ = module.sample(
            x, t, time_profile=True, **sample_kwargs,
        )
        synchronize()
        time_old.append(time.perf_counter() - start_time)
        time_profile_old.append(times_)

        log_prob_old.append(get_log_prob(x_, t_).detach().cpu())
        x_old_sample.append(x_.detach().cpu())
        t_old_sample.append(t_.detach().cpu())

        # New (rejection) sampler. The timer stops right after the call, so the
        # bookkeeping below is not charged to this method and both timings measure the
        # same thing.
        synchronize()
        start_time = time.perf_counter()
        x_, t_, ind, times_, batch_mark_consts, batch_time_consts = module.rejection_sample(
            x, t, time_profile=True, **sample_kwargs,
        )
        synchronize()
        time_new.append(time.perf_counter() - start_time)
        time_profile_new.append(times_)

        mark_consts.append(torch.cat(batch_mark_consts).detach().cpu())
        time_consts.append(torch.cat(batch_time_consts).detach().cpu())
        log_prob_new.append(get_log_prob(x_, t_).detach().cpu())
        x_new_sample.append(x_.detach().cpu())
        t_new_sample.append(t_.detach().cpu())

        # ``ind`` is a sequence of tensors; stack per batch, concatenate across batches below.
        all_inds.append(torch.stack(ind).detach().cpu())

    print(f'Old method time: {np.mean(time_old):.4f} ± {np.std(time_old):.4f}')
    print(f'New method time: {np.mean(time_new):.4f} ± {np.std(time_new):.4f}')

    return {
        'x_old_sample': torch.stack(x_old_sample, dim=0),
        't_old_sample': torch.stack(t_old_sample, dim=0),
        'x_new_sample': torch.stack(x_new_sample, dim=0),
        't_new_sample': torch.stack(t_new_sample, dim=0),
        'all_inds': torch.cat(all_inds),
        'log_prob_old': torch.stack(log_prob_old, dim=0),
        'log_prob_new': torch.stack(log_prob_new, dim=0),
        'time_old': time_old,
        'time_new': time_new,
        'time_profile_old': time_profile_old,
        'time_profile_new': time_profile_new,
        'mark_consts': mark_consts,
        'time_consts': time_consts,
    }



def kl_divergence(x_old_sample: Tensor, x_new_sample: Tensor, dim: int) -> dict[str, float]:
    """Compute the per-item KL divergence between old and new categorical (mark) samples.

    Samples are split along the sample axis ``S``:

    - The first half of the old samples is compared against the first half of the
      new samples.
    - As a baseline, the first half of the old samples is compared against the
      second half of the old samples.

    Both comparisons therefore use the same sample size. When ``S`` is odd, the last
    sample is dropped so that the halves have equal length.

    Args:
        x_old_sample: Non-negative integer marks of shape ``(B, S, L)`` from the
            reference sampler.
        x_new_sample: Marks of the same shape from the sampler under test.
        dim: Number of categories; marks must be below ``dim``.

    Returns:
        ``"KL Avg Per Item"`` / ``"KL Avg Per Item Std"`` and the matching
        ``"... Baseline"`` entries: mean and std over all ``(B, L)`` items.

    Raises:
        AssertionError: If the two inputs have different shapes.
        ValueError: If ``S < 2``, so two non-empty halves cannot be formed.
    """
    x_old = x_old_sample.detach().cpu().numpy()
    x_new = x_new_sample.detach().cpu().numpy()

    assert x_old.shape == x_new.shape, "Tensor shapes don't match"

    # Shape: (B, S, L)
    _, S, _ = x_old.shape
    half = S // 2
    if half == 0:
        raise ValueError(f"Need at least 2 samples per item to split into halves, got S={S}")

    # Floor for probabilities, to avoid log(0) and division by zero.
    epsilon = 1e-8

    # Slicing the second half to ``2 * half`` drops the last sample when S is odd.
    # Without it, the old/new and old/old comparisons would use unequal sample sizes.
    x_old_first_half = x_old[:, :half, :]
    x_new_first_half = x_new[:, :half, :]
    x_old_second_half = x_old[:, half:2 * half, :]

    kl_per_item = _compute_per_item_kl(x_old_first_half, x_new_first_half, dim, epsilon)
    kl_baseline_per_item = _compute_per_item_kl(x_old_first_half, x_old_second_half, dim, epsilon)

    return {
        "KL Avg Per Item": kl_per_item.mean(),
        "KL Avg Per Item Std": kl_per_item.std(),
        "KL Avg Per Item Baseline": kl_baseline_per_item.mean(),
        "KL Avg Per Item Baseline Std": kl_baseline_per_item.std(),
    }


def _compute_per_item_kl(x_p: np.ndarray, x_q: np.ndarray, dim: int, epsilon: float) -> np.ndarray:
    """Return KL(P || Q) between the empirical category distributions of each item.

    Args:
        x_p: Non-negative integer samples of shape ``(B, S, L)`` defining P per ``(b, l)``.
        x_q: Samples with the same ``B`` and ``L`` defining Q.
        dim: Number of categories (histogram ``minlength``).
        epsilon: Floor applied to both probability vectors, so that ``log(0)`` and
            division by zero cannot occur. The floored vectors are not renormalized.

    Returns:
        Float array of shape ``(B, L)`` with one KL value per item.
    """
    B, _, L = x_p.shape
    kl_values = np.zeros((B, L))

    for b in range(B):
        for pos in range(L):
            counts_p = np.bincount(x_p[b, :, pos], minlength=dim)
            counts_q = np.bincount(x_q[b, :, pos], minlength=dim)

            prob_p = np.clip(counts_p / counts_p.sum(), epsilon, 1)
            prob_q = np.clip(counts_q / counts_q.sum(), epsilon, 1)

            kl_values[b, pos] = np.sum(prob_p * np.log(prob_p / prob_q))

    return kl_values


def _median_l1_bandwidth(t_old_sample: np.ndarray) -> float:
    """Median heuristic for the Gaussian-kernel bandwidth used in the MMD.

    Sample ``s`` of the first half (along axis 1) is paired with sample ``s`` of the
    second half. The function returns the median absolute difference over all pairs,
    batch rows and positions. With an odd number of samples, the last one is ignored
    so that the halves line up.

    NOTE(review): the result is 0 when most paired differences are 0. A zero
    bandwidth makes a Gaussian kernel divide by zero.

    Args:
        t_old_sample: Array of shape ``(B, S, L)``.

    Returns:
        The median absolute difference.
    """
    _, S, _ = t_old_sample.shape
    half = S // 2
    first_half = t_old_sample[:, :half, :]           # (B, S//2, L)
    second_half = t_old_sample[:, half:2 * half, :]  # (B, S//2, L), even when S is odd
    return np.median(np.abs(first_half - second_half))


def _gaussian_mmd_1d(x: np.ndarray, y: np.ndarray, bandwidth: float) -> float:
    """Return the squared MMD between 1-D samples ``x`` and ``y`` under a Gaussian kernel.

    Each kernel term is the mean of the full pairwise matrix, diagonal included.
    """
    def kernel_mean(a: np.ndarray, b: np.ndarray) -> float:
        diff = a[:, np.newaxis] - b[np.newaxis, :]
        return np.exp(-(diff ** 2) / (2 * bandwidth ** 2)).mean()

    return kernel_mean(x, x) + kernel_mean(y, y) - 2 * kernel_mean(x, y)


def compute_mmd_with_variance(t_old_sample: Tensor, t_new_sample: Tensor) -> dict[str, float]:
    """
    Compute the per-item Gaussian-kernel MMD between old and new time samples.

    For every ``(b, l)`` item, the first half of the old samples along axis 1 is
    compared with the first half of the new samples. As a baseline, the first half of
    the old samples is also compared with the second half of the old samples.
    The kernel bandwidth is the median L1 distance between the two halves of
    ``t_old_sample`` (see ``_median_l1_bandwidth``).

    Args:
        t_old_sample (torch.Tensor): Old samples of shape (B, S, L).
        t_new_sample (torch.Tensor): New samples of shape (B, S, L).

    Returns:
        dict: Contains:
            - "MMD Mean": Mean MMD over all (B, L) items.
            - "MMD Std": Standard deviation of the MMD over all (B, L) items.
            - "Baseline MMD": Mean MMD between the two halves of t_old_sample.
            - "Baseline MMD Std": Standard deviation of that baseline MMD.
    """
    t_old = t_old_sample.detach().cpu().numpy()
    t_new = t_new_sample.detach().cpu().numpy()

    # Shape: (B, S, L)
    B, S, L = t_old.shape
    half = S // 2

    kernel_bandwidth = _median_l1_bandwidth(t_old)

    mmd_per_item = np.zeros((B, L))
    baseline_mmd_per_item = np.zeros((B, L))
    for b in range(B):
        for pos in range(L):
            old_first = t_old[b, :half, pos]           # (S//2,)
            new_first = t_new[b, :half, pos]           # (S//2,)
            old_second = t_old[b, half:2 * half, pos]  # (S//2,), same size as the others

            mmd_per_item[b, pos] = _gaussian_mmd_1d(old_first, new_first, kernel_bandwidth)
            baseline_mmd_per_item[b, pos] = _gaussian_mmd_1d(old_first, old_second, kernel_bandwidth)

    return {
        "MMD Mean": mmd_per_item.mean(),
        "MMD Std": mmd_per_item.std(),
        "Baseline MMD": baseline_mmd_per_item.mean(),
        "Baseline MMD Std": baseline_mmd_per_item.std(),
    }


def chi_squared_divergence(
    x_old_sample: Tensor, x_new_sample: Tensor, dim: int
) -> dict[str, float]:
    """
    Compute the Pearson chi-squared divergence for categorical (mark) samples.

    The first half of the old samples (along axis 1) is compared against the first
    half of the new samples. As a baseline, it is also compared against the second
    half of the old samples, so both use the same sample size. When S is odd, the
    last sample is dropped.

    Args:
        x_old_sample (torch.Tensor): Old samples of shape (B, S, L).
        x_new_sample (torch.Tensor): New samples of shape (B, S, L).
        dim (int): Number of categories (0 to dim-1).

    Returns:
        dict: Contains:
            - "Chi-Squared Marginal": Chi-squared divergence of the pooled (all items) distributions.
            - "Chi-Squared Avg Per Item": Mean per-item chi-squared divergence over all (B, L) items.
            - "Chi-Squared Avg Per Item Std": Standard deviation of the per-item divergence.
            - "Chi-Squared Avg Per Item Baseline": Mean per-item divergence between the two halves of x_old_sample.
            - "Chi-Squared Avg Per Item Baseline Std": Standard deviation of that baseline.

    Raises:
        AssertionError: If the two inputs have different shapes.
        ValueError: If S < 2, so two non-empty halves cannot be formed.
    """
    x_old = x_old_sample.detach().cpu().numpy()
    x_new = x_new_sample.detach().cpu().numpy()

    assert x_old.shape == x_new.shape, "Tensor shapes don't match"

    # Shape: (B, S, L)
    _, S, _ = x_old.shape
    half = S // 2
    if half == 0:
        raise ValueError(f"Need at least 2 samples per item to split into halves, got S={S}")

    # Floor for Q's probabilities, to avoid division by zero.
    epsilon = 1e-4

    # ``half:2 * half`` keeps the halves equal in size when S is odd.
    x_old_first_half = x_old[:, :half, :]
    x_new_first_half = x_new[:, :half, :]
    x_old_second_half = x_old[:, half:2 * half, :]

    results = {
        "Chi-Squared Marginal": _compute_marginal_chi_squared(
            x_old_first_half, x_new_first_half, dim, epsilon
        ),
    }

    chi_squared_per_item = _compute_per_item_chi_squared(x_old_first_half, x_new_first_half, dim, epsilon)
    results["Chi-Squared Avg Per Item"] = chi_squared_per_item.mean()
    results["Chi-Squared Avg Per Item Std"] = chi_squared_per_item.std()

    chi_squared_baseline_per_item = _compute_per_item_chi_squared(x_old_first_half, x_old_second_half, dim, epsilon)
    results["Chi-Squared Avg Per Item Baseline"] = chi_squared_baseline_per_item.mean()
    results["Chi-Squared Avg Per Item Baseline Std"] = chi_squared_baseline_per_item.std()

    return results


def _compute_marginal_chi_squared(
    x_p: np.ndarray, x_q: np.ndarray, dim: int, epsilon: float
) -> float:
    """Pearson chi-squared divergence between two pooled empirical distributions.

    Every entry of ``x_p`` and ``x_q`` is pooled regardless of position and turned
    into category frequencies over ``dim`` bins. The result is
    ``sum((p - q) ** 2 / q)``, where ``q`` is floored at ``epsilon`` so that bins
    empty in ``x_q`` do not divide by zero.
    """
    def frequencies(samples: np.ndarray) -> np.ndarray:
        histogram = np.bincount(samples.flatten(), minlength=dim)
        return histogram / histogram.sum()

    p = frequencies(x_p)
    q = np.clip(frequencies(x_q), epsilon, 1)
    squared_gap = (p - q) ** 2
    return np.sum(squared_gap / q)


def _compute_per_item_chi_squared(
    x_p: np.ndarray, x_q: np.ndarray, dim: int, epsilon: float
) -> np.ndarray:
    """Pearson chi-squared divergence for every ``(b, l)`` item.

    For each batch row ``b`` and position ``l``, this builds the category
    frequencies of ``x_p[b, :, l]`` and ``x_q[b, :, l]`` over ``dim`` bins. It then
    evaluates ``sum((p - q) ** 2 / q)`` with ``q`` floored at ``epsilon``.

    Returns:
        Float array of shape ``(B, L)``.
    """
    n_rows, _, n_positions = x_p.shape
    result = np.zeros((n_rows, n_positions))

    for row, pos in np.ndindex(n_rows, n_positions):
        hist_p = np.bincount(x_p[row, :, pos], minlength=dim)
        hist_q = np.bincount(x_q[row, :, pos], minlength=dim)
        freq_p = hist_p / hist_p.sum()
        freq_q = np.clip(hist_q / hist_q.sum(), epsilon, 1)
        result[row, pos] = np.sum((freq_p - freq_q) ** 2 / freq_q)

    return result


def compare_log_likelihoods(log_prob_old: Tensor, log_prob_new: Tensor) -> dict[str, float]:
    """
    Compare the log probabilities of old and new samples using several metrics.

    Log-probabilities are floored at ``log(1e-8)``. The first half of the old samples
    (along axis 1) is then compared against the first half of the new samples.
    As a baseline, it is also compared against the second half of the old samples.
    When S is odd, the last sample is dropped so that all three arrays have the same
    shape.

    Args:
        log_prob_old (torch.Tensor): Log probabilities of the old samples, shape (B, S, L).
        log_prob_new (torch.Tensor): Log probabilities of the new samples, shape (B, S, L).

    Returns:
        dict: Metrics for the log-likelihood ratio, sequence-level log-likelihood,
              rank correlation and the KL-style score, each with its baseline.

    Raises:
        ValueError: If S < 2, so two non-empty halves cannot be formed.
    """
    # Floor so that -inf log-probabilities cannot dominate the statistics.
    log_floor = np.log(1e-8)
    old = np.clip(log_prob_old.detach().cpu().numpy(), log_floor, None)
    new = np.clip(log_prob_new.detach().cpu().numpy(), log_floor, None)

    # Shape: (B, S, L)
    _, S, _ = old.shape
    half = S // 2
    if half == 0:
        raise ValueError(f"Need at least 2 samples per item to split into halves, got S={S}")

    # ``half:2 * half`` keeps old_second the same shape as old_first when S is odd;
    # otherwise the element-wise differences in the metric helpers would not broadcast.
    old_first = old[:, :half, :]
    old_second = old[:, half:2 * half, :]
    new_first = new[:, :half, :]

    results = {}
    for metric_fn in (
        _compute_log_likelihood_ratio_metrics,
        _compute_log_likelihood_sequence_metrics,
        _compute_rank_correlation_metrics,
        _compute_kl_divergence_metrics,
    ):
        results.update(metric_fn(old_first, new_first, old_second))

    return results


def _compute_log_likelihood_ratio_metrics(
    old_first: np.ndarray, new: np.ndarray, old_second: np.ndarray
) -> dict[str, float]:
    """Element-wise log-likelihood differences relative to ``old_first``.

    ``new - old_first`` is the metric of interest. ``old_second - old_first`` is the
    same difference taken in the same direction between two reference halves.
    For both, the mean and std are taken per batch row over axes 1 and 2 and then
    averaged over rows. The overall baseline mean is taken over every element.
    """
    def row_stats(diff: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        return diff.mean(axis=(1, 2)), diff.std(axis=(1, 2))

    diff_true = new - old_first
    diff_base = old_second - old_first
    true_mean, true_std = row_stats(diff_true)
    base_mean, base_std = row_stats(diff_base)

    return {
        "Log-Likelihood Ratio Mean": true_mean.mean(),
        "Log-Likelihood Ratio Std": true_std.mean(),
        "Baseline Log-Likelihood Ratio Mean": base_mean.mean(),
        "Baseline Log-Likelihood Ratio Std": base_std.mean(),
        "Overall Baseline Log-Likelihood Ratio Mean": diff_base.mean(),
    }


def _compute_log_likelihood_sequence_metrics(
    old_first: np.ndarray, new: np.ndarray, old_second: np.ndarray
) -> dict[str, float]:
    """Sequence-level log-likelihood differences.

    Each input is first summed over axis 2, giving one total per ``(b, s)``.
    The differences ``new - old_first`` and ``old_second - old_first`` are then
    summarized by their mean and std over all entries.
    """
    seq_old_first = old_first.sum(axis=2)
    seq_old_second = old_second.sum(axis=2)
    seq_new = new.sum(axis=2)

    delta_true = seq_new - seq_old_first
    delta_base = seq_old_second - seq_old_first

    return {
        "Log-Likelihood Ratio Seq Mean": delta_true.mean(),
        "Log-Likelihood Ratio Seq Std": delta_true.std(),
        "Log-Likelihood Ratio Seq Baseline Mean": delta_base.mean(),
        "Log-Likelihood Ratio Seq Baseline Std": delta_base.std(),
    }


def _compute_rank_correlation_metrics(
    old_first: np.ndarray, new: np.ndarray, old_second: np.ndarray
) -> dict[str, float]:
    """Spearman rank correlation of ``old_first`` against ``new`` and ``old_second``.

    For each batch row, the correlation is computed along axis 1 at every position
    and averaged over positions. NaN correlations (e.g. from constant inputs) count
    as 0. Mean and std are then taken over batch rows.
    """
    n_rows, _, n_positions = old_first.shape

    def per_row_correlation(other: np.ndarray) -> list[float]:
        row_means: list[float] = []
        for row in range(n_rows):
            position_corrs: list[float] = []
            for pos in range(n_positions):
                rho, _ = spearmanr(old_first[row, :, pos], other[row, :, pos])
                position_corrs.append(0.0 if np.isnan(rho) else rho)
            row_means.append(np.mean(position_corrs))
        return row_means

    true_corrs = per_row_correlation(new)
    baseline_corrs = per_row_correlation(old_second)

    return {
        "Rank Correlation Mean": np.mean(true_corrs),
        "Rank Correlation Std": np.std(true_corrs),
        "Baseline Rank Correlation Mean": np.mean(baseline_corrs),
        "Baseline Rank Correlation Std": np.std(baseline_corrs),
    }


def _compute_kl_divergence_metrics(
    old_first: np.ndarray, new: np.ndarray, old_second: np.ndarray
) -> dict[str, float]:
    """KL-style score: ``sum over axis 1 of exp(old_first) * (old_first - other)``.

    The score is computed with ``other = new`` and, as a baseline, with
    ``other = old_second``. This gives one value per ``(b, l)``. The results are
    summarized by per-row mean/std averaged over rows, and by an overall mean.

    NOTE(review): if the arrays hold log-probabilities of independently drawn samples
    (as ``compare_log_likelihoods`` passes them), entries at the same index are
    different samples. That makes this a heuristic discrepancy score rather than a
    KL divergence between two distributions — confirm the intended interpretation.
    """
    # The same weights serve both the true and baseline scores.
    weights = np.exp(old_first)

    kl = np.sum(weights * (old_first - new), axis=1)
    baseline_kl = np.sum(weights * (old_first - old_second), axis=1)

    return {
        "KL Divergence Mean": kl.mean(axis=1).mean(),
        "KL Divergence Std": kl.std(axis=1).mean(),
        "Overall KL Divergence Mean": kl.mean(),
        "Baseline KL Divergence Mean": baseline_kl.mean(axis=1).mean(),
        "Baseline KL Divergence Std": baseline_kl.std(axis=1).mean(),
        "Overall Baseline KL Divergence Mean": baseline_kl.mean(),
    }


def compute_time_profile_overall(time_profile: list[dict[str, float]], name: str) -> dict[str, dict[str, float]]:
    """Summarize a list of per-batch timing dicts as a mean/std per key.

    Every key seen in any entry becomes ``f"{name} {key}"`` in the output, mapped to
    ``{"Mean": ..., "Std": ...}`` over the values recorded for that key. Entries
    missing a key contribute nothing to it. Keys appear in first-seen order.
    """
    samples_by_key: dict[str, list[float]] = {}
    for profile in time_profile:
        for key, value in profile.items():
            samples_by_key.setdefault(key, []).append(value)

    summary = {}
    for key, samples in samples_by_key.items():
        arr = np.array(samples)
        summary[f'{name} {key}'] = {"Mean": arr.mean(), "Std": arr.std()}
    return summary


def compute_metrics(data: dict[str, Tensor]) -> dict[str, float]:
    """Aggregate every comparison metric for one result dict from ``generate_samples``.

    The result is a flat dict containing, in order:

    - Mark metrics (KL, chi-squared).
    - Time metrics (MMD).
    - Log-likelihood comparisons.
    - Wall-clock mean/std for each sampler.
    - Per-key time-profile stats, prefixed ``Usual`` / ``Speculative``.
    - ``Step``: the mean of ``all_inds`` plus one.
    - The means of the rejection sampler's mark and time constants.
    """
    x_old, x_new = data['x_old_sample'], data['x_new_sample']
    # Category count: one past the largest mark observed in either sample set.
    dim = max(x_old.max().item(), x_new.max().item()) + 1

    distribution_metrics = {
        **kl_divergence(x_old, x_new, dim),
        **compute_mmd_with_variance(data['t_old_sample'], data['t_new_sample']),
        **chi_squared_divergence(x_old, x_new, dim),
        **compare_log_likelihoods(data['log_prob_old'], data['log_prob_new']),
    }

    profile_stats = {
        **compute_time_profile_overall(data['time_profile_old'], 'Usual'),
        **compute_time_profile_overall(data['time_profile_new'], 'Speculative'),
    }

    durations_old, durations_new = data['time_old'], data['time_new']
    wall_clock = {
        'Time old mean': np.mean(durations_old),
        'Time old std': np.std(durations_old),
        'Time new mean': np.mean(durations_new),
        'Time new std': np.std(durations_new),
    }

    rejection_stats = {
        'Step': data['all_inds'].float().mean().item() + 1,
        'Mark const mean': torch.cat(data['mark_consts']).float().mean().item(),
        'Time const mean': torch.cat(data['time_consts']).float().mean().item(),
    }

    return {**distribution_metrics, **wall_clock, **profile_stats, **rejection_stats}


def process_results_to_dataframe(results: list[dict], to_str: bool = True) -> pd.DataFrame:
    df = pd.DataFrame(results)

    combine_columns = {
        "KL Avg Per Item": ("KL Avg Per Item", "KL Avg Per Item Std"),
        "KL Avg Per Item Baseline": ("KL Avg Per Item Baseline", "KL Avg Per Item Baseline Std"),
        "MMD": ("MMD Mean", "MMD Std"),
        "Baseline MMD": ("Baseline MMD", "Baseline MMD Std"),
        "Chi-Squared Avg Per Item": ("Chi-Squared Avg Per Item", "Chi-Squared Avg Per Item Std"),
        "Chi-Squared Avg Per Item Baseline": ("Chi-Squared Avg Per Item Baseline", "Chi-Squared Avg Per Item Baseline Std"),
        "Log-Likelihood Ratio": ("Log-Likelihood Ratio Mean", "Log-Likelihood Ratio Std"),
        "Baseline Log-Likelihood Ratio": ("Baseline Log-Likelihood Ratio Mean", "Baseline Log-Likelihood Ratio Std"),
        "KL Divergence": ("KL Divergence Mean", "KL Divergence Std"),
        "Baseline KL Divergence": ("Baseline KL Divergence Mean", "Baseline KL Divergence Std"),
        "Rank Correlation": ("Rank Correlation Mean", "Rank Correlation Std"),
        "Baseline Rank Correlation": ("Baseline Rank Correlation Mean", "Baseline Rank Correlation Std"),
        "Time old": ("Time old mean", "Time old std"),
        "Time new": ("Time new mean", "Time new std"),
    }

    # Combine mean and std columns into a single column
    for new_col, (mean_col, std_col) in combine_columns.items():
        df[new_col] = df.apply(
            lambda row: f"{round(row[mean_col], 2)}±{round(row[std_col], 2)}", axis=1
        )
        # Drop the original mean and std columns
        if new_col != mean_col:
            df.drop(columns=[mean_col], inplace=True)
        df.drop(columns=[std_col], inplace=True)

    # Handle columns with dictionary values containing 'Mean' and 'Std'
    if to_str:
        for col in df.columns:
            if isinstance(df[col].iloc[0], dict) and 'Mean' in df[col].iloc[0] and 'Std' in df[col].iloc[0]:
                # Combine 'Mean' and 'Std' into a single column
                if 'Usual' in col or 'Speculative' in col:
                    df[col] = df[col].apply(lambda x: f"{round(x['Mean'], 2)}±{round(x['Std'], 2)}")
                else:
                    df[col] = df[col].apply(lambda x: f"{round(x['Mean'], 4)}±{round(x['Std'], 4)}")

        # Round remaining float columns (not combined into mean±std)
        float_cols = df.select_dtypes(include=["float"]).columns
        df[float_cols] = df[float_cols].map(lambda x: str(round(x, 4)))

    return df
