import logging

import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizer
from transformers.data.data_collator import dataclass

from lib_llm.eval.metrics.metric import SequenceMetric, TokenValues
from lib_llm.eval.metrics.token_level import sequence_token_probs
from lib_llm.inference import iter_masked, predict, tokenize


# TODO: this is the old multiple choice-style probability computation code.
# It needs to be adapted to better fit into the new evaluation task
# framework.


logger = logging.getLogger(__name__)


@dataclass
class AnswerProbResult:
    """Bundle of the values returned by ``answer_probs``."""

    # NOTE(review): ``dataclass`` is picked up through
    # ``transformers.data.data_collator`` in this file's imports; importing
    # it from the stdlib ``dataclasses`` module would be more robust.

    # Frame indexed by (prompt, answer) holding the per-option probability
    # columns produced by ``_compute_prompt_result``.
    probs: pd.DataFrame
    # For each answer option, the ids of its tokens that are kept by the
    # attention mask (see ``extract_masked_tokens``).
    answer_token_ids: list[list[int]]
    # The same tokens as ``answer_token_ids``, converted to token strings
    # by the tokenizer.
    answer_token_str: list[list[str]]


def answer_probs(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    model_type: str,
    prompts: list[str],
    answer_options: list[str],
    prompt_answer_separator: str = " ",
) -> AnswerProbResult:
    """Score every answer option as a continuation of every prompt.

    For each prompt, one batch is built that contains the prompt joined
    with each answer option. The batches are fed to ``predict`` and the
    returned logits are turned into per-option probabilities by
    ``_compute_prompt_result``.

    Args:
        model: Model handed to ``predict``.
        tokenizer: Tokenizer used for the prompts, the answer options and
            the combined prompt + option strings.
        model_type: Selects how the answer options are encoded; passed
            through to ``encode_answer_options``.
        prompts: Prompts to evaluate.
        answer_options: Candidate answers, scored for every prompt.
        prompt_answer_separator: Text inserted between a prompt and an
            answer option when building the combined string.

    Returns:
        An ``AnswerProbResult`` with

        * ``probs``: a DataFrame indexed by ``(prompt, answer)`` whose
          columns are ``prob_prod`` and ``prob_mean`` (see
          ``_compute_prompt_result``),
        * ``answer_token_ids`` / ``answer_token_str``: the unmasked tokens
          of every encoded answer option, as ids and as strings.
    """
    # Width of the tokenized prompt batch; together with the width of the
    # option batch it bounds the length of the combined sequences.
    prompt_width = tokenize(tokenizer, prompts).input_ids.shape[1]

    # TODO: pad the answer options to the same length as the full prompt
    # batches, else there will be an error
    option_encoding = encode_answer_options(
        model_type,
        tokenizer,
        answer_options,
    )
    option_width = option_encoding.input_ids.shape[1]

    def _prompt_batches():
        # Lazily tokenize one batch per prompt (one row per answer option)
        # so that only the batch currently being predicted is materialised.
        for prompt in prompts:
            candidates = [
                f"{prompt}{prompt_answer_separator}{answer_option}"
                for answer_option in answer_options
            ]
            yield tokenize(
                tokenizer,
                candidates,
                max_length=prompt_width + option_width + 1,
            )

    # ``predict`` is expected to yield one output per batch, i.e. one per
    # prompt, in order; the concat keys below rely on that.
    per_prompt_frames = [
        _compute_prompt_result(
            answer_options,
            option_encoding,
            output["logits"],
        )
        for output in tqdm(
            predict(model, _prompt_batches()),
            total=len(prompts),
        )
    ]
    probs = pd.concat(per_prompt_frames, axis=0, keys=prompts)
    probs.index.names = ["prompt", "answer"]

    option_token_ids, option_token_strings = extract_masked_tokens(
        tokenizer,
        option_encoding,
    )
    return AnswerProbResult(
        probs=probs,
        answer_token_ids=option_token_ids,
        answer_token_str=option_token_strings,
    )


def _compute_prompt_result(
    answer_options: list[str],
    encoded_answer_options: BatchEncoding,
    logits: torch.Tensor,
) -> pd.DataFrame:
    """Turn the logits of one prompt batch into per-answer probabilities.

    Args:
        answer_options: Answer strings; used as the index of the result.
        encoded_answer_options: Encoding of the answer options alone. Its
            ``input_ids`` width fixes how many trailing positions of
            ``logits`` are inspected, and its ``attention_mask`` is used
            as the mask of those positions.
        logits: Model logits for the prompt + answer batch, sliced along
            the second axis and soft-maxed along the last one.

    Returns:
        A DataFrame indexed by ``answer`` with the columns

        * ``prob_prod``: product of the answer's token probabilities,
        * ``prob_mean``: arithmetic mean of the answer's token
          probabilities, which compensates for differing answer lengths.
    """
    num_answer_positions = encoded_answer_options.input_ids.shape[1]

    # Only the trailing positions are kept, so that the scores refer to
    # the answer and not to the whole prompt.
    # NOTE(review): this presumes the answer tokens sit at the end of each
    # sequence and line up with the option encoding's mask -- confirm
    # against the padding behaviour of ``tokenize``.
    log_probs = F.log_softmax(logits, dim=-1)
    tail_log_probs = TokenValues(
        values=log_probs[:, -num_answer_positions:],
        mask=encoded_answer_options.attention_mask,
    )

    # Pick, per position, the log probability of the actual option token.
    option_token_log_probs = sequence_token_probs(
        tail_log_probs,
        encoded_answer_options,
    )

    columns = {
        # exp(sum(log p)) == product of the token probabilities.
        "prob_prod": [
            torch.exp(log_p.sum()).item()
            for log_p in option_token_log_probs.masked
        ],
        # Sum of the token probabilities divided by the token count.
        "prob_mean": [
            (torch.exp(log_p).sum() / len(log_p)).item()
            for log_p in option_token_log_probs.masked
        ],
    }
    return pd.DataFrame(
        columns,
        index=pd.Index(answer_options, name="answer"),
    )


# TODO: implement this to be able to use the metric as a Trainer Callback
class AnswerProbMetric(SequenceMetric):
    """Unfinished ``SequenceMetric`` for answer-probability evaluation.

    The class currently only stores its configuration; ``update`` is a
    stub that always raises ``NotImplementedError``.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        prompts: list[str],
        answer_options: list[str],
    ) -> None:
        """Store the tokenizer, prompts and answer options on the instance."""
        # NOTE(review): ``SequenceMetric.__init__`` is not called here --
        # confirm the base class needs no initialisation of its own.
        self.tokenizer = tokenizer
        self.prompts = prompts
        self.answer_options = answer_options

    def update(self, token_probs: TokenValues) -> None:
        """Not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError


def encode_answer_options(
    model_type: str,
    tokenizer: PreTrainedTokenizer,
    answer_options: list[str],
) -> BatchEncoding:
    match model_type:
        case "pythia":
            # Add a space in front, because options get encoded differently
            # depending on whether they are preceded by a space or not.
            # The input prompt uses a space, so we should be consistent.
            encoding = tokenize(
                tokenizer, [f" {option.strip()}" for option in answer_options]
            )
        case "llama2":
            # The Llama2 tokenizer adds some special tokens that we
            # must remove
            # expected_initial_token_text = ["<s>", "_"]
            expected_initial_token_ids = torch.tensor([1, 29871])
            encoding = tokenize(
                tokenizer, [f"{option.strip()}" for option in answer_options]
            )
            for token_ids in encoding.input_ids:
                init_token_ids = token_ids[:2]
                assert torch.all(
                    torch.eq(init_token_ids, expected_initial_token_ids)
                ), (
                    f"Tokens ids {init_token_ids} don't match "
                    f"{expected_initial_token_ids}"
                )
            encoding = BatchEncoding(
                {
                    "input_ids": encoding.input_ids[:, 2:],
                    "attention_mask": encoding.attention_mask[:, 2:],
                }
            )
        case _:
            raise NotImplementedError
    return encoding


def extract_masked_tokens(
    tokenizer: PreTrainedTokenizer,
    batch: BatchEncoding,
) -> tuple[list[list[int]], list[list[str]]]:
    """List the tokens of ``batch`` that ``iter_masked`` yields per sequence.

    Args:
        tokenizer: Tokenizer used to convert token ids to token strings.
        batch: Encoding whose ``input_ids`` and ``attention_mask`` are
            passed to ``iter_masked``.

    Returns:
        A pair ``(ids, strings)``: for every sequence the selected token
        ids as a plain list, and the matching token strings.
    """
    ids_per_sequence = [
        selected.tolist()
        for selected in iter_masked(batch.input_ids, batch.attention_mask)
    ]
    strings_per_sequence = list(
        map(tokenizer.convert_ids_to_tokens, ids_per_sequence)
    )
    return ids_per_sequence, strings_per_sequence


# def aggregated_sequence_probs(
#     token_probs: TokenValues, target_token_ids: torch.Tensor,
# ) -> pd.DataFrame:
#     """Given the probability distribution over tokens assigned to an input
#     batch by the model, this function computes the probability of
#     specific token sequences according to that distribution.
#     """
# if len(target_sequences.input_ids.shape) == 1:
#     # 1D sequence, replicate across the batch
#     input_ids = torch.stack(
#         [target_sequences.input_ids] * len(token_probs.values)
#     )
# else:
#     input_ids = target_sequences.input_ids
#     attention_mask = target_sequences.attention_mask
#     assert len(input_ids) == len(token_probs.values)
# sequence_log_probs = token_probs.token_log_probs.sum(dim=-1)
# mean_token_log_probs = token_probs.token_log_probs.mean(dim=-1)
# df = pd.DataFrame(
#     {
#         "log_probs": [
#             np.array([tp.cpu().item() for tp in sequence_token_log_probs])
#             for sequence_token_log_probs
# in token_probs.masked_token_log_probs
#         ],
#         "log_prob_sum": sequence_log_probs.tolist(),
#         "log_prob_mean": mean_token_log_probs.tolist(),
#     }
# )
# return df
