import string
import os
import json
import orjsonl
from collections import OrderedDict, defaultdict
from datasets import load_dataset, load_from_disk, DatasetDict
import copy
from copy import deepcopy
import numpy as np
import pandas as pd

# read with orjsonl (jsonl in, list out)
def read_with_orjsonl(file_path):
    """Load a JSON Lines file through ``orjsonl`` and return the parsed records.

    Args:
        file_path: Path of the ``.jsonl`` file handed to ``orjsonl.load``.

    Returns:
        Whatever ``orjsonl.load`` yields for the file (one entry per line).
    """
    return orjsonl.load(file_path)

def build_choices(item):
    """Build the ``{letter: option}`` mapping for one multiple-choice item.

    Letter-named fields (``'A'``, ``'B'``, ...) that are present and not NaN
    take priority. When none exist, the ``'options'`` field (or, failing that,
    ``'choices'``) is used and its entries are labelled ``A``, ``B``, ... in
    order.

    Args:
        item: Mapping-like record (dict or pandas row) describing the question.

    Returns:
        dict mapping upper-case letters to option values.

    Raises:
        ValueError: If the item has no letter fields and neither an
            ``'options'`` nor a ``'choices'`` field.
    """
    lettered = {
        letter: item[letter]
        for letter in string.ascii_uppercase
        if letter in item and not pd.isna(item[letter])
    }
    if lettered:
        return lettered

    if 'options' in item:
        raw_options = item['options']
    elif 'choices' in item:
        raw_options = item['choices']
    else:
        raise ValueError('No options or choices found in the item.')

    if isinstance(raw_options, str):
        # SECURITY NOTE(review): eval() executes arbitrary code embedded in the
        # dataset string. If the stored value is always a plain Python literal,
        # ast.literal_eval would be the safe replacement - confirm and switch.
        raw_options = eval(raw_options)

    # At most 26 options can be labelled; any beyond that are dropped.
    letters = string.ascii_uppercase[:len(raw_options)]
    return {letter: raw_options[pos] for pos, letter in enumerate(letters)}

def can_infer_option(prediction, choices, verbose=False):
    """Try to read a single option letter out of a free-form prediction.

    The prediction is tokenised on whitespace after replacing the punctuation
    characters ``.()[],:;!*#{}`` with spaces. An option is returned only when
    exactly one of the candidate labels appears as a stand-alone token.

    Args:
        prediction: Model response text.
        choices: Iterable of candidate labels (a dict of letter -> text or a
            string of letters both work; iteration yields the labels).
        verbose: When true, a response containing the token ``A`` and more
            than three tokens is rejected with a printed notice, because the
            ``A`` may be an article rather than an answer.

    Returns:
        The matched label, ``'Z'`` when the response is a known refusal
        message or contains only the token ``Z``, otherwise ``False``.
    """
    refusal_messages = (
        "Sorry, I can't help with images of people yet.",
        "I can't process this file."
    )
    if any(message in prediction for message in refusal_messages):
        return 'Z'

    def tally(tokens, labels, prefix='', suffix=''):
        # Number of labels that occur as a whole token.
        return sum(1 for label in labels if prefix + label + suffix in tokens)

    separators = '.()[],:;!*#{}'
    cleaned = prediction.translate(str.maketrans(separators, ' ' * len(separators)))
    tokens = cleaned.split()

    hits = tally(tokens, choices)
    if hits == 1:
        # The quantifier guard is only active in verbose mode.
        if verbose and 'A' in tokens and len(tokens) > 3:
            print(f'A might be a quantifier in the string: {prediction}.')
            return False
        for label in choices:
            if label in tokens:
                return label
    elif hits == 0 and tally(tokens, {'Z', ''}) == 1:
        return 'Z'

    return False

def can_infer_text(prediction, choices):
    """Infer the chosen option by matching option *text* inside the prediction.

    Matching is case-insensitive substring search. An option is returned only
    when exactly one option's text occurs in the prediction.

    The caller's ``choices`` dict is left untouched: lower-casing is done on a
    local copy (previously the values were overwritten in place).

    Args:
        prediction: Model response text.
        choices: dict mapping upper-case option letters to option content;
            values are converted with ``str`` before comparison.

    Returns:
        The single matching option letter, or ``False`` when zero or several
        options match.

    Raises:
        AssertionError: If ``choices`` is not a dict or a key is not found in
            ``string.ascii_uppercase``.
    """
    prediction = prediction.lower()
    assert isinstance(choices, dict)

    lowered = {}
    for k, text in choices.items():
        assert k in string.ascii_uppercase
        lowered[k] = str(text).lower()

    cands = [k for k, text in lowered.items() if text in prediction]
    return cands[0] if len(cands) == 1 else False

def can_infer(prediction, choices):
    """Infer the selected option from a prediction.

    The prediction is stringified, then matched by option letter first
    (``can_infer_option``); only when that yields a falsy result is matching
    by option text attempted (``can_infer_text``).

    Args:
        prediction: Model response (any type; converted with ``str``).
        choices: dict mapping option letters to option content.

    Returns:
        The inferred option letter, or ``False`` if nothing could be inferred.
    """
    text = str(prediction)
    by_letter = can_infer_option(text, choices)
    if by_letter:
        return by_letter
    return can_infer_text(text, choices)

def prefetch_answer(prediction):
    """Infer the chosen option for one prediction record.

    Args:
        prediction: Record holding the question's option fields (see
            ``build_choices``) and the model output under ``'response_post'``.

    Returns:
        The result of ``can_infer`` on the record's ``'response_post'`` value.
    """
    option_map = build_choices(prediction)
    response = prediction['response_post']
    return can_infer(response, option_map)

def prefetch_acc(predictions, key="category", answer_key='answer'):
    """Score every prediction and aggregate accuracy overall and per category.

    For each row an option is inferred with ``prefetch_answer`` and a status
    code is written back into ``predictions`` (modified in place):
    ``1`` = inferred option equals the answer, ``0`` = inferred but wrong,
    ``2`` = no option could be inferred.

    The code is written to the ``'status'`` column, located by name. It used
    to be written to whichever column happened to be last, which silently
    overwrote other data when columns were appended after ``'status'``. If no
    ``'status'`` column exists, the last column is still used for backward
    compatibility.

    Args:
        predictions: DataFrame of prediction records.
        key: Column whose value is the grouping category of a row.
        answer_key: Column holding the ground-truth option letter.

    Returns:
        DataFrame with one row per group (``'Overall'`` first) and columns
        ``Category``, ``tot``, ``match``, ``hit``, ``match_rate``,
        ``acc (match)`` and ``acc`` (rates in percent).
    """
    if 'status' in predictions.columns:
        status_col = predictions.columns.get_loc('status')
    else:
        status_col = -1  # legacy behaviour: write into the last column

    tot = defaultdict(int)
    match = defaultdict(int)
    hit = defaultdict(int)
    for i in range(len(predictions)):
        item = predictions.iloc[i]
        cate = item[key]
        tot['Overall'] += 1
        tot[cate] += 1
        matched = prefetch_answer(item)
        if not matched:
            predictions.iloc[i, status_col] = 2
            continue
        match['Overall'] += 1
        match[cate] += 1
        if matched == item[answer_key]:
            hit['Overall'] += 1
            hit[cate] += 1
            predictions.iloc[i, status_col] = 1
        else:
            predictions.iloc[i, status_col] = 0

    res = defaultdict(list)
    for k in tot:
        res['Category'].append(k)
        res['tot'].append(tot[k])
        res['match'].append(match[k])
        res['hit'].append(hit[k])
        res['match_rate'].append(match[k] / tot[k] * 100)
        if match[k] == 0:
            # Nothing could be inferred for this group: avoid dividing by zero.
            res['acc (match)'].append(0)
            res['acc'].append(0)
        else:
            res['acc (match)'].append(hit[k] / match[k] * 100)
            res['acc'].append(hit[k] / tot[k] * 100)
    return pd.DataFrame(res)

# Extract answer from multiple rolling records wo chatgpt
def eval_sub_data(sub_data):
    """Decide whether a group of records counts as a hit.

    Args:
        sub_data: DataFrame whose rows carry a ``'status'`` value.

    Returns:
        ``0`` if any row has status ``0`` or ``2``, otherwise ``1`` (an empty
        frame therefore yields ``1``).
    """
    any_failed = any(row['status'] in (0, 2) for _, row in sub_data.iterrows())
    return 0 if any_failed else 1

# Accuracy Report
def report_acc(df, group='category'):
    """Report the mean ``hit`` value on the ``dev`` split.

    Args:
        df: DataFrame with at least ``'split'`` and ``'hit'`` columns.
        group: ``None`` for a single overall figure, or ``'category'`` /
            ``'l2-category'`` for one figure per distinct value of that column.

    Returns:
        A one-row DataFrame whose first column is ``'split'`` (value
        ``'dev'``) followed by ``'overall'`` or one column per sorted group
        value. ``None`` when ``group`` is given but is not a column of ``df``.

    Raises:
        AssertionError: If ``df`` has no ``'split'`` column or ``group`` is
            not one of the accepted values.
    """
    assert 'split' in df
    assert group in [None, 'category', 'l2-category']

    def dev_mean(frame):
        # Mean of 'hit' restricted to the rows of the dev split.
        return np.mean(frame.loc[frame['split'] == 'dev', 'hit'])

    table = {'split': ['dev']}

    if group is None:
        table['overall'] = [dev_mean(df)]
        return pd.DataFrame(table)

    if group not in df:
        return None

    for ability in sorted(set(df[group])):
        table[ability] = [dev_mean(df[df[group] == ability])]
    return pd.DataFrame(table)

def mmbench_eval(args, dataset_path, result_path, score_path, dataset_split, prediction_path=None, result_dataset=None):
    """Score multiple-choice predictions against a dataset on disk and save the results.

    Predictions are joined with the reference dataset, every row is scored with
    ``prefetch_acc`` ("without rolling"), and, for dataset names containing
    ``"MMBench"``, a grouped ("rolling") accuracy is computed in which a
    question counts as a hit only if all rows of its group have status 1.

    Args:
        args: Object providing ``dataset_name``, ``result_dir``,
            ``prompt_setting`` and ``template_index``.
        dataset_path: Path given to ``load_from_disk``.
        result_path: CSV path for per-row results. The MMBench branch also
            writes ``result_path[:-4] + '_main.csv'``.
        score_path: Score output path. Written as JSON in the MMBench branch;
            otherwise its last 12 characters are replaced by ``_scores.csv``.
        dataset_split: Key selecting the split of the loaded dataset.
        prediction_path: JSON Lines file of predictions; used only when
            ``result_dataset`` is None.
        result_dataset: Object exposing ``to_pandas()``; takes precedence over
            ``prediction_path``.

    Returns:
        ``(len(predictions), len(origin_dataset), None, None)`` when the number
        of de-duplicated predictions differs from the dataset size; otherwise
        ``None`` (the function falls off the end after writing its outputs).

    Raises:
        ValueError: If both ``prediction_path`` and ``result_dataset`` are None.
    """
    if result_dataset is None and prediction_path is None:
        raise ValueError('Either prediction_path or result_dataset should be provided.')
    elif result_dataset is not None:
        predictions = result_dataset.to_pandas()
    else:
        predictions = read_with_orjsonl(prediction_path)
        predictions = pd.DataFrame(predictions)
    
    origin_dataset = load_from_disk(dataset_path)[dataset_split]
    # Order by input_index and keep one prediction per input.
    predictions = predictions.sort_values(by=['input_index'])
    predictions = predictions.drop_duplicates(subset=['input_index'])
    # Incomplete (or over-complete) prediction set: report both sizes and stop.
    if len(predictions) != len(origin_dataset):
        return len(predictions), len(origin_dataset), None, None

    # Copy the reference columns into the prediction frame.
    # NOTE(review): the column assignments below are positional, so they assume
    # the sorted predictions line up row-for-row with origin_dataset - confirm.
    if args.dataset_name in ['MMMU_VAL_MultiChoice', 'MMMU_TEST_MultiChoice', 'MathVista_MultiChoice','MMLU', 'PIQA_VAL']:
        # These datasets take 'index' from input_index and skip the image columns.
        predictions['index'] = predictions['input_index']
        predictions.set_index('input_index', inplace=True)
        # origin_dataset = origin_dataset.rename_column("subset", "category")
        for key in [name for name in origin_dataset.column_names if name not in ['decoded_image'] + [f'image_{i}' for i in range(1, 8)]]:
            predictions[key] = origin_dataset[key]
    else:
        # Other datasets use the dataset's own 'index' column and copy every column.
        predictions['index'] = origin_dataset['index']
        predictions.set_index('input_index', inplace=True)
        for key in origin_dataset.column_names:
            predictions[key] = origin_dataset[key]

    # without rolling
    predictions['status'] = [0 for _ in range(len(predictions))] # 0 extract wrong choice, 1 extract right choice, 2 cannot extract choice
    if "ScienceQA" in args.dataset_name:
        # Three breakdowns (subject, grade, topic). `.iloc[1:]` drops the first
        # row of a breakdown, which is the 'Overall' row prefetch_acc emits first.
        res_without_rolling_subject = prefetch_acc(predictions, key="subject")
        res_without_rolling_grade = prefetch_acc(predictions, key="grade").iloc[1:]
        # Sort by the first column ('Category'), using the integer that follows
        # its first five characters (presumably labels look like 'grade7' - confirm).
        res_without_rolling_grade = res_without_rolling_grade.sort_values(by=['Category'], key=lambda x: x.str[5:].astype(int))
        # Aggregate row over the first six sorted grade rows. Column positions:
        # 0 Category, 1 tot, 2 match, 3 hit, 4 match_rate, 5 acc (match), 6 acc.
        # The summed rates are meaningless, so columns 4-6 are recomputed from
        # the summed counts and the label is overwritten.
        res_without_rolling_grade.loc['G1-6'] = res_without_rolling_grade.iloc[:6].sum()
        res_without_rolling_grade.iloc[-1, 4] = res_without_rolling_grade.iloc[-1, 2] / res_without_rolling_grade.iloc[-1, 1] * 100
        res_without_rolling_grade.iloc[-1, 5] = res_without_rolling_grade.iloc[-1, 3] / res_without_rolling_grade.iloc[-1, 2] * 100
        res_without_rolling_grade.iloc[-1, 6] = res_without_rolling_grade.iloc[-1, 3] / res_without_rolling_grade.iloc[-1, 1] * 100
        res_without_rolling_grade.iloc[-1, 0] = 'G1-6'
        # Same for the remaining grade rows; `6:-1` excludes the 'G1-6' row just added.
        res_without_rolling_grade.loc['G7-12'] = res_without_rolling_grade.iloc[6:-1].sum()
        res_without_rolling_grade.iloc[-1, 4] = res_without_rolling_grade.iloc[-1, 2] / res_without_rolling_grade.iloc[-1, 1] * 100
        res_without_rolling_grade.iloc[-1, 5] = res_without_rolling_grade.iloc[-1, 3] / res_without_rolling_grade.iloc[-1, 2] * 100
        res_without_rolling_grade.iloc[-1, 6] = res_without_rolling_grade.iloc[-1, 3] / res_without_rolling_grade.iloc[-1, 1] * 100
        res_without_rolling_grade.iloc[-1, 0] = 'G7-12'
        res_without_rolling_grade.reset_index(drop=True, inplace=True)
        # Move the two aggregate rows (currently last) to the top of the grade table.
        grade_row_index = list(range(len(res_without_rolling_grade)))
        res_without_rolling_grade = res_without_rolling_grade.reindex(grade_row_index[-2:] + grade_row_index[:-2])
        res_without_rolling_topic = prefetch_acc(predictions, key="topic").iloc[1:]
        res_without_rolling = pd.concat([res_without_rolling_subject, res_without_rolling_grade, res_without_rolling_topic], ignore_index=True)
    else:
        # if args.dataset_name in ["MMMU_VAL_MultiChoice", "MathVista_MultiChoice"]:
        if args.dataset_name in ["MathVista_MultiChoice"]:
            res_without_rolling = prefetch_acc(predictions, answer_key='answer_transformed')
        elif args.dataset_name in ["MMLU"]:
            res_without_rolling = prefetch_acc(predictions, key="subject", answer_key='answer_transformed')
        elif args.dataset_name in ["PIQA_VAL"]:
            # A constant category is added so prefetch_acc has a grouping column.
            # NOTE(review): this appends 'category' after 'status', changing which
            # column is last - confirm prefetch_acc still writes its status code
            # into the 'status' column.
            predictions['category'] = ["all" for _ in range(len(predictions))]
            res_without_rolling = prefetch_acc(predictions, answer_key='answer_transformed')
        else:
            res_without_rolling = prefetch_acc(predictions)
    

    if "MMBench" in args.dataset_name:
        # with rolling
        # Main rows have index < 1e6; rows sharing `index % 1e6` form one group
        # (presumably the rolled variants of the same question - confirm).
        predictions_main = deepcopy(predictions[predictions['index'] < int(1e6)])
        # Empty frame with the same columns, used as the accumulator below.
        res_data = {}
        for key in predictions.columns:
            res_data[key] = []
        res_data = pd.DataFrame(res_data)

        result = {}
        hit, tot = 0, 0
        for i in range(len(predictions_main)):
            # Dealing with the normal part
            item_main = predictions_main.iloc[i]
            idx = item_main['index']
            sub_data = predictions[predictions['index'] % int(1e6) == idx]
            # 1 only if every row of the group has status 1.
            ret = eval_sub_data(sub_data)
            sub_data_copy = deepcopy(sub_data)
            sub_data_copy['hit'] = [1 if ret == 1 else 0 for _ in range(len(sub_data))]
            # NOTE(review): concatenating inside the loop re-copies the
            # accumulated frame on every iteration.
            res_data = pd.concat([res_data, sub_data_copy], ignore_index=True)

            result[idx] = ret
            hit += ret
            tot += 1

        # Cast back to int. NOTE(review): presumably needed because concatenating
        # onto the empty accumulator changes the column dtypes - confirm.
        res_data['index'] = [int(index) for index in res_data['index']]
        res_data['image_index'] = [int(index) for index in res_data['image_index']]
        res_data['status'] = [int(status) for status in res_data['status']]

        predictions_main['hit'] = [result[i] for i in predictions_main['index']]

        overall = report_acc(predictions_main, None)
        l2 = report_acc(predictions_main, 'l2-category')
        leaf = report_acc(predictions_main, 'category')

        # Last three columns of the first ('Overall') row of the non-rolling table.
        match_rate, acc_match, acc = list(res_without_rolling.iloc[0][-3:])
        rolling_acc = hit / tot * 100

        # res_data = res_data.to_dict(orient='records')
        # predictions_main = predictions_main.to_dict(orient='records')
        overall = overall.to_dict(orient='records')
        l2 = l2.to_dict(orient='records')
        leaf = leaf.to_dict(orient='records')

        scores = {
            "overall": {
                "match_rate": match_rate,
                "acc_match": acc_match,
                "acc": acc,
                "rolling_acc": rolling_acc
            },
            "details":{
                "overall": overall,
                "l2": l2,
                "leaf": leaf,
            },
        }

        # NOTE(review): creates <result_dir>/<dataset_name>; result_path and
        # score_path are presumably located inside it - confirm with the caller.
        os.makedirs(os.path.join(args.result_dir, f"{args.dataset_name}"), exist_ok=True)
        # if in_training:
        #     result_path = result_path[:-4] + '_training.csv'
        #     score_path = score_path[:-5] + '_training.json'
        # Put the most relevant columns first in the saved CSV files.
        prior_column_names = ['index', 'status', 'hit', 'response', 'response_post']
        res_data = res_data[prior_column_names + [column_name for column_name in res_data.columns.tolist() if column_name not in prior_column_names]]
        res_data.to_csv(result_path, index=False)
        print(f"Saved results to {result_path}")
        predictions_main = predictions_main[prior_column_names + [column_name for column_name in predictions_main.columns.tolist() if column_name not in prior_column_names]]
        # `[:-4]` drops the last four characters (presumably the '.csv' suffix).
        predictions_main.to_csv(result_path[:-4] + '_main.csv', index=False)
        print(f"Saved main results to {result_path[:-4] + '_main.csv'}")

        res_without_rolling = res_without_rolling.to_dict(orient='records')
        scores["details_without_rolling"] = res_without_rolling
        with open(score_path, "w", encoding="utf-8-sig") as f:
            json.dump(scores, f, indent=4, ensure_ascii=False)
        print(f"Saved scores to {score_path}")
        print("########################################")
        # The per-category details are saved to disk only, not used in the console summary.
        scores.pop("details_without_rolling")
        # print(f"{args.dataset_name}_{args.prompt_setting}_{args.template_index}\n{scores}")
        print(f"{args.dataset_name}_{args.prompt_setting}_{args.template_index}\nAcc\tRolling Acc\n{scores['overall']['acc']:.2f}\t{scores['overall']['rolling_acc']:.2f}")
        print("########################################")
    else:
        os.makedirs(os.path.join(args.result_dir, f"{args.dataset_name}"), exist_ok=True)
        # Replace the last 12 characters of score_path with '_scores.csv'
        # (presumably a '_scores.json' suffix - confirm with the caller).
        score_path = f"{score_path[:-12]}_scores.csv"
        res_without_rolling.to_csv(score_path, index=False)
        print(f"Saved scores to {score_path}")
        print("########################################")
        # print(f"{args.dataset_name}_{args.prompt_setting}_{args.template_index}\n{res_without_rolling}")
        # First row, last column: the 'acc' value of the first reported group.
        print(f"{args.dataset_name}_{args.prompt_setting}_{args.template_index}\nAcc\n{res_without_rolling.iloc[0, -1]:.2f}")
        if args.dataset_name in ["MMMU_VAL_MultiChoice", "MathVista_MultiChoice"]:
            # Columns 3 and 1 are 'hit' and 'tot'.
            print(f"\nCorrect\tTotal\n{res_without_rolling.iloc[0, 3]}\t{res_without_rolling.iloc[0, 1]}")
        print("########################################")

        predictions.to_csv(result_path, index=False)
        print(f"Saved results to {result_path}")
    # if in_training:
    #     return match_rate, acc_match, acc, rolling_acc