import time
from collections import defaultdict

import torch
from datasets import Dataset
from torch import Tensor
from tqdm.auto import tqdm
from transformers import PreTrainedModel, TokenizersBackend
from transformers.generation import GenerateDecoderOnlyOutput, GenerationConfig

from hallucinations.llm.activation_storage import ActivationStorage


@torch.inference_mode()
def predict_with_llm(
    model: PreTrainedModel,
    tokenizer: TokenizersBackend,
    dataset: Dataset,
    generation_config: GenerationConfig,
    activation_storage: ActivationStorage | None,
    batch_size: int,
    num_proc: int,
) -> dict[str, list[str]]:
    """Generate completions for every example in ``dataset``, batch by batch.

    The dataset's format is switched *in place* to torch tensors restricted to
    the ``input_ids`` and ``attention_mask`` columns. Each batch is moved to the
    model's device and passed to ``model.generate``. Only the tokens after the
    prompt length are decoded, with special tokens skipped.

    When ``generate`` returns a ``GenerateDecoderOnlyOutput`` (e.g. because the
    generation config requests ``return_dict_in_generate``), the output is also
    forwarded to ``activation_storage`` together with per-token masks.

    Args:
        model: Model whose ``generate`` method is called. Its device is taken
            from its first parameter.
        tokenizer: Used for decoding, for building the special/added token
            masks, and for its ``eos_token_id``.
        dataset: Pre-tokenized dataset with ``input_ids`` and
            ``attention_mask`` columns. Its format is modified in place.
        generation_config: Passed straight through to ``model.generate``.
        activation_storage: Receives generation outputs. Required when
            ``generate`` returns a ``GenerateDecoderOnlyOutput``.
        batch_size: Number of examples per ``generate`` call.
        num_proc: Currently unused by this function.

    Returns:
        A dict with two lists, one entry per example in dataset order:
        ``"model_outputs"`` holds the decoded generations, and
        ``"stop_reason"`` holds ``"eos_token"`` or ``"max_length"``.

    Raises:
        AssertionError: If a ``GenerateDecoderOnlyOutput`` is returned but
            ``activation_storage`` is None.
        ValueError: If ``generate`` returns anything other than a tensor or a
            ``GenerateDecoderOnlyOutput``.
    """
    model.eval()
    dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

    results: dict[str, list[str]] = {
        "model_outputs": [],
        "stop_reason": [],
    }
    stop_reason_counter: dict[str, int] = defaultdict(int)

    device = next(model.parameters()).device
    num_batches = (len(dataset) + batch_size - 1) // batch_size  # ceil division

    with tqdm(
        dataset.iter(batch_size=batch_size),
        total=num_batches,
        desc="Generating predictions",
    ) as pbar:
        for i, batch in enumerate(pbar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # Width of the batched prompt tensor; generated tokens start here.
            input_length = input_ids.size(1)

            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            outputs = model.generate(  # type: ignore
                inputs=input_ids,
                attention_mask=attention_mask,
                generation_config=generation_config,
            )
            duration = time.perf_counter() - start_time

            if isinstance(outputs, GenerateDecoderOnlyOutput):
                assert activation_storage is not None, (
                    "activation_storage must be provided for GenerateDecoderOnlyOutput"
                )
                outputs.sequences = outputs.sequences.cpu()  # type: ignore
                generated_ids = outputs.sequences
                token_masks = get_token_masks(outputs.sequences, tokenizer)
                activation_storage.update(
                    outputs=outputs,
                    attention_mask=attention_mask,
                    special_token_mask=token_masks["special_token_mask"],
                    decoder_added_token_mask=token_masks["decoder_added_token_mask"],
                    input_length=input_length,
                    batch_idx=i,
                )
            elif isinstance(outputs, Tensor):
                generated_ids = outputs.cpu()  # type: ignore
            else:
                raise ValueError(f"Unexpected generation output: {type(outputs)}")

            # Read the final token from the CPU copy for both output kinds. A plain
            # Tensor has no `.sequences` attribute.
            # NOTE(review): only the last position is inspected. In a batch, a row
            # that finished early is presumably right-padded by `generate`, so it is
            # labelled "eos_token" only if the pad id equals the eos id. Confirm
            # this is intended.
            last_token_ids = generated_ids[:, -1].flatten()
            stop_reason = _get_stop_reason(last_token_ids, tokenizer.eos_token_id)

            decoded = tokenizer.batch_decode(
                generated_ids[:, input_length:],
                skip_special_tokens=True,
            )
            results["model_outputs"].extend(decoded)
            results["stop_reason"].extend(stop_reason)
            for reason in stop_reason:
                stop_reason_counter[reason] += 1

            stats = {
                "input_size": input_length,
                # Counts every token in the returned sequences (prompt + generation).
                "throughput": f"{generated_ids.numel() / duration:0.2f} tok/sec",
                # Fraction of zero entries in the prompt attention mask (i.e. padding).
                "mean(#special_tokens)": f"{(1 - attention_mask).float().mean().item():0.3f}",
                **{
                    f"#stop_reason({reason})": count
                    for reason, count in stop_reason_counter.items()
                },
            }
            pbar.set_postfix(stats)

            # Drop batch references so cached accelerator memory can be released.
            del outputs, input_ids, attention_mask, generated_ids
            torch.cuda.empty_cache()

    return results


def get_token_masks(token_ids: Tensor, tokenizer: "TokenizersBackend") -> dict[str, Tensor]:
    """Build per-token special/added-token masks for a batch of id sequences.

    Args:
        token_ids: Batch of token id sequences. Each row is passed to the
            tokenizer separately (assumed 2-D ``(batch, seq_len)``; TODO confirm).
        tokenizer: Provides ``get_special_tokens_mask`` and
            ``added_tokens_decoder``.

    Returns:
        ``"special_token_mask"``: int64 tensor with the value reported by
        ``tokenizer.get_special_tokens_mask`` for each token (1 = special).
        ``"decoder_added_token_mask"``: bool tensor, True where the id is a key
        of ``tokenizer.added_tokens_decoder``.
    """
    # Work with Python ints, not 0-d tensors. torch.Tensor hashes by identity,
    # so `tensor_scalar in dict_keys` never matches an int key.
    rows: list[list[int]] = token_ids.tolist()

    # Read once instead of per token. On HF tokenizers this is a property that
    # may rebuild a dict on every access.
    added_token_ids = set(tokenizer.added_tokens_decoder)

    special_token_masks = torch.tensor(
        [
            tokenizer.get_special_tokens_mask(
                row,
                already_has_special_tokens=True,
            )
            for row in rows
        ],
        dtype=torch.long,
    )

    decoder_added_token_mask = torch.tensor(
        [[tok_id in added_token_ids for tok_id in row] for row in rows],
        dtype=torch.bool,
    )

    return {
        "special_token_mask": special_token_masks,
        "decoder_added_token_mask": decoder_added_token_mask,
    }


def _get_stop_reason(token_ids: Tensor, eos_token_id: int) -> list[str]:
    return ["eos_token" if token_id == eos_token_id else "max_length" for token_id in token_ids]
