import torch
from omegaconf import OmegaConf
from typing import List, Optional, Tuple, Any

def create_template(template_config : dict, tokenizer=None, max_len=200) -> Any:
    """Build the template object described by ``template_config``.

    Args:
        template_config: Mapping holding a ``"class_name"`` entry plus the
            keyword arguments for that class's constructor.
        tokenizer: Forwarded to the constructor as ``tokenizer``.
        max_len: Stored under ``"max_len"`` in the config before the
            constructor is called, replacing any value already there.

    Returns:
        An instance of the class named by ``"class_name"``.

    Raises:
        KeyError: If no module-level name matches ``"class_name"``.
    """
    # Work on an OmegaConf container rather than the raw argument
    # (presumably so the caller's mapping is not mutated by pop -- confirm).
    config = OmegaConf.create(template_config)
    class_name = config.pop("class_name")
    config['max_len'] = max_len

    # The class is resolved by name among this module's globals.
    template_cls = globals()[class_name]
    return template_cls(**config, tokenizer=tokenizer)


class ToFUTemplate:
    """Prompt builder using the TOFU question/answer text layout.

    The question is rendered as
    ``<question_start_token> <question> <question_end_token>`` and the answer
    as ``<answer_token><answer>`` appended directly after it.
    """

    def __init__(self, question_start_token, question_end_token, answer_token, tokenizer=None, max_len=200, prefix="") -> None:
        self.question_start_token = question_start_token
        self.question_end_token = question_end_token
        self.answer_token = answer_token
        # Defaults used by prepare_input_ids when no per-call override is given.
        self.tokenizer = tokenizer
        self.max_len = max_len
        # Model prefix intended for the prompting baselines. It is only
        # stored here; no method of this class reads it.
        self.prefix = prefix

    def _format_question(self, question) -> str:
        """Wrap ``question`` between the start and end tokens, space separated."""
        # NOTE: this layout must stay in sync with the TOFU implementation.
        return " ".join((self.question_start_token, question, self.question_end_token))

    def prepare_prompt(self, question, answer) -> str:
        """Return the full training text: formatted question, answer token, answer."""
        return self._format_question(question) + (self.answer_token + answer)

    def prepare_gen_prompt(self, question, answer) -> str:
        """Return the generation prompt, ending right after the answer token.

        ``answer`` is accepted for signature parity with ``prepare_prompt``
        and is not used.
        """
        return self._format_question(question) + self.answer_token

    def prepare_input_ids(self, question, answer, tokenizer=None, max_len=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Tokenize one question/answer pair into fixed-length tensors.

        Args:
            question: Question text.
            answer: Answer text.
            tokenizer: Overrides ``self.tokenizer`` when given.
            max_len: Overrides ``self.max_len`` when given.

        Returns:
            ``(input_ids, labels, attention_mask)``. ``input_ids`` is padded
            with the EOS id and ``attention_mask`` with zeros up to the
            length limit. ``labels`` holds -100 on the question tokens and on
            padding (presumably the loss ignore index -- confirm against the
            training loop), with a single EOS label after the answer when
            there is room for one.

        Raises:
            ValueError: If neither the instance nor the call supplies a
                tokenizer.
        """
        active_tokenizer = tokenizer if tokenizer is not None else self.tokenizer
        if active_tokenizer is None:
            raise ValueError("Tokenizer is not provided")
        limit = max_len if max_len is not None else self.max_len

        question_text = self._format_question(question)
        full_text = question_text + (self.answer_token + answer)

        # The question is tokenized on its own to learn how many leading
        # positions of the label have to be masked out.
        question_token_count = len(
            active_tokenizer.tokenize(question_text, add_special_tokens=True)
        )

        encoded = active_tokenizer(
            full_text,
            add_special_tokens=True,
            max_length=limit,
            truncation=True,
        )
        token_ids = encoded['input_ids']
        eos_id = active_tokenizer.eos_token_id
        missing = limit - len(token_ids)

        padded_ids = token_ids + [eos_id] * missing
        padded_mask = encoded['attention_mask'] + [0] * missing

        if missing == 0:
            # Sequence already fills the limit: no room for an EOS label.
            label_ids = token_ids
        else:
            # One EOS target after the answer, then masked padding.
            label_ids = token_ids + [eos_id] + [-100] * (missing - 1)

        labels = torch.tensor(label_ids)
        labels[:question_token_count] = -100  # question tokens are not targets

        return torch.tensor(padded_ids), labels, torch.tensor(padded_mask)

# dataset_path: cais/wmdp
# group: wmdp
# test_split: test
# training_split: null
# validation_split: null
# num_fewshot: 0
# output_type: multiple_choice
# doc_to_text: "{{question.strip()}}\nA. {{choices[0]}}\nB. {{choices[1]}}\nC. {{choices[2]}}\nD. {{choices[3]}}\nAnswer:"
# doc_to_choice: ["A", "B", "C", "D"]
# doc_to_target: answer
# metric_list:
#   - metric: acc
#     aggregation: mean
#     higher_is_better: true
# metadata:
#   version: 0


class FullTemplate:
    """Prompt builder supporting question/answer, prefix/continuation and raw text.

    Unlike a space-joined layout, the tokens are concatenated directly around
    the stripped question and answer, so any separators must be part of the
    tokens themselves.
    """

    def __init__(self, question_start_token, question_end_token, answer_token, **kwargs) -> None:
        # Extra keyword arguments (e.g. the ``tokenizer`` and ``max_len`` that
        # ``create_template`` passes) are accepted and ignored.
        self.question_start_token = question_start_token
        self.question_end_token = question_end_token
        self.answer_token = answer_token

    @staticmethod
    def _unsupported_inputs(kwargs) -> ValueError:
        """Build the error raised when no supported input kind was supplied."""
        return ValueError(
            "Unknown prompt inputs: expected 'question', 'prefix' or 'text', "
            f"got keyword arguments {sorted(kwargs)}"
        )

    def prepare_gen_prompt(self, question=None, answer=None, **kwargs):
        """Return the text a model should be conditioned on before generating.

        The first matching input kind wins:

        * ``question`` -- start token, stripped question, end token and answer
          token, with the whole result stripped of surrounding whitespace.
        * ``prefix`` (keyword) -- the stripped prefix followed by one space.
        * ``text`` (keyword) -- the empty string.

        ``answer`` is accepted for signature parity and is not used.

        Raises:
            ValueError: If none of ``question``, ``prefix`` or ``text`` is given.
        """
        if question is not None:
            return (
                self.question_start_token
                + question.strip()
                + self.question_end_token
                + self.answer_token
            ).strip()

        prefix = kwargs.get('prefix')
        if prefix is not None:
            return prefix.strip() + " "

        if kwargs.get('text') is not None:
            return ""

        raise self._unsupported_inputs(kwargs)

    def prepare_prompt(self, question=None, answer=None, **kwargs):
        """Return the full text (prompt plus target) for one example.

        The first matching input kind wins:

        * ``question`` -- start token, stripped question, end token, answer
          token and stripped ``answer``. ``answer`` must be a string here.
        * ``prefix`` (keyword) -- the stripped prefix, one space, then the
          ``continuation`` keyword, which must be a string.
        * ``text`` (keyword) -- the text unchanged.

        Raises:
            ValueError: If none of ``question``, ``prefix`` or ``text`` is given.
        """
        if question is not None:
            return (
                self.question_start_token
                + question.strip()
                + self.question_end_token
                + self.answer_token
                + answer.strip()
            )

        prefix = kwargs.get('prefix')
        if prefix is not None:
            continuation = kwargs.get('continuation')
            return prefix.strip() + " " + continuation

        text = kwargs.get('text')
        if text is not None:
            return text

        raise self._unsupported_inputs(kwargs)

    def prepare_batch_prompt(self, question, choices, **kwargs):
        """Return one full prompt per choice, each pairing ``question`` with that choice.

        Extra keyword arguments are accepted and ignored.
        """
        return [
            self.prepare_prompt(question=question, answer=choice) for choice in choices
        ]