from collections import defaultdict

# Category names reported in the combined result tables; "all" is the
# overall aggregate.
categories = [
    "all",
    "detection",
    "counting",
    "logic",
    "attribute",
    "ocr",
    "position",
    "color",
    "shape",
]

# Per-category denominators used by the scoring functions below and as the
# multiple-choice weights in ModelEvalResult.cal_total_result.
# NOTE(review): presumably the number of multiple-choice samples in each
# category - confirm against the dataset.
num = dict(
    all=1000,
    detection=317,
    counting=272,
    attribute=241,
    shape=177,
    color=64,
    logic=174,
    ocr=75,
    position=92,
)

# Per-category weights applied to the QA scores in
# ModelEvalResult.cal_total_result.
# NOTE(review): presumably the number of QA samples in each category -
# confirm against the dataset.
num_qa = dict(
    all=799,
    detection=295,
    counting=142,
    attribute=238,
    shape=166,
    color=85,
    logic=146,
    ocr=64,
    position=101,
)

# All six orderings of the indices 0, 1, 2 in lexicographic order, each
# followed by a fixed trailing 3 (the fourth position never moves).
# NOTE(review): presumably the option shuffles applied to each
# multiple-choice question - not referenced elsewhere in this file.
permutation_list = [
    [first, second, third, 3]
    for first in range(3)
    for second in range(3)
    for third in range(3)
    if len({first, second, third}) == 3
]
# Letter orderings parallel to permutation_list: entry i is the inverse of
# permutation_list[i], written with the letters A-D instead of indices 0-3
# (which is why rows 4 and 5 look swapped relative to a direct translation).
# NOTE(review): presumably used to map an answer letter through the option
# shuffle - not referenced elsewhere in this file.
permutation_list2 = [
    list(order)
    for order in ("ABCD", "ACBD", "BACD", "CABD", "BCAD", "CBAD")
]

def cal_acc(answers, counts=None):
    """Compute overall and per-category accuracy from multiple-choice results.

    Every entry of ``answers`` is a dict holding a ``corrects`` sequence (its
    sum is the number of correct attempts for that sample) and a category,
    read from ``entry["category"]`` or, when that key is absent, from
    ``entry["sample"]["category"]``. The category is either a single name or
    a list of names. A sample tagged ``color`` or ``shape`` is additionally
    credited to ``attribute`` (once, even when it carries both tags).

    Args:
        answers: Sized collection of per-sample result dicts.
        counts: Optional mapping from category name to the sample count used
            as that category's denominator. Defaults to the module-level
            ``num`` table.

    Returns:
        dict: For each category, ``correct / (6 * counts[name])``; under
        ``"all"``, ``total_correct / (6 * len(answers))``. A zero denominator
        (including empty ``answers``) yields ``0.0`` instead of raising.

    Raises:
        KeyError: If a sample names a category that is not tracked here, or a
            tracked category is missing from ``counts``.
    """
    if counts is None:
        counts = num
    # Each sample contributes six attempts to the denominator.
    # NOTE(review): presumably one attempt per entry of permutation_list -
    # confirm against the code that produces ``corrects``.
    attempts_per_sample = 6
    acc_dict = {
        "all": 0,
        "detection": 0,
        "counting": 0,
        "logic": 0,
        "attribute": 0,
        "shape": 0,
        "color": 0,
        "ocr": 0,
        "position": 0,
    }
    total_correct = 0
    for entry in answers:
        # The category lives at the top level or under the nested sample.
        if "category" in entry:
            category = entry["category"]
        else:
            category = entry["sample"]["category"]
        labels = category if isinstance(category, list) else [category]
        correct = sum(entry["corrects"])
        for label in labels:
            acc_dict[label] += correct
        # color/shape are sub-categories that also feed "attribute".
        if "color" in labels or "shape" in labels:
            acc_dict["attribute"] += correct
        total_correct += correct

    for name in acc_dict:
        if name == "all":
            # The overall rate uses len(answers) as denominator, set below.
            continue
        denominator = attempts_per_sample * counts[name]
        acc_dict[name] = acc_dict[name] / denominator if denominator else 0.0

    n_samples = len(answers)
    acc_dict["all"] = (
        total_correct / (attempts_per_sample * n_samples) if n_samples else 0.0
    )
    return acc_dict

def cal_hallu(answers, counts=None):
    """Compute the overall and per-category rate of ``-1`` scores (MCQ).

    Every entry of ``answers`` is a dict holding a ``scores`` sequence and a
    category, read from ``entry["category"]`` or, when that key is absent,
    from ``entry["sample"]["category"]``. The category is either a single
    name or a list of names. The number of scores equal to ``-1`` is tallied
    per category; a sample tagged ``color`` or ``shape`` is additionally
    counted under ``attribute`` (once, even when it carries both tags).
    NOTE(review): a score of -1 presumably marks a hallucinated answer -
    confirm against the code that produces ``scores``.

    Args:
        answers: Sized collection of per-sample result dicts.
        counts: Optional mapping from category name to the sample count used
            as that category's denominator. Defaults to the module-level
            ``num`` table.

    Returns:
        dict: For each category, ``n_minus_one / (6 * counts[name])``; under
        ``"all"``, ``total_minus_one / (6 * len(answers))``. A zero
        denominator (including empty ``answers``) yields ``0.0`` instead of
        raising.

    Raises:
        KeyError: If a sample names a category that is not tracked here, or a
            tracked category is missing from ``counts``.
    """
    if counts is None:
        counts = num
    # Each sample contributes six attempts to the denominator.
    # NOTE(review): presumably one attempt per entry of permutation_list -
    # confirm against the code that produces ``scores``.
    attempts_per_sample = 6
    hallu_dict = {
        "all": 0,
        "detection": 0,
        "counting": 0,
        "logic": 0,
        "attribute": 0,
        "shape": 0,
        "color": 0,
        "ocr": 0,
        "position": 0,
    }
    total_flagged = 0
    for entry in answers:
        # The category lives at the top level or under the nested sample.
        if "category" in entry:
            category = entry["category"]
        else:
            category = entry["sample"]["category"]
        labels = category if isinstance(category, list) else [category]
        flagged = sum(1 for score in entry["scores"] if score == -1)
        for label in labels:
            hallu_dict[label] += flagged
        # color/shape are sub-categories that also feed "attribute".
        if "color" in labels or "shape" in labels:
            hallu_dict["attribute"] += flagged
        total_flagged += flagged

    for name in hallu_dict:
        if name == "all":
            # The overall rate uses len(answers) as denominator, set below.
            continue
        denominator = attempts_per_sample * counts[name]
        hallu_dict[name] = hallu_dict[name] / denominator if denominator else 0.0

    n_samples = len(answers)
    hallu_dict["all"] = (
        total_flagged / (attempts_per_sample * n_samples) if n_samples else 0.0
    )
    return hallu_dict

def cal_hallu_score(answers):
    """Bundle both multiple-choice score tables for one result set.

    Args:
        answers: Per-sample result dicts, passed unchanged to ``cal_acc`` and
            then to ``cal_hallu``.

    Returns:
        dict: ``{"acc": <cal_acc result>, "hallu": <cal_hallu result>}``.
    """
    scores = {}
    scores["acc"] = cal_acc(answers)
    scores["hallu"] = cal_hallu(answers)
    return scores

def cal_qa_acc(answers, counts=None):
    """Compute overall and per-category accuracy from QA results.

    Every entry of ``answers`` is a dict holding a numeric ``corrects`` value
    (added directly to the tallies) and a category, read from
    ``entry["category"]`` or, when that key is absent, from
    ``entry["sample"]["category"]``. The category is either a single name or
    a list of names. A sample tagged ``color`` or ``shape`` is additionally
    credited to ``attribute`` (once, even when it carries both tags).

    Two defects of the previous version are fixed here: the nested-sample /
    single-category branch called ``sum()`` on the scalar ``corrects`` (a
    ``TypeError``), and per-category rates were divided by the ``num`` table
    although ``ModelEvalResult.cal_total_result`` re-weights QA rates with
    ``num_qa``.

    Args:
        answers: Sized collection of per-sample result dicts.
        counts: Optional mapping from category name to the sample count used
            as that category's denominator. Defaults to the module-level
            ``num_qa`` table.

    Returns:
        dict: For each category, ``correct / counts[name]``; under ``"all"``,
        ``total_correct / len(answers)``. A zero denominator (including
        empty ``answers``) yields ``0.0`` instead of raising.

    Raises:
        KeyError: If a sample names a category that is not tracked here, or a
            tracked category is missing from ``counts``.
    """
    if counts is None:
        counts = num_qa
    acc_dict = {
        "all": 0,
        "detection": 0,
        "counting": 0,
        "logic": 0,
        "attribute": 0,
        "shape": 0,
        "color": 0,
        "ocr": 0,
        "position": 0,
    }
    total_correct = 0
    for entry in answers:
        # The category lives at the top level or under the nested sample.
        if "category" in entry:
            category = entry["category"]
        else:
            category = entry["sample"]["category"]
        labels = category if isinstance(category, list) else [category]
        # One numeric value per QA sample, not a list of attempts.
        correct = entry["corrects"]
        for label in labels:
            acc_dict[label] += correct
        # color/shape are sub-categories that also feed "attribute".
        if "color" in labels or "shape" in labels:
            acc_dict["attribute"] += correct
        total_correct += correct

    for name in acc_dict:
        if name == "all":
            # The overall rate uses len(answers) as denominator, set below.
            continue
        denominator = counts[name]
        acc_dict[name] = acc_dict[name] / denominator if denominator else 0.0

    n_samples = len(answers)
    acc_dict["all"] = total_correct / n_samples if n_samples else 0.0
    return acc_dict

def cal_qa_hallu(answers, counts=None):
    """Compute the overall and per-category rate of ``-1`` scores (QA).

    Every entry of ``answers`` is a dict holding a single ``score`` and a
    category, read from ``entry["category"]`` or, when that key is absent,
    from ``entry["sample"]["category"]``. The category is either a single
    name or a list of names. A sample counts when its score equals ``-1``; a
    sample tagged ``color`` or ``shape`` is additionally counted under
    ``attribute`` (once, even when it carries both tags).
    NOTE(review): a score of -1 presumably marks a hallucinated answer -
    confirm against the code that produces ``score``.

    The previous version divided per-category tallies by the ``num`` table,
    although ``ModelEvalResult.cal_total_result`` re-weights QA rates with
    ``num_qa``; the default denominator is now ``num_qa``.

    Args:
        answers: Sized collection of per-sample result dicts.
        counts: Optional mapping from category name to the sample count used
            as that category's denominator. Defaults to the module-level
            ``num_qa`` table.

    Returns:
        dict: For each category, ``n_minus_one / counts[name]``; under
        ``"all"``, ``total_minus_one / len(answers)``. A zero denominator
        (including empty ``answers``) yields ``0.0`` instead of raising.

    Raises:
        KeyError: If a sample names a category that is not tracked here, or a
            tracked category is missing from ``counts``.
    """
    if counts is None:
        counts = num_qa
    hallu_dict = {
        "all": 0,
        "detection": 0,
        "counting": 0,
        "logic": 0,
        "attribute": 0,
        "shape": 0,
        "color": 0,
        "ocr": 0,
        "position": 0,
    }
    total_flagged = 0
    for entry in answers:
        # The category lives at the top level or under the nested sample.
        if "category" in entry:
            category = entry["category"]
        else:
            category = entry["sample"]["category"]
        labels = category if isinstance(category, list) else [category]
        flagged = 1 if entry["score"] == -1 else 0
        for label in labels:
            hallu_dict[label] += flagged
        # color/shape are sub-categories that also feed "attribute".
        if "color" in labels or "shape" in labels:
            hallu_dict["attribute"] += flagged
        total_flagged += flagged

    for name in hallu_dict:
        if name == "all":
            # The overall rate uses len(answers) as denominator, set below.
            continue
        denominator = counts[name]
        hallu_dict[name] = hallu_dict[name] / denominator if denominator else 0.0

    n_samples = len(answers)
    hallu_dict["all"] = total_flagged / n_samples if n_samples else 0.0
    return hallu_dict


def cal_qa_score(answers):
    """Bundle both QA score tables for one result set.

    Args:
        answers: Per-sample result dicts, passed unchanged to ``cal_qa_acc``
            and then to ``cal_qa_hallu``.

    Returns:
        dict: ``{"acc": <cal_qa_acc result>, "hallu": <cal_qa_hallu result>}``.
    """
    scores = {}
    scores["acc"] = cal_qa_acc(answers)
    scores["hallu"] = cal_qa_hallu(answers)
    return scores

class ModelEvalResult:
    """Load one model's MCQ and QA result files and aggregate their scores.

    Attributes set by ``__init__``:
        model_name: The name passed to the constructor.
        result_dataset_dict: ``{"mcq": ..., "qa": ...}``, each mapping a run
            label to the list loaded from that run's JSON file.
        result_dict: Same layout, holding ``cal_hallu_score`` output for MCQ
            runs and ``cal_qa_score`` output for QA runs.
        total_result: Output of ``cal_total_result``.
    """

    def __init__(self, model_name, mcq_result_file_dict, qa_result_file_dict):
        """Load every listed result file and compute all score tables.

        Args:
            model_name: Label stored on the instance.
            mcq_result_file_dict: Mapping of run label to the path of a
                multiple-choice result JSON file.
            qa_result_file_dict: Mapping of run label to the path of a QA
                result JSON file.
        """
        self.model_name = model_name
        self.result_dataset_dict = {
            "mcq": self.load_json_file(mcq_result_file_dict, True),
            "qa": self.load_json_file(qa_result_file_dict),
        }
        self.result_dict = {
            "mcq": {
                k: cal_hallu_score(v)
                for k, v in self.result_dataset_dict["mcq"].items()
            },
            "qa": {
                k: cal_qa_score(v)
                for k, v in self.result_dataset_dict["qa"].items()
            },
        }
        self.total_result = self.cal_total_result()

    def load_json_file(self, file_dict, mcq=False):
        """Load result files and, for MCQ runs, attach dataset metadata.

        When ``mcq`` is true, ``dataset.json`` is read from the current
        working directory and each result record is paired with the dataset
        entry at the same index: a record that already has a ``sample`` key
        gets that key replaced by the dataset entry; otherwise, if the
        ``image_path`` values agree, the dataset's ``category``,
        ``question``, ``options`` and ``ground_truth`` are copied onto the
        record. On an ``image_path`` mismatch an error line is printed and
        the record is left unchanged.

        ``dataset.json`` is only opened for MCQ runs; QA loading never uses
        it, so it no longer has to exist when ``mcq`` is false.

        Args:
            file_dict: Mapping of run label to JSON file path.
            mcq: Whether the files are multiple-choice results.

        Returns:
            dict: Run label mapped to the (possibly augmented) loaded JSON.

        Raises:
            IndexError: If an MCQ result file has fewer records than the
                dataset.
        """
        import json

        dataset = None
        if mcq:
            with open("dataset.json", "r", encoding="utf-8") as f:
                dataset = json.load(f)

        eval_result = {}
        for k, v in file_dict.items():
            with open(v, 'r', encoding='utf-8') as f:
                output_file = json.load(f)
            if mcq:
                print(file_dict)
                print(len(dataset), ' | ', len(output_file))
                for i, sample in enumerate(dataset):
                    record = output_file[i]
                    if 'sample' in record:
                        record['sample'] = sample
                    elif sample['image_path'] != record['image_path']:
                        # Best effort: report the misaligned entry and go on.
                        print('ERROR :  at ', sample['id'])
                    else:
                        for field in ('category', 'question', 'options', 'ground_truth'):
                            record[field] = sample[field]
            eval_result[k] = output_file
        return eval_result

    def cal_total_result(self):
        """Combine MCQ and QA score tables for runs present in both.

        For every run label found in both ``self.result_dict['mcq']`` and
        ``self.result_dict['qa']``, and for each name in ``categories``, the
        combined value is
        ``(mcq * num[k] + qa * num_qa[k]) / (num[k] + num_qa[k])``,
        computed separately for the ``acc`` and ``hallu`` tables.

        Returns:
            dict: Run label mapped to ``{"acc": {...}, "hallu": {...}}``.
        """
        mcq_scores = self.result_dict['mcq']
        qa_scores = self.result_dict['qa']
        total_result = {}
        for key in mcq_scores:
            if key not in qa_scores:
                continue
            total_result[key] = {
                metric: {
                    k: (
                        mcq_scores[key][metric][k] * num[k]
                        + qa_scores[key][metric][k] * num_qa[k]
                    ) / (num[k] + num_qa[k])
                    for k in categories
                }
                for metric in ("acc", "hallu")
            }
        return total_result



if __name__ == '__main__':
    # Example: score one model that has a "thinking" and a "wo" run for both
    # the multiple-choice (mcq) and the QA benchmark.
    # NOTE(review): the "wo" mcq path has no directory prefix, unlike the
    # other three paths - confirm it is intentional.
    gemini_mcq_files = {
        "thinking": "./final_result_woCoT/gemini-2.5-flash_shuffle_20250806_120622.json",
        "wo": "gemini-2.5-flash-nothinking_shuffle_20250919_013211.json",
    }
    gemini_qa_files = {
        "thinking": "final_result_qa/gemini-2.5-flash_qa_20250918_080644.json",
        "wo": "final_result_qa/no_think_gemini-2.5-flash_qa_20250919_010117.json",
    }
    all_data = {
        "Gemini2.5 flash": {"mcq": gemini_mcq_files, "qa": gemini_qa_files},
    }

    # One ModelEvalResult per model, keyed by model name.
    eval_results = {
        model: ModelEvalResult(model, files['mcq'], files['qa'])
        for model, files in all_data.items()
    }

    gemini = eval_results['Gemini2.5 flash']
    print(gemini.total_result)
    print(gemini.result_dict)
