from pathlib import Path

import pandas as pd
import torch
import typer
from lightning_fabric import seed_everything
from loguru import logger
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold, cross_validate
from torch import Tensor

from hallucinations.dirs import DatasetDir

# Scorer names handed to sklearn's cross_validate for the trained probe. The same
# names, prefixed with "test_", select the metric columns that are aggregated.
SCORING = ["accuracy", "precision", "recall", "f1", "roc_auc", "average_precision"]
# Metric column referenced when the aggregated results are reported.
BEST_METRIC = "test_roc_auc"
# Default number of stratified folds used when no CV split file exists yet.
NUM_CV_SPLITS = 5
# Output file names; both are written below the dataset's root directory.
RESULTS_FILE = "attn_score_probe_results.json"
BEST_RESULTS_FILE = "best_attn_score_results.json"


def main(
    dataset_dir: Path = typer.Argument(..., help="Path to the dataset directory"),
) -> None:
    # Pipeline: load labels, CV folds and attention diagonals for one dataset,
    # derive the two feature sets (attention score, attention log det), evaluate
    # each with the naive and the trained probe, then write the per-fold results
    # and the per-configuration aggregate below the dataset root directory.
    # (Kept as comments rather than a docstring so the CLI help text is unchanged.)
    ds_dir = DatasetDir(dataset_dir)

    all_labels, valid_label_idx = load_labels(ds_dir)
    # Everything downstream works on the valid examples only.
    labels = all_labels[valid_label_idx]
    logger.info(f"Loaded {len(labels)} valid samples")

    cv_folds = load_or_create_cv_folds(ds_dir, labels)

    attn_diags = load_attn_diags(ds_dir, valid_label_idx)
    logger.info(f"Loaded attention diagonals for {len(attn_diags)} samples")

    # Compute features
    score_features = compute_attn_score(attn_diags)
    log_det_features = compute_attn_log_det(attn_diags)

    logger.info(f"Attention score features shape: {score_features.shape}")
    logger.info(f"Attention log det features shape: {log_det_features.shape}")

    # One entry per probe run, in execution order:
    # (log message, probe function, features, feature_type, probe_type)
    probe_plan = (
        (
            "Running naive probing for attention score...",
            run_naive_probing,
            score_features,
            "attn_score",
            "naive",
        ),
        (
            "Running trained probing for attention score...",
            run_trained_probing,
            score_features,
            "attn_score",
            "trained",
        ),
        (
            "Running naive probing for attention log det...",
            run_naive_probing,
            log_det_features,
            "attn_log_det",
            "naive",
        ),
        (
            "Running trained probing for attention log det...",
            run_trained_probing,
            log_det_features,
            "attn_log_det",
            "trained",
        ),
    )

    results = []
    for message, probe, features, feature_type, probe_type in probe_plan:
        logger.info(message)
        fold_frame = probe(
            features=features,
            labels=labels,
            cv_folds=cv_folds,
            feature_type=feature_type,
            probe_type=probe_type,
        )
        # Only the naive probe can return None (nothing to evaluate); the
        # trained probe always yields a frame, so this check never drops it.
        if fold_frame is not None:
            results.append(fold_frame)

    _save_results(results, ds_dir.root_dir / RESULTS_FILE)
    _compute_and_save_best_results(results, ds_dir.root_dir / BEST_RESULTS_FILE)


def load_attn_diags(ds_dir: DatasetDir, valid_label_idx: Tensor) -> list[Tensor]:
    """Load the attention diagonals of the selected examples.

    The internal-states metrics file is tried first; the standalone
    attention-diagonals file is the fallback.

    Args:
        ds_dir: Dataset directory providing the two candidate file paths.
        valid_label_idx: Indices of the examples to keep, in output order.

    Returns:
        One tensor per selected example, converted with ``.float()``.

    Raises:
        FileNotFoundError: If neither file exists.
    """
    metrics_file = ds_dir.internal_states_metrics_file
    if metrics_file.exists():
        logger.info(f"Loading attn_diags from {metrics_file}")
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, so this must only be pointed at trusted files.
        per_example = torch.load(metrics_file, weights_only=False)["attn_metrics"]
        return [per_example[i]["attn_diags"].float() for i in valid_label_idx]

    diags_file = ds_dir.attn_diags_file
    if not diags_file.exists():
        raise FileNotFoundError(f"No attn_diags found in {metrics_file} or {diags_file}")

    logger.info(f"Loading attn_diags from {diags_file}")
    all_diags = torch.load(diags_file, mmap=True, map_location="cpu")
    return [all_diags[i].float() for i in valid_label_idx]


def load_labels(ds_dir: DatasetDir) -> tuple[Tensor, Tensor]:
    """Load the dataset labels and the indices of the valid ones.

    Also logs how many valid examples fall into each class.

    Returns:
        ``(labels, valid_label_idx)`` where ``labels`` is the full, unfiltered
        label sequence and ``valid_label_idx`` holds the positions selected by
        the stored ``valid_labels_mask``.
    """
    info = ds_dir.load_labels()
    labels = info["labels"]
    valid_label_idx = torch.arange(len(labels))[info["valid_labels_mask"]]

    # Position i of the bincount is the number of valid examples labelled i.
    # Only classes 0 and 1 have a name here, so any larger label value would
    # raise a KeyError below.
    class_names = {0: "non-hallucinated", 1: "hallucinated"}
    label_dist = {}
    for class_id, count in enumerate(torch.bincount(labels[valid_label_idx])):
        label_dist[class_names[class_id]] = count.item()
    logger.info(f"Label distribution: {label_dist}")

    return labels, valid_label_idx


def load_or_create_cv_folds(
    ds_dir: DatasetDir,
    labels: Tensor,
    num_cv_splits: int = NUM_CV_SPLITS,
) -> list[tuple[Tensor, Tensor]]:
    """Return the dataset's cross-validation folds, creating them on first use.

    If the folds file already exists its contents are returned as-is; neither
    ``labels`` nor ``num_cv_splits`` is checked against it. Otherwise stratified
    shuffled folds are generated with a fixed seed, saved, and returned.

    Args:
        ds_dir: Dataset directory that stores the folds.
        labels: Labels used for stratification when new folds are created.
        num_cv_splits: Number of folds to create.

    Returns:
        One ``(train_indices, test_indices)`` pair per fold.
    """
    folds_file = ds_dir.cv_folds_file
    if not folds_file.exists():
        logger.info(f"Creating new CV folds and saving to {folds_file}")
        seed_everything(42)

        splitter = StratifiedKFold(n_splits=num_cv_splits, shuffle=True, random_state=42)
        folds = []
        # The first argument to split() only fixes the sample count; the second
        # drives the stratification.
        for train_idx, test_idx in splitter.split(labels, labels):
            folds.append((torch.tensor(train_idx), torch.tensor(test_idx)))

        ds_dir.save_cv_folds(folds)
        return folds

    logger.info(f"Loading existing CV folds from {folds_file}")
    return ds_dir.load_cv_folds()


def compute_attn_score(attn_diags: list[Tensor]) -> Tensor:
    """Compute one attention score per layer for every example.

    For each layer the score is ``-sum_over_heads(mean_over_tokens(log(attn_diag)))``.
    Examples are processed one at a time, so their sequence lengths may differ.
    A zero on the diagonal yields an infinite score, since log(0) is -inf.

    Input: attn_diags with shape [layers, heads, seq_len] per example
    Output: [examples, layers] - one scalar per layer
    """

    def _layer_scores(diag: Tensor) -> Tensor:
        # Mean of the log over tokens (not the log of the mean).
        per_head = diag.log().mean(dim=-1)  # [layers, heads]
        return -per_head.sum(dim=-1)  # [layers]

    return torch.stack([_layer_scores(diag) for diag in attn_diags])


def compute_attn_log_det(attn_diags: list[Tensor]) -> Tensor:
    """Compute one log-det feature per attention head for every example.

    Each feature is ``mean_over_tokens(log(attn_diag))`` of a single head; the
    [layers, heads] grid is flattened layer by layer. Examples are processed one
    at a time, so their sequence lengths may differ.

    Input: attn_diags with shape [layers, heads, seq_len] per example
    Output: [examples, layers * heads] - one scalar per head
    """
    per_example = [
        diag.log().mean(dim=-1).flatten()  # [layers, heads] -> [layers * heads]
        for diag in attn_diags
    ]
    return torch.stack(per_example)


def run_naive_probing(
    features: Tensor,
    labels: Tensor,
    cv_folds: list[tuple[Tensor, Tensor]],
    **metadata: str,
) -> pd.DataFrame | None:
    # Average multi-dimensional features to scalar
    xs = features.mean(dim=-1).numpy()
    ys = labels.numpy()

    if xs.std() == 0:
        return None

    fold_results = []
    for _, test_idx in cv_folds:
        test_xs = xs[test_idx.numpy()]
        test_ys = ys[test_idx.numpy()]

        if len(set(test_ys)) < 2:
            continue

        auc_pos = roc_auc_score(test_ys, test_xs)
        auc_neg = roc_auc_score(test_ys, -test_xs)
        best_auc = max(auc_pos, auc_neg)

        fold_results.append({"test_roc_auc": best_auc})

    if not fold_results:
        return None

    results = pd.DataFrame(fold_results)
    for key, value in metadata.items():
        results[key] = value

    return results


def run_trained_probing(
    features: Tensor,
    labels: Tensor,
    cv_folds: list[tuple[Tensor, Tensor]],
    **metadata: str,
) -> pd.DataFrame:
    """Cross-validate a class-balanced logistic-regression probe on the features.

    Args:
        features: Per-example feature matrix.
        labels: Per-example labels aligned with ``features``.
        cv_folds: ``(train_indices, test_indices)`` pairs used as the CV splits.
        **metadata: Constant columns added to every result row.

    Returns:
        The ``cross_validate`` output as a frame (one row per fold) with the
        metadata columns appended.
    """
    design_matrix = features.numpy()
    targets = labels.numpy()

    seed_everything(42)

    classifier = LogisticRegression(max_iter=1_000, class_weight="balanced")

    # sklearn accepts an iterable of (train, test) index arrays as ``cv``.
    index_splits = [(train_idx.numpy(), test_idx.numpy()) for train_idx, test_idx in cv_folds]

    cv_scores = cross_validate(
        classifier,
        design_matrix,
        targets,
        cv=index_splits,
        scoring=SCORING,
    )

    frame = pd.DataFrame(cv_scores)
    for column, value in metadata.items():
        frame[column] = value

    return frame


def _save_results(results: list[pd.DataFrame], results_file: Path) -> None:
    """Write all per-fold probe results to ``results_file`` as a JSON record list.

    Args:
        results: One frame per probe configuration. The frames are stacked
            row-wise; a column missing from a frame is written as null for
            that frame's rows.
        results_file: Destination path of the JSON file.

    Nothing is written, and a warning is logged, when ``results`` is empty.
    """
    # pd.concat raises on an empty list; warn and bail out instead.
    if not results:
        logger.warning("No results to save.")
        return

    # Every input frame restarts its index at 0, so build a fresh one.
    results_all = pd.concat(results, axis=0, ignore_index=True)

    # orient="records" never serialises the index. Passing index=False is
    # redundant on recent pandas and raises ValueError on pandas < 2.1, so it
    # is left out.
    results_all.to_json(results_file, indent=2, orient="records")
    logger.info(f"Results saved to {results_file}")


def _compute_and_save_best_results(
    results: list[pd.DataFrame],
    best_results_file: Path,
) -> None:
    """Average the metrics per probe configuration and save the summary table.

    The per-fold rows are grouped by ``(feature_type, probe_type)`` and every
    available ``test_<metric>`` column is averaged over the folds. The summary
    is ordered by ``BEST_METRIC``, best configuration first, then logged and
    written as a JSON record list.

    Args:
        results: One frame of per-fold results per probe configuration.
        best_results_file: Destination path of the JSON summary.

    Nothing is written, and a warning is logged, when ``results`` is empty.
    """
    if not results:
        logger.warning("No results to compute best results from.")
        return

    results_all = pd.concat(results, axis=0, ignore_index=True)

    # Not every probe reports every metric; aggregate only the columns present.
    available_metric_cols = [
        column
        for column in (f"test_{metric}" for metric in SCORING)
        if column in results_all.columns
    ]

    groupby_cols = ["feature_type", "probe_type"]
    agg_results = (
        results_all.groupby(groupby_cols, dropna=False)[available_metric_cols].mean().reset_index()
    )

    # The log line below announces the table as being "by BEST_METRIC", so
    # actually order it that way. The stable sort keeps the group order for ties.
    if BEST_METRIC in agg_results.columns:
        agg_results = agg_results.sort_values(
            BEST_METRIC, ascending=False, kind="stable"
        ).reset_index(drop=True)

    logger.info(f"Results per configuration (by {BEST_METRIC}):\n{agg_results.to_string()}")

    # orient="records" never serialises the index; index=False is redundant
    # and raises ValueError on pandas < 2.1, so it is left out.
    agg_results.to_json(best_results_file, indent=2, orient="records")
    logger.info(f"Best results saved to {best_results_file}")


# When executed as a script, expose main() as a Typer command-line app.
if __name__ == "__main__":
    typer.run(main)
