# metrics_utils.py
import torch
import math
import pandas as pd
import numpy as np
from tqdm import tqdm
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


@torch.no_grad()
def measure_neural_collapse(model, dataloader, num_classes, device):
    """
    Measure neural collapse metrics (NC1, NC2, NC3).

    This version follows the logic from the original nc_experiment.py. The
    whole computation runs under ``torch.no_grad()`` so the returned metrics
    do not carry an autograd graph through the classifier weights.

    Args:
        model: Module whose forward pass returns ``(outputs, features)`` and
            which provides ``get_classifier_weights()`` returning a
            ``(num_classes, feature_dim)`` weight matrix.
        dataloader: Iterable of ``(images, labels)`` batches; labels are class
            indices in ``[0, num_classes)``.
        num_classes: Number of classes C.
        device: Device that batches and intermediate tensors are placed on.

    Returns:
        ``(nc1, nc2, nc3)`` as 0-dim tensors, or ``(None, None, None)`` when no
        features were collected, some class has no samples (its mean would be
        undefined), or the within-class scatter is not positive.

        - nc1: within-class scatter divided by between-class scatter.
        - nc2: Frobenius distance of the Gram matrix of the L2-normalised class
          means from the identity, divided by sqrt(C).
        - nc3: Frobenius distance between the Frobenius-normalised
          ``W @ M.T @ pinv(M @ M.T)`` and ``(I - 11^T / C) / sqrt(C - 1)``.
    """
    model.eval()
    features_list = []
    labels_list = []

    for images_batch, labels_batch in dataloader:
        images_batch = images_batch.to(device)
        labels_batch = labels_batch.to(device)
        _, features_from_batch = model(images_batch)
        features_list.append(features_from_batch)
        labels_list.append(labels_batch)

    if not features_list:
        logger.warning("measure_neural_collapse: No features collected. Dataloader might be empty.")
        return None, None, None

    features = torch.cat(features_list)
    labels = torch.cat(labels_list)

    # One boolean mask per class, computed once and reused below.
    class_masks = [labels == c for c in range(num_classes)]

    # torch returns NaN (it does not raise) for the mean of an empty selection,
    # which would silently poison every metric, so check for empty classes explicitly.
    missing_classes = [c for c, mask in enumerate(class_masks) if not mask.any()]
    if missing_classes:
        logger.error(
            f"measure_neural_collapse: no samples for classes {missing_classes}; "
            f"class means are undefined. Returning None for NC metrics."
        )
        return None, None, None

    class_means = torch.stack([features[mask].mean(dim=0) for mask in class_masks])
    global_mean = features.mean(dim=0)

    # Between-class scatter: mean squared deviation of class means from the global mean.
    Sb = ((class_means - global_mean) ** 2).mean()

    # Within-class scatter: each class's mean squared deviation from its own mean,
    # averaged over classes (every class is non-empty at this point).
    Sw = torch.stack([
        ((features[mask] - class_means[c]) ** 2).mean()
        for c, mask in enumerate(class_masks)
    ]).mean()

    # NC1 is only reported for a strictly positive within-class scatter
    # (this also rejects NaN, since NaN > 0 is False).
    if not Sw > 0:
        logger.warning("measure_neural_collapse: within-class scatter is not positive. Returning None for NC metrics.")
        return None, None, None
    nc1 = Sw / Sb

    # NC2: deviation of the normalised class means' Gram matrix from the identity.
    norm_class_means = torch.norm(class_means, dim=1, keepdim=True) + 1e-9
    normalized_class_means = class_means / norm_class_means
    gram_matrix = normalized_class_means @ normalized_class_means.T
    identity_matrix_nc2 = torch.eye(num_classes, device=device)
    nc2 = torch.norm(gram_matrix - identity_matrix_nc2, p='fro') / math.sqrt(num_classes)

    # NC3: classifier alignment with an ETF-shaped target.
    last_layer_weights = model.get_classifier_weights().to(device)
    # pinv rather than inv: M @ M.T can be singular (e.g. feature_dim < num_classes).
    sigma = torch.linalg.pinv(class_means @ class_means.T)
    v_phi = last_layer_weights @ class_means.T @ sigma
    v_phi = v_phi / (torch.norm(v_phi, p='fro') + 1e-9)

    etf_matrix_base = torch.eye(num_classes, device=device) - (1 / num_classes) * torch.ones(
        num_classes, num_classes, device=device
    )
    target_etf_for_nc3 = etf_matrix_base / math.sqrt(num_classes - 1)
    nc3 = torch.norm(v_phi - target_etf_for_nc3, p='fro')

    return nc1, nc2, nc3


def evaluate_accuracy(model, dataloader, device, desc="Evaluating"):
    """Evaluate top-1 classification accuracy of ``model`` on ``dataloader``.

    Args:
        model: Callable that returns either the logits tensor directly or a
            tuple/list whose first element is the logits (e.g.
            ``(outputs, features)``).
        dataloader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.
        desc: Label for the progress bar and the empty-dataloader warning.

    Returns:
        Accuracy as a percentage (0-100). Returns 0.0 when the dataloader
        yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc=desc, leave=False, disable=None):
            images = images.to(device)
            labels = labels.to(device)
            result = model(images)
            # Inspect the return type instead of tuple-unpacking it: unpacking a
            # bare logits tensor iterates over its batch dimension, which raises
            # ValueError (or silently "works" when the batch size is 2).
            outputs = result[0] if isinstance(result, (tuple, list)) else result
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        logger.warning(f"No samples found in dataloader for accuracy evaluation ({desc}). Returning 0 accuracy.")
        return 0.0
    return 100.0 * correct / total


def save_metrics_to_csv(metrics_dict, epoch, output_file_path: Path):
    """Save one epoch of training metrics as a row of a CSV file.

    The file is (re)written with a header when it does not exist, is empty, or
    ``epoch == 0``. Otherwise the row is appended, aligned to the existing
    header by column name, so a dict whose keys arrive in a different order
    (or lack some keys) still lands in the right columns. If the dict
    introduces a key the header does not contain, the file is rewritten with
    the extended header instead of appending misaligned values.

    Args:
        metrics_dict: Mapping of metric name to scalar value for this epoch.
        epoch: Epoch number; written as the ``epoch`` index column.
        output_file_path: Destination CSV path (``Path`` or ``str``).

    I/O errors are logged rather than raised.
    """
    output_file_path = Path(output_file_path)
    df_new_row = pd.DataFrame([metrics_dict], index=[epoch])
    df_new_row.index.name = 'epoch'

    try:
        if epoch == 0 or not output_file_path.exists() or output_file_path.stat().st_size == 0:
            df_new_row.to_csv(output_file_path, mode='w', header=True)
            return

        # Read only the header of the existing file (index column excluded).
        existing_columns = list(pd.read_csv(output_file_path, index_col=0, nrows=0).columns)
        extra_columns = [col for col in df_new_row.columns if col not in existing_columns]
        if extra_columns:
            # The header itself must change, so rewrite the whole file.
            logger.warning(
                f"New metric columns {extra_columns} at epoch {epoch}; "
                f"rewriting {output_file_path} with an extended header."
            )
            existing_df = pd.read_csv(output_file_path, index_col=0)
            pd.concat([existing_df, df_new_row]).to_csv(output_file_path, mode='w', header=True)
        else:
            # Reorder to the header's column order; absent metrics become empty cells.
            df_new_row.reindex(columns=existing_columns).to_csv(output_file_path, mode='a', header=False)
    except IOError as e:
        logger.error(f"Could not write metrics to {output_file_path}: {e}")


def aggregate_run_metrics(metrics_files, num_epochs_config):
    """Aggregate per-epoch metrics from multiple runs into mean/std columns.

    Args:
        metrics_files: Iterable of CSV paths (``str`` or ``Path``), one per run.
            Each file is expected to have an ``epoch`` column plus one column
            per metric. Missing, empty or unreadable files are logged and skipped.
        num_epochs_config: Fallback number of epochs, used when no file has a
            usable ``epoch`` column.

    Returns:
        DataFrame with an ``epoch`` column (0 .. max_epoch) and, for every
        metric column seen in any run, ``<metric>_mean`` and ``<metric>_std``.
        The std is pandas' sample std, so it is NaN where only one run has a
        value. If an epoch appears more than once in a run (e.g. a resumed run
        re-appended rows), the last row for that epoch is used. Returns an empty
        DataFrame if no valid file was found.
    """
    all_metrics_dfs = []
    logger.info("Aggregating metrics from multiple runs...")

    for metrics_file_path_str in metrics_files:
        metrics_file_path = Path(metrics_file_path_str)
        if not metrics_file_path.exists():
            logger.warning(f"Metrics file {metrics_file_path} not found and will be skipped.")
            continue
        try:
            df = pd.read_csv(metrics_file_path)
            if not df.empty:
                all_metrics_dfs.append(df)
            else:
                logger.warning(f"Metrics file {metrics_file_path} is empty and will be skipped.")
        except pd.errors.EmptyDataError:
            logger.warning(f"Metrics file {metrics_file_path} is empty (pandas EmptyDataError) and will be skipped.")
        except Exception as e:
            logger.error(f"Error reading metrics file {metrics_file_path}: {e}")

    if not all_metrics_dfs:
        logger.error("No valid, non-empty metrics files found for aggregation. Returning empty DataFrame.")
        return pd.DataFrame()

    max_epochs_recorded = 0
    for df_run in all_metrics_dfs:
        if 'epoch' in df_run.columns:
            last_epoch = df_run['epoch'].max()
            # int() so float-typed epoch columns still give a valid range() bound.
            if pd.notna(last_epoch):
                max_epochs_recorded = max(max_epochs_recorded, int(last_epoch) + 1)

    if max_epochs_recorded == 0:
        max_epochs_recorded = num_epochs_config
        logger.warning(f"Could not determine max epochs from CSVs, using configured num_epochs: {num_epochs_config}")

    epochs = range(max_epochs_recorded)
    agg_df = pd.DataFrame({'epoch': epochs})

    # Union of metric columns over all runs (first-seen order), so a metric that
    # is absent from the first run is still aggregated.
    metric_column_names = list(dict.fromkeys(
        col for df_run in all_metrics_dfs for col in df_run.columns if col != 'epoch'
    ))

    for metric_col in metric_column_names:
        metric_series_list = []
        for df_run in all_metrics_dfs:
            if metric_col in df_run.columns and 'epoch' in df_run.columns:
                run_series = pd.Series(data=df_run[metric_col].values, index=df_run['epoch'])
                # reindex() raises on duplicate labels; keep the latest row per epoch.
                run_series = run_series[~run_series.index.duplicated(keep='last')]
                metric_series_list.append(run_series.reindex(epochs))
            else:
                metric_series_list.append(pd.Series(np.nan, index=epochs))

        metric_df_for_agg = pd.concat(metric_series_list, axis=1)
        agg_df[f'{metric_col}_mean'] = metric_df_for_agg.mean(axis=1)
        agg_df[f'{metric_col}_std'] = metric_df_for_agg.std(axis=1)

    logger.info(f"Aggregated metrics DataFrame head:\n{agg_df.head()}")
    return agg_df