import pandas as pd
from ...utils import can_infer, track_progress_rich
from ...smp import *
import numpy as np
import re

# Display-name overrides for ability labels in report_acc; a label with no
# entry here is used verbatim (currently no overrides are defined).
MMB_abbrs = dict()

# Display-name overrides used by report_acc_MMT (pre-created columns and
# l2-category renames); currently no overrides are defined.
MMT_abbrs = dict()




def report_acc(df):
    """Tabulate accuracy per split: overall, per ``l2-category`` and per ``category``.

    Args:
        df: scored results with a ``hit`` column. ``split``, ``l2-category`` and
            ``category`` columns are optional. If ``split`` is missing, a
            constant ``'none'`` split column is written into ``df`` in place.

    Returns:
        pd.DataFrame with a ``split`` column, an ``Overall`` column and one
        column per distinct l2-category / category value (renamed through
        ``MMB_abbrs`` when it has an entry). Each column holds the mean ``hit``
        for every split, in the order of the ``split`` column.
    """
    table = defaultdict(list)

    if 'split' not in df:
        df['split'] = ['none'] * len(df)
        table['split'] = ['none']
    else:
        table['split'] = list(set(df['split']))

    def _acc_per_split(frame):
        # Mean ``hit`` of ``frame`` within each split, in table['split'] order.
        return [np.mean(frame[frame['split'] == name]['hit']) for name in table['split']]

    table['Overall'] = _acc_per_split(df)
    for column in ('l2-category', 'category'):
        if column not in df:
            continue
        for ability in sorted(set(df[column])):
            label = MMB_abbrs[ability] if ability in MMB_abbrs else ability
            table[label] = _acc_per_split(df[df[column] == ability])
    return pd.DataFrame(table)


def report_acc_MMT(df):
    """Build the MMT accuracy table: overall, per ``category`` and per ``l2-category``.

    Every metric column holds one value per split followed by one value
    computed over all rows; the ``split`` column lists the split names followed
    by ``'ALL'``. If ``df`` has no ``split`` column, a constant ``'none'`` split
    column is written into ``df`` in place.

    ``category`` columns are plain mean ``hit`` values. ``l2-category`` columns
    are the unweighted mean of the accuracies of the categories that occur
    under that l2-category (renamed through ``MMT_abbrs`` when it has an entry).
    """
    res = defaultdict(list)
    res['split'] = list()
    res['Overall'] = list()
    # Pre-create the override columns so they keep a fixed position in the table.
    for abbr in MMT_abbrs.values():
        res[abbr] = list()

    if 'split' not in df:
        df['split'] = ['none'] * len(df)
        res['split'] = ['none']
    else:
        res['split'] = list(set(df['split']))

    def _per_split(frame):
        # Mean ``hit`` of ``frame`` within each split, in res['split'] order.
        return [np.mean(frame[frame['split'] == sp]['hit']) for sp in res['split']]

    res['Overall'] = _per_split(df) + [np.mean(df['hit'])]

    if 'category' in df:
        for cat in sorted(set(df['category'])):
            cat_df = df[df['category'] == cat]
            res[cat] = _per_split(cat_df) + [np.mean(cat_df['hit'])]

    if 'l2-category' in df:
        for l2 in sorted(set(df['l2-category'])):
            members = df[df['l2-category'] == l2]['category'].unique()
            member_dfs = [df[df['category'] == name] for name in members]
            split_rows = [_per_split(member) for member in member_dfs]
            n_members = len(split_rows)
            # Macro average: each member category weighs the same regardless of its size.
            macro = [
                sum(row[col] for row in split_rows) / n_members
                for col in range(len(split_rows[0]))
            ]
            label = MMT_abbrs[l2] if l2 in MMT_abbrs else l2
            totals = [np.mean(member['hit']) for member in member_dfs]
            res[label] = macro + [sum(totals) / len(totals)]

    res['split'].append('ALL')
    return pd.DataFrame(res)


def report_acc_MMSci(df):
    """Summarise MMSci accuracy per subject, per category and overall.

    Only rows whose ``setting`` is ``Fig2Cap``, ``SubFig2Cap`` or
    ``SubCap2Fig`` are scored; ``hit`` is averaged per (group, setting).

    Args:
        df: scored results with ``setting``, ``subject``, ``category`` and
            ``hit`` columns.

    Returns:
        pd.DataFrame with columns ``subject, Fig2Cap, SubFig2Cap, SubCap2Fig,
        Avg``: one row per subject, one per category (``subject`` prefixed with
        ``'CATEGORY_'``) and a final ``'Overall'`` row. ``Avg`` is the mean over
        the setting columns present. A setting that never occurs in ``df``
        yields a NaN column instead of a KeyError.
    """
    settings = ['Fig2Cap', 'SubFig2Cap', 'SubCap2Fig']
    df_filtered = df[df['setting'].isin(settings)]

    # Missing (subject, setting) combinations are filled with 0, i.e. they
    # count as 0 accuracy in that row's Avg.
    subject_acc = df_filtered.groupby(['subject', 'setting'])['hit'].mean().unstack(fill_value=0)
    subject_acc['Avg'] = subject_acc.mean(axis=1)
    subject_acc = subject_acc.reset_index()

    category_acc = df_filtered.groupby(['category', 'setting'])['hit'].mean().unstack(fill_value=0)
    category_acc['Avg'] = category_acc.mean(axis=1)
    category_acc = category_acc.reset_index()
    category_acc['category'] = 'CATEGORY_' + category_acc['category']
    category_acc = category_acc.rename(columns={'category': 'subject'})

    overall_acc = df_filtered.groupby(['setting'])['hit'].mean().to_frame().T
    overall_acc['Avg'] = overall_acc.mean(axis=1)
    overall_acc['subject'] = 'Overall'

    full_acc_df = pd.concat([subject_acc, category_acc, overall_acc], ignore_index=True)
    column_order = ['subject'] + settings + ['Avg']
    # reindex (not plain column selection) so an absent setting does not raise.
    return full_acc_df.reindex(columns=column_order)


def build_prompt(question, options, prediction):
    """Build the English few-shot prompt that asks a judge model to map an answer to an option letter.

    Args:
        question: question text; a ``?`` is appended after it.
        options: pre-formatted option string.
        prediction: the free-form answer to be matched.

    Returns:
        The prompt string: instructions, two fixed worked examples, then the
        query as "Example 3" ending in ``'Your output: '``.
    """
    instructions = (
        'You are an AI assistant who will help me to match '
        'an answer with several options of a single-choice question. '
        'You are provided with a question, several options, and an answer, '
        'and you need to find which option is most similar to the answer. '
        'If the meaning of all options are significantly different from the answer, output Z. '
        'Your should output a single uppercase character in A, B, C, D (if they are valid options), and Z. \n'
    )
    sample_question = 'Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n'
    examples = (
        'Example 1: \n' + sample_question + 'Answer: a cute teddy bear\nYour output: A\n'
        + 'Example 2: \n' + sample_question + 'Answer: Spider\nYour output: Z\n'
    )
    # Only the query part has placeholders; the fixed text above contains no braces.
    query = 'Example 3: \nQuestion: {}?\nOptions: {}\nAnswer: {}\nYour output: '.format(
        question, options, prediction)
    return instructions + examples + query


def build_prompt_cn(question, options, prediction):
    """Build the Chinese few-shot prompt that asks a judge model to map an answer to an option letter.

    Args:
        question: question text; a ``?`` is appended after it.
        options: pre-formatted option string.
        prediction: the free-form answer to be matched.

    Returns:
        The prompt string: instructions, two fixed worked examples, then the
        query as example 3 ending in ``'输出: '``.
    """
    guidance = (
        '你是一个帮助我匹配答案与单选题中多个选项的 AI 助手。'
        '你会被提供：一个问题，多个选项，一个答案。你的任务是找到与答案意义最相近的选项。'
        '如果所有选项的意义都与答案显著不同，则输出 Z。'
        '你应该输出一个单个的大写字母，例如 A, B, C, D（如果它们是有效选项），或 Z。'
    )
    demo_question = '问题: 图中最主要的物体是什么?\n选项: A. 泰迪熊 B. 兔子 C. 猫 D. 狗\n'
    # NOTE: the first example header has no trailing newline, unlike the others;
    # the prompt text is kept exactly as it has always been sent.
    examples = (
        '例 1:' + demo_question + '答案: 一只可爱的泰迪熊\n输出: A\n'
        + '例 2: \n' + demo_question + '答案: 蜘蛛\n输出: Z\n'
    )
    # Only the query part has placeholders; the fixed text above contains no braces.
    query = '例 3: \n问题: {}?\n选项: {}\n答案: {}\n输出: '.format(question, options, prediction)
    return guidance + examples + query



def build_choices(item):
    """Collect the answer options present in ``item``.

    Args:
        item: a mapping-like record (e.g. a pandas row) whose single uppercase
            letter keys ``'A'``..``'Z'`` hold option texts.

    Returns:
        dict mapping each present, non-NA option letter to its text, in
        alphabetical order.
    """
    return {
        letter: item[letter]
        for letter in string.ascii_uppercase
        if letter in item and not pd.isna(item[letter])
    }


def prefetch_answer(item):
    """Try to read the chosen option straight from ``item['prediction']`` without a judge model.

    Returns whatever ``can_infer`` returns for the prediction and the item's
    options (presumably an option letter, or a falsy value when no option can
    be inferred — confirm against ``can_infer``).
    """
    option_map = build_choices(item)
    prediction = item['prediction']
    return can_infer(prediction, option_map)


def extract_answer_from_item(model, item, dataset_name=None):
    """Map the free-form prediction of ``item`` to one option letter.

    First tries rule-based extraction (``can_infer``) on the raw prediction. If
    that fails and ``model`` is given, asks ``model`` up to 3 times with a
    matching prompt. If every attempt fails, a random option (including 'Z')
    is returned.

    Args:
        model: judge with a ``generate(prompt)`` method, or ``None`` for
            exact-matching only.
        item: row with ``question``, ``prediction`` and option columns
            (``'A'``, ``'B'``, ...); an ``answer`` entry is only used for logging.
        dataset_name: unused here; kept for interface compatibility.

    Returns:
        dict(opt=<option letter or 'Z'>, log=<explanation string>).
    """
    logger = get_logger('Evaluation')
    choices = build_choices(item)

    # Cheap path first: the prediction may already name an option.
    ret = can_infer(item['prediction'], choices)
    if ret:
        return dict(opt=ret, log=item['prediction'])
    if model is None:
        return dict(opt='Z', log='Failed in Prefetch, no GPT-based answer matching under `exact_matching` policy.')

    option_str = build_option_str(choices)
    prompt = build_prompt(item['question'], option_str, item['prediction'])

    retry = 3
    for _ in range(retry):
        ans = model.generate(prompt)
        if 'Failed to obtain answer via API' in ans:
            logger.warning('GPT API failed to answer. ')
            continue
        ret = can_infer(ans, choices)
        if ret:
            return dict(opt=ret, log=ans)
        # Build the message in two steps: a conditional expression after
        # implicitly-concatenated f-strings would blank the whole message.
        msg = f'Failed to infer: prediction is {ans}, choice labels are {set(choices)}'
        if 'answer' in item:
            msg += f', Answer is {item["answer"]}'
        logger.warning(msg)

    # All attempts failed: guess among the options plus 'Z' (added only if absent),
    # so the candidate list is never empty.
    options = list(choices) + (['Z'] if 'Z' not in choices else [])
    return dict(opt=rd.choice(options), log='Failed to predict, thus randomly generate one. ')


# For Circular Evaluation
def prefetch_circular_group(sub_data, verbose=False):
    """Try to score one circular group with rule-based answer extraction only.

    For every rolling (row) of ``sub_data`` the ground truth ``GT`` is collected
    and ``prefetch_answer`` is applied to the row.

    Returns:
        - ``dict(hit=0, log=...)`` as soon as a row's pre-fetched answer is
          truthy and differs from its ``GT`` (regardless of ``verbose``);
        - otherwise, with ``verbose=False``: ``dict(hit=1, ...)`` if every
          pre-fetched answer equals its ``GT``, else ``None``;
        - otherwise, with ``verbose=True``: the tuple ``(status, GT, PRED)``
          where ``status`` is that same dict-or-``None`` and ``GT``/``PRED``
          are per-row lists.
    """
    gts, preds = [], []
    for i in range(len(sub_data)):
        row = sub_data.iloc[i]
        gt, pred = row['GT'], prefetch_answer(row)
        gts.append(gt)
        preds.append(pred)
        if pred and gt != pred:
            # A confidently extracted wrong answer fails the whole group.
            return dict(hit=0, log=(
                f'Failed in Prefetching Rolling {i}: Answer is {gt}, '
                f"Prediction is {row['prediction']}, Pre-fetched is {pred}. "
            ))

    all_matched = all(g == p for g, p in zip(gts, preds))
    status = dict(hit=1, log='Succeed During Pre-fetching') if all_matched else None
    return (status, gts, preds) if verbose else status


def eval_vanilla(model, item, dataset_name=None):
    """Score a single (non-circular) item.

    Returns:
        dict(hit=1 if the extracted option equals ``item['GT']`` else 0,
             log='Match Log: <extraction log>. ').
    """
    outcome = extract_answer_from_item(model, item, dataset_name=dataset_name)
    opt = outcome['opt']
    log = f"Match Log: {outcome['log']}. "
    hit = 1 if opt == item['GT'] else 0
    return dict(hit=hit, log=log)


# For Circular Evaluation
def eval_circular_group(model, sub_data, dataset_name=None):
    """Score one circular group: it is a hit only if every rolling is answered correctly.

    Rule-based pre-fetching is tried first; rollings it could not resolve are
    matched with ``extract_answer_from_item`` (which may query ``model``).

    Args:
        model: judge passed to ``extract_answer_from_item``.
        sub_data: rows of one group, each with ``GT`` and ``prediction``.
        dataset_name: forwarded to ``extract_answer_from_item``.

    Returns:
        dict(hit=0 or 1, log=<per-rolling log>).
    """
    prefetched = prefetch_circular_group(sub_data, verbose=True)
    # A bare dict means pre-fetching already found a wrong rolling.
    if isinstance(prefetched, dict) and 'hit' in prefetched:
        return prefetched

    # Reuse the first call's result instead of pre-fetching the group again.
    res, GT, PRED = prefetched
    if res is not None:
        return res

    log = ''
    for i in range(len(sub_data)):
        if PRED[i]:
            # Non-empty pre-fetched answers already equal GT (else we returned above).
            log += f'Rolling {i} Matched.\n'
            continue
        row = sub_data.iloc[i]
        match = extract_answer_from_item(model, row, dataset_name=dataset_name)
        opt, match_log = match['opt'], match['log']
        PRED[i] = opt
        if PRED[i] != GT[i]:
            log += (
                f"Failed in Rolling {i}: Answer is {GT[i]}; Prediction is {row['prediction']}; "
                f'Pre-fetched is {PRED[i]}; Match Log is {match_log}.\n'
            )
            return dict(hit=0, log=log)
        log += (
            f"Rolling {i}: Answer is {GT[i]}, Prediction is {row['prediction']}, "
            f'Pre-fetched is {PRED[i]}.\n'
        )

    return dict(hit=1, log=log)


# data, meta are pd.DataFrame, result_file is a path
def mcq_vanilla_eval(model, data, meta, nproc, result_file, dataset_name=None):
    """Score every row of ``data`` independently against the answers in ``meta``.

    Args:
        model: judge forwarded to ``eval_vanilla`` (``None`` = rule-based only).
        data: predictions with ``index``, ``prediction`` and option columns.
            Rows whose ``index`` is not in ``meta`` are dropped.
        meta: ground truth with ``index`` and ``answer`` columns.
        nproc: worker count for ``track_progress_rich``.
        result_file: cache of per-index results. Existing entries are reused;
            new ones are written there by ``track_progress_rich`` (via its
            ``save=`` argument) and reloaded.
        dataset_name: forwarded to ``eval_vanilla``.

    Returns:
        A new DataFrame with the kept rows of ``data`` plus ``hit`` and ``log``.
    """
    result = load(result_file) if osp.exists(result_file) else {}
    answer_map = dict(zip(meta['index'], meta['answer']))

    # .copy(): the column assignments below then act on an independent frame
    # rather than on a filtered slice (avoids pandas' SettingWithCopyWarning).
    data = data[data['index'].isin(answer_map)].copy()
    data['GT'] = [answer_map[idx] for idx in data['index']]

    # Only rows without a cached result need evaluating.
    items = []
    for i in range(len(data)):
        item = data.iloc[i]
        if item['index'] not in result:
            items.append(item)

    tups = [dict(model=model, item=x, dataset_name=dataset_name) for x in items]
    keys = [x['index'] for x in items]
    if tups:
        res = track_progress_rich(eval_vanilla, tups, nproc=nproc, chunksize=nproc, save=result_file, keys=keys)
        result = load(result_file)
        for k, v in zip(keys, res):
            result.setdefault(k, v)

    data['hit'] = [result[i]['hit'] for i in data['index']]
    data['log'] = [result[i]['log'] for i in data['index']]
    if 'GT' in data:
        data.pop('GT')
    return data


# data, meta are pd.DataFrame, result_file is a path
def mcq_circular_eval(model, data, meta, nproc, result_file, dataset_name=None):
    """Score ``data`` with circular evaluation, one result per question group.

    Rows sharing a ``g_index`` form a group (the rollings of one question); the
    "main" row of a group is the one whose ``index`` equals its ``g_index``.
    Groups are scored by rule-based pre-fetching first and, if needed and
    ``model`` is given, by ``eval_circular_group``.

    Args:
        model: judge for answer matching, or ``None`` for exact matching only
            (unresolved groups then get ``hit=0``).
        data: predictions with integer ``index`` values. A ``g_index`` column
            is added to the caller's ``data`` if missing.
        meta: ground truth with integer ``index`` and ``answer`` columns.
        nproc: worker count for ``track_progress_rich``.
        result_file: cache of per-group results, reused and updated.
        dataset_name: forwarded to ``eval_circular_group``.

    Returns:
        The main rows with ``hit`` and ``log`` columns added.
    """
    import os
    import tempfile

    result = {}
    if osp.exists(result_file):
        result = load(result_file)
    # Build Answer Map
    answer_map = {i: c for i, c in zip(meta['index'], meta['answer'])}

    for idx in list(meta['index']) + list(data['index']):
        assert istype(idx, int)
    if 'g_index' not in data:
        # Rollings share their index modulo 1e6; integer modulo avoids float
        # rounding for very large indices.
        data['g_index'] = [int(x) % 1000000 for x in data['index']]

    # Only keep those lines in the meta data. .copy() makes the assignments
    # below act on an independent frame instead of a filtered slice.
    data = data[data['index'].isin(answer_map)].copy()
    data['GT'] = [answer_map[idx] for idx in data['index']]

    is_main = data['index'] == data['g_index']
    data_main = data[is_main].copy()

    data_groups = []
    for i in range(len(data_main)):
        idx = data_main.iloc[i]['index']
        if idx not in result:
            data_groups.append(data[data['g_index'] == idx])

    if data_groups:
        prefetched = [prefetch_circular_group(g, verbose=False) for g in data_groups]
        remain = []
        for dg, pf in zip(data_groups, prefetched):
            if pf is not None:
                result[dg.iloc[0]['g_index']] = pf
            else:
                remain.append(dg)
        dump(result, result_file)

        tups = [dict(model=model, sub_data=x, dataset_name=dataset_name) for x in remain]
        keys = [x.iloc[0]['g_index'] for x in remain]

        if len(tups) == 0:
            pass
        elif model is None:
            logger = get_logger('Evaluation')
            logger.warning('Exact Matching mode, will not do GPT-based answer matching. ')
            for k in keys:
                result[k] = dict(
                    hit=0, log='Failed in Prefetch, no GPT-based answer matching under `exact_matching` policy.')
        else:
            res = track_progress_rich(
                eval_circular_group,
                tups,
                nproc=nproc,
                chunksize=nproc,
                save=result_file,
                keys=keys)
            result = load(result_file)
            for k, v in zip(keys, res):
                if k not in result:
                    result[k] = v

    # Round-trip data_main through an .xlsx file as before (presumably to
    # normalise index/dtypes — confirm), but via a unique temp file that is
    # always removed, instead of a timestamp-named /tmp path that concurrent
    # runs could clobber and that was never deleted.
    fd, tmp_pth = tempfile.mkstemp(suffix='.xlsx')
    os.close(fd)
    try:
        dump(data_main, tmp_pth)
        data_main = load(tmp_pth)
    finally:
        if os.path.exists(tmp_pth):
            os.remove(tmp_pth)

    indices = data_main['index']
    data_main['hit'] = [result[i]['hit'] for i in indices]
    data_main['log'] = [result[i]['log'] for i in indices]
    if 'GT' in data_main:
        data_main.pop('GT')

    return data_main


def extract_characters_regex(s, choices=('(A)', '(B)', '(C)', '(D)', '(E)')):
    """Extract a single option letter (A-E) from a free-form model response.

    Common answer prefixes ("The answer is", "Best option:", ...) are removed
    first, so capital letters inside them (e.g. the B of "Best") cannot be
    mistaken for the answer.

    Args:
        s: model response; a dict is treated as an empty response.
        choices: option labels of the form ``'(X)'``, used when the response
            contains no capital A-E: a response that is a case-insensitive
            substring of a label maps to that label's letter.

    Returns:
        The option letter, or ``''`` if none can be extracted (including for an
        empty response).
    """
    if isinstance(s, dict):
        s = ''
    s = s.strip()
    # Each prefix must be its own list element; adjacent literals without a
    # comma would silently concatenate into a prefix that never matches.
    answer_prefixes = [
        'The best answer is',
        'The correct answer is',
        'The answer is',
        'The answer',
        'The best option is',
        'The correct option is',
        'Best answer:',
        'Best option:',
    ]
    for answer_prefix in answer_prefixes:
        s = s.replace(answer_prefix, '')

    # An empty string is a substring of every choice label; without this guard
    # an empty response would be scored as option A.
    if not s:
        return ''
    if len(s.split()) > 10 and not re.search('[ABCDE]', s):
        return ''
    matches = re.search(r'[ABCDE]', s)
    if matches is None:
        for choice in choices:
            if s.lower() in choice.lower():
                return choice[1]
        return ''
    return matches[0]


def get_dimension_rating(data_path):
    """Aggregate per-question scores from ``data_path`` into a nested accuracy report.

    Each row needs a ``category`` of the form ``'<Task>/<Subtask>...'`` (Task in
    TASKS and Subtask in SUBTASKS, else KeyError), an ``l2-category`` string and
    a numeric ``score``. Rows with a negative score are skipped. Each counted
    row adds ``score`` to 'true' and ``1 - score`` to 'false' of its bucket.

    Returns:
        dict with
            results[task][subtask][l2_category] -> true / (true + false)
            results[task][subtask]['Avg']       -> mean score over the subtask's counted rows
            results[task]['Avg']                -> mean score over the task's counted rows
            results['Overall']                  -> mean score over all counted rows
        Averages over zero counted rows are reported as 0.
    """
    TASKS = [
        'Reasoning',
        'Perception',
    ]

    SUBTASKS = [
        'Monitoring',
        'Autonomous_Driving',
        'OCR with Complex Context',
        'Diagram and Table',
        'Remote Sensing',
    ]
    data = load(data_path)
    results = {'Overall': {}}
    for task in TASKS:
        results[task] = {subtask: {} for subtask in SUBTASKS}

    for i in range(len(data)):
        question = data.iloc[i]
        parts = question['category'].split('/')
        task, subtask = parts[0], parts[1]
        category = question['l2-category'].lower()
        if 'attribute' in category:
            # Merge every attribute variant under one '<prefix>/attribute' bucket.
            category = category.split('/')[0] + '/attribute'
        if question['score'] >= 0:
            cnt = question['score']
            buckets = results[task][subtask]
            if category not in buckets:
                buckets[category] = {'true': cnt, 'false': 1 - cnt}
            else:
                buckets[category]['true'] += cnt
                buckets[category]['false'] += 1 - cnt

    def _ratio(num, den):
        # Guarded division: an empty group reports 0 instead of raising.
        return num / den if den else 0

    succ_all, sum_all = 0, 0
    for task in TASKS:
        cnt_task, sum_task = 0, 0
        for subtask, subtask_value in results[task].items():
            cnt_subtask, sum_subtask = 0, 0
            for category, category_dict in subtask_value.items():
                cnt_subtask += category_dict['true']
                sum_subtask += category_dict['false'] + category_dict['true']
                # Replace the raw counts with the category accuracy in place.
                subtask_value[category] = category_dict['true'] / (category_dict['false'] + category_dict['true'])
            subtask_value['Avg'] = _ratio(cnt_subtask, sum_subtask)
            cnt_task += cnt_subtask
            sum_task += sum_subtask
        results[task]['Avg'] = _ratio(cnt_task, sum_task)
        succ_all += cnt_task
        sum_all += sum_task
    results['Overall'] = _ratio(succ_all, sum_all)
    return results
