from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

from omegaconf import DictConfig
from transformers import AutoTokenizer


@dataclass
class BaseQueryBuilder:
    """Assemble few-shot prompts and render them as chat-template strings.

    A query is built from (in order) the task description, the few-shot
    examples and the formatted question.  The resulting text is wrapped in a
    ``user`` chat message (plus an ``assistant`` message for query/response
    pairs) and, unless disabled, rendered with the tokenizer's chat template.

    The configuration object is read for ``INST``, ``replace_inst``, ``name``
    and ``model``.  When ``cfg.name`` starts with ``"gpt"`` no tokenizer is
    loaded, so only the code paths that do not apply a chat template are
    usable on such an instance.
    """

    # Configuration; see the class docstring for the attributes that are read.
    cfg: DictConfig
    # Task description placed at the top of every query (see __post_init__).
    task_desc: str
    # ``str.format`` template with a ``{question}`` placeholder.
    question_format: str
    # ``str.format`` template with an ``{answer}`` placeholder.
    answer_format: str
    # Separator string inserted between the parts of a prompt.
    sep: str
    # Number of few-shot examples to prepend; 0 or None disables few-shot.
    shots: Optional[int] = 0
    # Few-shot questions / answers, paired by position.
    question_examples: Optional[List[str]] = field(default_factory=list)
    answer_examples: Optional[List[str]] = field(default_factory=list)
    # Labels of the multiple-choice options, indexed by option position.
    choices: Optional[List[str]] = field(default_factory=list)
    # ``str.format`` template with ``{choice}`` and ``{choice_string}``.
    choices_format: Optional[str] = field(default_factory=str)

    def __post_init__(self):
        """Pre-compute the few-shot block, merge ``cfg.INST`` and load the tokenizer."""
        self.fewshot_examples = self.build_fewshot_str()

        inst = self.cfg.INST
        if self.cfg.replace_inst:
            self.task_desc = inst
        else:
            # Appended verbatim: no separator is inserted between the two.
            self.task_desc += inst

        if not self.cfg.name.startswith("gpt"):
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.model, trust_remote_code=True
            )

    def format_q(self, question: str, n: int = 0) -> str:
        """Return ``question`` rendered with ``question_format`` plus ``n`` separators."""
        return self.question_format.format(question=question) + n * self.sep

    def format_a(self, answer: str, n: int = 0) -> str:
        """Return ``answer`` rendered with ``answer_format`` plus ``n`` separators."""
        return self.answer_format.format(answer=answer) + n * self.sep

    def format_choices(self, choices, n=0):
        """Render the multiple-choice options followed by an ``Answer:`` cue.

        Each entry of ``choices`` is paired with the label at the same index
        in ``self.choices``.  ``n`` is accepted for signature parity with
        ``format_q``/``format_a`` and is not used.

        Raises:
            IndexError: if ``choices`` has more entries than ``self.choices``.
        """
        rendered = [
            self.choices_format.format(
                choice=self.choices[idx], choice_string=choice_string
            )
            for idx, choice_string in enumerate(choices)
        ]
        return self.sep.join(rendered) + 2 * self.sep + "Answer:"

    def build_fewshot_str(self):
        """Return the concatenated few-shot examples, or ``""`` without shots.

        Every example is the formatted question followed by one separator and
        the formatted answer followed by two separators.

        Raises:
            AssertionError: if ``shots`` exceeds the number of available
                question or answer examples.
        """
        # ``shots`` is Optional: treat None exactly like 0.
        if not self.shots:
            return ""

        questions = self.question_examples or []
        answers = self.answer_examples or []
        # Raised explicitly (instead of ``assert``) so the check survives
        # ``python -O``.  Using exactly as many shots as there are examples
        # is valid; only asking for more is an error.
        for kind, examples in (("question", questions), ("answer", answers)):
            if self.shots > len(examples):
                raise AssertionError(
                    f"The # of shots ({self.shots}) is larger than the number of {kind} examples ({len(examples)})."
                )

        print(f'Building fewshot string with {self.shots} examples.')
        return "".join(
            self.format_q(question, n=1) + self.format_a(answer, n=2)
            for question, answer in zip(questions[:self.shots], answers[:self.shots])
        )

    def build_queries(self, questions: List[str]) -> List[str]:
        """Return ``build_query(question)`` for every question."""
        return [self.build_query(question) for question in questions]

    def build_query_responses(self, questions: List[str], responses: List[str]) -> List[str]:
        """Render question/response pairs with the chat template in one batch.

        Questions and responses are paired by position.  The rendered strings
        are stripped of surrounding whitespace.
        """
        conversations = [
            self.build_query_response(question, response, apply_chat_template=False)
            for question, response in zip(questions, responses)
        ]
        rendered = self.tokenizer.apply_chat_template(
            conversations, tokenize=False, add_generation_prompt=False
        )
        return [text.strip() for text in rendered]

    def build_query_response(
        self,
        question: str,
        response: str,
        apply_chat_template: bool = True,
        include_task_desc: bool = True,
    ) -> Union[str, List[Dict[str, str]]]:
        """Build a user/assistant exchange for ``question`` and ``response``.

        Returns:
            The stripped chat-template string when ``apply_chat_template`` is
            true, otherwise the raw list of message dicts.
        """
        messages = [
            {
                'role': 'user',
                'content': self._build_query_str(question, include_task_desc=include_task_desc),
            },
            {
                'role': 'assistant',
                'content': self._build_response_str(response),
            },
        ]

        if apply_chat_template:
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            ).strip()

        return messages

    def build_query(
        self,
        question: str,
        include_task_desc: bool = True,
        apply_chat_template: bool = True,
    ) -> Union[str, List[Dict[str, str]]]:
        """Build the user prompt for ``question``.

        Returns:
            The chat-template string (with generation prompt) when
            ``apply_chat_template`` is true, otherwise the raw list holding
            the single user message.
        """
        messages = [
            {
                'role': 'user',
                'content': self._build_query_str(question, include_task_desc=include_task_desc),
            }
        ]
        if apply_chat_template:
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        return messages

    def build_mc_query(self, question: str, choices: List[str]) -> str:
        """Build the chat-template prompt for a multiple-choice question."""
        messages = [
            {
                'role': 'user',
                'content': self._build_mc_query_str(question, choices),
            }
        ]
        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def _build_prefix(self, include_task_desc: bool = True) -> str:
        """Return the task description (if requested) plus the few-shot block."""
        prefix = ""
        if include_task_desc:
            prefix += self.task_desc
            # Only separate the description from what follows if there is one.
            if self.task_desc != "":
                prefix += self.sep * 2
        return prefix + self.fewshot_examples

    def _build_query_str(self, question, include_task_desc: bool = True):
        """Return the plain-text user prompt for ``question``."""
        return self._build_prefix(include_task_desc) + self.format_q(question, n=0)

    def _build_response_str(self, response):
        """Return the plain-text assistant turn for ``response``."""
        return self.format_a(response)

    def _build_mc_query_str(self, question, choices):
        """Return the plain-text user prompt for a multiple-choice question."""
        return (
            self._build_prefix()
            + self.format_q(question, n=1)
            + self.format_choices(choices)
        )
