# flake8: noqa
import ast
import json
import os

import pandas as pd
import tiktoken
from tqdm import tqdm

from .constructions import ChatGPTSchema, ResultsForHumanSchema
from .utils import extract_answer, read_jsonl, save_jsonl

# Dataset groups used to pick prompt templates. The four lists below are
# disjoint: multiple-choice QA (prompts list the answer options) vs. cloze
# (prompts carry no options), each split by prompt language.
english_qa_datasets = [
    'lsat-ar',
    'lsat-lr',
    'lsat-rc',
    'logiqa-en',
    'sat-math',
    'sat-en',
    'aqua-rat',
    'sat-en-without-passage',
    'gaokao-english',
]
chinese_qa_datasets = [
    'logiqa-zh',
    'jec-qa-kd',
    'jec-qa-ca',
    'gaokao-chinese',
    'gaokao-geography',
    'gaokao-history',
    'gaokao-biology',
    'gaokao-chemistry',
    'gaokao-physics',
    'gaokao-mathqa',
]
english_cloze_datasets = ['math']
chinese_cloze_datasets = ['gaokao-mathcloze']

# Subsets not referenced by the functions in this module's visible code.
# NOTE(review): the names suggest multiple-answer QA and math-expression
# answers respectively — confirm against their users elsewhere.
multi_choice_datasets = ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']
math_output_datasets = ['gaokao-mathcloze', 'math']


def convert_zero_shot(line, dataset_name):
    """Build the zero-shot prompt for a single test item.

    Args:
        line: one dataset record; the keys read are 'passage' (may be None),
            'question' and, for QA datasets, 'options' (a list of strings).
        dataset_name: selects the template via the module-level dataset lists.

    Returns:
        The prompt string, or None (after printing 'Dataset not defined.')
        when ``dataset_name`` belongs to none of the known dataset lists.
    """
    passage = line['passage'] if line['passage'] is not None else ''
    if dataset_name in english_qa_datasets:
        option_string = 'ABCDEFG'
        count = len(line['options'])
        # NOTE(review): a single-entry options list falls back to A-E; the
        # reason is not visible in this module.
        if count == 1:
            count = 5
        return passage + 'Q: ' + line['question'] + ' ' \
            + 'Answer Choices: ' + ' '.join(line['options']) + '\n' + \
            'A: Among A through {}, the answer is'.format(option_string[count - 1])

    elif dataset_name in chinese_qa_datasets:
        option_string = 'ABCDEFG'
        count = len(line['options'])
        # NOTE(review): a single-entry options list falls back to A-D here.
        if count == 1:
            count = 4
        return passage + '问题：' + line['question'] + ' ' \
            + '选项：' + ' '.join(line['options']) + '\n' + \
            '答案：从A到{}, 我们应选择'.format(option_string[count - 1])

    elif dataset_name in english_cloze_datasets:
        return passage + 'Q: ' + line['question'] + '\n' \
                                          'A: The answer is'

    elif dataset_name in chinese_cloze_datasets:
        return passage + '问题：' + line['question'] + '\n' \
                                            '答案：'

    # The original wrapped the body in `except NameError`, which could never
    # fire; report unknown datasets explicitly instead of returning silently.
    print('Dataset not defined.')
    return None


# Chinese instruction: "This is a single-choice question; among all options
# there is exactly one correct answer."
# NOTE(review): not referenced by any function in this module's visible code.
prefix = ('该问题为单选题，所有选项中必有一个正确答案，'
          '且只有一个正确答案。\n')


def convert_zero_shot_CoT_stage1(line, dataset_name):
    """Build the first-stage zero-shot chain-of-thought prompt for one item.

    The prompt ends with a "let's think step by step" cue so the model
    produces its reasoning.

    Args:
        line: one dataset record; the keys read are 'passage' (may be None),
            'question' and, for QA datasets, 'options' (a list of strings).
        dataset_name: selects the template via the module-level dataset lists.

    Returns:
        The prompt string, or None (after printing 'Dataset not defined.')
        when ``dataset_name`` belongs to none of the known dataset lists.
    """
    passage = line['passage'] if line['passage'] is not None else ''
    if dataset_name in english_qa_datasets:
        return passage + 'Q: ' + line['question'] + ' ' \
            + 'Answer Choices: ' + ' '.join(line['options']) + '\n' + \
            "Let's think step by step."

    elif dataset_name in chinese_qa_datasets:
        option_string = 'ABCDEFG'
        count = len(line['options'])
        # NOTE(review): a single-entry options list falls back to A-D.
        if count == 1:
            count = 4
        return passage + '问题：' + line['question'] + ' ' \
            + '选项：' + ' '.join(line['options']) + '\n' + \
            '从A到{}, 我们应选择什么？让我们逐步思考：'.format(option_string[count - 1])

    elif dataset_name in english_cloze_datasets:
        return passage + 'Q: ' + line['question'] + '\n' \
                                          "A: Let's think step by step."

    elif dataset_name in chinese_cloze_datasets:
        return passage + '问题：' + line['question'] + '\n' \
                                            '答案：让我们逐步思考：'

    # The original `except NameError` could never fire; report unknown
    # datasets explicitly instead of silently returning None.
    print('Dataset not defined.')
    return None


# process few-shot raw_prompts
def combine_prompt(prompt_path,
                   dataset_name,
                   load_explanation=True,
                   chat_mode=False):
    """Build few-shot demonstrations for ``dataset_name`` from a prompt CSV.

    The CSV has one column per dataset. It is read twice with row 0 as the
    header: rows 1, 3, 5, 7, 9 hold demonstration items written as Python
    dict literals (keys 'passage', 'question', 'options', 'label' and
    optionally 'answer'), and rows 2, 4, 6, 8, 10 hold the matching
    explanations. Empty cells are skipped.

    Args:
        prompt_path: path to the prompt CSV.
        dataset_name: column to read; 'sat-en-without-passage' reuses the
            'sat-en' column with passages removed.
        load_explanation: include each demo's explanation (few-shot CoT).
        chat_mode: return (question, answer) tuples instead of strings.

    Returns:
        list of demonstration strings, or of (input, output) tuples when
        ``chat_mode`` is True.

    Raises:
        ValueError: if ``dataset_name`` is in none of the known dataset lists
            (checked per demonstration, so only when the column is non-empty).
    """
    skip_passage = False
    if dataset_name == 'sat-en-without-passage':
        skip_passage = True
        dataset_name = 'sat-en'
    # Row 0 is the header in both reads; odd rows are items, even rows are
    # explanations.
    context_rows = {0, 1, 3, 5, 7, 9}
    explanation_rows = {0, 2, 4, 6, 8, 10}
    raw_prompts_context = pd.read_csv(prompt_path,
                                      header=0,
                                      skiprows=lambda x: x not in context_rows,
                                      keep_default_na=False)
    raw_prompts_explanation = pd.read_csv(
        prompt_path,
        header=0,
        skiprows=lambda x: x not in explanation_rows,
        keep_default_na=False).replace(r'\n\n', '\n', regex=True)
    # keep_default_na=False keeps empty cells as '' so they can be skipped.
    # literal_eval only accepts Python literals, so the CSV cannot run code.
    contexts = [
        ast.literal_eval(cell) for cell in raw_prompts_context[dataset_name]
        if cell
    ]
    explanations = [
        exp for exp in raw_prompts_explanation[dataset_name] if exp
    ]

    demonstrations = []
    # zip stops at the shorter list if items and explanations are unequal.
    for idx, (con, exp) in enumerate(zip(contexts, explanations)):
        passage = con['passage'] if con[
            'passage'] is not None and not skip_passage else ''
        question = con['question']
        options = con['options'] if con['options'] is not None else ''
        label = con['label'] if con['label'] is not None else ''
        answer = con[
            'answer'] if 'answer' in con and con['answer'] is not None else ''

        if dataset_name in english_qa_datasets:
            question_input = 'Problem {}.   '.format(idx + 1) + passage + ' ' + question + '\n' \
                              + 'Choose from the following options:    ' + ' '.join(options) + '\n'
            question_output = (('Explanation for Problem {}:   '.format(idx + 1) + exp + '\n') if load_explanation else '') \
                              + 'The answer is therefore {}'.format(label)

        elif dataset_name in chinese_qa_datasets:
            question_input = '问题 {}.   '.format(idx + 1) + passage + ' ' + question + '\n' \
                              + '从以下选项中选择:    ' + ' '.join(options) + '\n'
            question_output = (('问题 {}的解析:   '.format(idx + 1) + exp + '\n') if load_explanation else '') \
                              + '答案是 {}'.format(label)

        elif dataset_name in english_cloze_datasets:
            question_input = 'Problem {}.   '.format(idx + 1) + question + '\n'
            question_output = (('Explanation for Problem {}:   '.format(idx + 1) + exp + '\n') if load_explanation else '') \
                              + 'The answer is therefore {}'.format(answer)

        elif dataset_name in chinese_cloze_datasets:
            question_input = '问题 {}.   '.format(idx + 1) + question + '\n'
            question_output = (('问题 {}的解析:   '.format(idx + 1) + exp + '\n') if load_explanation else '') \
                              + '答案是 {}'.format(answer)
        else:
            raise ValueError(
                f'During loading few-shot examples, found unknown dataset: {dataset_name}'
            )
        if chat_mode:
            demonstrations.append((question_input, question_output))
        else:
            demonstrations.append(question_input + question_output + '\n')

    return demonstrations


# Module-level tiktoken encoder cache, used by concat_prompt and
# concat_prompt_chat_mode to count tokens. It stays None until
# _lazy_load_enc() first runs, so importing this module does not load it.
enc = None


def _lazy_load_enc():
    """Load the gpt-4 tiktoken encoding into the global ``enc`` if unset."""
    global enc
    if enc is not None:
        return
    enc = tiktoken.encoding_for_model('gpt-4')


# cut prompt if reach max token length
def concat_prompt(demos,
                  dataset_name,
                  max_tokens,
                  end_of_example='\n',
                  verbose=False):
    """Greedily join string demos under a header while under a token budget.

    The header language follows the dataset group (English or Chinese).
    Demos are appended in order, each followed by ``end_of_example``. The
    first demo that would bring the prompt to ``max_tokens`` or more
    (counted with the module-level tiktoken encoder) stops the loop; it and
    all later demos are dropped.

    Args:
        demos: list of demonstration strings (e.g. from ``combine_prompt``).
        dataset_name: selects the header via the module-level dataset lists.
        max_tokens: exclusive upper bound on the prompt's token count.
        end_of_example: separator appended after every demo.
        verbose: print the budget, tokens used and number of shots.

    Returns:
        (prompt, n_shot). If no demo fits (or ``demos`` is empty), the
        header alone is returned with ``n_shot == 0``.

    Raises:
        ValueError: if ``dataset_name`` is in none of the known dataset lists.
    """
    _lazy_load_enc()
    if dataset_name in english_qa_datasets or dataset_name in english_cloze_datasets:
        header = 'Here are the answers for the problems in the exam.\n'
    elif dataset_name in chinese_qa_datasets or dataset_name in chinese_cloze_datasets:
        header = '以下是考试中各个问题的答案。\n'
    else:
        raise ValueError(
            f'Unknown dataset for few-shot prompt: {dataset_name}')

    # Start from the bare header so the results are always bound, even when
    # nothing fits (the original raised UnboundLocalError in that case).
    output = header
    prompt_num = 0
    candidate = header
    for demo in demos:
        candidate = candidate + demo + end_of_example
        # break if reach max token limit
        if len(enc.encode(candidate)) >= max_tokens:
            break
        output = candidate
        prompt_num += 1
    if verbose:
        print('max_tokens set as ', max_tokens, 'actual_tokens is',
              len(enc.encode(output)), 'num_shot is', prompt_num)
    return output, prompt_num


def concat_prompt_chat_mode(demos,
                            dataset_name,
                            max_tokens,
                            end_of_example='\n',
                            verbose=False):
    """Greedily pack (question, answer) demos into chat messages under a budget.

    Each demo becomes a user message followed by an assistant message. Pairs
    are added in order while the JSON-serialized messages stay within
    ``max_tokens`` (counted with the module-level tiktoken encoder); the
    first pair that would exceed it stops the loop, and it and all later
    pairs are dropped.

    Args:
        demos: sequence of (user_content, assistant_content) pairs, e.g. from
            ``combine_prompt(..., chat_mode=True)``.
        dataset_name: unused; kept so the signature mirrors ``concat_prompt``.
        max_tokens: inclusive upper bound on the counted tokens.
        end_of_example: unused; kept for signature compatibility.
        verbose: print the budget, tokens counted and number of shots.

    Returns:
        (messages, n_shot): the list of message dicts and the number of
        demo pairs it contains.
    """
    _lazy_load_enc()
    answers = []
    sentences = ''
    for question, answer in demos:
        pair = [
            {
                'role': 'user',
                'content': question
            },
            {
                'role': 'assistant',
                'content': answer
            },
        ]
        # Count BOTH messages: the original only counted the assistant turn,
        # leaving the question/passage outside the budget. ensure_ascii=False
        # tokenizes non-ASCII text (e.g. Chinese) as written instead of as
        # \uXXXX escapes, which would distort the count.
        candidate = sentences + ''.join(
            json.dumps(msg, ensure_ascii=False) for msg in pair)
        # break if reach max token limit
        if len(enc.encode(candidate)) > max_tokens:
            break
        answers += pair
        sentences = candidate
    if verbose:
        print('max_tokens set as ', max_tokens, 'actual_tokens is',
              len(enc.encode(sentences)), 'num_shot is',
              len(answers) // 2)
    return answers, len(answers) // 2


def convert_few_shot(line, dataset_name, demo, n_shot, chat_mode=False):
    """Append one test item to a few-shot prefix as problem ``n_shot + 1``.

    Args:
        line: test record; the keys read are 'passage' and 'options' (either
            may be None) and 'question'.
        dataset_name: selects the template via the module-level dataset lists.
        demo: the demonstration prefix — a string, or in chat mode a list of
            message dicts.
        n_shot: number of demonstrations already in ``demo``.
        chat_mode: if True, return ``demo`` plus a new user message dict;
            otherwise return ``demo`` concatenated with the question text.

    Returns:
        The full prompt string, or the extended message list in chat mode.

    Raises:
        ValueError: if ``dataset_name`` is in none of the known dataset lists
            (previously this surfaced as an UnboundLocalError).
    """
    passage = line['passage'] if line['passage'] is not None else ''
    question = line['question']
    options = line['options'] if line['options'] is not None else ''

    if dataset_name in english_qa_datasets:
        question_input = 'Problem {}.   '.format(n_shot + 1) + passage + ' ' + question + '\n' \
            + 'Choose from the following options:    ' + ' '.join(options) + '\n'
    elif dataset_name in chinese_qa_datasets:
        question_input = '问题 {}.   '.format(n_shot + 1) + passage + ' ' + question + '\n' \
            + '从以下选项中选择:    ' + ' '.join(options) + '\n'
    elif dataset_name in english_cloze_datasets:
        question_input = 'Problem {}.   '.format(n_shot + 1) + question + '\n'
    elif dataset_name in chinese_cloze_datasets:
        question_input = '问题 {}.   '.format(n_shot + 1) + question + '\n'
    else:
        raise ValueError(
            f'Unknown dataset for few-shot prompt: {dataset_name}')

    if chat_mode:
        return demo + [
            {
                'role': 'user',
                'content': question_input
            },
        ]
    return demo + question_input


def load_dataset(dataset_name,
                 setting_name,
                 parent_path,
                 prompt_path=None,
                 max_tokens=None,
                 end_of_example='\n',
                 chat_mode=False,
                 verbose=False):
    """Load ``<parent_path>/<dataset_name>.jsonl`` and build model inputs.

    Args:
        dataset_name: dataset file stem; also selects the prompt templates.
        setting_name: one of 'zero-shot', 'zero-shot-CoT', 'few-shot',
            'few-shot-CoT'.
        parent_path: directory containing the jsonl file.
        prompt_path: few-shot prompt CSV (required for few-shot settings).
        max_tokens: token budget for the few-shot demonstrations (required
            for few-shot settings).
        end_of_example: separator after each demo (non-chat few-shot only).
        chat_mode: build chat-message lists instead of plain strings.
        verbose: wrap iteration in tqdm and print few-shot budget info.

    Returns:
        list of ``ChatGPTSchema(context=..., metadata=index).to_dict()``.

    Raises:
        ValueError: for an unknown ``setting_name``. Previously this left the
            prompt unbound and the resulting error was swallowed by an
            ``except NameError`` that printed 'Dataset not defined.'.
    """
    few_shot_settings = ('few-shot-CoT', 'few-shot')
    if setting_name not in ('zero-shot', 'zero-shot-CoT') + few_shot_settings:
        raise ValueError(f'Unknown setting_name: {setting_name}')

    test_path = os.path.join(parent_path, dataset_name + '.jsonl')
    loaded_jsonl = read_jsonl(test_path)
    if setting_name in few_shot_settings:
        # Build the demonstration prefix once; it is shared by every item.
        processed_demos = combine_prompt(
            prompt_path,
            dataset_name,
            load_explanation=setting_name == 'few-shot-CoT',
            chat_mode=chat_mode)
        if chat_mode:
            chosen_prompt, n_shot = concat_prompt_chat_mode(processed_demos,
                                                            dataset_name,
                                                            max_tokens,
                                                            end_of_example,
                                                            verbose=verbose)
        else:
            chosen_prompt, n_shot = concat_prompt(processed_demos,
                                                  dataset_name,
                                                  max_tokens,
                                                  end_of_example,
                                                  verbose=verbose)
    if verbose:
        loaded_jsonl = tqdm(loaded_jsonl)

    processed = []
    for meta_idx, line in enumerate(loaded_jsonl):
        if setting_name == 'zero-shot':
            ctxt = convert_zero_shot(line, dataset_name)
        elif setting_name == 'zero-shot-CoT':
            ctxt = convert_zero_shot_CoT_stage1(line, dataset_name)
        else:
            ctxt = convert_few_shot(line, dataset_name, chosen_prompt, n_shot,
                                    chat_mode)
        new_instance = ChatGPTSchema(context=ctxt, metadata=meta_idx)
        processed.append(new_instance.to_dict())
    return processed


def generate_second_stage_input(dataset_name,
                                input_list,
                                output_list,
                                with_format_prompt=False):
    """Build second-stage (answer extraction) prompts from stage-one results.

    Each new context is the stage-one context, the stage-one output passed
    through ``extract_answer``, and an answer-eliciting suffix, joined by
    newlines. Presumably this follows ``convert_zero_shot_CoT_stage1``.

    Args:
        dataset_name: selects the suffix via the module-level dataset lists.
        input_list: stage-one instances, dicts with 'context' and 'metadata'.
        output_list: stage-one model outputs aligned by index with
            ``input_list`` (must be at least as long).
        with_format_prompt: prepend an instruction asking for the final
            answer inside 【】 brackets.

    Returns:
        list of ``ChatGPTSchema(...).to_dict()``, one per ``input_list`` item.

    Raises:
        ValueError: if ``dataset_name`` is in none of the known dataset lists
            (previously an UnboundLocalError inside the loop).
    """
    english_format_prompt = 'Based on the previous results, your task is to extract the final answer and provide the output enclosed in brackets【】, such as 【0】 or 【A】.'
    chinese_format_prompt = '根据以上内容，你的任务是把最终的答案提取出来并填在【】中，例如【0】或者【A】。'
    if dataset_name in english_qa_datasets:
        prompt_suffix = 'Therefore, among A through E, the answer is'
        if with_format_prompt:
            # A space is needed between the English sentences; the original
            # produced "...【A】.Therefore".
            prompt_suffix = english_format_prompt + ' ' + prompt_suffix
    elif dataset_name in chinese_qa_datasets:
        prompt_suffix = '因此，从A到D, 我们应选择'
        if with_format_prompt:
            prompt_suffix = chinese_format_prompt + prompt_suffix
    elif dataset_name in english_cloze_datasets:
        prompt_suffix = 'Therefore, the answer is'
        if with_format_prompt:
            prompt_suffix = english_format_prompt + ' ' + prompt_suffix
    elif dataset_name in chinese_cloze_datasets:
        prompt_suffix = '因此，答案是'
        if with_format_prompt:
            prompt_suffix = chinese_format_prompt + prompt_suffix
    else:
        raise ValueError(
            f'Unknown dataset for second-stage prompt: {dataset_name}')

    processed = []
    # Index into output_list so a too-short output_list still raises
    # IndexError rather than being silently truncated by zip.
    for i, instance in enumerate(input_list):
        ctxt = '{0}\n{1}\n{2}'.format(instance['context'],
                                      extract_answer(output_list[i]),
                                      prompt_suffix)
        new_instance = ChatGPTSchema(context=ctxt,
                                     metadata=instance['metadata'])
        processed.append(new_instance.to_dict())
    return processed


def load_dataset_as_result_schema(dataset_name, parent_path):
    """Load ``<parent_path>/<dataset_name>.jsonl`` as ResultsForHumanSchema rows.

    Each row pairs the item's zero-shot prompt with its gold label; the
    record's 'answer' is used when 'label' is falsy.

    Args:
        dataset_name: dataset file stem; also selects the prompt template.
        parent_path: directory containing the jsonl file.

    Returns:
        list of ``ResultsForHumanSchema`` objects indexed from 0.
    """
    records = read_jsonl(os.path.join(parent_path, dataset_name + '.jsonl'))
    return [
        ResultsForHumanSchema(
            index=idx,
            problem_input=convert_zero_shot(record, dataset_name),
            label=record['label'] if record['label'] else record['answer'],
        ) for idx, record in enumerate(records)
    ]


if __name__ == '__main__':
    # Example driver: build the inputs for one dataset/setting and save them
    # as <save_dir>/<data_name>.jsonl.

    # set variables
    parent_dir = '../../data/V1_1/'
    raw_prompt_path = '../data/few_shot_prompts.csv'

    # set dataset name to process
    # setting_name can be chosen from
    # ["zero-shot", "zero-shot-CoT", "few-shot", "few-shot-CoT"]
    setting_name = 'few-shot-CoT'
    data_name = 'jec-qa-kd'
    save_dir = '../../experiment_input/{}/'.format(setting_name)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)
    processed_data = load_dataset(data_name,
                                  setting_name,
                                  parent_dir,
                                  prompt_path=raw_prompt_path,
                                  max_tokens=2048)
    save_jsonl(processed_data,
               os.path.join(save_dir, '{}.jsonl'.format(data_name)))
