from dataclasses import asdict
from typing import Any, Dict, List, Union

import torch.nn as nn
from evalscope.models.custom import CustomModel
from transformers import PreTrainedModel

from ..infer import PtEngine, RequestConfig
from ..template import InferRequest


class EvalModel(CustomModel):

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module],
        template,
        max_batch_size,
        model_name: str,
        **kwargs
    ) -> None:
        super().__init__(config={"model_id": model_name}, **kwargs)
        self.model_name = model_name
        self.model = model
        self.template = template
        self.engine = PtEngine.from_model_template(
            model, template, max_batch_size=max_batch_size
        )

    def predict(self, prompts: List[dict], **kwargs) -> List[Dict[str, Any]]:
        # use origin inputs
        infer_requests = self.prepare_inputs(kwargs.get("origin_inputs", prompts))

        infer_cfg = kwargs["infer_cfg"].copy()
        generation_config = RequestConfig(**infer_cfg)

        response = self.engine.infer(
            infer_requests=infer_requests,
            request_config=generation_config,
            use_tqdm=False,
        )
        dict_response = [asdict(item) for item in response]
        return dict_response

    def prepare_inputs(
        self, prompts: Union[List[dict], List[str]]
    ) -> List[InferRequest]:
        infer_requests = []
        for input_item in prompts:
            if isinstance(input_item, str):
                query = input_item
                system_prompt = None
            else:
                data: list = input_item["data"]
                if isinstance(data[0], tuple):  # for truthful_qa and hellaswag
                    query = "\n".join("".join(item) for item in data)
                    system_prompt = input_item.get("system_prompt", None)
                else:
                    query = data[0]
                    system_prompt = input_item.get("system_prompt", None)
            #  prepare messages
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({"role": "user", "content": query})
            infer_requests.append(InferRequest(messages=messages))
        return infer_requests
