from dataset.base import BaseDataset

import logging
import os
import random

import ndjson
import einops
import numpy as np
import torch


class MovieNetDataset(BaseDataset):
    """MovieNet shot dataset for pre-training, fine-tuning and feature extraction.

    Each annotation row describes one shot (``video_id``, ``shot_id``,
    ``num_shot`` and mode-specific extra fields).  ``__getitem__`` dispatches
    on ``self.mode`` / ``self.is_train`` to the matching ``_getitem_for_*``
    builder.
    """

    def __init__(self, cfg, mode, is_train, is_test):
        super().__init__(cfg, mode, is_train, is_test)

        # `self.use_raw_shot` is assigned in `load_data` (finetune branch);
        # presumably `BaseDataset.__init__` calls `load_data` — TODO confirm.
        if mode == "finetune" and not self.use_raw_shot:
            assert len(self.cfg.PRETRAINED_LOAD_FROM) > 0
            self.shot_repr_dir = os.path.join(
                self.cfg.FEAT_PATH, self.cfg.PRETRAINED_LOAD_FROM
            )

        sampling_cfg = cfg.LOSS.sampling_method
        sampling_name = sampling_cfg.name
        # Position of the centre shot inside a sampled window: everything
        # before it is the "left" side, everything after it the "right" side.
        neighbor_key = "neighbor_left" if sampling_name == "asymmetric" else "neighbor_size"
        self.cidx = sampling_cfg.params[sampling_name][neighbor_key]

        self.activate_nearby_shots = cfg.LOSS.get("activate_nearby_shots", False)
        self.first_shot_prediction = cfg.LOSS.get("first_shot_prediction", False)
        self.reverse_shot_prediction = cfg.LOSS.get("reverse_shot_prediction", False)

        self.use_duration = sampling_cfg.get("use_duration", False)
        self.anchor_sample_type = sampling_cfg.get("anchor_sample_type", "short_weighted")
        self.use_random = sampling_cfg.get("use_random", False)

        logging.info(f"use_duration: {self.use_duration}")
        logging.info(f"anchor_sample_type: {self.anchor_sample_type}")
        logging.info(f"use_random: {self.use_random}")

    def get_ndjson(self, path):
        """Load and return every record of the newline-delimited JSON file ``path``."""
        with open(path, "r", encoding="utf-8") as f:
            return ndjson.load(f)

    def load_data(self):
        """Select and load the annotation file for the current mode and split.

        Also sets ``self.tmpl`` (per-shot file-name template) and, in finetune
        mode, ``self.use_raw_shot``.

        Raises:
            NotImplementedError: unknown mode, or unsupported ``cfg.NUM_MOVIES``
                for pre-training.
        """
        self.tmpl = "{}/shot_{}_img_{}.jpg"  # video_id, shot_id, shot_num

        if self.mode == "extract":
            fname = "anno.trainvaltest_v2.ndjson"
        elif self.mode == "pretrain":
            if self.is_train:
                # NUM_MOVIES -> pre-training annotation file:
                #   782 = 1100 - 318 (190+64+64)
                #   972 = 1100 - 128 (64+64)
                pretrain_files = {
                    1100: "anno.pretrain_v2.ndjson",
                    782: "anno.pretrain_782.ndjson",
                    972: "anno.pretrain_972.ndjson",
                }
                if self.cfg.NUM_MOVIES not in pretrain_files:
                    raise NotImplementedError
                fname = pretrain_files[self.cfg.NUM_MOVIES]
            elif self.is_test:
                fname = "anno.test.ndjson"
            else:
                fname = "anno.val.ndjson"
        elif self.mode == "finetune":
            if self.is_train:
                fname = "anno.train_v2.ndjson"
            elif self.is_test:
                fname = "anno.test_v2.ndjson"
            else:
                fname = "anno.val_v2.ndjson"
            self.use_raw_shot = self.cfg.USE_RAW_SHOT
            if not self.use_raw_shot:
                self.tmpl = "{}/shot_{}.npy"  # video_id, shot_id
        else:
            raise NotImplementedError

        self.anno_data = self.get_ndjson(os.path.join(self.cfg.ANNO_PATH, fname))

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _shot_row(idx, ref_shot, shot):
        """Map a within-video shot id to its row index in ``self.anno_data``.

        Every annotation row carries its own ``video_id``, while the shot
        samplers return shot ids local to one video.  Assumes row ``idx``
        holds ``ref_shot`` and that the rows of one video are contiguous and
        ordered by shot id — TODO confirm against the annotation files.
        """
        return idx + int(shot) - int(ref_shot)

    @staticmethod
    def _format_sid_like(sid, new_sid):
        """Return ``new_sid`` in the same representation as the annotation ``sid``.

        ``load_shot`` receives the raw ``shot_id`` from the annotation.  When
        that is a string it is assumed to be fixed-width zero-padded
        (e.g. "0012") — TODO confirm against the annotation files.
        """
        if isinstance(sid, str):
            return str(int(new_sid)).zfill(len(sid))
        return int(new_sid)

    def _dense_durations(self, idx, dense_idx):
        """Return the annotated ``"length"`` of every shot in ``dense_idx``.

        Rows are located relative to ``idx``, using ``dense_idx[self.cidx]``
        as the reference shot (see ``_shot_row``).
        """
        ref = dense_idx[self.cidx]
        return np.array(
            [self.anno_data[self._shot_row(idx, ref, s)]["length"] for s in dense_idx]
        )

    def _sample_sparse_by_duration(self, dense_duration):
        """Pick one shot left and one shot right of the centre from shot durations.

        ``self.anchor_sample_type`` selects the rule:

        * ``"short_weighted"`` / ``"long_weighted"``: sample each side with
          probability proportional to ``1 / duration`` / ``duration``.
        * ``"short_fixed"`` / ``"long_fixed"``: take the shortest / longest
          shot on each side.

        Returns:
            Length-2 integer array of positions inside the dense window
            (left position < ``self.cidx`` < right position).

        Raises:
            NotImplementedError: unknown ``anchor_sample_type``.
        """
        c = self.cidx
        left, right = dense_duration[:c], dense_duration[c + 1:]
        kind = self.anchor_sample_type

        if kind in ("short_weighted", "short_fixed"):
            left_w, right_w = 1. / left, 1. / right
        elif kind in ("long_weighted", "long_fixed"):
            left_w, right_w = left, right
        else:
            raise NotImplementedError

        if kind.endswith("_weighted"):
            left_pos = np.random.choice(len(left_w), 1, p=left_w / np.sum(left_w))
            right_pos = np.random.choice(len(right_w), 1, p=right_w / np.sum(right_w))
        else:
            # argsort(...)[-1:] selects the largest weight as a 1-element array,
            # matching the shape produced by the weighted branch.
            left_pos = np.argsort(left_w)[-1:]
            right_pos = np.argsort(right_w)[-1:]

        # Right-side positions are offset past the centre shot.
        return np.concatenate((left_pos, right_pos + c + 1), axis=None)

    def _load_dense_window(self, vid, sid, num_shot):
        """Sample the dense shot window around ``sid`` and load it.

        Returns:
            ``(dense_idx, video, place, audio)`` where ``video`` is the
            transformed keyframes rearranged to ``[S, K, C, ...]``.
        """
        _, dense_idx = self.shot_sampler(int(sid), num_shot)
        video, place, audio = self.load_shot_list(vid, dense_idx)
        video = self.apply_transform(video)
        video = einops.rearrange(video, "(s k) c ... -> s k c ...", s=len(dense_idx))
        return dense_idx, video, place, audio

    def _simclr_views(self, vid, sid, num_shot):
        """Return two augmented views of a shot, stacked as ``[nView=2, S, C, H, W]``.

        View 1 is shot ``sid``.  View 2 is the shot returned by
        ``self.shot_sampler`` (re-augmented ``sid`` when the sampler returns it).
        """
        keyframes, nshot = self.load_shot(vid, sid)
        view1 = self.apply_transform(keyframes)
        view1 = einops.rearrange(view1, "(s k) c ... -> s (k c) ...", s=nshot)

        new_sid = self.shot_sampler(int(sid), num_shot)
        if new_sid != int(sid):
            # Load the *sampled* shot so the second view can come from it.
            keyframes, nshot = self.load_shot(vid, self._format_sid_like(sid, new_sid))
        view2 = self.apply_transform(keyframes)
        view2 = einops.rearrange(view2, "(s k) c ... -> s (k c) ...", s=nshot)

        return torch.stack([view1, view2])

    def _sample_sparse_dense(self, idx, vid, sid, num_shot):
        """Build the sparse/dense shot sample for the shotcol/bassl objectives.

        Returns a dict with ``video`` (sparse shots followed by dense shots,
        each ``[num_keyframe, C, 224, 224]``), ``sparse_idx`` (positions of the
        sparse shots inside the dense window), ``dense_idx``, ``mask`` and,
        when listed in ``cfg.OTHER_MODALITY.TYPE``, ``place`` / ``audio``.
        """
        sparse_method = (
            "edge" if self.sampling_method in ["bassl", "asymmetric"] else "edge+center"
        )
        sparse_idx_to_dense, dense_idx = self.shot_sampler(
            int(sid), num_shot, sparse_method=sparse_method
        )

        # Optionally override the sampler's sparse (anchor) positions.
        if self.use_duration:
            sparse_idx_to_dense = self._sample_sparse_by_duration(
                self._dense_durations(idx, dense_idx)
            )
        elif self.use_random:
            sparse_idx_to_dense = np.sort(np.random.choice(len(dense_idx), 2, replace=False))

        k = self.num_keyframe
        raw_dense, dense_place, dense_audio = self.load_shot_list(vid, dense_idx)
        dense_video = self.apply_transform(raw_dense)
        dense_video = dense_video.view(len(dense_idx), k, -1, 224, 224)

        # The sparse shots reuse the dense window's raw keyframes but go
        # through their own `apply_transform` call.
        raw_sparse = [raw_dense[d * k + i] for d in sparse_idx_to_dense for i in range(k)]
        sparse_video = self.apply_transform(raw_sparse)
        sparse_video = sparse_video.view(len(sparse_idx_to_dense), k, -1, 224, 224)

        out = {}
        modalities = self.cfg.OTHER_MODALITY.TYPE
        if 'PLACE' in modalities:
            dense_place = torch.from_numpy(dense_place)
            out["place"] = torch.cat([dense_place[sparse_idx_to_dense], dense_place], dim=0)
        if 'AUDIO' in modalities:
            dense_audio = torch.from_numpy(dense_audio)
            out["audio"] = torch.cat([dense_audio[sparse_idx_to_dense], dense_audio], dim=0)

        out["video"] = torch.cat([sparse_video, dense_video], dim=0)
        out["sparse_idx"] = sparse_idx_to_dense
        out["dense_idx"] = dense_idx
        out["mask"] = self._get_mask(len(dense_idx))
        return out

    # ----------------------------------------------------------- item builders

    def _getitem_for_pretrain(self, idx: int):
        """Build one pre-training sample for annotation row ``idx``.

        * ``"instance"`` / ``"temporal"`` (SimCLR baselines): ``video`` holds
          two augmented views.
        * ``"shotcol"``, ``"bassl+shotcol"``, ``"bassl"``, ``"asymmetric"``:
          see ``_sample_sparse_dense``.
        """
        data = self.anno_data[idx]  # contains {"video_id", "shot_id", "num_shot"}
        vid = data["video_id"]
        sid = data["shot_id"]
        num_shot = data["num_shot"]
        payload = {"idx": idx, "vid": vid, "sid": sid}

        if self.sampling_method in ["instance", "temporal"]:
            payload["video"] = self._simclr_views(vid, sid, num_shot)
        elif self.sampling_method in ["shotcol", "bassl+shotcol", "bassl", "asymmetric"]:
            payload.update(self._sample_sparse_dense(idx, vid, sid, num_shot))

        # Fails for an unrecognised `self.sampling_method`.
        assert "video" in payload
        return payload

    def _getitem_for_knn_val(self, idx: int):
        """Build a pre-training validation sample (dense window plus scene ids)."""
        data = self.anno_data[idx]  # contains {"video_id", "shot_id", "num_shot", ...}
        sid = data["shot_id"]
        payload = {
            "global_video_id": data["global_video_id"],
            "sid": sid,
            "invideo_scene_id": data["invideo_scene_id"],
            "global_scene_id": data["global_scene_id"],
        }
        _, video, place, audio = self._load_dense_window(
            data["video_id"], sid, data["num_shot"]
        )
        payload["video"] = video
        payload["place"] = place
        payload["audio"] = audio
        return payload

    def _getitem_for_extract_shot(self, idx: int):
        """Build a feature-extraction sample (dense window around shot ``idx``)."""
        data = self.anno_data[idx]  # contains {"video_id", "shot_id", "num_shot"}
        vid = data["video_id"]
        sid = data["shot_id"]
        dense_idx, video, place, audio = self._load_dense_window(vid, sid, data["num_shot"])
        payload = {"vid": vid, "sid": sid, "video": video, "place": place, "audio": audio}

        # Only the "short_weighted" anchor rule is applied at extraction time;
        # other rules leave "sparse_idx" unset.
        if self.use_duration and self.anchor_sample_type == "short_weighted":
            payload["sparse_idx"] = self._sample_sparse_by_duration(
                self._dense_durations(idx, dense_idx)
            )
        return payload

    def _getitem_for_finetune(self, idx: int):
        """Build a boundary-classification sample for annotation row ``idx``.

        Raises:
            NotImplementedError: sampling method other than "asymmetric"/"sbd".
        """
        data = self.anno_data[idx]  # contains {"video_id", "shot_id", "num_shot", ...}
        vid = data["video_id"]
        sid = data["shot_id"]
        num_shot = data["num_shot"]

        method = self.cfg.LOSS.sampling_method.name
        if method == "asymmetric":
            _, shot_idx = self.shot_sampler(int(sid), num_shot)
        elif method == "sbd":
            shot_idx = self.shot_sampler(int(sid), num_shot)
        else:
            raise NotImplementedError

        if self.use_raw_shot:
            video, place, audio = self.load_shot_list(vid, shot_idx)
            video = self.apply_transform(video)
            video = video.view(len(shot_idx), self.num_keyframe, 3, 224, 224)
        else:
            shot_feat_path = os.path.join(self.shot_repr_dir, self.tmpl.format(vid, sid))
            video = torch.from_numpy(np.load(shot_feat_path))
            place, audio = [], []

        payload = {
            "idx": idx,
            "vid": vid,
            "sid": sid,
            "video": video,
            "place": place,
            "audio": audio,
            "label": abs(data["boundary_label"]),  # ignore -1 label.
        }

        if self.activate_nearby_shots:
            payload["first"] = data["boundary_first"]
            # `shot_idx` holds shot ids local to this video; convert them to
            # annotation rows before reading their labels.
            prev_row = self._shot_row(idx, sid, shot_idx[self.cidx - 1])
            payload["label_prev"] = abs(self.anno_data[prev_row]["boundary_label"])
            next_row = self._shot_row(idx, sid, shot_idx[self.cidx + 1])
            payload["first_next"] = abs(self.anno_data[next_row]["boundary_first"])

        if self.first_shot_prediction or self.reverse_shot_prediction:
            payload["first"] = data["boundary_first"]

        return payload

    def _getitem_for_sbd_eval(self, idx: int):
        """Evaluation samples are built exactly like fine-tuning samples."""
        return self._getitem_for_finetune(idx)

    def __getitem__(self, idx: int):
        """Dispatch to the sample builder for the current mode and split."""
        if self.mode == "extract":
            return self._getitem_for_extract_shot(idx)
        if self.mode == "pretrain":
            if self.is_train:
                return self._getitem_for_pretrain(idx)
            return self._getitem_for_knn_val(idx)
        if self.mode == "finetune":
            if self.is_train:
                return self._getitem_for_finetune(idx)
            return self._getitem_for_sbd_eval(idx)
        raise NotImplementedError(f"unsupported mode: {self.mode}")

    def _get_mask(self, N: int, mask_prob: float = 0.15):
        """Return a float16 0/1 mask of length ``N``.

        Each entry is set to 1 independently with probability ``mask_prob``.
        If none were set and ``N > 0``, one uniformly chosen entry is set so
        at least one token is masked.
        """
        mask = np.zeros(N, dtype=np.float16)
        for i in range(N):
            if random.random() < mask_prob:
                mask[i] = 1.0

        if N > 0 and not mask.any():
            mask[random.randrange(N)] = 1.0
        return mask