import json
import re
from pathlib import Path

import pandas as pd
import torch
import typer
from datasets import load_dataset
from loguru import logger
from tabulate import tabulate
from tqdm.auto import tqdm

from hallucinations.utils.misc import get_ds_dir_tokenizer
from hallucinations.dirs import DatasetDir

# Credit: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k/gsm8k-cot-llama.yaml
# Case-insensitive pattern that locates "The final answer is" in a prediction.
# Group 1 captures the number that follows, together with any surrounding
# `*`, `$`, `,` and `.` characters and whitespace.
FINAL_ANSWER_REGEX = r"(?i)The final answer is[:\s]*([*]*\s*-?[$0-9.,]*[0-9]+[$0-9.,]*\s*[*]*)"
# Regexes whose matches are deleted, in list order, from the captured answer text.
IGNORE_ANSWER_PATTERNS = [
    r",",  # every comma
    r"\$",  # every dollar sign
    r"(?s).*#### ",  # everything up to and including a "#### " marker
    r"\.$",  # one period at the very end of the text
    r"\*",  # every asterisk
]

# Stop reason a generation must have for its label to be treated as valid;
# any other stop reason gets the label -1 in compute_gsm8k_labels.
VALID_ANSWER_STOP_REASON = "eos_token"


def get_shards_dir(ds_dir: DatasetDir) -> Path:
    """Return the first activation-shards directory of ``ds_dir`` that exists.

    The locations are probed in a fixed order of preference: full hidden states
    with attentions, then attentions only, then hidden states only.

    Args:
        ds_dir: Dataset directory object exposing the three candidate paths.

    Returns:
        The first candidate path for which ``exists()`` is true.

    Raises:
        FileNotFoundError: If none of the candidate directories exists.
    """
    candidates = [
        ds_dir.full_hidden_states_with_attentions_dir,
        ds_dir.attentions_dir,
        ds_dir.hidden_states_dir,
    ]
    first_existing = next((path for path in candidates if path.exists()), None)
    if first_existing is None:
        raise FileNotFoundError(
            f"No activation shards directory found. Tried: {[str(c) for c in candidates]}"
        )
    return first_existing


def main(
    dataset_dir: Path = typer.Option(..., help="Path to the dataset directory"),
) -> None:
    """Label GSM8K generations as hallucinated or not and persist the results.

    Steps:
      1. Read the answers file and extract a final answer from every prediction.
      2. If the answers have no ``stop_reason`` column, derive it from the
         activation shards: a generation whose last token equals the
         tokenizer's EOS id stopped on ``eos_token``, otherwise ``max_length``.
      3. Load the GSM8K ``main`` split named in the dataset config, extract the
         gold exact answers and join them to the predictions by row index.
      4. Compute labels (1 = hallucinated, 0 = non-hallucinated, -1 = invalid),
         log summary statistics and print a few sample rows per label.
      5. Save the labels and validity mask to ``ds_dir.labels_file`` and the
         statistics to ``labels_stats.json`` under ``ds_dir.root_dir``.

    Args:
        dataset_dir: Path to the dataset directory to process.

    Raises:
        AssertionError: If the joined gold answers differ from the ``gold``
            column of the answers file, i.e. the rows are not in the same order.
        FileNotFoundError: If ``stop_reason`` must be derived but no activation
            shards directory exists.
    """
    max_examples = 5  # Upper bound on the sample rows printed per label.

    ds_dir = DatasetDir(dataset_dir)
    ds_config = ds_dir.load_dataset_config()

    df = pd.read_json(ds_dir.answers_file)
    df["final_answer"] = df["prediction"].apply(extract_final_answer_from_prediction)

    n_missing_final_answer = int(df["final_answer"].isnull().sum())
    logger.info(f"Number of missing final answers: {n_missing_final_answer}")

    if "stop_reason" not in df.columns:
        logger.warning("Adding stop_reason column to answers based on shards")
        shards_dir = get_shards_dir(ds_dir)
        tokenizer = get_ds_dir_tokenizer(ds_dir)
        stop_reasons = []
        for shard_file in tqdm(
            ds_dir.get_sorted_shards(shards_dir),
            desc="Computing stop reasons",
        ):
            shard = torch.load(shard_file, mmap=True, map_location="cpu")
            toks = shard["generated_tokens"][0]
            ended_on_eos = toks[-1] == tokenizer.eos_token_id
            stop_reasons.append(VALID_ANSWER_STOP_REASON if ended_on_eos else "max_length")
        df["stop_reason"] = stop_reasons

    # Counted once and reused for both the log line and the stats file.
    stop_reason_counts = {k: int(v) for k, v in df["stop_reason"].value_counts().items()}
    logger.info(f"Stop reasons: {stop_reason_counts}")

    dataset = load_dataset("gsm8k", "main", split=ds_config.test_split_name)
    dataset = dataset.map(extract_exact_answer_from_gold, batched=False)
    dataset = dataset.to_pandas()

    # Index-based join; the right join keeps exactly the rows of the gold split.
    res = df.join(dataset[["answer", "exact_answer"]], how="right")

    # Explicit raise instead of `assert` so the check still runs under `python -O`.
    if not (res["answer"] == res["gold"]).all():
        raise AssertionError("Gold answers and LLM answers have different order")
    res = res.drop(columns=["answer"])

    labels, valid_labels_mask = compute_gsm8k_labels(res, df["stop_reason"])

    valid_labels = labels[valid_labels_mask]
    n_hallucinated = valid_labels.sum().item()
    n_non_hallucinated = (valid_labels == 0).sum().item()
    n_invalid_stop_reason = int((df["stop_reason"] != VALID_ANSWER_STOP_REASON).sum())
    accuracy = (valid_labels == 0).float().mean().item()

    logger.info(f"Number of hallucinated answers: {n_hallucinated}")
    logger.info(f"Number of non-hallucinated answers: {n_non_hallucinated}")
    logger.info(f"LLM accuracy: {accuracy:0.3f}")

    res["label"] = labels.numpy()
    for label_val, label_name in [(1, "hallucinated"), (0, "non-hallucinated")]:
        subset = res[res["label"] == label_val]
        if subset.empty:
            continue
        n_examples = min(max_examples, len(subset))
        # The header reports the number of rows actually sampled.
        print(f"\nRandom {n_examples} examples of {label_name} answers:")
        rows = []
        for idx, row in subset.sample(n=n_examples).iterrows():
            # Show only the tail of the prediction, where the final answer usually is.
            snippet = row["prediction"][-150:].replace("\n", "\\n")
            rows.append([idx, f"...{snippet}", row["final_answer"], row["exact_answer"]])
        print(
            tabulate(
                rows,
                headers=["Id", "Prediction", "Extracted", "Gold"],
                tablefmt="grid",
                maxcolwidths=[None, 100, None, None],
            )
        )

    torch.save(
        {"labels": labels, "valid_labels_mask": valid_labels_mask},
        ds_dir.labels_file,
    )

    stats = {
        "n_total": len(labels),
        "n_hallucinated": n_hallucinated,
        "n_non_hallucinated": n_non_hallucinated,
        "n_missing_final_answer": n_missing_final_answer,
        "n_invalid_stop_reason": n_invalid_stop_reason,
        "accuracy": accuracy,
        "answer_stop_reasons": stop_reason_counts,
    }
    stats_file = ds_dir.root_dir / "labels_stats.json"
    with open(stats_file, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)
    logger.info(f"Saved stats to {stats_file}")


def extract_exact_answer_from_gold(item: dict[str, str]) -> dict[str, str]:
    """Extract the exact answer that follows the first ``####`` in a gold answer.

    Args:
        item: Record with an ``"answer"`` string.

    Returns:
        A dict with the single key ``"exact_answer"`` holding the stripped text
        after the first ``####``, or an empty string when the marker is absent.
    """
    _, _, after_marker = item["answer"].partition("####")
    return {"exact_answer": after_marker.strip()}


def extract_final_answer_from_prediction(prediction: str) -> str | None:
    """Extract and clean the final answer stated in a model prediction.

    The first match of ``FINAL_ANSWER_REGEX`` is used. Every pattern in
    ``IGNORE_ANSWER_PATTERNS`` is then deleted from the captured text, in list
    order.

    Args:
        prediction: Generated text to search.

    Returns:
        The cleaned, whitespace-stripped answer, or ``None`` when the prediction
        contains no match.
    """
    found = re.search(FINAL_ANSWER_REGEX, prediction)
    if found is None:
        return None

    cleaned = found.group(1).strip()
    for pattern in IGNORE_ANSWER_PATTERNS:
        cleaned = re.sub(pattern, "", cleaned)
    return cleaned.strip()


def compute_gsm8k_labels(
    res: pd.DataFrame, stop_reason: pd.Series
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute hallucination labels for GSM8K predictions.

    A row is labelled 0 when its ``final_answer`` is equivalent to its
    ``exact_answer`` (see ``is_equiv``) and 1 otherwise. Rows with a missing
    final answer, or whose stop reason differs from
    ``VALID_ANSWER_STOP_REASON``, are labelled -1.

    Args:
        res: Frame with ``final_answer`` and ``exact_answer`` columns.
        stop_reason: Stop reason per row, aligned with ``res`` by position
            (not by index) and of the same length.

    Returns:
        A tuple ``(labels, valid_labels_mask)``: a ``torch.long`` tensor of
        labels in {-1, 0, 1} and a boolean tensor that is true where the label
        is non-negative.
    """
    # Build the labels from a plain Python list rather than passing a pandas
    # Series to torch.tensor. That keeps the conversion independent of the
    # frame's index and also works for an empty frame, where
    # DataFrame.apply(axis=1) would not return a Series.
    is_wrong = [
        not is_equiv(pred, gold)
        for pred, gold in zip(res["final_answer"], res["exact_answer"])
    ]
    labels = torch.tensor(is_wrong, dtype=torch.long)

    missing_final_answer = res["final_answer"].isnull()
    invalid_stop_reason = stop_reason != VALID_ANSWER_STOP_REASON

    n_missing = int(missing_final_answer.sum())
    n_invalid_stop = int(invalid_stop_reason.sum())
    if n_missing > 0:
        logger.info(f"Setting {n_missing} labels to -1 (missing final answer)")
    if n_invalid_stop > 0:
        logger.warning(f"Setting {n_invalid_stop} labels to -1 (invalid stop_reason)")

    # Combine the masks positionally; a length mismatch raises here.
    invalid_mask = torch.as_tensor(
        missing_final_answer.to_numpy() | invalid_stop_reason.to_numpy(),
        dtype=torch.bool,
    )
    labels[invalid_mask] = -1
    valid_labels_mask = labels >= 0
    return labels, valid_labels_mask


def is_equiv(pred: str | float | None, gold: str | float | None) -> bool:
    if pd.isna(pred) or pd.isna(gold):
        return False
    try:
        return float(str(pred).strip().replace(",", "").rstrip(".")) == float(
            str(gold).strip().replace(",", "").rstrip(".")
        )
    except ValueError:
        return str(pred).strip().rstrip(".") == str(gold).strip().rstrip(".")


# Script entry point: hand ``main`` to typer when the file is executed directly,
# so nothing runs on import.
if __name__ == "__main__":
    typer.run(main)
