import json
from pathlib import Path
from typing import Any

import pandas as pd
import torch
import typer
from loguru import logger
from tqdm.auto import tqdm

from hallucinations.utils.misc import get_ds_dir_tokenizer
from hallucinations.dirs import DatasetDir

# Integer class assigned to each judge verdict string; compute_labels turns any
# verdict that is missing from this mapping into -1.
LABEL_MAPPING = {"correct": 0, "incorrect": 1}
# Answer stop_reason that compute_labels keeps; rows with any other value get label -1.
VALID_ANSWER_STOP_REASON = "eos_token"
# Judge finish_reason that compute_labels keeps; rows with any other value get label -1.
VALID_JUDGE_FINISH_REASON = "stop"


def load_json(path: Path) -> Any:
    """Parse the JSON document stored at ``path`` and return the decoded object.

    Args:
        path: Location of the JSON file.

    Returns:
        Whatever ``json.load`` produces for the file contents.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def main(
    dataset_dir: Path = typer.Option(..., help="Path to the dataset directory"),
    llm_judge_prompt: str = typer.Option(
        None, help="LLM judge prompt name (auto-detected if not provided)"
    ),
    llm_judge_llm: str = typer.Option(
        None, help="LLM judge model name (auto-detected if not provided)"
    ),
) -> None:
    """Build label tensors for a dataset directory from LLM-judge verdicts.

    Loads the judge results and judge metadata for the selected judge config,
    joins them onto the answers table, computes labels plus a validity mask,
    saves both with ``torch.save`` to ``ds_dir.labels_file`` and writes summary
    statistics to ``labels_stats.json`` under ``ds_dir.root_dir``.

    Args:
        dataset_dir: Path handed to ``DatasetDir``.
        llm_judge_prompt: Judge prompt name. When omitted, it is taken from the
            first available judge config compatible with ``llm_judge_llm``.
        llm_judge_llm: Judge model name. When omitted, it is taken from the
            first available judge config compatible with ``llm_judge_prompt``.

    Raises:
        ValueError: If the dataset has no judge configs, or none of them
            matches the partially specified prompt/model.
        FileNotFoundError: If the judge result or metadata file is missing.
    """
    ds_dir = DatasetDir(dataset_dir)

    judge_configs = ds_dir.list_llm_judge_configs()
    if not judge_configs:
        raise ValueError(f"No LLM judge configs found in {ds_dir.llm_judge_dir}")

    if llm_judge_prompt is None or llm_judge_llm is None:
        # Honour whichever half of the (prompt, llm) pair the caller supplied
        # instead of overwriting it with the first config on disk. When neither
        # is supplied every config is a candidate.
        candidates = [
            cfg
            for cfg in judge_configs
            if (llm_judge_prompt is None or cfg["llm_judge_prompt"] == llm_judge_prompt)
            and (llm_judge_llm is None or cfg["llm_judge_llm"] == llm_judge_llm)
        ]
        if not candidates:
            raise ValueError(
                f"No LLM judge config matches prompt={llm_judge_prompt!r}, "
                f"llm={llm_judge_llm!r}. Available configs: {judge_configs}"
            )
        if len(candidates) > 1:
            logger.warning(f"Multiple LLM judge configs found: {candidates}. Using first one.")
        llm_judge_prompt = candidates[0]["llm_judge_prompt"]
        llm_judge_llm = candidates[0]["llm_judge_llm"]

    logger.info(f"Using LLM judge: {llm_judge_llm} with prompt: {llm_judge_prompt}")

    judge_results = load_json(ds_dir.llm_judge_file(llm_judge_prompt, llm_judge_llm))
    judge_metadata = load_json(ds_dir.llm_judge_metadata_file(llm_judge_prompt, llm_judge_llm))

    judge_finish_reasons = [
        m["response_metadata"]["finish_reason"] for m in judge_metadata["metadata"]
    ]

    answers = pd.read_json(ds_dir.answers_file)
    if "stop_reason" not in answers.columns:
        # The answers file carries no stop reasons, so derive one per shard
        # from the last generated token.
        logger.warning("Adding stop_reason column to answers based on shards")
        tokenizer = get_ds_dir_tokenizer(ds_dir)
        stop_reasons = []
        for shard_file in tqdm(
            ds_dir.get_sorted_shards(ds_dir.full_hidden_states_with_attentions_dir),
            desc="Computing stop reasons",
        ):
            shard = torch.load(shard_file, mmap=True, map_location="cpu")
            # NOTE(review): presumably each shard holds a single sequence at
            # index 0 and shards are ordered like the rows of `answers` -- confirm.
            toks = shard["generated_tokens"][0]
            stop_reasons.append("eos_token" if toks[-1] == tokenizer.eos_token_id else "max_length")
        answers["stop_reason"] = stop_reasons

    logger.info(f"Stop reasons: {answers['stop_reason'].value_counts().to_dict()}")

    answers["judge_label"] = judge_results
    answers["judge_finish_reason"] = judge_finish_reasons

    labels, valid_labels_mask = compute_labels(answers)
    stats = compute_stats(answers, labels, valid_labels_mask)
    print_stats(stats)

    torch.save({"labels": labels, "valid_labels_mask": valid_labels_mask}, ds_dir.labels_file)
    logger.info(f"Saved labels to {ds_dir.labels_file}")

    stats_file = ds_dir.root_dir / "labels_stats.json"
    with open(stats_file, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2)
    logger.info(f"Saved stats to {stats_file}")


def compute_labels(answers: pd.DataFrame) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn judge verdicts into integer labels and a validity mask.

    Verdicts are translated through ``LABEL_MAPPING``; anything unmapped
    becomes -1. Rows whose ``stop_reason`` or ``judge_finish_reason`` differs
    from the accepted value are forced to -1 as well.

    Args:
        answers: Frame with ``judge_label``, ``stop_reason`` and
            ``judge_finish_reason`` columns.

    Returns:
        ``(labels, valid_mask)`` where ``labels`` is a long tensor and
        ``valid_mask`` is ``labels >= 0``.
    """
    mapped = answers["judge_label"].map(LABEL_MAPPING).fillna(-1)
    labels = torch.tensor(mapped.values, dtype=torch.long)

    bad_answer = answers["stop_reason"] != VALID_ANSWER_STOP_REASON
    bad_judge = answers["judge_finish_reason"] != VALID_JUDGE_FINISH_REASON

    # Report each rejection cause separately; a row may be counted in both.
    for rejected, cause in (
        (bad_answer, "invalid answer stop_reason"),
        (bad_judge, "invalid judge finish_reason"),
    ):
        n_rejected = rejected.sum()
        if n_rejected > 0:
            logger.warning(f"Setting {n_rejected} labels to -1 ({cause})")

    discard = bad_answer.values | bad_judge.values
    labels[discard] = -1

    valid_labels_mask = labels >= 0
    return labels, valid_labels_mask


def compute_stats(answers: pd.DataFrame, labels: torch.Tensor, valid_mask: torch.Tensor) -> dict:
    """Summarise label counts and the raw value distributions of ``answers``.

    Args:
        answers: Frame with ``judge_label``, ``stop_reason`` and
            ``judge_finish_reason`` columns.
        labels: Label tensor holding 0, 1 or -1 per row.
        valid_mask: Boolean tensor marking the rows counted as valid.

    Returns:
        Dict of plain Python values: the label counts, ``accuracy`` (share of
        label 0 among valid rows, or ``None`` when no row is valid) and one
        value-count histogram per column listed above.
    """

    def _occurrences(value: int) -> int:
        # Number of entries in `labels` equal to `value`.
        return (labels == value).sum().item()

    def _histogram(column: str) -> dict:
        # value_counts() as a dict with built-in int counts.
        return {key: int(count) for key, count in answers[column].value_counts().items()}

    correct = _occurrences(0)
    incorrect = _occurrences(1)
    invalid = _occurrences(-1)
    valid = valid_mask.sum().item()

    summary = {
        "n_total": len(labels),
        "n_correct": correct,
        "n_incorrect": incorrect,
        "n_invalid": invalid,
        "n_valid": valid,
        "accuracy": correct / valid if valid > 0 else None,
    }
    summary["judge_labels"] = _histogram("judge_label")
    summary["answer_stop_reasons"] = _histogram("stop_reason")
    summary["judge_finish_reasons"] = _histogram("judge_finish_reason")
    return summary


def print_stats(stats: dict) -> None:
    """Log the statistics dict at INFO level.

    Emits the three value-count histograms, then the label counts, then the
    accuracy line; the accuracy line is skipped when ``stats["accuracy"]`` is
    ``None``.

    Args:
        stats: Dict with the keys read below (as built by ``compute_stats``).
    """
    for title, key in (
        ("Judge labels", "judge_labels"),
        ("Answer stop_reason", "answer_stop_reasons"),
        ("Judge finish_reason", "judge_finish_reasons"),
    ):
        logger.info(f"{title}: {stats[key]}")

    counts_line = ", ".join(
        f"{name}={stats[key]}"
        for name, key in (
            ("correct", "n_correct"),
            ("incorrect", "n_incorrect"),
            ("invalid", "n_invalid"),
            ("total", "n_total"),
        )
    )
    logger.info(f"Labels: {counts_line}")

    accuracy = stats["accuracy"]
    if accuracy is None:
        return
    logger.info(f"LLM accuracy (valid only): {accuracy:.3f}")


# When executed as a script, hand `main` to typer so its parameters are
# filled from the command line.
if __name__ == "__main__":
    typer.run(main)
