import huggingface_hub
from huggingface_hub import snapshot_download
from ..smp import *
from .video_concat_dataset import ConcatVideoDataset
from .video_base import VideoBaseDataset
from .utils import build_judge, DEBUG_MESSAGE
from ..utils import track_progress_rich
import torchvision.transforms as T
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from decord import VideoReader, cpu
import pandas as pd
import imageio
import cv2
import zipfile
import os
import glob
from .utils.mlvu import *

# Marker text looked for inside a prediction to detect that the API call
# producing it failed; such predictions are scored -1 during evaluation.
FAIL_MSG = 'Failed to obtain answer via API.'


class MLVU(ConcatVideoDataset):
    """MLVU benchmark assembled from the 'MLVU_MCQ' and 'MLVU_OpenEnded' sub-datasets.

    ``evaluate`` extends the table returned by the parent class with two
    aggregate rows, 'M-Avg' and 'G-Avg', and writes the result to a CSV file.
    """

    def __init__(self, dataset='MLVU'):
        # Register which sub-datasets this concatenated dataset is made of.
        self.DATASET_SETS[dataset] = ['MLVU_MCQ', 'MLVU_OpenEnded']
        # Aggregate row name -> task rows whose counts are summed into it.
        # NOTE(review): MLVU_MCQ also defines an 'order' task that is not listed
        # under 'M-Avg' -- confirm this exclusion is intentional.
        self.type_data_dict = {
            'M-Avg': ['plotQA', 'needle', 'ego', 'count', 'anomaly_reco', 'topic_reasoning'],
            'G-Avg': ['sub_scene', 'summary'],
        }
        super().__init__(dataset=dataset)

    @classmethod
    def supported_datasets(cls):
        """Return the dataset names handled by this class."""
        return ['MLVU']

    def evaluate(self, eval_file, **judge_kwargs):
        """Evaluate predictions and append the 'M-Avg' / 'G-Avg' aggregate rows.

        Args:
            eval_file: Path of the prediction file; forwarded to the parent
                ``evaluate`` and used to derive the output file name.
            **judge_kwargs: Forwarded unchanged to the parent ``evaluate``.

        Returns:
            The result table with the task names moved into a 'task' column.
            The same table is dumped to ``<eval_file without extension>_acc.csv``.
        """
        result = super().evaluate(eval_file=eval_file, **judge_kwargs)

        # Strip only the trailing extension. A plain str.replace would also
        # rewrite the same text elsewhere in the path, and would return the
        # path unchanged (overwriting the prediction file) if it had no extension.
        score_file = os.path.splitext(eval_file)[0] + '_acc.csv'

        for key, tasks in self.type_data_dict.items():
            result.loc[key] = 0.0
            for name, item in result.iterrows():
                if name in tasks:
                    result.loc[key, 'success'] += item['success']
                    result.loc[key, 'overall'] += item['overall']

            success = result.loc[key, 'success']
            overall = result.loc[key, 'overall']
            if overall == 0:
                # None of the group's tasks are present: report "no value"
                # explicitly instead of dividing zero by zero.
                result.loc[key, 'acc'] = float('nan')
            elif key == 'G-Avg':
                # Reported as a plain ratio with two decimals.
                result.loc[key, 'acc'] = round(success / overall, 2)
            else:
                # Reported as a percentage with one decimal.
                result.loc[key, 'acc'] = round(success / overall * 100, 1)

        result = result.reset_index().rename(columns={'index': 'task'})
        dump(result, score_file)
        return result


class MLVU_MCQ(VideoBaseDataset):
    """Multiple-choice tasks of the MLVU video benchmark."""

    # MD5 the generated '<dataset_name>.tsv' must have to be accepted as valid.
    MD5 = 'bb5c37e7cf8d43fc9a25c23d2b4633f5'
    BASE_SYS = 'Carefully watch this video and pay attention to every detail. '
    SYS = BASE_SYS + 'Based on your observations, select the best option that accurately addresses the question.'
    TYPE = 'Video-MCQ'

    def __init__(self, dataset='MLVU_MCQ'):
        # task name -> (annotation json file, video directory relative to the
        # dataset root, question-type tag). The tag is not read in this class.
        self.type_data_list = {
            'plotQA': ('1_plotQA.json', './MLVU/video/1_plotQA', 'MCQ'),
            'needle': ('2_needle.json', './MLVU/video/2_needle', 'MCQ'),
            'ego': ('3_ego.json', './MLVU/video/3_ego', 'MCQ'),
            'count': ('4_count.json', './MLVU/video/4_count', 'MCQ'),
            'order': ('5_order.json', './MLVU/video/5_order', 'MCQ'),
            'anomaly_reco': ('6_anomaly_reco.json', './MLVU/video/6_anomaly_reco', 'MCQ'),
            'topic_reasoning': ('7_topic_reasoning.json', './MLVU/video/7_topic_reasoning', 'MCQ'),
        }
        super().__init__(dataset=dataset)

    @classmethod
    def supported_datasets(cls):
        """Return the dataset names handled by this class."""
        return ['MLVU_MCQ']

    @staticmethod
    def _parse_candidates(raw):
        """Turn a 'candidates' cell into a list of option strings.

        The cell normally holds the text form of a Python list. It is parsed
        with ``ast.literal_eval`` so that dataset text is never executed.
        """
        import ast

        if isinstance(raw, str):
            return ast.literal_eval(raw)
        return list(raw)

    def prepare_dataset(self, dataset_name='MLVU_MCQ', repo_id='MLVU/MVLU'):
        """Locate or download the dataset and make sure its TSV index exists.

        Args:
            dataset_name: Base name of the TSV index file.
            repo_id: Dataset repository to download from; replaced by the
                ModelScope repository when ``modelscope_flag_set()`` is true.

        Returns:
            ``dict(root=<dataset directory>, data_file=<path of the TSV index>)``.
        """
        def check_integrity(pth):
            # Valid only if the TSV exists, has the expected MD5 and every
            # video it references is present under `pth`.
            data_file = osp.join(pth, f'{dataset_name}.tsv')
            if not os.path.exists(data_file):
                return False
            if md5(data_file) != self.MD5:
                return False
            data = load(data_file)
            return all(
                osp.exists(osp.join(pth, item['prefix'], item['video']))
                for _, item in data.iterrows()
            )

        def generate_tsv(pth):
            # Build the TSV index from the per-task annotation files, unless a
            # file with the expected MD5 is already there.
            data_file = osp.join(pth, f'{dataset_name}.tsv')
            if os.path.exists(data_file) and md5(data_file) == self.MD5:
                return
            json_data_dir = os.path.join(pth, 'MLVU', 'json')
            self.data_list = []
            for task, spec in self.type_data_list.items():
                with open(os.path.join(json_data_dir, spec[0]), 'r', encoding='utf-8') as f:
                    json_data = json.load(f)
                for data in json_data:
                    self.data_list.append({
                        'task_type': task,
                        'prefix': spec[1],
                        'duration': data['duration'],
                        'video': data['video'],
                        'question': data['question'],
                        'answer': data['answer'],
                        'candidates': data['candidates'],
                    })

            data_df = pd.DataFrame(self.data_list)
            data_df = data_df.assign(index=range(len(data_df)))
            data_df.to_csv(data_file, sep='\t', index=False)

        if modelscope_flag_set():
            repo_id = "AI-ModelScope/MLVU"

        cache_path = get_cache_path(repo_id)
        if cache_path is not None and check_integrity(cache_path):
            dataset_path = cache_path
        else:
            if modelscope_flag_set():
                from modelscope import dataset_snapshot_download
                dataset_path = dataset_snapshot_download(dataset_id=repo_id)
            else:
                hf_token = os.environ.get('HUGGINGFACE_TOKEN')
                # Log in only when a token is supplied: calling login() without
                # one falls back to an interactive prompt, which cannot be
                # answered in a batch job.
                if hf_token:
                    huggingface_hub.login(hf_token)
                dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')

            generate_tsv(dataset_path)

        data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
        return dict(root=dataset_path, data_file=data_file)

    def qa_template(self, data):
        """Format one record as a lettered multiple-choice question.

        Returns:
            ``(question, answer)`` where ``question`` lists the options as
            '(A) ...', '(B) ...' and ``answer`` is '(<letter>) <answer text>'.
        """
        question = f"Question: {data['question']}\n"
        question += 'Options:\n'
        answer = data['answer']
        answer_idx = -1
        for idx, c in enumerate(self._parse_candidates(data['candidates'])):
            question += f"({chr(ord('A') + idx)}) {c}\n"
            if c == answer:
                answer_idx = idx
        question = question.rstrip()
        # If the answer text matches no option, answer_idx stays -1 and the
        # letter becomes '@' (the character before 'A').
        answer = f"({chr(ord('A') + answer_idx)}) {answer}"
        return question, answer

    def save_video_frames(self, line, num_frames=8, fps=-1):
        """Extract sampled frames of a record's video to image files.

        Args:
            line: Record providing 'video' and 'prefix'.
            num_frames: Number of evenly spaced frames to take when ``fps < 0``.
            fps: When positive, sample at this rate over the whole video
                instead of using ``num_frames``.

        Returns:
            The list of frame image paths.

        Raises:
            ValueError: If neither sampling mode is selected.
        """
        suffix = line['video'].split('.')[-1]
        video = line['video'].replace(f'.{suffix}', '')
        vid_path = osp.join(self.data_root, line['prefix'], line['video'])
        # Use the VideoReader name imported by this module; the bare `decord`
        # module name is not imported here.
        vid = VideoReader(vid_path)
        n_frames = len(vid)

        if num_frames > 0 and fps < 0:
            # Evenly spaced frames that skip the very first and very last one.
            step_size = n_frames / (num_frames + 1)
            indices = [int(i * step_size) for i in range(1, num_frames + 1)]
            frame_paths = self.frame_paths(video, num_frames)
        elif fps > 0:
            # not constrained by num_frames, get frames by fps
            avg_fps = vid.get_avg_fps()
            total_duration = n_frames / avg_fps
            required_frames = int(total_duration * fps)
            step_size = avg_fps / fps
            indices = [int(i * step_size) for i in range(required_frames)]
            frame_paths = self.frame_paths_fps(video, len(indices), fps)
        else:
            raise ValueError(
                f'Cannot sample frames with num_frames={num_frames} and fps={fps}: '
                'need num_frames > 0 with fps < 0, or fps > 0.'
            )

        # Decode and write only the frames that are missing, one at a time, so
        # a long video sampled by fps is never held in memory as a whole.
        for frame_idx, pth in zip(indices, frame_paths):
            if not osp.exists(pth):
                Image.fromarray(vid[frame_idx].asnumpy()).save(pth)

        return frame_paths

    def save_video_into_images(self, line, num_frames, fps):
        """Return frame image paths for ``line``; see ``save_video_frames``."""
        return self.save_video_frames(line, num_frames, fps)

    def build_prompt(self, line, num_frames, video_llm, fps=-1):
        """Build the message list sent to the model for one record.

        Args:
            line: A record, or an integer position into ``self.data``.
            num_frames: Frame count used when the video is sent as images.
            video_llm: If true, attach the video file itself; otherwise attach
                extracted frame images.
            fps: Frame sampling rate used when the video is sent as images.

        Returns:
            A list of message dicts: system text, question, the video or its
            frames, and a closing instruction.
        """
        if isinstance(line, int):
            assert line < len(self)
            line = self.data.iloc[line]

        question, _ = self.qa_template(line)
        message = [dict(type='text', value=self.SYS, role='system')]
        message.append(dict(type='text', value=question))
        video_path = os.path.join(self.data_root, line['prefix'], line['video'])
        if video_llm:
            message.append(dict(type='video', value=video_path))
        else:
            img_frame_paths = self.save_video_into_images(line, num_frames, fps)
            for im in img_frame_paths:
                message.append(dict(type='image', value=im))
        message.append(dict(type='text', value='\nOnly give the best option.'))
        return message

    @classmethod
    def evaluate(cls, eval_file, **judge_kwargs):
        """Score multiple-choice predictions and return the rating.

        Args:
            eval_file: Path of the '.xlsx' prediction file. It must provide the
                'prediction', 'answer' and 'candidates' columns.
            **judge_kwargs: Judge settings; 'model' defaults to 'chatgpt-0125'
                and must be one of 'chatgpt-0125', 'exact_matching' or
                'gpt-4-0125'.

        Returns:
            The value of ``get_dimension_rating`` computed on the score file
            ``<eval_file without .xlsx>_score.xlsx``. An existing score file is
            reused without rescoring.
        """
        assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'

        # Replace only the trailing extension, never the same text elsewhere in the path.
        score_file = eval_file[:-len('.xlsx')] + '_score.xlsx'

        if not osp.exists(score_file):
            model = judge_kwargs.setdefault('model', 'chatgpt-0125')
            assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']

            if model == 'exact_matching':
                model = None
            elif gpt_key_set():
                model = build_judge(**judge_kwargs)
                if not model.working():
                    warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
                    warnings.warn(DEBUG_MESSAGE)
                    model = None
            else:
                warnings.warn('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
                model = None

            data = load(eval_file)

            n_missing = 0   # rows without any prediction
            n_failed = 0    # rows whose prediction carries the failure marker
            scores = []     # one score per row, in row order

            for _, row in data.iterrows():
                pred = row['prediction']
                if pd.isna(pred):
                    # No prediction at all: count it as unanswered rather than
                    # failing on a substring test against a float NaN.
                    n_missing += 1
                    scores.append(-1)
                    continue
                pred = str(pred)
                if FAIL_MSG in pred:
                    n_failed += 1
                    scores.append(-1)
                    continue

                answer_text = row['answer']
                input_item = row.to_dict()
                answer_idx = -1
                for opt_idx, option in enumerate(cls._parse_candidates(row['candidates'])):
                    input_item[chr(ord('A') + opt_idx)] = option
                    if option == answer_text:
                        answer_idx = opt_idx
                if answer_idx >= 0:
                    input_item['answer'] = chr(ord('A') + answer_idx)
                # An answer matching no option keeps index -1, i.e. the letter '@'.
                ans = f"({chr(ord('A') + answer_idx)}) {answer_text}"

                scores.append(int(check_ans_with_model(
                    pred, ans, model,
                    input_item,
                    'MLVU_MCQ'
                )))

            # Assign per row position instead of using 'index' values as row
            # labels, so every score lands on the row it was computed for.
            data['score'] = scores

            print(
                f'Among {len(data)} questions, failed to obtain prediction for {n_missing} questions, '
                f'failed to obtain the score for another {n_failed} questions. '
                f'Those questions will be counted as -1 score in ALL rating, and will not be counted in VALID rating.'
            )

            dump(data, score_file)

        rating = get_dimension_rating(score_file)
        return rating


class MLVU_OpenEnded(VideoBaseDataset):
    """Open-ended (free-form answer) tasks of the MLVU video benchmark."""

    # MD5 the generated '<dataset_name>.tsv' must have to be accepted as valid.
    MD5 = 'cee573a3627c6ac434ded704c60511ba'
    BASE_SYS = 'Carefully watch this video and pay attention to every detail. '
    SYS = BASE_SYS + 'Based on your observations, answer the given questions.'
    TYPE = 'Video-VQA'

    def __init__(self, dataset='MLVU_OpenEnded'):
        # task name -> (annotation json file, video directory relative to the
        # dataset root, question-type tag). The tag is not read in this class.
        self.type_data_list = {
            'sub_scene': ('8_sub_scene.json', './MLVU/video/8_sub_scene', 'VQA'),
            'summary': ('9_summary.json', './MLVU/video/9_summary', 'VQA')
        }
        super().__init__(dataset=dataset)

    @classmethod
    def supported_datasets(cls):
        """Return the dataset names handled by this class."""
        return ['MLVU_OpenEnded']

    def prepare_dataset(self, dataset_name='MLVU_OpenEnded', repo_id='MLVU/MVLU'):
        """Locate or download the dataset and make sure its TSV index exists.

        Args:
            dataset_name: Base name of the TSV index file.
            repo_id: Dataset repository to download from; replaced by the
                ModelScope repository when ``modelscope_flag_set()`` is true.

        Returns:
            ``dict(root=<dataset directory>, data_file=<path of the TSV index>)``.
        """
        def check_integrity(pth):
            # Valid only if the TSV exists, has the expected MD5 and every
            # video it references is present under `pth`.
            data_file = osp.join(pth, f'{dataset_name}.tsv')
            if not os.path.exists(data_file):
                return False
            if md5(data_file) != self.MD5:
                return False
            data = load(data_file)
            return all(
                osp.exists(osp.join(pth, item['prefix'], item['video']))
                for _, item in data.iterrows()
            )

        def generate_tsv(pth):
            # Build the TSV index from the per-task annotation files, unless a
            # file with the expected MD5 is already there.
            data_file = osp.join(pth, f'{dataset_name}.tsv')
            if os.path.exists(data_file) and md5(data_file) == self.MD5:
                return
            json_data_dir = os.path.join(pth, 'MLVU', 'json')
            self.data_list = []
            for task, spec in self.type_data_list.items():
                with open(os.path.join(json_data_dir, spec[0]), 'r', encoding='utf-8') as f:
                    json_data = json.load(f)
                for data in json_data:
                    self.data_list.append({
                        'task_type': task,
                        'prefix': spec[1],
                        'duration': data['duration'],
                        'video': data['video'],
                        'question': data['question'],
                        'answer': data['answer'],
                        # Not every record carries scoring points; store '' then.
                        'scoring_points': data['scoring_points'] if 'scoring_points' in data else ''
                    })

            data_df = pd.DataFrame(self.data_list)
            data_df = data_df.assign(index=range(len(data_df)))
            data_df.to_csv(data_file, sep='\t', index=False)

        if modelscope_flag_set():
            repo_id = "AI-ModelScope/MLVU"

        cache_path = get_cache_path(repo_id)
        if cache_path is not None and check_integrity(cache_path):
            dataset_path = cache_path
        else:
            if modelscope_flag_set():
                from modelscope import dataset_snapshot_download
                dataset_path = dataset_snapshot_download(dataset_id=repo_id)
            else:
                hf_token = os.environ.get('HUGGINGFACE_TOKEN')
                # Log in only when a token is supplied: calling login() without
                # one falls back to an interactive prompt, which cannot be
                # answered in a batch job.
                if hf_token:
                    huggingface_hub.login(hf_token)
                dataset_path = snapshot_download(repo_id=repo_id, repo_type='dataset')

            generate_tsv(dataset_path)

        data_file = osp.join(dataset_path, f'{dataset_name}.tsv')
        return dict(root=dataset_path, data_file=data_file)

    def qa_template(self, data):
        """Return ``(question, answer)`` taken directly from the record."""
        question = f"{data['question']}"
        answer = data['answer']
        return question, answer

    def save_video_frames(self, line, num_frames=8, fps=-1):
        """Extract sampled frames of a record's video to image files.

        Args:
            line: Record providing 'video' and 'prefix'.
            num_frames: Number of evenly spaced frames to take when ``fps < 0``.
            fps: When positive, sample at this rate over the whole video
                instead of using ``num_frames``.

        Returns:
            The list of frame image paths.

        Raises:
            ValueError: If neither sampling mode is selected.
        """
        suffix = line['video'].split('.')[-1]
        video = line['video'].replace(f'.{suffix}', '')
        vid_path = osp.join(self.data_root, line['prefix'], line['video'])
        # Use the VideoReader name imported by this module; the bare `decord`
        # module name is not imported here.
        vid = VideoReader(vid_path)
        n_frames = len(vid)

        if num_frames > 0 and fps < 0:
            # Evenly spaced frames that skip the very first and very last one.
            step_size = n_frames / (num_frames + 1)
            indices = [int(i * step_size) for i in range(1, num_frames + 1)]
            frame_paths = self.frame_paths(video, num_frames)
        elif fps > 0:
            # not constrained by num_frames, get frames by fps
            avg_fps = vid.get_avg_fps()
            total_duration = n_frames / avg_fps
            required_frames = int(total_duration * fps)
            step_size = avg_fps / fps
            indices = [int(i * step_size) for i in range(required_frames)]
            frame_paths = self.frame_paths_fps(video, len(indices), fps)
        else:
            raise ValueError(
                f'Cannot sample frames with num_frames={num_frames} and fps={fps}: '
                'need num_frames > 0 with fps < 0, or fps > 0.'
            )

        # Decode and write only the frames that are missing, one at a time, so
        # a long video sampled by fps is never held in memory as a whole.
        for frame_idx, pth in zip(indices, frame_paths):
            if not osp.exists(pth):
                Image.fromarray(vid[frame_idx].asnumpy()).save(pth)

        return frame_paths

    def save_video_into_images(self, line, num_frames, fps):
        """Return frame image paths for ``line``; see ``save_video_frames``."""
        return self.save_video_frames(line, num_frames, fps)

    def build_prompt(self, line, num_frames, video_llm, fps=-1):
        """Build the message list sent to the model for one record.

        Args:
            line: A record, or an integer position into ``self.data``.
            num_frames: Frame count used when the video is sent as images.
            video_llm: If true, attach the video file itself; otherwise attach
                extracted frame images.
            fps: Frame sampling rate used when the video is sent as images.

        Returns:
            A list of message dicts: system text, question, then the video or
            its frames.
        """
        if isinstance(line, int):
            assert line < len(self)
            line = self.data.iloc[line]

        question, _ = self.qa_template(line)
        message = [dict(type='text', value=self.SYS, role='system')]
        message.append(dict(type='text', value=question))
        video_path = os.path.join(self.data_root, line['prefix'], line['video'])
        if video_llm:
            message.append(dict(type='video', value=video_path))
        else:
            img_frame_paths = self.save_video_into_images(line, num_frames, fps)
            for im in img_frame_paths:
                message.append(dict(type='image', value=im))
        return message

    @classmethod
    def evaluate(cls, eval_file, **judge_kwargs):
        """Judge open-ended predictions with gpt-4-0125 and return the rating.

        Args:
            eval_file: Path of the prediction file; it must provide the
                'index' and 'task_type' columns.
            **judge_kwargs: Judge settings. 'model' is forced to 'gpt-4-0125';
                'nproc' (default 4) sets the number of parallel workers and is
                not passed on to the judge.

        Returns:
            The value of ``get_dimension_rating`` computed on the score file
            ``<eval_file without extension>_gpt-4-0125_score.xlsx``. An existing
            score file is reused without judging again.
        """
        model = judge_kwargs.setdefault('model', 'gpt-4-0125')
        if model != 'gpt-4-0125':
            print('MLVU Open Ended default using gpt-4-0125! So judge model is changed to gpt-4-0125')
            # Keep the local name in sync with the override so the output files
            # are named after the judge that actually produced them.
            model = judge_kwargs['model'] = 'gpt-4-0125'

        # Strip only the trailing extension, never the same text elsewhere in the path.
        stem = os.path.splitext(eval_file)[0]
        score_file = f'{stem}_{model}_score.xlsx'
        tmp_file = f'{stem}_{model}.pkl'
        nproc = judge_kwargs.pop('nproc', 4)

        if not osp.exists(score_file):
            data = load(eval_file)
            # One judge per task type, each with its own system prompt.
            model_dict = {
                'sub_scene': build_judge(system_prompt=system_prompt_sub_scene, **judge_kwargs),
                'summary': build_judge(system_prompt=system_prompt_summary, **judge_kwargs)
            }
            lines = [data.iloc[i] for i in range(len(data))]

            # Resume: skip records already answered in a previous run.
            ans = load(tmp_file) if osp.exists(tmp_file) else {}
            pending = [line for line in lines if line['index'] not in ans]
            tups = [(model_dict[line['task_type']], line) for line in pending]
            indices = [line['index'] for line in pending]

            if len(indices):
                _ = track_progress_rich(
                    MLVU_OpenEnded_generate,
                    tups,
                    nproc=nproc,
                    chunksize=nproc,
                    keys=indices,
                    save=tmp_file,
                )
            # The tmp file does not exist when there was nothing to judge and
            # no earlier run; keep the (empty) answers in that case.
            if osp.exists(tmp_file):
                ans = load(tmp_file)
            data = MLVU_OpenEnded_extract(ans, data)
            dump(data, score_file)

        rating = get_dimension_rating(score_file)
        return rating
