from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import json
import tempfile
import pandas as pd
from os.path import join, splitext, exists
from collections import OrderedDict
from .dataloader_retrieval import RetrievalDataset


class MSRVTTDataset(RetrievalDataset):
    """MSRVTT retrieval dataset whose entries carry auxiliary captions.

    This class only overrides ``_get_anns``; all other behaviour is inherited
    from ``RetrievalDataset``.
    """

    # Number of auxiliary caption slots attached to every sentence entry.
    _NUM_AUX_CAPTIONS = 30

    def __init__(self, subset, anno_path, video_path, tokenizer, max_words=32,
                 max_frames=12, video_framerate=1, image_resolution=224, mode='all', config=None):
        """Forward all arguments unchanged to ``RetrievalDataset.__init__``."""
        super(MSRVTTDataset, self).__init__(subset, anno_path, video_path, tokenizer, max_words,
                                            max_frames, video_framerate, image_resolution, mode, config=config)

    def _get_anns(self, subset='train'):
        """Build the video and sentence tables for ``subset``.

        Args:
            subset: ``'train'``, ``'val'`` or ``'test'``. ``'val'`` and
                ``'test'`` read the same CSV file.

        Returns:
            A ``(video_dict, sentences_dict)`` pair of ``OrderedDict``s.

            video_dict: video_id -> ``<self.video_path>/<video_id>.mp4``.
            sentences_dict: running index ->
                ``(video_id, (caption, None, None), aux_captions)`` where
                ``aux_captions`` is a list of ``(text, None, None)`` tuples.
                For ``'train'`` it is either empty (no list of titles for the
                video) or exactly ``_NUM_AUX_CAPTIONS`` long, padded with
                ``(None, None, None)``. For the other subsets it holds one
                entry per ``fast_caption_<i>`` column present in the CSV.

        Raises:
            KeyError: if ``subset`` is not one of the names above, or an
                expected CSV column / JSON key is missing.
            FileNotFoundError: if the subset's CSV file does not exist.
        """
        csv_path = {'train': join(self.anno_path, 'MSRVTT_train.9k.csv'),
                    'val': join(self.anno_path, 'MSRVTT_test_text.csv'),
                    'test': join(self.anno_path, 'MSRVTT_test_text.csv')}[subset]
        if not exists(csv_path):
            raise FileNotFoundError('Annotation CSV not found: {}'.format(csv_path))
        frame = pd.read_csv(csv_path)

        # A set makes the per-sentence membership test below O(1) instead of
        # scanning a list for every annotation.
        wanted_ids = set(frame['video_id'].values)
        slots = self._NUM_AUX_CAPTIONS

        video_dict = OrderedDict()
        sentences_dict = OrderedDict()
        if subset == 'train':
            json_path = join(self.anno_path, 'msrvtt_train_text.json')
            # Context manager so the handle is closed even if parsing fails.
            with open(json_path, 'r') as handle:
                data = json.load(handle)

            # NOTE(review): both the 'title' and the 'titles' key are read
            # from the same JSON document -- confirm the annotation file
            # really provides both and that this is not a typo.
            title_list = []
            for vid, title in data['title'].items():
                # A video may map to a single title or to a list of them;
                # each title becomes an extra training sentence.
                for cap in (title if isinstance(title, list) else [title]):
                    title_list.append({'video_id': vid, 'caption': cap})
            data['sentences'].extend(title_list)

            titles = data['titles']

            for itm in data['sentences']:
                video_id = itm['video_id']
                if video_id not in wanted_ids:
                    continue

                t_datas = []
                if video_id in titles:
                    captions = titles[video_id]
                    if isinstance(captions, list):
                        # Keep at most `slots` captions, then pad up to
                        # exactly `slots` entries.
                        t_datas = [(cap, None, None) for cap in captions[:slots]]
                        t_datas.extend([(None, None, None)] * (slots - len(t_datas)))

                sentences_dict[len(sentences_dict)] = (video_id, (itm['caption'], None, None), t_datas)
                video_dict[video_id] = join(self.video_path, "{}.mp4".format(video_id))
        else:
            # Which caption columns exist is the same for every row, so it is
            # resolved once instead of per row.
            caption_keys = [key for key in (f'fast_caption_{i}' for i in range(slots))
                            if key in frame.columns]
            for _, itm in frame.iterrows():
                t_datas = [(itm[key], None, None) for key in caption_keys]

                sentences_dict[len(sentences_dict)] = (itm['video_id'], (itm['sentence'], None, None), t_datas)
                video_dict[itm['video_id']] = join(self.video_path, "{}.mp4".format(itm['video_id']))

        unique_sentence = {entry[1][0] for entry in sentences_dict.values()}
        print('[{}] Unique sentence is {} , all num is {}'.format(subset, len(unique_sentence), len(sentences_dict)))

        return video_dict, sentences_dict
