# Modified from: https://github.com/MMStar-Benchmark/MMStar/blob/main/eval/vlmeval/evaluate/mmstar.py

from copy import deepcopy
import os
import csv
import json
import pickle
import argparse
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
import logging
# Configure the root logger at INFO so the progress messages emitted via
# `logger.info` below are actually shown when this file runs as a script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in Python types.

    ``json`` only calls :meth:`default` for objects it cannot serialize natively.
    This hook maps NumPy integers, floats, complex numbers, booleans, arrays and
    ``np.void`` values to JSON-friendly equivalents, and defers everything else
    to the base class, which raises ``TypeError``.
    """

    def default(self, obj):
        # Check against the abstract scalar bases instead of listing concrete
        # aliases: ``np.float_`` / ``np.complex_`` were removed in NumPy 2.0, so
        # merely referencing them raises AttributeError there.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.complexfloating):
            # JSON has no complex type; emit the two components. NumPy float
            # components are handled by a recursive call to this hook.
            return {'real': obj.real, 'imag': obj.imag}
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.void):
            return None
        return super().default(obj)

# LOAD & DUMP
def dump(data, f, **kwargs):
    """Serialize ``data`` to the path ``f``; the file extension selects the format.

    Supported extensions: ``pkl``, ``json``, ``jsonl``, ``xlsx``, ``csv``, ``tsv``.
    ``jsonl`` expects an iterable of JSON-serializable records. ``xlsx``/``csv``/``tsv``
    call ``to_excel``/``to_csv`` on ``data`` (presumably a pandas DataFrame; not
    checked here). Missing parent directories are created.

    Raises:
        KeyError: if the extension is not one of the supported ones.
    """
    def dump_pkl(data, pth, **kwargs):
        with open(pth, 'wb') as fout:
            pickle.dump(data, fout)

    def dump_json(data, pth, **kwargs):
        # ensure_ascii=False writes raw non-ASCII text, so pin the encoding
        # instead of relying on the platform default.
        with open(pth, 'w', encoding='utf-8') as fout:
            json.dump(data, fout, indent=4, ensure_ascii=False, cls=NumpyEncoder)

    def dump_jsonl(data, f, **kwargs):
        lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
        with open(f, 'w', encoding='utf8') as fout:
            fout.write('\n'.join(lines))

    def dump_xlsx(data, f, **kwargs):
        data.to_excel(f, index=False, engine='xlsxwriter')

    def dump_csv(data, f, quoting=csv.QUOTE_ALL):
        data.to_csv(f, index=False, encoding='utf-8', quoting=quoting)

    def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
        data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting)

    # os.makedirs('') raises FileNotFoundError, so only create a directory when
    # the path actually has one (a bare 'scores.json' targets the cwd).
    parent = os.path.dirname(f)
    if parent:
        os.makedirs(parent, exist_ok=True)
    handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv)
    suffix = f.split('.')[-1]
    return handlers[suffix](data, f, **kwargs)

def _load_predictions(eval_file):
    """Read a JSONL results file into a ``{question_id: text}`` dict.

    Blank lines are skipped; if a ``question_id`` repeats, the last record wins.
    """
    predictions = {}
    with open(eval_file, encoding='utf-8') as fin:
        for raw in fin:
            if not raw.strip():
                continue
            record = json.loads(raw)
            predictions[record['question_id']] = record['text']
    return predictions


def _score_prediction(answer, predict):
    """Grade one normalized (lower-cased, stripped) prediction against its answer.

    The prediction is correct when ``answer`` equals its first character, the
    character after a leading '(', the character after a leading 'option ', or the
    character after a leading 'the answer is '.

    Returns:
        ``(correct, extraction_failed)``. ``extraction_failed`` is True only when a
        check ran past the end of the prediction while the prediction still
        starts like an answer ('a'-'d', '(', 'option ', 'the answer is ').
    """
    try:
        if answer == predict[0]:
            return True, False
        if predict[0] == '(' and answer == predict[1]:
            return True, False
        if predict[0:7] == 'option ' and answer == predict[7]:
            return True, False
        if predict[0:14] == 'the answer is ' and answer == predict[14]:
            return True, False
    except IndexError:
        # Indexing predict[0] here would raise again for an empty prediction
        # (e.g. a question with no result) and abort the whole run; startswith()
        # gives the same answer for non-empty strings and is simply False for ''.
        # NOTE(review): with stripped input this branch is essentially reachable
        # only for '' or a bare '('; confirm that is the intended 'fail extraction'.
        looks_like_answer = predict.startswith(
            ('a', 'b', 'c', 'd', '(', 'option ', 'the answer is '))
        return False, looks_like_answer
    return False, False


def _safe_ratio(num, den):
    """Return ``num / den`` as a float, or 0.0 when ``den`` is zero."""
    return float(num) / float(den) if den else 0.0


def MMStar_eval(eval_file, save_file):
    """Score MMStar predictions and write the accuracy summary to ``save_file``.

    Args:
        eval_file: JSONL file; each record has ``question_id`` (matched against
            the dataset's ``index``) and ``text`` (the model's raw answer).
        save_file: output path handed to ``dump``; its extension picks the format.

    The summary dict holds 'final score', one accuracy per category and per
    ``category(l2_category)``, and 'fail extraction' (the fraction of questions
    ``_score_prediction`` flags as extraction failures). Questions without a
    prediction are scored as wrong and reported with a warning.
    """
    MMStar_score_l2 = {
        'coarse perception': {
            'image scene and topic': 0,
            'image style & quality': 0,
            'image emotion': 0
        },
        'fine-grained perception': {
            'object counting': 0,
            'recognition': 0,
            'localization': 0
        },
        'instance reasoning': {
            'single-instance reasoning': 0,
            'cross-instance attribute reasoning': 0,
            'cross-instance relation reasoning': 0
        },
        'logical reasoning': {
            'code & sequence reasoning': 0,
            'diagram reasoning': 0,
            'common reasoning': 0
        },
        'science & technology': {
            'biology & chemistry & physics': 0,
            'electronics & energy & mechanical eng.': 0,
            'geography & earth science & agriculture': 0
        },
        'math': {
            'geometry': 0,
            'numeric commonsense and calculation': 0,
            'statistical reasoning': 0
        },
    }
    # Same nested shape, but counting questions instead of correct answers.
    MMStar_counter = deepcopy(MMStar_score_l2)

    dataset = load_dataset("Lin-Chen/MMStar", "val")["val"]
    predictions = _load_predictions(eval_file)
    fail_counter = 0
    missing_counter = 0
    for line in dataset:
        index = line['index']
        if index not in predictions:
            missing_counter += 1
        category = str(line['category'])
        l2_category = str(line['l2_category'])
        MMStar_counter[category][l2_category] += 1

        # NOTE: a MathVista-specific substring match exists upstream but is
        # disabled in this version.
        answer = str(line['answer']).lower().strip().replace('\n', ' ')
        predict = predictions.get(index, '').lower().strip().replace('\n', ' ')
        correct, extraction_failed = _score_prediction(answer, predict)
        if correct:
            MMStar_score_l2[category][l2_category] += 1
        if extraction_failed:
            fail_counter += 1

    if missing_counter:
        logger.warning('%d question(s) have no prediction in %s and were scored as wrong.',
                       missing_counter, eval_file)

    # Normalize by the observed question counts rather than fixed 250/1500
    # constants, so the scores stay correct if the split size ever differs.
    MMStar_score = {'final score': 0}
    total_correct = 0
    total_count = 0
    for k, v in MMStar_score_l2.items():
        MMStar_score[k] = 0  # reserve the key so it precedes its sub-category entries
        cat_correct = 0
        cat_count = 0
        for l2_k, l2_v in v.items():
            l2_count = MMStar_counter[k][l2_k]
            MMStar_score[f'{k}({l2_k})'] = _safe_ratio(l2_v, l2_count)
            cat_correct += l2_v
            cat_count += l2_count
        MMStar_score[k] = _safe_ratio(cat_correct, cat_count)
        total_correct += cat_correct
        total_count += cat_count
    MMStar_score['final score'] = _safe_ratio(total_correct, total_count)
    MMStar_score['fail extraction'] = _safe_ratio(fail_counter, total_count)

    score_pth = save_file
    dump(MMStar_score, score_pth)
    logger.info(
        f'MMStar_eval successfully finished evaluating {eval_file}, results saved in {score_pth}')
    logger.info('Score: ')
    for key, value in MMStar_score.items():
        print('{}: {}'.format(key, value*100))


if __name__ == "__main__":
    # Both paths are mandatory: without them MMStar_eval would fail later with an
    # opaque open(None) error instead of a clear usage message.
    parser = argparse.ArgumentParser(
        description='Score MMStar predictions against the Lin-Chen/MMStar val split.')
    parser.add_argument("--result-file", type=str, required=True,
                        help='JSONL file with one {"question_id": ..., "text": ...} record per line.')
    parser.add_argument("--result-save", type=str, required=True,
                        help='Output path for the score summary; the extension selects the format (e.g. .json).')
    args = parser.parse_args()

    MMStar_eval(args.result_file, args.result_save)