import json
import os

# Base directory that get_datasets joins the evaluation split paths onto.
# NOTE(review): machine-specific absolute path; consider making it configurable.
root_path = "/root/workspace/self-improvement"

# Raw "book" corpora for document-QA ("dqa"), used for generation and for
# non-initial SFT rounds. Paths are relative to the current working directory.
# Entries are tried in order and the first key that is a *substring* of
# ``dataset_name`` wins, so the order of this tuple matters.
_DQA_BOOK_PATHS = (
    ("yaojishi", "data/yaojishi/yaojishi_book_cn_less512.json"),
    ("MedQA_cn", "data/MedQA/MedQA_book_cn_less512.json"),
    ("OpenBookQA", "data/OpenBookQA/OpenBookQA_book_en.json"),
    ("MedQA_en", "data/MedQA/MedQA_book_en_less512.json"),
)

# Rewritten dqa samples used for the initial SFT round (``init_train=True``).
# Same substring/first-match rule as above; paths relative to the cwd.
_DQA_INIT_SFT_PATHS = (
    ("MedQA_en", "data/MedQA/en_rewrite_MedQA_10000-422054_selected_samples_dqa_gpt-4-turbo-2024-04-09.json"),
    ("OpenBookQA", "data/OpenBookQA/en_rewrite_openbook_book_cn_dqa_gpt-4-0125-preview.json"),
    ("pile_freelaw", "data/pile_freelaw/en_rewrite_pile_freelaw_10000-198931_selected_samples_dqa_gpt-4-turbo-2024-04-09.json"),
    ("finance_alpaca", "data/finance_alpaca/en_rewrite_finance_alpaca_10000-68071_selected_samples_dqa_gpt-4-turbo-2024-04-09.json"),
)

# Evaluation splits, matched by *exact* dataset name and joined onto root_path
# at call time (so root_path may be reassigned after import).
_EVAL_PATHS = {
    "MedQA_cn": "data/MedQA/questions/Mainland/4_options/formatted/test.json",
    "MedQA_en": "data/MedQA/questions/US/4_options/formatted/phrases_no_exclude_test.json",
    "yaojishi": "data/yaojishi/test.json",
    "OpenBookQA": "data/OpenBookQA/test.json",
    "pile_freelaw": "data/lighteval/lexglue/test.json",
    "finance_alpaca": "data/AdaptLLM/finance-tasks/test_full_q.json",
}


def _load_json(path):
    """Parse the JSON file at *path*, always decoding it as UTF-8."""
    # Explicit encoding: several corpora contain Chinese text, and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open(path, encoding="utf-8") as fp:
        return json.load(fp)


def _match_dataset_path(dataset_name, table):
    """Return the path of the first ``(key, path)`` entry whose key occurs in *dataset_name*.

    Raises:
        NotImplementedError: if no key is a substring of *dataset_name*.
    """
    for key, path in table:
        if key in dataset_name:
            return path
    raise NotImplementedError(f"unsupported dataset: {dataset_name!r}")


def get_datasets(dataset_name, dataset_type, source_path="", task="generation", init_train=False):
    """Load a dataset from disk as parsed JSON.

    Args:
        dataset_name: Dataset identifier. For ``"dqa"`` it is matched by
            substring against the known corpora (first match wins). For
            ``"eval"`` it must equal one of the known names exactly.
        dataset_type: ``"dqa"`` loads a document/book corpus. With
            ``task="sft"`` and ``init_train=True`` it loads the rewritten dqa
            samples instead. ``"eval"`` loads a test split and flattens it
            into a list.
        source_path: When truthy, this JSON file is loaded and returned as-is,
            and every other argument is ignored.
        task: Only ``"sft"`` is special-cased, together with ``init_train``.
        init_train: Selects the initial-round SFT data. It only has an effect
            for ``dataset_type="dqa"`` with ``task="sft"``.

    Returns:
        For *source_path* or ``"dqa"``: the parsed JSON, unchanged.
        For ``"eval"``: the file is expected to map keys to lists of sample
        dicts that have a ``question`` field. These are flattened into one
        list, and each sample is annotated in place with ``idx`` (its position
        in the list), ``dataset`` (the key it came from) and ``query`` (a copy
        of ``question``).

    Raises:
        NotImplementedError: For an unknown dataset name or dataset type.
    """
    if source_path:
        return _load_json(source_path)

    if dataset_type == "dqa":
        table = _DQA_INIT_SFT_PATHS if (task == "sft" and init_train) else _DQA_BOOK_PATHS
        return _load_json(_match_dataset_path(dataset_name, table))

    if dataset_type == "eval":
        try:
            rel_path = _EVAL_PATHS[dataset_name]
        except KeyError:
            raise NotImplementedError(f"unsupported eval dataset: {dataset_name!r}") from None
        dataset = _load_json(os.path.join(root_path, rel_path))

        eval_dataset = []
        for subset_name, samples in dataset.items():
            for sample in samples:
                sample['idx'] = len(eval_dataset)
                sample['dataset'] = subset_name
                sample['query'] = sample['question']
                eval_dataset.append(sample)
        return eval_dataset

    raise NotImplementedError(f"unsupported dataset_type: {dataset_type!r}")


def _first_present_key(data, candidates, role):
    """Return the first key from *candidates* that exists in the dict *data*.

    Raises:
        KeyError: Naming *role* when none of *candidates* is present. This
            replaces a confusing unbound-local-variable error further down.
    """
    for key in candidates:
        if key in data:
            return key
    raise KeyError(f"no {role} field in sample; expected one of {list(candidates)}")


def apply_chat_template(args, tokenizer, data, dataset_type, task="generation"):
    """Build the chat conversation for one sample and render it with *tokenizer*.

    The sample dict *data* is modified in place and also returned:

    * a ``query`` field is renamed to ``question`` when ``question`` is absent;
    * an existing ``text`` value is copied to ``ori_text``;
    * ``text`` is replaced by the rendered prompt string;
    * ``conversation`` holds the list of ``{"role", "content"}`` messages.

    Args:
        args: Either a dict of prompt templates or an object whose
            ``prompt_templates`` attribute is such a dict. In the ``"dqa"``
            template ``***book***`` is replaced by ``data['text']``. In the
            ``"qa"`` template ``***question***`` is replaced by the question.
        tokenizer: Object exposing ``apply_chat_template(conversation=...,
            tokenize=False, add_generation_prompt=...)`` that returns the
            rendered prompt. Presumably a Hugging Face tokenizer; confirm
            against callers.
        data: The sample dict to format.
        dataset_type: One of:
            ``"dqa"``: document QA built from ``data['text']``.
            ``"qa"``: plain question answering.
            ``"eval"``: multiple choice, using ``data['question']`` and the
            non-empty entries of the ``data['option']`` mapping.
        task: ``"generation"`` asks the tokenizer to add the generation
            prompt. ``"sft"`` appends the reference answer as an assistant
            turn (``dqa``/``qa`` only).

    Returns:
        The same *data* dict, updated as described above.

    Raises:
        KeyError: If a required question or answer field is missing.
        NotImplementedError: For an unknown *dataset_type*.
    """
    prompt_templates = args if isinstance(args, dict) else args.prompt_templates

    # Normalise the question field name: some samples carry it as "query".
    if "query" in data and "question" not in data:
        data['question'] = data.pop('query')

    if dataset_type == "dqa":
        user_prompt = prompt_templates['dqa'].replace("***book***", data['text'])
        conversation = [{"role": "user", "content": user_prompt}]
        if task == "sft":
            answer_template = """<question>: {}\n<answer>: {}"""
            answer_key = _first_present_key(data, ("answer", "dqa_answer", "response"), "answer")
            qa_prompt = answer_template.format(data['question'], data[answer_key])
            conversation.append({"role": "assistant", "content": qa_prompt})

    elif dataset_type == "qa":
        question_key = _first_present_key(data, ("question",), "question")
        user_prompt = prompt_templates['qa'].replace("***question***", data[question_key])
        conversation = [{"role": "user", "content": user_prompt}]
        if task == "sft":
            answer_key = _first_present_key(data, ("answer", "qa_answer", "response"), "answer")
            conversation.append({"role": "assistant", "content": data[answer_key]})

    elif dataset_type == "eval":
        query_prompt_without_ty = "Please answer the following multiple choice questions. \n{}\n{}"
        # Render only the options that have non-empty text, one per line as "K. text".
        data['option_str'] = '\n'.join(f'{k}. {v}' for k, v in data['option'].items() if len(v) > 0)
        user_prompt = query_prompt_without_ty.format(data["question"], data["option_str"])
        conversation = [{"role": "user", "content": user_prompt}]

    else:
        raise NotImplementedError(f"unsupported dataset_type: {dataset_type!r}")

    # Keep the original text before "text" is overwritten with the rendered prompt.
    if "text" in data:
        data['ori_text'] = data['text']
    data['text'] = tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=False,
        add_generation_prompt=(task == "generation"),
    )
    data['conversation'] = conversation
    return data