# flake8: noqa
import json
import re

from . import dataset_loader


def extract_last_line(string):
    """Return the last line of *string* that contains non-whitespace text.

    The string is split on newline characters and scanned from the end. The
    returned line is not stripped. If every line is empty or whitespace-only,
    the input string is returned unchanged.
    """
    non_blank_lines = (
        line for line in reversed(string.split('\n')) if line.strip()
    )
    return next(non_blank_lines, string)


def remove_few_shot_prefix(string: str):
    """Strip the few-shot answer marker(s) and return the text after them.

    Markers are processed in order, each applied to the result of the previous
    one. If the text starts with a marker, everything after that leading
    occurrence is kept. Otherwise, if the marker occurs anywhere, everything
    after its last occurrence is kept. Kept text is whitespace-stripped.
    """
    for marker in ('The answer is therefore', '答案是'):
        if string.startswith(marker):
            string = string[len(marker):].strip()
            continue
        position = string.rfind(marker)
        if position != -1:
            string = string[position + len(marker):].strip()
    return string


def try_parse_few_shot_qa_single_answer(string, setting_name, language='en'):
    """Find the option letter (A-G) that follows an explicit answer phrase.

    For the 'few-shot-CoT' setting only the last non-blank line is searched.
    English output is matched after 'answer is ', Chinese output after the
    '答案是' marker.

    Returns:
        The first letter A-G found after the phrase, or None if absent.

    Raises:
        ValueError: if *language* is neither 'en' nor 'zh'.
    """
    if setting_name == 'few-shot-CoT':
        string = extract_last_line(string)

    if language == 'en':
        answer_pattern = 'answer is .*?([A-G])'
    elif language == 'zh':
        answer_pattern = '答案是.*?([A-G])'
    else:
        raise ValueError('Unknown language {0}'.format(language))

    found = re.search(answer_pattern, string)
    return found.group(1) if found else None


def try_parse_few_shot_pattern(string: str, dataset_name, setting_name):
    """Report whether *string* contains the answer marker for *dataset_name*.

    For 'few-shot-CoT' only the last non-blank line is checked. Cloze
    datasets require the text to start with the marker. QA datasets require
    the marker followed (anywhere later) by a letter A-G. Datasets not in any
    of the four dataset_loader lists yield False.
    """
    if setting_name == 'few-shot-CoT':
        string = extract_last_line(string)

    if dataset_name in dataset_loader.chinese_cloze_datasets:
        return string.startswith('答案是')
    if dataset_name in dataset_loader.english_cloze_datasets:
        return string.startswith('The answer is therefore')
    if dataset_name in dataset_loader.chinese_qa_datasets:
        return re.search('答案是.*?([A-G])', string) is not None
    if dataset_name in dataset_loader.english_qa_datasets:
        return re.search('answer is .*?([A-G])', string) is not None
    return False


def parse_few_shot_qa_single_answer(string, setting_name, language='en'):
    """Parse a single-choice answer, falling back to the first capital letter.

    First tries try_parse_few_shot_qa_single_answer. If that finds nothing
    (None), the result of find_first_capital_letter on the original,
    unmodified *string* is returned instead.
    """
    parsed = try_parse_few_shot_qa_single_answer(string, setting_name,
                                                 language)
    if parsed is not None:
        return parsed
    return find_first_capital_letter(string)


def find_first_capital_letter(answer, letters='ABCDEF'):
    """Return the first character of *answer* that is an accepted option label.

    Args:
        answer: Text to scan, character by character.
        letters: Characters accepted as option labels. Defaults to 'A'-'F',
            the set previously hard-coded here. Note that the regex-based
            parsers in this module accept 'A'-'G'; pass 'ABCDEFG' to match.

    Returns:
        The first accepted character, or '' if none occurs.
    """
    accepted = frozenset(letters)
    return next((ch for ch in answer if ch in accepted), '')


def extract_answer_in_bracket(answer, prefix='【', suffix='】'):
    """Return the text between the first *prefix* and the next *suffix*.

    Args:
        answer: Text to search.
        prefix: Opening delimiter (default '【').
        suffix: Closing delimiter (default '】').

    Returns:
        The enclosed text, or '' when the opening delimiter is missing or no
        closing delimiter follows it.
    """
    # Previously the guard used `and`, so a string containing only one of
    # the delimiters fell through to str.index() and raised ValueError.
    start = answer.find(prefix)
    if start < 0:
        return ''
    start += len(prefix)
    # Search for the suffix only after the prefix, so a stray closing
    # delimiter earlier in the text does not produce an empty slice.
    end = answer.find(suffix, start)
    if end < 0:
        return ''
    return answer[start:end]


def parse_math_answer(setting_name, raw_string):
    if setting_name == 'few-shot-CoT':
        raw_string = extract_last_line(raw_string)
    if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
        raw_string = remove_few_shot_prefix(raw_string)
        return raw_string

    def remove_boxed(s):
        left = '\\boxed{'
        try:
            assert s[:len(left)] == left
            assert s[-1] == '}'
            answer = s[len(left):-1]
            if '=' in answer:
                answer = answer.split('=')[-1].lstrip(' ')
            return answer
        except:
            return None

    def last_boxed_only_string(string):
        idx = string.rfind('\\boxed')
        if idx < 0:
            idx = string.rfind('\\fbox')
            if idx < 0:
                return None
        i = idx
        right_brace_idx = None
        num_left_braces_open = 0
        while i < len(string):
            if string[i] == '{':
                num_left_braces_open += 1
            if string[i] == '}':
                num_left_braces_open -= 1
                if num_left_braces_open == 0:
                    right_brace_idx = i
                    break
            i += 1

        if right_brace_idx == None:
            retval = None
        else:
            retval = string[idx:right_brace_idx + 1]

        return retval

    def get_answer_with_dollar_sign(s):
        first_pattern = '\$(.*)\$'
        last_match = None
        matches = re.findall(first_pattern, s)
        if matches:
            last_match = matches[-1]
            if '=' in last_match:
                last_match = last_match.split('=')[-1].lstrip(' ')
        return last_match

    def get_answer_without_dollar_sign(s):
        last_match = None
        if '=' in s:
            last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
            if '\n' in last_match:
                last_match = last_match.split('\n')[0]
        else:
            pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])'
            matches = re.findall(pattern, s)
            if matches:
                last_match = matches[-1]
        return last_match

    raw_string = remove_few_shot_prefix(raw_string)
    if '\\boxed' in raw_string:
        answer = remove_boxed(last_boxed_only_string(raw_string))
    else:
        answer = get_answer_with_dollar_sign(raw_string)
        if not answer:
            answer = get_answer_without_dollar_sign(raw_string)
    return answer


# Abbreviations/labels whose capital letters must not be mistaken for answer
# options. NOTE(review): this list also contains option-like pairs such as
# 'AB', 'BC', 'AC'; confirm that stripping them is intended for the
# multi-answer datasets.
_MULTI_ANSWER_NOISE_TOKENS = (
    'CC', 'CA', 'AC', 'POMES', 'AI', 'MIBG', 'CF', 'CTE', 'AD', 'CB', 'BG',
    'BD', 'BE', 'BH', 'CTB', 'BI', 'CE', 'Pugh', 'Child', 'CTI', 'CTA',
    'TACE', 'PPD', 'Castleman', 'BA', 'CH', 'AB', 'CTC', 'CT', 'CTH', 'CD',
    'AH', 'AE', 'AA', 'AF', 'BC', 'CG', 'BB', 'CI', 'BF', 'CTF', 'CTG', 'AG',
    'CTD', '分级C', '分级A', 'I131', '分级B', '分级D', '131I‐MIBG', 'NYHA',
    'IPF', 'DIP', 'Lambert-Eaton', 'Graves', 'IIA期', 'CKD', 'FDA', 'A级',
    'B级', 'C级', 'D级', '维生素D',
)
# Longest tokens first, so e.g. 'CTH' is removed whole instead of 'CT'
# leaving a stray 'H' (likewise 'TACE' vs 'AC', '131I‐MIBG' vs 'MIBG').
# sorted() is stable, so equal-length tokens keep their listed order.
_MULTI_ANSWER_NOISE_TOKENS_LONGEST_FIRST = tuple(
    sorted(_MULTI_ANSWER_NOISE_TOKENS, key=len, reverse=True))
# Each capital letter, optionally wrapped in parentheses.
_OPTION_LETTER_PATTERN = re.compile(r'\(*([A-Z])\)*')


def parse_qa_multiple_answer(string, setting_name=None):
    """Return every option letter (A-Z) in *string*, in order of appearance.

    Known abbreviations and labels listed in _MULTI_ANSWER_NOISE_TOKENS are
    deleted first, so that their capital letters are not counted.

    Args:
        string: Model prediction text.
        setting_name: Accepted for compatibility with two-argument callers;
            currently ignored.

    Returns:
        A list of single-letter strings (empty if none are found).
    """
    for token in _MULTI_ANSWER_NOISE_TOKENS_LONGEST_FIRST:
        string = string.replace(token, '')
    return _OPTION_LETTER_PATTERN.findall(string)


def post_process(dataset_name, setting_name, prediction):
    """Extract the final answer from a model *prediction* for *dataset_name*.

    Dispatch:
      * cloze datasets (English or Chinese) -> parse_math_answer
      * 'jec-qa-kd', 'jec-qa-ca', 'gaokao-physics' -> parse_qa_multiple_answer
        (returns a list of letters)
      * any other dataset under a 'zero-shot' setting -> first capital letter
      * otherwise (few-shot), English/Chinese QA datasets ->
        parse_few_shot_qa_single_answer

    Raises:
        ValueError: if a non-zero-shot setting is used with a dataset that is
            not in the English or Chinese QA dataset lists.
    """
    if (dataset_name in dataset_loader.english_cloze_datasets
            or dataset_name in dataset_loader.chinese_cloze_datasets):
        return parse_math_answer(setting_name, prediction)

    if dataset_name in ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']:
        # parse_qa_multiple_answer is called with the prediction text only;
        # passing setting_name positionally raised TypeError against its
        # single-parameter signature.
        return parse_qa_multiple_answer(prediction)

    # All remaining datasets are single-answer QA problems.
    if 'zero-shot' in setting_name:
        return find_first_capital_letter(prediction)

    # Remaining case: single-answer QA under a few-shot setting.
    if dataset_name in dataset_loader.english_qa_datasets:
        language = 'en'
    elif dataset_name in dataset_loader.chinese_qa_datasets:
        language = 'zh'
    else:
        raise ValueError(f'Unsupported dataset name {dataset_name}')
    return parse_few_shot_qa_single_answer(prediction, setting_name, language)
