from collections import Counter, defaultdict
from math import log2
from pathlib import Path

import torch
from torch import Tensor
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from hallucinations.data.factory import get_dataset_from_dir
from hallucinations.dirs import DatasetDir
from hallucinations.utils.misc import load_json


def compute_baseline_features(
    ds_dir: DatasetDir,
    tokenizer: PreTrainedTokenizerBase,
    valid_label_idx: Tensor,
) -> dict[str, Tensor]:
    """Compute every baseline feature for the examples in ``valid_label_idx``.

    Answers, questions and (when their files exist) answer-token indices and
    input lengths are loaded from ``ds_dir``. For each selected example the
    length, lexical, numeric, punctuation and position helpers are run on the
    answer's ``"prediction"`` text and the matching question.

    Args:
        ds_dir: Dataset directory the inputs are loaded from.
        tokenizer: Tokenizer forwarded to the token-based feature helpers.
        valid_label_idx: Indices of the examples to compute features for.

    Returns:
        Mapping from feature name to a ``float32`` tensor holding one value
        per index, in the order of ``valid_label_idx``.
    """
    all_answers = _load_answers(ds_dir)
    all_questions = _load_questions(ds_dir)
    all_token_info = _load_answer_token_indices(ds_dir)
    all_input_lengths = _load_input_lengths(ds_dir)

    columns: dict[str, list[float]] = defaultdict(list)

    for example_idx in tqdm(valid_label_idx.tolist(), desc="Computing baseline features"):
        prediction = all_answers[example_idx]["prediction"]
        question_text = all_questions[example_idx]

        # Both optional inputs degrade to None; the position helper then
        # falls back to its neutral value.
        token_info = None
        if all_token_info:
            token_info = all_token_info[example_idx]

        input_length = None
        if all_input_lengths is not None:
            input_length = int(all_input_lengths[example_idx].item())

        feature_groups = (
            _compute_length_features(prediction, question_text, tokenizer),
            _compute_lexical_features(prediction, question_text, tokenizer),
            _compute_numeric_features(prediction),
            _compute_punctuation_features(prediction),
            _compute_position_features(token_info, input_length, tokenizer, prediction),
        )
        for group in feature_groups:
            for name, value in group.items():
                columns[name].append(value)

    return {
        name: torch.tensor(values, dtype=torch.float32) for name, values in columns.items()
    }


def load_or_compute_baseline_features(
    ds_dir: DatasetDir,
    tokenizer: PreTrainedTokenizerBase,
    valid_label_idx: Tensor,
    cache_file: Path | None = None,
) -> dict[str, Tensor]:
    if cache_file is not None and cache_file.exists():
        return torch.load(cache_file, weights_only=True)

    features = compute_baseline_features(ds_dir, tokenizer, valid_label_idx)

    if cache_file is not None:
        torch.save(features, cache_file)

    return features


def _load_answers(ds_dir: DatasetDir) -> list[dict]:
    """Load the answers JSON of ``ds_dir`` and check that it is a list.

    Args:
        ds_dir: Dataset directory whose ``answers_file`` is read.

    Returns:
        The parsed top-level JSON list.

    Raises:
        TypeError: If the top-level JSON value is not a list.
    """
    answers = load_json(ds_dir.answers_file)
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would let a malformed file through unnoticed.
    if not isinstance(answers, list):
        raise TypeError(
            f"Expected a JSON list in {ds_dir.answers_file}, got {type(answers).__name__}"
        )
    return answers


def _load_questions(ds_dir: DatasetDir) -> list[str]:
    """Return the ``"question"`` column of the dataset stored in ``ds_dir``.

    The dataset is obtained via ``get_dataset_from_dir`` with ``split=None``.
    """
    return get_dataset_from_dir(ds_dir, split=None)["question"]


def _load_answer_token_indices(ds_dir: DatasetDir) -> list[dict] | None:
    """Load the ``"results"`` entry of the answer-token-indices JSON file.

    Returns:
        The value stored under ``"results"``, or ``None`` when the file does
        not exist.
    """
    indices_file = ds_dir.answer_token_indices_file
    if indices_file.exists():
        return load_json(indices_file)["results"]
    return None


def _load_input_lengths(ds_dir: DatasetDir) -> Tensor | None:
    """Load the ``"input_lengths"`` entry of the internal-states metrics file.

    Returns:
        The value stored under ``"input_lengths"``, or ``None`` when the
        metrics file does not exist.
    """
    metrics_file = ds_dir.internal_states_metrics_file
    if metrics_file.exists():
        return torch.load(metrics_file, weights_only=True)["input_lengths"]
    return None


def _compute_length_features(
    answer: str,
    question: str,
    tokenizer: PreTrainedTokenizerBase,
) -> dict[str, float]:
    """Count the tokens of the answer, of the question, and of both together.

    Both texts are encoded with ``add_special_tokens=False``.

    Returns:
        The three counts as floats, keyed ``length_answer_tokens``,
        ``length_question_tokens`` and ``length_total_tokens``.
    """
    n_answer = len(tokenizer.encode(answer, add_special_tokens=False))
    n_question = len(tokenizer.encode(question, add_special_tokens=False))

    return {
        "length_answer_tokens": float(n_answer),
        "length_question_tokens": float(n_question),
        "length_total_tokens": float(n_answer + n_question),
    }


def _compute_lexical_features(
    answer: str,
    question: str,
    tokenizer: PreTrainedTokenizerBase,
) -> dict[str, float]:
    """Compute entropy, vocabulary diversity and bigram repetition per text.

    The answer and the question are each encoded with
    ``add_special_tokens=False``; the three token-level statistics are then
    computed separately for each of the two token sequences.

    Returns:
        Six floats keyed ``lexical_{answer,question}_{entropy,
        vocab_diversity,repetition_rate}``.
    """
    answer_ids = tokenizer.encode(answer, add_special_tokens=False)
    question_ids = tokenizer.encode(question, add_special_tokens=False)

    stats: dict[str, float] = {}
    stats["lexical_answer_entropy"] = _shannon_entropy(answer_ids)
    stats["lexical_answer_vocab_diversity"] = _vocab_diversity(answer_ids)
    stats["lexical_answer_repetition_rate"] = _bigram_repetition_rate(answer_ids)
    stats["lexical_question_entropy"] = _shannon_entropy(question_ids)
    stats["lexical_question_vocab_diversity"] = _vocab_diversity(question_ids)
    stats["lexical_question_repetition_rate"] = _bigram_repetition_rate(question_ids)
    return stats


def _shannon_entropy(tokens: list[int]) -> float:
    """Return the Shannon entropy, in bits, of the token frequency distribution.

    An empty sequence has entropy ``0.0``.
    """
    if not tokens:
        return 0.0

    n_tokens = len(tokens)
    entropy = 0.0
    for occurrences in Counter(tokens).values():
        # Every counted token occurs at least once, so prob > 0 and log2 is defined.
        prob = occurrences / n_tokens
        entropy -= prob * log2(prob)
    return entropy


def _vocab_diversity(tokens: list[int]) -> float:
    """Return the ratio of distinct tokens to total tokens (``0.0`` if empty)."""
    n_tokens = len(tokens)
    if n_tokens == 0:
        return 0.0
    n_unique = len(set(tokens))
    return n_unique / n_tokens


def _bigram_repetition_rate(tokens: list[int]) -> float:
    """Return how many distinct bigrams repeat, relative to all bigram positions.

    The numerator counts distinct adjacent-token pairs that occur more than
    once; the denominator is the total number of adjacent pairs
    (``len(tokens) - 1``). Sequences shorter than two tokens yield ``0.0``.
    """
    n_bigrams = len(tokens) - 1
    if n_bigrams < 1:
        return 0.0

    bigram_counts = Counter(zip(tokens, tokens[1:]))
    n_repeated = len([count for count in bigram_counts.values() if count > 1])
    return n_repeated / n_bigrams


def _compute_numeric_features(answer: str) -> dict[str, float]:
    """Report whether the answer contains digits and how many.

    A character counts as a digit when ``str.isdigit`` is true for it.

    Returns:
        ``numeric_has_numbers`` (``1.0`` or ``0.0``) and
        ``numeric_digit_count`` (the count as a float).
    """
    n_digits = len([ch for ch in answer if ch.isdigit()])
    return {
        "numeric_has_numbers": 1.0 if n_digits > 0 else 0.0,
        "numeric_digit_count": float(n_digits),
    }


def _compute_punctuation_features(answer: str) -> dict[str, float]:
    """Count periods, question marks and a fixed set of special characters.

    Returns:
        ``punctuation_period_count``, ``punctuation_question_mark_count`` and
        ``punctuation_special_char_count``, each as a float.
    """
    special_chars = set("!@#$%^&*()[]{}|\\:;<>,/~`")

    n_periods = answer.count(".")
    n_question_marks = answer.count("?")
    n_special = len([ch for ch in answer if ch in special_chars])

    return {
        "punctuation_period_count": float(n_periods),
        "punctuation_question_mark_count": float(n_question_marks),
        "punctuation_special_char_count": float(n_special),
    }


def _compute_position_features(
    token_info: dict | None,
    input_length: int | None,
    tokenizer: PreTrainedTokenizerBase,
    answer: str,
) -> dict[str, float]:
    if token_info is None or input_length is None:
        return {"position_relative_answer": 0.5}

    token_indices = token_info.get("token_indices", [])
    if not token_indices:
        return {"position_relative_answer": 0.5}

    first_answer_idx = token_indices[0]
    gen_token_idx = first_answer_idx - input_length

    answer_tokens = tokenizer.encode(answer, add_special_tokens=False)
    gen_length = len(answer_tokens)

    if gen_length == 0:
        return {"position_relative_answer": 0.5}

    relative_position = gen_token_idx / gen_length
    relative_position = max(0.0, min(1.0, relative_position))

    return {"position_relative_answer": relative_position}


# Baseline feature names grouped by category. Each list holds exactly the
# keys returned by the matching ``_compute_<category>_features`` helper in
# this module, so the two must be kept in sync when a feature is added.
BASELINE_FEATURE_CATEGORIES: dict[str, list[str]] = {
    "length": [
        "length_answer_tokens",
        "length_question_tokens",
        "length_total_tokens",
    ],
    "lexical": [
        "lexical_answer_entropy",
        "lexical_answer_vocab_diversity",
        "lexical_answer_repetition_rate",
        "lexical_question_entropy",
        "lexical_question_vocab_diversity",
        "lexical_question_repetition_rate",
    ],
    "numeric": [
        "numeric_has_numbers",
        "numeric_digit_count",
    ],
    "punctuation": [
        "punctuation_period_count",
        "punctuation_question_mark_count",
        "punctuation_special_char_count",
    ],
    "position": [
        "position_relative_answer",
    ],
}
