from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import os
from torch.utils.data import Dataset
import numpy as np
import json
import math
from .rawvideo_util import RawVideoExtractor


class ActivityNet_DataLoader(Dataset):
    def __init__(
        self,
        subset,
        data_path,
        features_path,
        tokenizer,
        max_words=30,
        feature_framerate=1.0,
        max_frames=100,
        image_resolution=224,
        frame_order=0,
        slice_framepos=0,
    ):
        self.data_path = data_path
        self.features_path = features_path
        self.feature_framerate = feature_framerate
        self.max_words = max_words
        self.max_frames = max_frames
        self.tokenizer = tokenizer
        # 0: ordinary order; 1: reverse order; 2: random order.
        self.frame_order = frame_order
        assert self.frame_order in [0, 1, 2]
        # 0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.
        self.slice_framepos = slice_framepos
        assert self.slice_framepos in [0, 1, 2]

        self.subset = subset
        assert self.subset in ["train", "val"]

        video_id_path_dict = {}
        video_id_path_dict["train"] = os.path.join(self.data_path, "train_ids.json")
        video_id_path_dict["val"] = os.path.join(self.data_path, "val_ids.json")

        video_json_path_dict = {}
        video_json_path_dict["train"] = os.path.join(self.data_path, "train.json")
        video_json_path_dict["val"] = os.path.join(self.data_path, "val_1.json")

        pseudo_video_id_list, video_id_list = self._get_video_id_single(
            video_id_path_dict[self.subset]
        )
        pseudo_caption_dict = self._get_captions_single(
            video_json_path_dict[self.subset]
        )

        print("video id list: {}".format(len(video_id_list)))
        print("pseudo caption dict: {}".format(len(pseudo_caption_dict.keys())))

        video_dict = {}
        for root, dub_dir, video_files in os.walk(self.features_path):
            for video_file in video_files:
                video_id_ = ".".join(video_file.split(".")[:-1])
                if video_id_ not in video_id_list:
                    continue
                file_path_ = os.path.join(root, video_file)
                video_dict[video_id_] = file_path_
        self.video_dict = video_dict
        print("video dict: {}".format(len(video_dict)))

        self.pseudo_video_id_list = pseudo_video_id_list
        self.video_id_list = video_id_list
        self.pseudo_caption_dict = pseudo_caption_dict

        # Get iterator video ids
        self.video_id2idx_dict = {
            pseudo_video_id: id
            for id, pseudo_video_id in enumerate(self.pseudo_video_id_list)
        }
        # Get all captions
        self.iter2video_pairs_dict = {}
        for pseudo_video_id, video_id in zip(
            self.pseudo_video_id_list, self.video_id_list
        ):
            if (
                pseudo_video_id not in self.pseudo_caption_dict
                or video_id not in self.video_dict
            ):
                continue
            caption = self.pseudo_caption_dict[pseudo_video_id]
            n_caption = len(caption["start"])
            for sub_id in range(n_caption):
                self.iter2video_pairs_dict[len(self.iter2video_pairs_dict)] = (
                    pseudo_video_id,
                    sub_id,
                )

        self.rawVideoExtractor = RawVideoExtractor(
            framerate=feature_framerate, size=image_resolution
        )
        self.SPECIAL_TOKEN = {
            "CLS_TOKEN": "<|startoftext|>",
            "SEP_TOKEN": "<|endoftext|>",
            "MASK_TOKEN": "[MASK]",
            "UNK_TOKEN": "[UNK]",
            "PAD_TOKEN": "[PAD]",
        }

    def __len__(self):
        return len(self.iter2video_pairs_dict)

    def _get_video_id_from_pseduo(self, pseudo_video_id):
        video_id = pseudo_video_id[2:]
        return video_id

    def _get_video_id_single(self, path):
        pseudo_video_id_list = []
        video_id_list = []
        print("Loading json: {}".format(path))
        with open(path, "r") as f:
            json_data = json.load(f)

        for pseudo_video_id in json_data:
            if pseudo_video_id in pseudo_video_id_list:
                print("reduplicate.")
            else:
                video_id = self._get_video_id_from_pseduo(pseudo_video_id)
                pseudo_video_id_list.append(pseudo_video_id)
                video_id_list.append(video_id)
        return pseudo_video_id_list, video_id_list

    def _get_captions_single(self, path):
        pseudo_caption_dict = {}
        with open(path, "r") as f:
            json_data = json.load(f)

        for pseudo_video_id, v_ in json_data.items():
            pseudo_caption_dict[pseudo_video_id] = {}
            duration = v_["duration"]
            pseudo_caption_dict[pseudo_video_id]["start"] = np.array([0], dtype=object)
            pseudo_caption_dict[pseudo_video_id]["end"] = np.array(
                [int(math.ceil(float(duration)))], dtype=object
            )
            pseudo_caption_dict[pseudo_video_id]["text"] = np.array(
                [" ".join(v_["sentences"])], dtype=object
            )
        return pseudo_caption_dict

    def _get_text(self, pseudo_video_id, sub_id):
        caption = self.pseudo_caption_dict[pseudo_video_id]
        k = 1
        r_ind = [sub_id]

        starts = np.zeros(k, dtype=np.long)
        ends = np.zeros(k, dtype=np.long)
        pairs_text = np.zeros((k, self.max_words), dtype=np.long)
        pairs_mask = np.zeros((k, self.max_words), dtype=np.long)
        pairs_segment = np.zeros((k, self.max_words), dtype=np.long)

        for i in range(k):
            ind = r_ind[i]
            start_, end_ = caption["start"][ind], caption["end"][ind]
            words = self.tokenizer.tokenize(caption["text"][ind])
            starts[i], ends[i] = start_, end_

            words = [self.SPECIAL_TOKEN["CLS_TOKEN"]] + words
            total_length_with_CLS = self.max_words - 1
            if len(words) > total_length_with_CLS:
                words = words[:total_length_with_CLS]
            words = words + [self.SPECIAL_TOKEN["SEP_TOKEN"]]

            input_ids = self.tokenizer.convert_tokens_to_ids(words)
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)
            while len(input_ids) < self.max_words:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)
            assert len(input_ids) == self.max_words
            assert len(input_mask) == self.max_words
            assert len(segment_ids) == self.max_words

            pairs_text[i] = np.array(input_ids)
            pairs_mask[i] = np.array(input_mask)
            pairs_segment[i] = np.array(segment_ids)

        return pairs_text, pairs_mask, pairs_segment, starts, ends

    def _get_rawvideo(self, idx, s, e):
        video_mask = np.zeros((len(s), self.max_frames), dtype=np.long)
        max_video_length = [0] * len(s)

        # Pair x L x T x 3 x H x W
        video = np.zeros(
            (
                len(s),
                self.max_frames,
                1,
                3,
                self.rawVideoExtractor.size,
                self.rawVideoExtractor.size,
            ),
            dtype=np.float,
        )
        video_path = self.video_dict[idx]
        try:
            for i in range(len(s)):
                start_time = int(s[i])
                end_time = int(e[i])
                start_time = start_time if start_time >= 0.0 else 0.0
                end_time = end_time if end_time >= 0.0 else 0.0
                if start_time > end_time:
                    start_time, end_time = end_time, start_time
                elif start_time == end_time:
                    end_time = end_time + 1

                # Should be optimized by gathering all asking of this video
                raw_video_data = self.rawVideoExtractor.get_video_data(
                    video_path, start_time, end_time
                )
                raw_video_data = raw_video_data["video"]

                if len(raw_video_data.shape) > 3:
                    raw_video_data_clip = raw_video_data
                    # L x T x 3 x H x W
                    raw_video_slice = self.rawVideoExtractor.process_raw_data(
                        raw_video_data_clip
                    )
                    if self.max_frames < raw_video_slice.shape[0]:
                        if self.slice_framepos == 0:
                            video_slice = raw_video_slice[: self.max_frames, ...]
                        elif self.slice_framepos == 1:
                            video_slice = raw_video_slice[-self.max_frames :, ...]
                        else:
                            sample_indx = np.linspace(
                                0,
                                raw_video_slice.shape[0] - 1,
                                num=self.max_frames,
                                dtype=int,
                            )
                            video_slice = raw_video_slice[sample_indx, ...]
                    else:
                        video_slice = raw_video_slice

                    video_slice = self.rawVideoExtractor.process_frame_order(
                        video_slice, frame_order=self.frame_order
                    )

                    slice_len = video_slice.shape[0]
                    max_video_length[i] = (
                        max_video_length[i]
                        if max_video_length[i] > slice_len
                        else slice_len
                    )
                    if slice_len < 1:
                        pass
                    else:
                        video[i][:slice_len, ...] = video_slice
                else:
                    print(
                        "video path: {} error. video id: {}, start: {}, end: {}".format(
                            video_path, idx, start_time, end_time
                        )
                    )
        except Exception as excep:
            print(
                "video path: {} error. video id: {}, start: {}, end: {}, Error: {}".format(
                    video_path, idx, s, e, excep
                )
            )
            raise excep

        for i, v_length in enumerate(max_video_length):
            video_mask[i][:v_length] = [1] * v_length

        return video, video_mask

    def __getitem__(self, feature_idx):
        pseudo_video_id, sub_id = self.iter2video_pairs_dict[feature_idx]
        idx = self.video_id2idx_dict[pseudo_video_id]

        pairs_text, pairs_mask, pairs_segment, starts, ends = self._get_text(
            pseudo_video_id, sub_id
        )
        video, video_mask = self._get_rawvideo(self.video_id_list[idx], starts, ends)
        return pairs_text, pairs_mask, pairs_segment, video, video_mask
