import logging
from dataclasses import dataclass, field
from typing import Callable

import numpy as np
import pandas as pd
import torch
from transformers import (
    BatchEncoding,
    PreTrainedModel,
    PreTrainedTokenizer,
    TrainerState,
)

from lib_dl_base.visualization.progress import conditional_tqdm
from lib_llm.eval.metrics.eval_tasks import SequenceEvaluationTask
from lib_llm.inference import PredictionConfig, predict


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Names of the two index levels of the result frames built in this module:
# the position of the probed token and the length of the retained prefix.
INDEX_NAMES = ["token_idx", "prefix_length"]


@dataclass
class PrefixEvalConfig:
    """Settings for the prefix-length evaluation.

    The fields are read by `eval_prefix_lengths`, which uses them to pick
    the token positions to probe and the prefix lengths to test.
    """

    # Seed of the random generator used when token positions are subsampled.
    seed: int
    # Prefix lengths (in tokens) to evaluate. Each instance receives its own
    # copy of the default list.
    prefix_lengths: list[int] = field(
        default_factory=lambda: [
            1, 2, 3, 4, 5, 8, 10, 15, 16, 20, 32, 64, 128, 256, 512,
        ]
    )
    # NOTE(review): not read anywhere in this module; presumably consumed by
    # the code that builds the replacement sampler -- confirm with callers.
    num_samples_per_prefix: int = 10
    # Upper bound on the number of token positions that are probed.
    max_token_samples: int = 1024
    # Relative size of the non-prefix part of the context: values above 1
    # add tokens, values below 1 remove tokens.
    relative_context_size: float = 1.0


def eval_prefix_lengths(
    config: PrefixEvalConfig,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    sequence_token_ids: torch.Tensor,
    get_replacements: Callable[[torch.Tensor, int], torch.Tensor],
) -> pd.DataFrame:
    """Evaluate token predictions of one sequence for several prefix lengths.

    Token positions are selected with `_get_probing_token_indices` (all
    positions from index 1 on, or a seeded random subset of at most
    `config.max_token_samples` positions). Each selected position is
    evaluated with `_eval_token` for the configured prefix lengths and the
    per-token results are concatenated.

    Args:
        config: Evaluation settings (seed, prefix lengths, sampling limit
            and relative context size).
        model: Model that is handed to `_eval_token` for prediction.
        tokenizer: Tokenizer that is handed to `_eval_token`.
        sequence_token_ids: 1-D tensor with the token ids of the sequence.
        get_replacements: Callable handed to `_eval_token`, which uses it
            to obtain replacement tokens for the non-retained context.

    Returns:
        The per-token result frames stacked along the index. The index has
        two levels named after `INDEX_NAMES`: the token position and the
        prefix length.

    Raises:
        AssertionError: If `sequence_token_ids` is not one-dimensional.
        ValueError: If no token position is available for evaluation.
    """
    # Report through the module logger rather than printing to stdout.
    logger.info("Evaluating prefix lengths")
    assert sequence_token_ids.ndim == 1, (
        "Expected sequence_token_ids to have shape (sequence_length, ), "
        f"but got {sequence_token_ids.shape}"
    )

    prediction_config = PredictionConfig(trim_last_token=False)
    token_indices = _get_probing_token_indices(
        len(sequence_token_ids),
        config.max_token_samples,
        config.seed,
    )
    if len(token_indices) == 0:
        # Without this check pd.concat below fails with the much less
        # helpful "No objects to concatenate" (also a ValueError).
        raise ValueError(
            "No token positions to evaluate "
            f"(sequence length: {len(sequence_token_ids)}, "
            f"max_token_samples: {config.max_token_samples})"
        )

    token_results = []
    for token_idx in conditional_tqdm(
        token_indices,
        disabled_progress_interval=32,
    ):
        target_token_id = sequence_token_ids[token_idx]
        token_result = _eval_token(
            model,
            tokenizer,
            prediction_config=prediction_config,
            sequence_token_ids=sequence_token_ids,
            get_replacements=get_replacements,
            target_token_idx=token_idx,
            target_token_id=target_token_id,
            relative_size=config.relative_context_size,
            prefix_lengths=config.prefix_lengths,
        )
        token_results.append(token_result)

    # The token positions become the outer index level; the inner level is
    # the index of each per-token frame.
    performance = pd.concat(
        token_results,
        axis="index",
        keys=token_indices,
    )
    performance.index.names = INDEX_NAMES
    return performance


class PrefixEvalTask(SequenceEvaluationTask):
    """Evaluation task that runs `eval_prefix_lengths` on one sequence."""

    def __init__(
        self,
        config: PrefixEvalConfig,
        data: tuple[list[list[str]], BatchEncoding],
        tokenizer: PreTrainedTokenizer,
        get_replacements: Callable[[torch.Tensor, int], torch.Tensor],
        eval_condition: Callable[[TrainerState], bool] = lambda _: True,
    ) -> None:
        """Register the task with the base class and store its settings.

        Args:
            config: Settings forwarded to `eval_prefix_lengths`.
            data: Pair of texts and their encoding; the first element must
                contain exactly one entry.
            tokenizer: Tokenizer forwarded to `eval_prefix_lengths`.
            get_replacements: Callable forwarded to `eval_prefix_lengths`.
            eval_condition: Predicate on the trainer state, forwarded to
                the base class. The default always returns True.

        Raises:
            AssertionError: If `data[0]` does not hold exactly one entry.
        """
        # No metric functions are registered with the base class; the
        # results are produced by `evaluate` below.
        super().__init__(
            data=data,
            metrics={},
            eval_condition=eval_condition,
            index_names=INDEX_NAMES,
        )
        num_sequences = len(data[0])
        assert num_sequences == 1, "Only one sequence can be evaluated"
        self.config = config
        self.tokenizer = tokenizer
        self.get_replacements = get_replacements

    def evaluate(self, model: PreTrainedModel) -> pd.DataFrame:
        """Run the prefix-length evaluation of `model` on the sequence."""
        # The task holds a single sequence, so take the first row.
        # NOTE(review): `encoded_text` is presumably set by the base class
        # from `data` -- confirm.
        token_ids = self.encoded_text.input_ids[0]
        return eval_prefix_lengths(
            config=self.config,
            model=model,
            tokenizer=self.tokenizer,
            sequence_token_ids=token_ids,
            get_replacements=self.get_replacements,
        )


def _get_probing_token_indices(
    sequence_length: int,
    max_token_samples: int,
    seed: int,
) -> np.ndarray:
    """Select the token positions that are probed.

    Position 0 is never selected. If the sequence is longer than
    `max_token_samples`, that many positions are drawn without replacement
    from a generator seeded with `seed` and returned in ascending order;
    otherwise all positions from 1 on are returned.
    """
    candidates = np.arange(1, sequence_length)
    if sequence_length > max_token_samples:
        generator = np.random.default_rng(seed)
        drawn = generator.choice(
            candidates,
            size=max_token_samples,
            replace=False,
        )
        return np.sort(drawn)
    return candidates


def _eval_token(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    prediction_config: PredictionConfig,
    sequence_token_ids: torch.Tensor,
    get_replacements: Callable[[torch.Tensor, int], torch.Tensor],
    target_token_idx: int,
    target_token_id: torch.Tensor,
    relative_size: float,
    prefix_lengths: list[int],
) -> pd.DataFrame:
    """Evaluate the prediction of one target token for several prefixes.

    For every evaluated prefix length, the tokens directly in front of the
    target are kept (the retained prefix) and the tokens before them are
    swapped for each row returned by `get_replacements`. All resulting
    samples are right-padded to a common length, predicted in one call to
    `predict`, and the logits at the last non-padding position of each
    sample are aggregated per prefix length.

    Args:
        model: Model handed to `predict`.
        tokenizer: Only its `pad_token_id` is used, as filler token.
        prediction_config: Configuration handed to `predict`.
        sequence_token_ids: 1-D tensor with the token ids of the sequence.
        get_replacements: Called as `get_replacements(replaced_prefix,
            prior_size)`; must return a 2-D tensor whose second dimension
            equals `prior_size`. Every row yields one sample.
        target_token_idx: Position of the token to predict.
        target_token_id: Id of the token to predict.
        relative_size: Factor applied to the length of the non-retained
            context to obtain the replacement length.
        prefix_lengths: Candidate prefix lengths; filtered with
            `_get_prefix_lengths` before use.

    Returns:
        The frame produced by `_aggregate_predictions`, indexed by the
        evaluated prefix lengths.

    Raises:
        AssertionError: If the tokenizer has no pad token, if the
            replacements have an unexpected width, or if the samples do
            not all have the common context length.
    """
    # Loop-invariant lookups, hoisted out of the per-prefix loop below.
    pad_token_id = tokenizer.pad_token_id
    assert pad_token_id is not None
    device = sequence_token_ids.device

    samples = []
    target_idxs = []
    batch_starts = []
    prefix_lengths = _get_prefix_lengths(target_token_idx, prefix_lengths)
    # The maximum size we might grow the prompt to, in case we are adding
    # additional tokens
    max_context_size = max(
        target_token_idx, round(relative_size * target_token_idx)
    )
    for prefix_length in prefix_lengths:
        # Remember where the samples of this prefix length start so the
        # predictions can be grouped again afterwards.
        batch_starts.append(len(samples))
        prefix_start = target_token_idx - prefix_length
        prior_size = round(relative_size * prefix_start)
        replaced_prefix = sequence_token_ids[0:prefix_start]
        retained_prefix = sequence_token_ids[prefix_start:target_token_idx]

        # We add dummy tokens for right padding here. Otherwise the
        # encoding function will add left padding, which will shift the
        # positions and use attention masks, which we want to avoid.
        padding_size = max_context_size - (prior_size + prefix_length)
        filler = torch.full(
            (padding_size,),
            pad_token_id,
            dtype=torch.int64,
            device=device,
        )

        if prefix_start == 0:
            # We only use the prefix, no other tokens in the context
            samples.append(torch.cat((retained_prefix, filler), dim=0))
            target_idxs.append(len(retained_prefix) - 1)
        else:
            # There are other tokens in the context, create replacement
            # samples
            replacements = get_replacements(replaced_prefix, prior_size)
            assert replacements.shape[1] == prior_size, (
                f"Replacements have shape {replacements.shape}, "
                "along axis 1, but "
                f"expected {prior_size}"
            )
            # Position of the last real (non-filler) token of each sample;
            # its logits are the prediction for the target token.
            last_token_idx = prior_size + prefix_length - 1
            for replacement in replacements:
                replacement = replacement.to(device)
                samples.append(
                    torch.cat((replacement, retained_prefix, filler), dim=0)
                )
                target_idxs.append(last_token_idx)
    assert all(len(sample) == max_context_size for sample in samples)

    input_ids = torch.stack(samples).long()
    encoding = BatchEncoding(
        {
            "input_ids": input_ids,
            # All-ones mask, created on the same device as the input ids
            # so the encoding never mixes devices.
            "attention_mask": torch.ones_like(input_ids),
        }
    )
    output = predict(model, encoding, prediction_config)
    target_logits = torch.stack(
        [
            output["logits"][i, target_idx].float().detach().cpu()
            for i, target_idx in enumerate(target_idxs)
        ]
    )

    result = _aggregate_predictions(
        logits=target_logits,
        batch_starts=batch_starts,
        target_token_id=target_token_id.cpu(),
    )
    result.index = prefix_lengths
    return result


def _get_prefix_lengths(
    token_idx: int,
    prefix_lengths: list[int],
) -> list[int]:
    """Filter the prefix lengths to only include those that are smaller
    than the target token index. Always include the target token index
    itself, i.e. the full prefix, as the last entry.

    Args:
        token_idx (int): The index of the target token
        prefix_lengths (list[int]): The list of prefix lengths to filter

    Returns:
        list[int]: A new list with the filtered prefix lengths in their
            original order, followed by the target token index
    """
    shorter_lengths = [
        prefix_length
        for prefix_length in prefix_lengths
        if prefix_length < token_idx
    ]
    # Every retained length is strictly smaller than token_idx, so the full
    # prefix can never be present already and is appended unconditionally.
    return [*shorter_lengths, token_idx]


def _aggregate_predictions(
    logits: torch.Tensor,
    batch_starts: list[int],
    target_token_id: torch.Tensor,
) -> pd.DataFrame:
    """Aggregate the predictions for different prefix lengths.

    The rows of `logits` are split into consecutive groups, one per prefix
    length, using `batch_starts`.

    Args:
        logits: One row of logits per sample.
        batch_starts: Index of the first row of every group; a group ends
            where the next one starts, the last one at the end of `logits`.
        target_token_id: Single-element tensor with the id of the token
            that should be predicted.

    Returns:
        A frame with one row per group and two columns: `correct_samples`,
        the fraction of samples whose top token is the target, and
        `plurality_correct`, whether the most frequent top token of the
        group is the target.
    """
    predictions_correct = _is_prediction_correct(
        logits,
        target_token_id,
    ).numpy()
    # Convert the target once instead of once per group.
    target_id = target_token_id.item()

    # Average the correctness over the different replacement samples,
    # i.e. compute one mean value for each prefix length
    batch_ends = batch_starts[1:] + [len(logits)]
    avg_correct = []
    plurality_correct = []
    for start, end in zip(batch_starts, batch_ends):
        avg_correct.append(predictions_correct[start:end].mean())
        plurality_id = _compute_plurality_prediction(logits[start:end])
        plurality_correct.append(bool(plurality_id == target_id))

    # Both columns are materialized lists, so the frame does not depend on
    # pandas consuming a lazy generator.
    return pd.DataFrame(
        {
            "correct_samples": avg_correct,
            "plurality_correct": plurality_correct,
        },
    )


def _is_prediction_correct(
    logits: torch.Tensor, target_token_ids: torch.Tensor
) -> torch.Tensor:
    """Compare the highest-scoring token id of `logits` (arg-max over the
    last dimension) with the target token id(s) and return the elementwise
    result.
    """
    return logits.argmax(dim=-1) == target_token_ids


def _compute_plurality_prediction(logits: torch.Tensor) -> torch.Tensor:
    """
    Determine the token id that is the top prediction (arg-max over the
    last dimension) for the largest number of rows in `logits`.
    If several token ids share the highest count, the smallest of them is
    returned.
    """
    predicted_ids = logits.argmax(dim=-1)
    distinct_ids, occurrences = torch.unique(
        predicted_ids,
        return_counts=True,
        dim=-1,
    )
    # Keep only the ids that reach the highest count, then break ties by
    # taking the smallest id.
    is_most_frequent = occurrences == occurrences.max()
    return distinct_ids[is_most_frequent].min()
