# from dataset.base import BaseDataset
from dataset.movienet import MovieNetDataset

import logging
import os
import random

import ndjson
import einops
import numpy as np
import torch


class TestDataset(MovieNetDataset):
    """Dataset variant that only ever loads the test annotation file.

    Every supported mode/split combination reads ``anno.test.ndjson`` from
    ``cfg.ANNO_PATH``; all other combinations raise ``NotImplementedError``.
    """

    def __init__(self, cfg, mode, is_train, is_test):
        """Forward all arguments unchanged to ``MovieNetDataset.__init__``."""
        super().__init__(cfg, mode, is_train, is_test)

    def load_data(self):
        """Pick the path template and load the test annotations.

        Side effects:
            * ``self.tmpl`` -- path template; the ``.jpg`` form by default,
              switched to the ``.npy`` form in ``"finetune"`` mode when
              ``cfg.USE_RAW_SHOT`` is falsy.
            * ``self.use_raw_shot`` -- set only in ``"finetune"`` mode.
            * ``self.anno_data`` -- result of ``self.get_ndjson`` on the
              test annotation file.

        Raises:
            NotImplementedError: for ``"pretrain"`` mode (always), for
                ``"finetune"`` mode unless ``is_train`` is falsy and
                ``is_test`` is truthy, and for any unrecognised mode.
        """
        anno_file = "anno.test.ndjson"

        self.tmpl = "{}/shot_{}_img_{}.jpg"  # video_id, shot_id, shot_num

        if self.mode == "extract":
            # Extraction ignores the is_train / is_test flags.
            data_path = os.path.join(self.cfg.ANNO_PATH, anno_file)
        elif self.mode == "pretrain":
            # NOTE(review): the original code raised unconditionally here and
            # carried an unreachable test-split branch below the raise; that
            # dead code is removed. Confirm whether pretrain + test should be
            # enabled before adding it back.
            raise NotImplementedError(
                "TestDataset does not support mode 'pretrain'"
            )
        elif self.mode == "finetune":
            if self.is_train or not self.is_test:
                raise NotImplementedError(
                    "TestDataset in mode 'finetune' supports only the test "
                    "split (is_train={!r}, is_test={!r})".format(
                        self.is_train, self.is_test
                    )
                )
            data_path = os.path.join(self.cfg.ANNO_PATH, anno_file)
            self.use_raw_shot = self.cfg.USE_RAW_SHOT
            if not self.use_raw_shot:
                # presumably pre-extracted per-shot arrays -- TODO confirm
                self.tmpl = "{}/shot_{}.npy"  # video_id, shot_id
        else:
            raise NotImplementedError(
                "TestDataset does not support mode {!r}".format(self.mode)
            )

        self.anno_data = self.get_ndjson(data_path)