import torch

import json
from templates import CUSTOM_DQA_TEMPLATE, CUSTOM_QA_TEMPLATE, DEFAULT_CHAT_TEMPLATE, CUSTOM_SFT_TEMPLATE
from alignment import apply_chat_template


# Benchmark name -> language code ("cn" / "en") -> path of the JSON file used
# for evaluation. Only MedQA has separate files per language; the other
# benchmarks point both language codes at the same file.
downstream_data_path_map = {
    "MedQA": {
        "cn": "data/MedQA/questions/Mainland/4_options/formatted/test.json",
        "en": "data/MedQA/questions/US/4_options/formatted/phrases_no_exclude_test.json",
    },
    "OpenBookQA": dict.fromkeys(("cn", "en"), "data/OpenBookQA/test.json"),
    "yaojishi": dict.fromkeys(("cn", "en"), "data/yaojishi/imp_dev2_new.json"),
    "wiki": dict.fromkeys(("cn", "en"), "data/wiki/formatted/nq-test.json"),
}


class PredictDataset(torch.utils.data.Dataset):
    """Dataset of chat-formatted prompts for generation.

    Each raw record becomes a two-message (system + user) conversation that is
    rendered to text with ``apply_chat_template``. When ``task`` contains
    ``"DQA"``, records are article texts and the prompt asks the model to write
    a question/answer pair about them. Otherwise records must carry a
    ``query`` field and the prompt asks the model to answer it.
    """

    def __init__(self, task, data_path, tokenizer, few_shot_num, language, dataset):
        """Load the records and pre-render every prompt.

        Args:
            task: Task name. Exactly ``"DQA"`` selects the DQA chat template;
                any task containing ``"DQA"`` selects the DQA prompt builder.
            data_path: Path to a JSON file, or an already-loaded iterable of
                records (dicts, or plain strings for DQA tasks). Dict records
                are modified in place.
            tokenizer: Tokenizer object (assumed Hugging Face-style). It is
                mutated in place: missing EOS/PAD tokens are filled in and
                ``chat_template`` is overwritten.
            few_shot_num: Unused; kept for interface compatibility.
            language: ``"cn"`` or ``"en"``.
            dataset: Dataset name. It picks the QA system prompt and is
                recorded as ``record['dataset']``.
        """
        self.task = task
        self.language = language
        # Must be set before prompts are built: generate_prompt_QA reads it.
        self.dataset_type = dataset
        if isinstance(data_path, str):
            # JSON text is UTF-8; don't rely on the platform default encoding.
            with open(data_path, encoding="utf-8") as f:
                self.dataset = json.load(f)
        else:
            self.dataset = data_path

        if tokenizer.eos_token is None:
            # Fallback for tokenizers without an EOS token. Assumes such
            # tokenizers expose `eod_id` (e.g. Qwen-style); TODO confirm.
            tokenizer.eos_token = "<|endoftext|>"
            tokenizer.eos_token_id = tokenizer.eod_id
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # NOTE(review): the template test is an exact match while the prompt
        # test in the loop below is a substring match, so a task such as
        # "DQA_x" gets the QA template with DQA prompts. Confirm intended.
        if task == "DQA":
            print("using DQA template")
            tokenizer.chat_template = CUSTOM_DQA_TEMPLATE
        else:
            print("using QA template")
            tokenizer.chat_template = CUSTOM_QA_TEMPLATE

        self.tokenizer = tokenizer
        self.datas = []
        for da in self.dataset:
            if "DQA" in self.task:
                if isinstance(da, str):
                    da = {"text": da}
                da['messages'] = self.generate_prompt_DQA(da['text'], None)
            else:
                da['messages'] = self.generate_prompt_QA(da['query'], None)
            da['dataset'] = f"{dataset}_book_{self.language}"
            da = apply_chat_template(da, tokenizer=self.tokenizer, task="generation")
            self.datas.append(da)

    def generate_prompt_DQA(self, query, history):
        """Build the messages asking the model to write a Q/A pair about an article.

        Args:
            query: The article text, sent as the user message.
            history: Unused; kept for interface compatibility.

        Returns:
            A list of two message dicts: the system instruction, then the user
            message.

        Raises:
            NotImplementedError: If ``self.language`` is neither "cn" nor "en".
        """
        if self.language == "cn":
            # NOTE(review): this literal begins with a stray '"' character. It
            # is kept as-is so prompts stay identical to earlier runs.
            system_prompt = """"请根据给定的文章，构建一个与其内容紧密相关的<问题>。确保<问题>不直接提及该文章。可以在<问题>中加入特定的情境或背景，使得文章可以完整且准确地作为<问题>的答案，同时你需要生成一个针对该生成<问题>的<答案>，你可以参考文章的内容进行回答，但你的答案不能透露出你参考过这篇文章。输出请按照模版"<问题>：...\n<答案>：...。"""
        elif self.language == "en":
            system_prompt = """Please create a question that closely aligns with the provided article. Ensure that the <question> does not explicitly reference the text. You may incorporate specific scenarios or contexts in the <question>, allowing the <text> to serve as a comprehensive and precise answer, at the same time, you need to generate an <answer> for the generated <question>. You can refer to the content of the article to answer, but your answer cannot reveal that you have referred to this article. Please output according to the template '<question>:...\n<answer>:....'"""
        else:
            raise NotImplementedError(f"Unsupported language: {self.language!r}")
        return [
            {"content": system_prompt, "role": "system"},
            {"content": query, "role": "user"},
        ]

    def generate_prompt_QA(self, query, history):
        """Build the messages asking the model to answer ``query`` directly.

        Medical datasets (names containing "MedQA" or "yaojishi") get the
        HuatuoGPT system prompt. All other datasets get the KnowledgeGPT one.

        Args:
            query: The user question.
            history: Unused; kept for interface compatibility.

        Returns:
            A list of two message dicts: the system instruction, then the user
            message.

        Raises:
            NotImplementedError: If ``self.language`` is neither "cn" nor "en".
        """
        # Test the dataset *name*. The old code tested membership in the
        # loaded records, which never matched.
        dataset_name = str(getattr(self, "dataset_type", None) or "")
        is_medical = "MedQA" in dataset_name or "yaojishi" in dataset_name
        if is_medical:
            if self.language == 'cn':
                query_template = "你是HuatuoGPT模型，你具备丰富的医学知识，你需要直接回复用户提出的问题。"
            elif self.language == 'en':
                query_template = "You are HuatuoGPT, equipped with in-depth knowledge in medicine. Your task is to directly answer the user's question."
            else:
                raise NotImplementedError(f"Unsupported language: {self.language!r}")
        else:
            if self.language == "cn":
                query_template = """你是KnowledgeGPT模型，你具备丰富的知识，你需要直接回复用户提出的问题。"""
            elif self.language == "en":
                query_template = """You are KnowledgeGPT, equipped with in-depth knowledge. Your task is to directly answer the user's question. """
            else:
                raise NotImplementedError(f"Unsupported language: {self.language!r}")

        return [
            {"content": query_template, "role": "system"},
            {"content": query, "role": "user"},
        ]

    def __getitem__(self, index):
        """Return the rendered record and its prompt text (``record['text']``)."""
        da = self.datas[index]
        return {
            'data': da,
            'input': da['text']
        }

    def __len__(self):
        """Return the number of rendered records."""
        return len(self.datas)

    def collate_fn(self, batch):
        """Tokenize a batch of rendered prompts for generation.

        Pads to the longest prompt in the batch. If the padded length exceeds
        2048, only the last 2048 positions are kept.

        Returns:
            dict with ``data`` (records), ``text`` (prompt strings), and the
            tokenizer's ``input_ids`` / ``attention_mask`` tensors.
        """
        batch_query = [x['input'] for x in batch]
        out_batch = {
            'data': [x['data'] for x in batch],
            'text': batch_query,
        }
        output_tokenizer = self.tokenizer(batch_query, return_tensors='pt', padding='longest')
        out_batch['input_ids'] = output_tokenizer['input_ids']
        out_batch['attention_mask'] = output_tokenizer['attention_mask']
        max_length = 2048
        if out_batch['input_ids'].shape[-1] > max_length:
            # Truncate from the left so the end of each prompt (where
            # generation presumably continues) is preserved.
            out_batch['input_ids'] = out_batch['input_ids'][:, -max_length:]
            out_batch['attention_mask'] = out_batch['attention_mask'][:, -max_length:]
        return out_batch


class EvaluateDataset(torch.utils.data.Dataset):
    """Multiple-choice evaluation prompts grouped by source.

    ``data_path`` (a JSON file path or an already-loaded object) is expected to
    map a source name to a list of records. Each record needs ``question`` and
    ``option`` (a mapping of option label to option text). Every record is
    rendered with ``apply_chat_template`` and tagged with its source name
    (``dataset``) and its position in ``self.datas`` (``idx``).
    """

    def __init__(self, task, data_path, tokenizer, few_shot_num='', language='cn', dataset=''):
        """Load the records and pre-render every prompt.

        Args:
            task: ``"MedQA"``, ``"DQA"`` or ``"QA"``; selects the chat template.
            data_path: JSON file path, or an already-loaded mapping of source
                name -> list of records. Records are modified in place.
            tokenizer: Tokenizer object (assumed Hugging Face-style). It is
                mutated in place: missing EOS/PAD tokens are filled in and
                ``chat_template`` may be overwritten.
            few_shot_num: Unused; kept for interface compatibility.
            language: ``"cn"`` or ``"en"``.
            dataset: Stored as ``self.dataset_type``.

        Raises:
            NotImplementedError: If ``task == "MedQA"`` and ``language`` is
                neither "cn" nor "en".
        """
        self.task = task
        if isinstance(data_path, str):
            # JSON text is UTF-8; don't rely on the platform default encoding.
            with open(data_path, encoding="utf-8") as f:
                self.dataset = json.load(f)
        else:
            self.dataset = data_path
        self.language = language
        if tokenizer.eos_token is None:
            # Fallback for tokenizers without an EOS token. Assumes such
            # tokenizers expose `eod_id` (e.g. Qwen-style); TODO confirm.
            tokenizer.eos_token = "<|endoftext|>"
            tokenizer.eos_token_id = tokenizer.eod_id
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        if task == "MedQA":
            if language == "cn":
                self.query_prompt_without_ty = "请回答下面选择题。\n{question}\n{option_str}"
            elif language == "en":
                self.query_prompt_without_ty = "Please answer the following multiple choice questions. \n{question}\n{option_str}"
            else:
                # Previously a bare `NotImplementedError` expression that never
                # raised, which left the query template unset.
                raise NotImplementedError(f"Unsupported language for MedQA: {language!r}")
            tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
        elif task == "DQA":
            print("using DQA template")
            tokenizer.chat_template = CUSTOM_DQA_TEMPLATE
        elif task == "QA":
            print("using QA template")
            tokenizer.chat_template = CUSTOM_QA_TEMPLATE

        self.datas = []
        self.tokenizer = tokenizer

        # NOTE(review): every record goes through generate_prompt_MedQA, which
        # reads self.query_prompt_without_ty. That attribute is only set when
        # task == "MedQA", so other tasks with non-empty data raise
        # AttributeError here. Confirm intended usage.
        for sou, sou_dataset in self.dataset.items():
            for da in sou_dataset:
                da['messages'] = self.generate_prompt_MedQA(da, None)
                da['dataset'] = sou
                da = apply_chat_template(da, tokenizer=self.tokenizer, task="generation")
                da['idx'] = len(self.datas)
                self.datas.append(da)
        self.dataset_type = dataset

    def generate_prompt_MedQA(self, data, history):
        """Build system + user messages for one multiple-choice record.

        Side effect: sets ``data['option_str']`` (one ``"<label>. <text>"`` line
        per non-empty option) and ``data['query']`` (the filled-in query
        template).

        Args:
            data: Record with ``question`` and ``option`` fields.
            history: Unused; kept for interface compatibility.

        Returns:
            A list of two message dicts: the system instruction, then the user
            query.
        """
        def get_query(da):
            da['option_str'] = '\n'.join([f'{k}. {v}' for k, v in da['option'].items() if len(v) > 0])
            da['query'] = self.query_prompt_without_ty.format_map(da)
            return da['query']

        data['query'] = get_query(data)

        if self.task != "MedQA":
            system_prompt = """You are KnowledgeGPT, equipped with in-depth knowledge. Your task is to directly answer the user's question. """
        elif self.language == 'en':
            system_prompt = "You are HuatuoGPT, equipped with in-depth knowledge in medicine. Your task is to directly answer the user's question."
        else:
            # "cn", and the fallback for any other language.
            system_prompt = "你是HuatuoGPT模型，你具备丰富的医学知识，你需要直接回复用户提出的问题。"

        return [
            {"content": system_prompt, "role": "system"},
            {"content": data['query'], "role": "user"},
        ]

    def generate_prompt_QA(self, query, history):
        """Build HuatuoGPT (Chinese) system + user messages for a free-form query."""
        query_template = """你是HuatuoGPT模型，你具备丰富的医学知识，你需要直接回复用户提出的问题。"""
        messages = [
            {
                "content": query_template,
                "role": "system"
            },
            {
                "content": query,
                "role": "user"
            },
        ]
        return messages

    def __getitem__(self, index):
        """Return the rendered record and its prompt text (``record['text']``)."""
        da = self.datas[index]
        return {
            'data': da,
            'input': da['text']
        }

    def __len__(self):
        """Return the number of rendered records."""
        return len(self.datas)

    def collate_fn(self, batch):
        """Tokenize a batch of rendered prompts for generation.

        Pads to the longest prompt in the batch. If the padded length exceeds
        2048, only the last 2048 positions are kept.

        Returns:
            dict with ``data`` (records), ``text`` (prompt strings), and the
            tokenizer's ``input_ids`` / ``attention_mask`` tensors.
        """
        batch_query = [x['input'] for x in batch]
        out_batch = {
            'data': [x['data'] for x in batch],
            'text': batch_query,
        }
        output_tokenizer = self.tokenizer(batch_query, return_tensors='pt', padding='longest')
        out_batch['input_ids'] = output_tokenizer['input_ids']
        out_batch['attention_mask'] = output_tokenizer['attention_mask']
        max_length = 2048
        if out_batch['input_ids'].shape[-1] > max_length:
            # Truncate from the left so the end of each prompt is preserved.
            out_batch['input_ids'] = out_batch['input_ids'][:, -max_length:]
            out_batch['attention_mask'] = out_batch['attention_mask'][:, -max_length:]
        return out_batch


class SFTEvaluateDataset(torch.utils.data.Dataset):
    """Multiple-choice evaluation prompts in the exam-style SFT format.

    ``data_path`` (a JSON file path or an already-loaded object) is expected to
    map a source name to a list of records. Each record needs ``question`` and
    ``option`` (a mapping of option label to option text). Every record is
    rendered with ``apply_chat_template`` and tagged with its source name
    (``dataset``) and its position in ``self.datas`` (``idx``).
    """

    def __init__(self, task, data_path, tokenizer, few_shot_num='', language='cn'):
        """Load the records and pre-render every prompt.

        Args:
            task: ``"MedQA"``, ``"DQA"`` or ``"QA"``; selects the chat template.
            data_path: JSON file path, or an already-loaded mapping of source
                name -> list of records. Records are modified in place.
            tokenizer: Tokenizer object (assumed Hugging Face-style). It is
                mutated in place: missing EOS/PAD tokens are filled in and
                ``chat_template`` may be overwritten.
            few_shot_num: Unused; kept for interface compatibility.
            language: ``"cn"`` or ``"en"``.

        Raises:
            NotImplementedError: If ``task == "MedQA"`` and ``language`` is
                neither "cn" nor "en", or when building a prompt for an
                unsupported language.
        """
        self.task = task
        if isinstance(data_path, str):
            # JSON text is UTF-8; don't rely on the platform default encoding.
            with open(data_path, encoding="utf-8") as f:
                self.dataset = json.load(f)
        else:
            self.dataset = data_path
        self.language = language
        if tokenizer.eos_token is None:
            # Fallback for tokenizers without an EOS token. Assumes such
            # tokenizers expose `eod_id` (e.g. Qwen-style); TODO confirm.
            tokenizer.eos_token = "<|endoftext|>"
            tokenizer.eos_token_id = tokenizer.eod_id
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        if task == "MedQA":
            # Not read by generate_prompt_MedQA in this class; kept so the
            # attribute stays available to any external readers.
            if language == "cn":
                self.query_prompt_without_ty = "请回答下面选择题。\n{question}\n{option_str}"
            elif language == "en":
                self.query_prompt_without_ty = "Please answer the following multiple choice questions. \n{question}\n{option_str}"
            else:
                # Previously a bare `NotImplementedError` expression that never
                # raised.
                raise NotImplementedError(f"Unsupported language for MedQA: {language!r}")
            tokenizer.chat_template = CUSTOM_SFT_TEMPLATE
        elif task == "DQA":
            print("using DQA template")
            tokenizer.chat_template = CUSTOM_DQA_TEMPLATE
        elif task == "QA":
            print("using QA template")
            tokenizer.chat_template = CUSTOM_QA_TEMPLATE

        self.datas = []
        self.tokenizer = tokenizer

        for sou, sou_dataset in self.dataset.items():
            for da in sou_dataset:
                da['messages'] = self.generate_prompt_MedQA(da, None)
                da['dataset'] = sou
                da = apply_chat_template(da, tokenizer=self.tokenizer, task="generation")
                da['idx'] = len(self.datas)
                self.datas.append(da)

    def generate_prompt_MedQA(self, data, history):
        """Build exam-style system + user messages for one multiple-choice record.

        Side effect: sets ``data['query']`` to the question followed by one
        ``"<label>. <text>"`` line per non-empty option.

        Args:
            data: Record with ``question`` and ``option`` fields.
            history: Unused; kept for interface compatibility.

        Returns:
            A list of two message dicts: the system instruction, then the user
            query.

        Raises:
            NotImplementedError: If ``self.language`` is neither "cn" nor "en".
        """
        option_str = '\n'.join([f'{k}. {v}' for k, v in data['option'].items() if len(v) > 0])
        data['query'] = "{}\n{}".format(data['question'], option_str)

        if self.language == 'cn':
            system_prompt = '''你是一名参加中国医学执业考试的医生。你需要展示对基础科学、临床科学、医学知识以及健康、疾病、患者护理和治疗模式的机制的理解。展示你应用于医学实践的知识的能力。对于以下的多项选择题，请从A到E中选择一个正确答案。基于医学指南中提及的当前和标准实践来选择你的答案。'''
        elif self.language == 'en':
            system_prompt = '''You are a medical doctor taking the US Medical Licensing Examination. You need to demonstrate your understanding of basic and clinical science, medical knowledge, and mechanisms underlying health, disease, patient care, and modes of therapy. Show your ability to apply the knowledge essential for medical practice. For the following multiple-choice question, select one correct answer from A to E. Base your answer on the current and standard practices referenced in medical guidelines.'''
        else:
            raise NotImplementedError

        return [
            {"content": system_prompt, "role": "system"},
            {"content": data['query'], "role": "user"},
        ]

    def __getitem__(self, index):
        """Return the rendered record and its prompt text (``record['text']``)."""
        da = self.datas[index]
        return {
            'data': da,
            'input': da['text']
        }

    def __len__(self):
        """Return the number of rendered records."""
        return len(self.datas)

    def collate_fn(self, batch):
        """Tokenize a batch of rendered prompts for generation.

        Pads to the longest prompt in the batch. If the padded length exceeds
        2048, only the last 2048 positions are kept.

        Returns:
            dict with ``data`` (records), ``text`` (prompt strings), and the
            tokenizer's ``input_ids`` / ``attention_mask`` tensors.
        """
        batch_query = [x['input'] for x in batch]
        out_batch = {
            'data': [x['data'] for x in batch],
            'text': batch_query,
        }
        output_tokenizer = self.tokenizer(batch_query, return_tensors='pt', padding='longest')
        out_batch['input_ids'] = output_tokenizer['input_ids']
        out_batch['attention_mask'] = output_tokenizer['attention_mask']
        max_length = 2048
        if out_batch['input_ids'].shape[-1] > max_length:
            # Truncate from the left so the end of each prompt is preserved.
            out_batch['input_ids'] = out_batch['input_ids'][:, -max_length:]
            out_batch['attention_mask'] = out_batch['attention_mask'][:, -max_length:]
        return out_batch


