from typing import Any

import torch
from transformers import PreTrainedTokenizerBase
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput
from loguru import logger

from src.dataset.chat import ChatDatasetRecord
from src.generators.base import BaseGenerator
from src.settings.generators.chat import CustomChatGenerationSettings
from src.settings.generators.outputs.chat import AnswerMessage, ChatInferenceOutput
from src.settings.tf.generation import GeneratorTransformersSettings


class VLLMChatGenerator(BaseGenerator[ChatDatasetRecord, ChatInferenceOutput]):
    """Chat generator that runs inference through a vLLM ``LLM`` engine.

    Sampling behaviour is fixed at construction time: the transformers-style
    generation settings are translated once into a vLLM ``SamplingParams``
    object that is reused for every batch.
    """

    def __init__(
        self,
        transformers_settings: GeneratorTransformersSettings,
        custom_generation_settings: CustomChatGenerationSettings,
        model: LLM,
        tokenizer: PreTrainedTokenizerBase,
        batch: int,
        return_logits: bool = False,
        logprobs: int | None = None,
    ):
        """Configure the vLLM engine and build the shared sampling parameters.

        Args:
            transformers_settings: Generation settings; supplies the sampling
                knobs (``num_return_sequences``, ``num_beams``, ``temperature``,
                ``top_p``, ``top_k``, ``repetition_penalty``, stop conditions,
                ``max_new_tokens``) and fallback values for ``logprobs`` and
                ``return_logits``.
            custom_generation_settings: Only ``skip_special_tokens`` is read.
            model: vLLM engine; ``tokenizer`` is installed on it via
                ``set_tokenizer``.
            tokenizer: Tokenizer installed on the engine and forwarded to the
                base generator.
            batch: Batch size forwarded to the base generator.
            return_logits: When true (here or in ``transformers_settings``),
                every answer also carries prompt and answer token ids.
            logprobs: Number of log-probabilities requested from vLLM. When
                ``None``, ``transformers_settings.logprobs`` is used instead.
        """
        model.set_tokenizer(tokenizer)
        super().__init__(model, tokenizer, batch=batch)

        # With beam search, vLLM's candidate pool (`best_of`) is the beam
        # width; otherwise it matches the number of returned sequences.
        use_beam_search = transformers_settings.num_beams > 1
        beam_search_params: dict[str, Any] = {
            'best_of': (
                transformers_settings.num_beams
                if use_beam_search
                else transformers_settings.num_return_sequences
            ),
            'use_beam_search': use_beam_search,
        }

        # An explicit constructor argument wins; the settings object is the
        # fallback so callers relying on the defaults see no change.
        effective_logprobs = logprobs if logprobs is not None else transformers_settings.logprobs

        self._sampling_params = SamplingParams(
            n=transformers_settings.num_return_sequences,
            repetition_penalty=transformers_settings.repetition_penalty,
            temperature=transformers_settings.temperature,
            top_p=transformers_settings.top_p,
            top_k=transformers_settings.top_k,
            skip_special_tokens=custom_generation_settings.skip_special_tokens,
            stop=transformers_settings.stop_strings,
            stop_token_ids=transformers_settings.stop_token_ids,
            max_tokens=transformers_settings.max_new_tokens,
            logprobs=effective_logprobs,
            **beam_search_params,
        )

        # Enabled when requested through either the argument or the settings.
        self._return_logits = bool(return_logits or transformers_settings.return_logits)

    def _generate_from_batch(
        self, records: list[dict[str, Any]], original_records: list[ChatDatasetRecord], dataset_name: str
    ) -> list[ChatInferenceOutput]:
        """Generate answers for one batch of tokenized chat records.

        Args:
            records: Tokenized records; each must expose ``input_ids`` with a
                ``tolist()`` method.
            original_records: Source dataset records, positionally aligned
                with ``records``.
            dataset_name: Name copied into every produced output.

        Returns:
            One ``ChatInferenceOutput`` per request output, each holding every
            completion vLLM returned for that prompt. An empty batch yields an
            empty list.
        """
        if not records:
            # Nothing to generate; also avoids indexing [0] for the logs below.
            return []

        input_ids = [record['input_ids'].tolist() for record in records]
        first_prompt = self._tokenizer.decode(input_ids[0], skip_special_tokens=False)
        logger.info(f"VLLM GENERATOR input_ids[0]: <|<|{first_prompt}|>|>")

        request_outputs: list[RequestOutput] = self._model.generate(
            prompts=None,
            prompt_token_ids=input_ids,
            sampling_params=self._sampling_params,
        )

        if request_outputs and request_outputs[0].outputs:
            first_completion = self._tokenizer.decode(
                request_outputs[0].outputs[0].token_ids, skip_special_tokens=False
            )
            logger.info(
                f"VLLM GENERATOR request_outputs[0].outputs[0].token_ids: <|<|{first_completion}|>|>"
            )

        outputs = []
        # strict=True: a length mismatch would pair answers with the wrong
        # record, so fail loudly instead.
        for original_record, request_output in zip(original_records, request_outputs, strict=True):
            answers = []
            for completion in request_output.outputs:
                ans_msg = AnswerMessage(
                    id=str(completion.index),
                    content=completion.text,
                    sequence_score=completion.cumulative_logprob,
                    logprobs=completion.logprobs if completion.logprobs else None,
                )
                if self._return_logits:
                    # unsqueeze(0) adds a leading dimension of size 1.
                    ans_msg.input_token_ids = torch.tensor(request_output.prompt_token_ids).unsqueeze(0)
                    ans_msg.answer_token_ids = torch.tensor(completion.token_ids).unsqueeze(0)

                answers.append(ans_msg)

            outputs.append(
                ChatInferenceOutput(
                    id=original_record.id,
                    dataset_name=dataset_name,
                    messages=original_record.messages,
                    label=original_record.label,
                    answers=answers,
                )
            )
        return outputs
