import os
import json
from pllava.tasks.eval.eval_utils import (
    dump_json,
    load_json,
    EvalDataset,
)


def check_ans(pred, gt):
    """Return True when the predicted option token matches the ground truth's.

    Only the first space-separated token of each string (lower-cased) is
    compared; any text after the option token is ignored.

    Args:
        pred: Model output, expected to start with an option such as
            ``"(A)"``, ``"A."`` or ``"A"``.
        gt: Ground-truth answer, expected to start with its option token,
            e.g. ``"(A) some text."``.

    Returns:
        bool: True if the prediction's option (with dots stripped) is a
        substring of the ground-truth option, or the ground-truth option is a
        substring of the prediction's option. False otherwise, including when
        the prediction's first token contains none of the letters a-g (a
        message is printed in that case) or when the ground-truth option is
        empty.
    """
    pred_option = pred.lower().split(' ')[0]
    gt_option = gt.lower().split(' ')[0]

    # Both strings are already lower-cased, so only lower-case letters can
    # appear in the option token.
    if not any(c in pred_option for c in 'abcdefg'):
        print(f"model doesn't follow instructions: {pred}")
        return False

    # An empty ground-truth option would match every prediction below
    # ('' is a substring of any string), so reject it explicitly.
    if not gt_option:
        return False

    return pred_option.replace('.', '') in gt_option or gt_option in pred_option

def save_results(result_list, save_path):
    """Score predictions per task type and write the result files.

    Args:
        result_list: Iterable of dicts, each with the keys ``'task_type'``,
            ``'pred'`` and ``'gt'``.
        save_path: Location passed through to ``dump_json``.

    Two files are written via ``dump_json``:

    * ``all_results.json`` holds ``acc_dict`` (task type ->
      ``[correct, total]``) together with the raw ``result_list``.
    * ``upload_leaderboard.json`` holds the per-task accuracy in percent plus
      an ``'Avg'`` entry: the accuracy over all samples (total correct
      divided by total samples, not the mean of the per-task percentages).

    An empty ``result_list`` yields ``'Avg': 0.0`` instead of raising
    ``ZeroDivisionError``.
    """
    acc_dict = {}
    for res in result_list:
        stats = acc_dict.setdefault(res['task_type'], [0, 0])  # correct, total
        stats[1] += 1
        if check_ans(pred=res['pred'], gt=res['gt']):
            stats[0] += 1

    # Every task entry has total >= 1 because it is created on first sight of
    # a sample, so the per-task division is always safe.
    final_res = {task: hits / seen * 100 for task, (hits, seen) in acc_dict.items()}

    # Derive the overall counts once from the per-task counts rather than
    # accumulating them both per sample and per task.
    correct = sum(hits for hits, _ in acc_dict.values())
    total = sum(seen for _, seen in acc_dict.values())
    final_res['Avg'] = correct / total * 100 if total else 0.0

    all_results = {
        "acc_dict": acc_dict,
        "result_list": result_list
    }
    dump_json(all_results, save_path, 'all_results.json')
    dump_json(final_res, save_path, 'upload_leaderboard.json')

def load_results(save_path):
    """Return the ``'result_list'`` entry stored in ``all_results.json``.

    The file is read through ``load_json``; when that returns ``None`` this
    function returns ``None`` as well.
    """
    stored = load_json(save_path, 'all_results.json')
    if stored is None:
        return None
    return stored['result_list']

class MVBenchDataset(EvalDataset):
    """Multiple-choice video QA samples gathered from the MVBench task files.

    On construction, every JSON file listed in ``data_list_info`` is read from
    ``data_dir`` and flattened into ``self.data_list``, one entry per sample.
    Indexing the dataset decodes the sample's media and returns the formatted
    question and answer alongside it.

    The decoding callables ``read_video``, ``read_gif`` and ``read_frame`` are
    not defined in this class; they are expected to come from ``EvalDataset``.
    """

    data_list_info = {
        # "task_type (sub task name)": ("json file name", "image/video prefix", "data_type", "bound")
        "Action Sequence": ("action_sequence.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
        "Action Prediction": ("action_prediction.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
        "Action Antonym": ("action_antonym.json", "DATAS/MVBench/video/ssv2_video/", "video", False),
        "Fine-grained Action": ("fine_grained_action.json", "DATAS/MVBench/video/Moments_in_Time_Raw/videos/", "video", False),
        "Unexpected Action": ("unexpected_action.json", "DATAS/MVBench/video/FunQA_test/test/", "video", False),
        "Object Existence": ("object_existence.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
        "Object Interaction": ("object_interaction.json", "DATAS/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
        "Object Shuffle": ("object_shuffle.json", "DATAS/MVBench/video/perception/videos/", "video", False),
        "Moving Direction": ("moving_direction.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
        "Action Localization": ("action_localization.json", "DATAS/MVBench/video/sta/sta_video/", "video", True),  # has start & end
        "Scene Transition": ("scene_transition.json", "DATAS/MVBench/video/scene_qa/video/", "video", False),
        "Action Count": ("action_count.json", "DATAS/MVBench/video/perception/videos/", "video", False),
        "Moving Count": ("moving_count.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
        "Moving Attribute": ("moving_attribute.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
        "State Change": ("state_change.json", "DATAS/MVBench/video/perception/videos/", "video", False),
        "Fine-grained Pose": ("fine_grained_pose.json", "DATAS/MVBench/video/nturgbd/", "video", False),
        "Character Order": ("character_order.json", "DATAS/MVBench/video/perception/videos/", "video", False),
        "Egocentric Navigation": ("egocentric_navigation.json", "DATAS/MVBench/video/vlnqa/", "video", False),
        "Episodic Reasoning": ("episodic_reasoning.json", "DATAS/MVBench/video/tvqa/frames_fps3_hq/", "frame", True),  # has start & end, read frame
        "Counterfactual Inference": ("counterfactual_inference.json", "DATAS/MVBench/video/clevrer/video_validation/", "video", False),
    }
    # Directory holding the per-task JSON annotation files named above.
    data_dir = "DATAS/MVBench/json"

    def __init__(self, *args, **kwargs):
        """Load every task's annotation file into ``self.data_list``.

        All positional and keyword arguments are forwarded to the base class.
        """
        super().__init__(*args, **kwargs)

        self.data_list = []
        for task_type, info in self.data_list_info.items():
            # Decode explicitly as UTF-8 instead of relying on the locale's
            # default encoding.
            with open(os.path.join(self.data_dir, info[0]), 'r', encoding='utf-8') as f:
                json_data = json.load(f)
            self.data_list.extend(
                {
                    'task_type': task_type,
                    'prefix': info[1],
                    'data_type': info[2],
                    'bound': info[3],
                    'data': data,
                }
                for data in json_data
            )

        # Maps a sample's 'data_type' to the callable used to decode it.
        self.decord_method = {
            'video': self.read_video,
            'gif': self.read_gif,
            'frame': self.read_frame,
        }

    def __getitem__(self, idx):
        """Return the decoded media and formatted QA pair for sample ``idx``.

        Returns:
            dict with the keys ``'video_path'``, ``'video_pils'``,
            ``'question'``, ``'answer'`` and ``'task_type'``. If decoding
            raises, a message is printed, ``'video_pils'`` is ``None`` and
            ``'task_type'`` is set to ``'error_reading_video'``.
        """
        item = self.data_list[idx]
        data = item['data']

        question, answer = self.qa_template(data)
        task_type = item['task_type']
        decord_method = self.decord_method[item['data_type']]
        # Only tasks flagged with 'bound' carry start/end fields.
        bound = (data['start'], data['end']) if item['bound'] else None
        video_path = os.path.join(item['prefix'], data['video'])

        try:  # might be problem with decord
            images_group = decord_method(video_path, bound)
        except Exception:
            # Best effort: flag the sample instead of aborting the whole run.
            print(f'error decoding {video_path}')
            task_type = 'error_reading_video'
            images_group = None

        return {
            'video_path': video_path,
            'video_pils': images_group,  # some might use the original pils and do their own transforms
            'question': question,
            'answer': answer,
            'task_type': task_type,
        }

    def qa_template(self, data):
        """Format one sample as a lettered multiple-choice question.

        Args:
            data: Mapping with the keys ``'question'``, ``'answer'`` and
                ``'candidates'`` (the list of option texts).

        Returns:
            ``(question, answer)``. ``question`` lists each candidate on its
            own line prefixed with ``(A)``, ``(B)``, ...; ``answer`` is the
            correct candidate prefixed with its letter.

        NOTE(review): when ``data['answer']`` equals none of the candidates,
        the index stays at -1 and the answer label becomes ``(@)``. When
        several candidates equal the answer, the last one determines the
        letter.
        """
        question = f"Question: {data['question']}\n"
        question += "Options:\n"
        answer = data['answer']
        answer_idx = -1
        for idx, c in enumerate(data['candidates']):
            question += f"({chr(ord('A') + idx)}) {c}\n"
            if c == answer:
                answer_idx = idx
        question = question.rstrip()
        answer = f"({chr(ord('A') + answer_idx)}) {answer}"
        return question, answer

