from pathlib import Path
from typing import Any

import pandas as pd
import torch
import typer
from lightning_fabric import seed_everything
from loguru import logger
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_validate
from torch import Tensor
from tqdm.auto import tqdm, trange
from transformers import AutoConfig, PretrainedConfig

from hallucinations.dirs import DatasetDir, TeacherForcingDatasetDir
from hallucinations.features.sink_scores import (
    compute_sink_scores_from_laplacian_and_attention_diags,
)

# Scorer names passed to sklearn's cross_validate in run_probing; the resulting
# frames are later read back through "test_<name>" columns.
SCORING = ["accuracy", "precision", "recall", "f1", "roc_auc", "average_precision"]
# Column used to pick the best configuration per feature type.
BEST_METRIC = "test_roc_auc"
# NOTE(review): not referenced anywhere in the visible part of this file.
METRICS_FILE = "internal_states_metrics.pt"
# Default number of stratified folds created by load_or_create_cv_folds.
NUM_CV_SPLITS = 5
# Values of k swept in main for the top-k eigenvalue / sink-score features.
TOP_K = [1, 2, 3, 4, 5, 10, 25, 50, 100]
# File name (inside the dataset root directory) of the best-per-feature summary.
BEST_RESULTS_FILE = "best_probe_results.json"


def main(
    dataset_dir: Path = typer.Argument(..., help="Path to the dataset directory"),
) -> None:
    # Pipeline: pick the first available feature source, load labels and CV folds,
    # probe the scalar features (MTopDiv, Lookback Lens), then sweep TOP_K for the
    # three sorted per-head features. Results are re-saved after every probe so an
    # interrupted run keeps what was already computed.
    # (Kept as comments rather than a docstring: typer surfaces a command's
    # docstring as CLI help text.)
    ds_dir = DatasetDir(dataset_dir)
    tf_dir = TeacherForcingDatasetDir(dataset_dir)
    metrics_file = ds_dir.root_dir / "layer_subset_metrics.json"

    # Feature sources, in order of preference.
    if ds_dir.internal_states_metrics_file.exists():
        logger.info(
            "Loading data from internal states metrics file: {fname}...",
            fname=ds_dir.internal_states_metrics_file,
        )
        attn_metrics, llm_config = load_data(ds_dir)
    elif tf_dir.features_file.exists():
        logger.info(
            "Loading data from teacher forcing features file: {fname}...",
            fname=tf_dir.features_file,
        )
        attn_metrics, llm_config = load_data(tf_dir)
    elif tf_dir.attentions_dir.exists():
        logger.info(
            "Loading data from teacher forcing attentions directory: {fname}...",
            fname=tf_dir.attentions_dir,
        )
        attn_metrics, llm_config = load_tf_data_from_shards(tf_dir)
    else:
        logger.warning(
            "Loading data from legacy features files: {attn_fname} and {lap_fname}...",
            attn_fname=ds_dir.attn_diags_file,
            lap_fname=ds_dir.laplacian_diags_file,
        )
        attn_metrics, llm_config = load_data_from_legacy_features(ds_dir)

    all_labels, valid_label_idx = load_labels(ds_dir)
    labels = all_labels[valid_label_idx]
    cv_folds = load_or_create_cv_folds(ds_dir, labels)

    results = []

    def _record(frame: pd.DataFrame) -> None:
        # Append one probe's results and immediately persist everything so far.
        results.append(frame)
        _save_results(results, metrics_file)

    # (feature key, start message, skip message) for features holding one scalar per head.
    scalar_specs = (
        (
            "mtopdiv",
            "Running probing for MTopDiv features...",
            "Skipping MTopDiv: feature not found in internal_states_metrics.pt",
        ),
        (
            "lookback_lens",
            "Running probing for Lookback Lens features...",
            "Skipping Lookback Lens: feature not found in internal_states_metrics.pt",
        ),
    )
    # (start message, key in attn_metrics, feature tag written to the results).
    topk_specs = (
        (
            "Running probing for attention features...",
            "attention_eigvals_sorted",
            "attn_eigvals_topk",
        ),
        (
            "Running probing for Laplacian features...",
            "laplacian_eigvals_sorted",
            "laplacian_eigvals_topk",
        ),
        (
            "Running probing for per-token sink score features...",
            "sink_score_per_token_sorted",
            "sink_score_topk",
        ),
    )

    try:
        for feature_key, start_msg, skip_msg in scalar_specs:
            scalar_features = extract_scalar_features(attn_metrics, valid_label_idx, feature_key)
            if scalar_features is None:
                logger.warning(skip_msg)
                continue
            logger.info(start_msg)
            _record(
                run_probing(
                    llm_config=llm_config,
                    # Add a trailing feature axis so the layout matches the top-k features.
                    features=scalar_features.unsqueeze(-1),
                    labels=labels,
                    cv_folds=cv_folds,
                    feature=feature_key,
                    top_k=None,
                )
            )

        for top_k in tqdm(TOP_K, desc="TopEigvals"):
            for start_msg, source_key, feature_tag in topk_specs:
                logger.info(start_msg)
                topk_features = prepare_topk_features(
                    llm_config=llm_config,
                    feat_name=source_key,
                    attn_metrics=attn_metrics,
                    valid_label_idx=valid_label_idx,
                    top_k=top_k,
                )
                _record(
                    run_probing(
                        llm_config=llm_config,
                        features=topk_features,
                        labels=labels,
                        cv_folds=cv_folds,
                        feature=feature_tag,
                        top_k=top_k,
                    )
                )
    except KeyboardInterrupt:
        # Not re-raised: the summary below is still produced from partial results.
        logger.error("Keyboard interrupt detected. Exiting...")
    else:
        logger.info("All experiments completed successfully.")

    # Display and save best results for each feature type
    _compute_and_save_best_results(results, ds_dir.root_dir / BEST_RESULTS_FILE)


def load_data(ds_dir: DatasetDir | TeacherForcingDatasetDir) -> tuple[list[dict], PretrainedConfig]:
    """Load per-example attention metrics and the LLM config for a dataset directory.

    The metrics are read from ``features_file`` for a ``TeacherForcingDatasetDir``
    and from ``internal_states_metrics_file`` otherwise. Sorted features that are
    missing from the file are derived in place:

    * ``sink_score_per_token_sorted`` from
      ``compute_sink_scores_from_laplacian_and_attention_diags`` (decided by
      looking at the first item only),
    * ``attention_eigvals_sorted`` / ``laplacian_eigvals_sorted`` from
      ``attn_diags`` / ``laplacian_diags`` (checked per item).

    Args:
        ds_dir: Dataset directory wrapper providing the config and metrics file.

    Returns:
        ``(attn_metrics, llm_config)`` where ``attn_metrics`` is the list stored
        under the ``"attn_metrics"`` key of the loaded file.

    Raises:
        KeyError: If a sorted feature is missing and its source diagonal is too.
    """
    config = ds_dir.load_config()
    llm_config = AutoConfig.from_pretrained(config.llm.name)

    if isinstance(ds_dir, TeacherForcingDatasetDir):
        metrics_file = ds_dir.features_file
    else:
        metrics_file = ds_dir.internal_states_metrics_file

    # Always restore tensors on the CPU: without map_location, tensors saved from
    # a GPU are restored to that device (or fail on a CPU-only host), and the
    # probing code later calls .numpy(), which requires CPU tensors. This also
    # matches the other loaders in this file.
    metrics = torch.load(metrics_file, map_location="cpu")
    attn_metrics = metrics["attn_metrics"]

    # Only the first item is inspected; the file is assumed to be homogeneous.
    if attn_metrics and "sink_score_per_token_sorted" not in attn_metrics[0]:
        sink_scores = compute_sink_scores_from_laplacian_and_attention_diags(metrics)
        for item, sink_score in zip(attn_metrics, sink_scores):
            item["sink_score_per_token_sorted"] = sink_score.sort(dim=-1, descending=True).values

    for item in attn_metrics:
        if "attention_eigvals_sorted" not in item:
            item["attention_eigvals_sorted"] = (
                item["attn_diags"].sort(dim=-1, descending=True).values
            )
        if "laplacian_eigvals_sorted" not in item:
            item["laplacian_eigvals_sorted"] = (
                item["laplacian_diags"].sort(dim=-1, descending=True).values
            )

    return attn_metrics, llm_config


def load_tf_data_from_shards(
    tf_dir: TeacherForcingDatasetDir,
) -> tuple[list[dict[str, Any]], PretrainedConfig]:
    """Merge the attention metrics of all teacher-forcing shards into one list.

    Shards are read in the order given by ``tf_dir.get_sorted_shards`` and the
    ``shard["features"]["attn_metrics"]`` items are concatenated. As a side
    effect the merged list is written to ``tf_dir.features_file`` under the
    ``"attn_metrics"`` key.

    NOTE(review): unlike ``load_data``, no derived ``*_sorted`` features are
    added here; presumably the shards already contain them. Confirm.

    Args:
        tf_dir: Teacher-forcing dataset directory wrapper.

    Returns:
        ``(attn_metrics, llm_config)`` with ``attn_metrics`` a flat list of
        per-example items.
    """
    config = tf_dir.load_config()
    llm_config = AutoConfig.from_pretrained(config.llm.name)

    attn_metrics: list[dict[str, Any]] = []
    for shard_file in tqdm(
        tf_dir.get_sorted_shards(tf_dir.attentions_dir), desc="Loading attentions shards"
    ):
        shard = torch.load(shard_file, weights_only=True, mmap=True, map_location="cpu")
        attn_metrics.extend(shard["features"]["attn_metrics"])

    # Cache the merged result in the features file for later reuse.
    torch.save({"attn_metrics": attn_metrics}, tf_dir.features_file)

    return attn_metrics, llm_config


def load_data_from_legacy_features(ds_dir: DatasetDir) -> tuple[list[dict], PretrainedConfig]:
    """Build the attention-metrics list from the two legacy diagonal files.

    ``attn_diags_file`` and ``laplacian_diags_file`` must each hold a list with
    one tensor per example, of matching shapes. For every example the float
    diagonals are kept and three descending sorts along the last dimension are
    added (values and indices): the attention diagonal, the Laplacian diagonal
    and their element-wise sum, which is stored as the per-token sink score.

    Returns:
        ``(attn_metrics, llm_config)``.
    """
    llm_name = ds_dir.load_config().llm.name
    llm_config = AutoConfig.from_pretrained(llm_name)

    load_kwargs = {"mmap": True, "map_location": "cpu"}
    attn_diags = torch.load(ds_dir.attn_diags_file, **load_kwargs)
    laplacian_diags = torch.load(ds_dir.laplacian_diags_file, **load_kwargs)

    assert isinstance(attn_diags, list)
    assert isinstance(laplacian_diags, list)
    assert len(attn_diags) == len(laplacian_diags)

    def _sort_desc(values: Tensor):
        # Descending sort along the last axis; yields (values, indices).
        return torch.sort(values, dim=-1, descending=True)

    attn_metrics = []
    for position in trange(len(attn_diags), desc="Loading attention metrics"):
        attn = attn_diags[position].float()
        lap = laplacian_diags[position].float()
        assert attn.shape == lap.shape

        sink_values, sink_order = _sort_desc(lap + attn)
        attn_values, attn_order = _sort_desc(attn)
        lap_values, lap_order = _sort_desc(lap)

        attn_metrics.append(
            {
                "attn_diags": attn,
                "laplacian_diags": lap,
                "sink_score_per_token_sorted": sink_values,
                "sink_score_per_token_sorted_idx": sink_order,
                "attention_eigvals_sorted": attn_values,
                "attention_eigvals_sorted_idx": attn_order,
                "laplacian_eigvals_sorted": lap_values,
                "laplacian_eigvals_sorted_idx": lap_order,
            }
        )
    return attn_metrics, llm_config


def extract_scalar_features(
    attn_metrics: list[dict[str, Any]],
    valid_label_idx: Tensor,
    feat_name: str,
) -> Tensor | None:
    """Extracts scalar features (shape [#layers, #heads]) from attn_metrics.

    Returns None if feature is not present in any sample.
    """
    if not attn_metrics or feat_name not in attn_metrics[0]:
        return None

    features = []
    for idx in valid_label_idx:
        feat = attn_metrics[idx].get(feat_name)
        if feat is None:
            return None
        features.append(feat)

    return torch.stack(features)


def load_labels(ds_dir: DatasetDir) -> tuple[Tensor, Tensor]:
    """Load the labels and the indices of the examples whose label is valid.

    The label distribution over the valid examples is logged. Both known
    classes are always reported (with a count of 0 when absent), and an
    unexpected label id is reported under its numeric value instead of raising.

    Args:
        ds_dir: Dataset directory wrapper whose ``load_labels()`` returns a
            mapping with ``"labels"`` and ``"valid_labels_mask"`` entries.

    Returns:
        ``(labels, valid_label_idx)``: the unfiltered label tensor and the
        positions selected by ``valid_labels_mask``.
    """
    label_info = ds_dir.load_labels()
    labels = label_info["labels"]
    valid_label_idx = torch.arange(len(labels))[label_info["valid_labels_mask"]]

    label_map = {0: "non-hallucinated", 1: "hallucinated"}
    # minlength keeps a class with zero valid examples visible in the log.
    counts = torch.bincount(labels[valid_label_idx], minlength=len(label_map))
    label_dist = {
        # .get avoids a KeyError if a label id outside label_map shows up.
        label_map.get(label_id, str(label_id)): count
        for label_id, count in enumerate(counts.tolist())
    }
    logger.info(f"Label distribution: {label_dist}")

    return labels, valid_label_idx


def load_or_create_cv_folds(
    ds_dir: DatasetDir,
    labels: Tensor,
    num_cv_splits: int = NUM_CV_SPLITS,
) -> list[tuple[Tensor, Tensor]]:
    """Return the cross-validation folds for ``labels``, reusing saved ones if present.

    If ``ds_dir.cv_folds_file`` exists, the stored folds are returned as-is
    (``labels`` and ``num_cv_splits`` are not consulted). Otherwise a shuffled,
    stratified split with a fixed seed is created, saved through ``ds_dir`` and
    returned as ``(train_idx, test_idx)`` tensor pairs.
    """
    if not ds_dir.cv_folds_file.exists():
        logger.info(f"Creating new CV folds and saving to {ds_dir.cv_folds_file}")
        seed_everything(42)

        splitter = StratifiedKFold(n_splits=num_cv_splits, shuffle=True, random_state=42)
        folds = []
        # Labels double as the "X" argument: only their length and classes matter here.
        for train_part, test_part in splitter.split(labels, labels):
            folds.append((torch.tensor(train_part), torch.tensor(test_part)))

        ds_dir.save_cv_folds(folds)
        return folds

    logger.info(f"Loading existing CV folds from {ds_dir.cv_folds_file}")
    return ds_dir.load_cv_folds()


def prepare_topk_features(
    llm_config: PretrainedConfig,
    feat_name: str,
    attn_metrics: list[dict[str, Any]],
    valid_label_idx: Tensor,
    top_k: int,
) -> Tensor:
    """Extract the first ``top_k`` values of a feature for each layer and head.

    The feature of every selected example is reshaped to
    [#layers, #heads, #features] and truncated along the last axis. The stored
    features are expected to be sorted already, so the leading entries are the
    top ones.

    Args:
        llm_config: Model config; ``text_config`` is used when present.
        feat_name: Key of the feature inside each ``attn_metrics`` item.
        attn_metrics: Per-example feature dictionaries.
        valid_label_idx: Indices of the examples to extract.
        top_k: Number of leading values to keep per head.

    Returns:
        Tensor of shape [#examples, #layers, #heads, top_k].

    Raises:
        ValueError: If a feature is neither 1D nor 3D, or if it has fewer than
            ``top_k`` values per head.
    """
    text_config = getattr(llm_config, "text_config", llm_config)
    num_layers = text_config.num_hidden_layers
    num_heads = text_config.num_attention_heads

    xs = []
    for idx in valid_label_idx:
        item_attn_metrics = attn_metrics[idx][feat_name]
        # Explicit raises instead of asserts so the checks survive `python -O`.
        if item_attn_metrics.dim() not in (1, 3):
            raise ValueError(
                f"Feature {feat_name} must be either flat 1D (#layers * #heads * #features) or 3D ([#layers, #heads, #features]), but has shape: {item_attn_metrics.shape}"
            )
        per_head = item_attn_metrics.reshape(num_layers, num_heads, -1)
        feat = per_head[:, :, :top_k]
        if feat.size(-1) != top_k:
            # Report the requested k first, then what is actually available.
            raise ValueError(
                f"Chosen top-k exceeds the number of features: {top_k} > {per_head.size(-1)}"
            )
        xs.append(feat)
    return torch.stack(xs)


def run_probing(
    llm_config: PretrainedConfig,
    features: Tensor,
    labels: Tensor,
    cv_folds: list[tuple[Tensor, Tensor]],
    **metadata: Any,
) -> pd.DataFrame:
    """Cross-validate a logistic-regression probe on the given features.

    Features shape: [#examples, #layers, #heads, #features]
    Labels shape: [#examples]

    All layers (0 up to ``num_hidden_layers``) are used; everything after the
    example axis is flattened into one feature vector per example. The
    returned frame holds the ``cross_validate`` output plus ``start_layer``,
    ``end_layer`` and one column per ``metadata`` entry.
    """
    text_config = getattr(llm_config, "text_config", llm_config)
    first_layer, last_layer = 0, text_config.num_hidden_layers

    seed_everything(42)

    targets = labels.numpy()
    design_matrix = features[:, first_layer:last_layer].flatten(start_dim=1).numpy()

    probe = LogisticRegression(max_iter=1_000, class_weight="balanced")

    # scikit-learn expects the predefined splits as numpy index arrays.
    splits = []
    for train_idx, test_idx in cv_folds:
        splits.append((train_idx.numpy(), test_idx.numpy()))

    scores = cross_validate(probe, design_matrix, targets, cv=splits, scoring=SCORING)
    frame = pd.DataFrame(scores)

    # Layer range first, then caller metadata (which may override a same-named column).
    extra_columns = {"start_layer": first_layer, "end_layer": last_layer, **metadata}
    for column, value in extra_columns.items():
        frame[column] = value

    return frame


def _save_results(results: list[pd.DataFrame], metrics_file: Path) -> None:
    """Concatenate all result frames and write them to ``metrics_file`` as JSON records."""
    combined = pd.concat(results, axis=0)
    json_options = {"indent": 2, "orient": "records", "index": False}
    combined.to_json(metrics_file, **json_options)
    logger.info(f"Results saved to {metrics_file}")


def _compute_and_save_best_results(
    results: list[pd.DataFrame],
    best_results_file: Path,
) -> None:
    """Log and save, per feature type, the configuration with the highest mean BEST_METRIC.

    Test metrics are averaged over all rows sharing the same ``feature`` (and
    ``top_k`` when that column exists); for each feature the row with the
    largest mean ``BEST_METRIC`` is kept. Nothing is written when ``results``
    is empty.
    """
    if not results:
        logger.warning("No results to compute best results from.")
        return

    combined = pd.concat(results, axis=0)

    has_top_k = "top_k" in combined.columns
    group_cols = ["feature", "top_k"] if has_top_k else ["feature"]
    metric_cols = [f"test_{name}" for name in SCORING]

    # dropna=False keeps the groups whose top_k is missing (the scalar features).
    grouped = combined.groupby(group_cols, dropna=False)
    agg_results = grouped[metric_cols].mean().reset_index()

    winners = agg_results.groupby("feature")[BEST_METRIC].idxmax()
    best_df = agg_results.loc[winners]

    logger.info(f"Best results per feature type (by {BEST_METRIC}):\n{best_df.to_string()}")

    best_df.to_json(best_results_file, indent=2, orient="records", index=False)
    logger.info(f"Best results saved to {best_results_file}")


# Run main as a typer command-line application when executed as a script.
if __name__ == "__main__":
    typer.run(main)
