import os
import os.path as osp
import numpy as np
from PIL import Image
from datasets.utils import resize_and_pad, resize
import collections
import torch
import torchvision
from torch.utils import data
from bert_embedding import BertEmbedding
import scipy.io
import glob, pdb
import time
import cv2
import random
import csv
import json
import pickle
from tqdm import tqdm
# import torch.nn.functional as F
import torchvision.transforms.functional as F

from itertools import chain, combinations

import tqdm

import matplotlib.pyplot as plt


def _flip_axis(x, axis):
    x = np.asarray(x).swapaxes(axis, 0)
    x = x[::-1, ...]
    x = x.swapaxes(0, axis)
    return x


import pdb
from torch.utils import data
from pathlib import Path
from time import time

import os

import cv2
import numpy as np
import scipy.io as scio
import torch
import json
import nltk
from PIL import Image
from gensim.models import KeyedVectors
from nltk import word_tokenize
from nltk.corpus import stopwords, wordnet
# from bert_embedding import BertEmbedding
from skimage import io
from skimage.segmentation import slic
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms


def load_json(filename):
    """Read the UTF-8 encoded file ``filename`` and return its parsed JSON content."""
    with open(filename, encoding='utf8') as handle:
        text = handle.read()
    return json.loads(text)


class RefSubset(Dataset):
    def __init__(self, word2vec, args, train=False):
        """Set up one split of the dataset and load its sample list.

        Args:
            word2vec: stored on the instance as ``self.word2vec``.
            args: stored on the instance as ``self.args``.
            train: truthy selects the ``'train'`` split, otherwise ``'valid'``.
        """
        self.word2vec = word2vec
        self.args = args
        self.train = train
        self.data_root = '/home/user/data/rvos'
        self.split = 'train' if self.train else 'valid'

        split_root = self.data_root + '/' + self.split
        self.image_dir = split_root + '/JPEGImages'
        self.mask_dir = split_root + '/Annotations'
        # Fills ``self.videos`` (reads or writes the on-disk cache).
        self.set_meta_file()

        # Clip length and maximum temporal skip used by the frame sampler.
        self.N = 8
        self.max_skip = 1

        # Side lengths at which segments are computed, and the segment budget
        # per side length: 16 for the coarsest one, 32 for every other.
        self.resolution = [10, 20, 40, 80, 160, 320]
        self.n_segments = {side: 32 for side in self.resolution}
        self.n_segments[self.resolution[0]] = 16

    def set_meta_file(self):

        mymeta_path = self.data_root + '/' + self.split + '/mymeta.pkl'
        if os.path.exists(mymeta_path):
            with open(mymeta_path, 'rb') as f:
                self.videos = pickle.load(f)
        else:
            data = json.load(open('/home/user/data/rvos/' + self.split + '/meta_expressions.json'))
            category_data = json.load(open('/home/user/data/rvos/' + self.split + "/meta.json"))
            self.videos = []
            for vid, objs in tqdm.tqdm(data['videos'].items(), desc='Data processing'):
                # vid = '0a2f2bd294'
                # objs = data['videos'][vid]
                if not self.train:
                    for obj_id, obj in objs['expressions'].items():
                        # oid = int(obj_id)

                        sents = obj['exp']
                        if len(sents) > 0:
                            self.videos.append([vid, obj_id, objs['frames'], sents])


                else:
                    for obj_id, obj in objs['expressions'].items():
                        oid = int(obj['obj_id'])
                        object_id = obj['obj_id']
                        sents = obj['exp']

                        same_video = category_data['videos'][vid]
                        same_object = same_video['objects'][object_id]
                        if len(sents) == 0:
                            print("Not included (no sents): ", vid, oid)

                        if same_object['category'] == "person" or same_object['category'] == 'bird' or same_object['category'] == 'dog'\
                                or same_object['category'] == 'cat' or same_object['category'] == 'car':

                            anker = []
                            for frm in objs['frames']:
                                mask_name = self.data_root + '/' + self.split + '/Annotations/' + vid + '/' + frm + '.png'
                                mask = np.uint8(Image.open(mask_name).convert('P'))
                                mask = np.uint8(mask == oid)
                                if float(mask.sum()) / mask.size > np.square(0.02):
                                    anker += [1]
                                else:
                                    anker += [0]

                            if sum(anker) >= 3:
                                print("Add")
                                # for sent in sents:
                                self.videos.append([vid, oid, objs['frames'], anker, sents])
                            else:
                                print("Not included : ", vid, oid, anker)



            with open(mymeta_path, 'wb') as f:
                pickle.dump(self.videos, f, pickle.HIGHEST_PROTOCOL)

        len_videos = len(self.videos)
        # if self.scale < 1.0:
        #     len_videos = int(len_videos * self.scale)
        self.videos = self.videos[:len_videos]

    def load_pair(self, vid, oid, fid, flip):
        """Load one frame (and, when training, the binary mask of object ``oid``).

        Args:
            vid: video directory name.
            oid: object id; mask pixels equal to it become 1, all others 0.
            fid: frame name without extension.
            flip: when truthy the image is flipped horizontally. The mask is
                returned unflipped; callers flip it themselves if needed.

        Returns:
            ``frame`` when ``self.train`` is false, otherwise ``(frame, mask)``.
            ``frame`` is the output of ``resize_and_pad(image, limit_size=320)``
            rescaled with ``2 * (x / 255) - 1``; ``mask`` is a ``uint8`` array
            at the annotation's own resolution.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        split_root = self.data_root + '/' + self.split
        img_name = split_root + '/JPEGImages/' + vid + '/{}.jpg'.format(fid)

        frame = cv2.imread(img_name)
        if frame is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; fail with a message that names the file.
            raise FileNotFoundError('Could not read image: {}'.format(img_name))

        image = frame[:, :, [2, 1, 0]]  # OpenCV loads BGR; reorder to RGB
        if flip:
            image = np.asarray(F.hflip(Image.fromarray(image)))

        # Only the first of the three values returned by resize_and_pad is used.
        image, valid_mask, _ = resize_and_pad(image, limit_size=320)
        frame = 2 * (image / 255.0) - 1

        if not self.train:
            return frame

        mask_name = split_root + '/Annotations/' + vid + '/{}.png'.format(fid)
        with Image.open(mask_name) as mask_img:
            mask = np.uint8(mask_img.convert('P'))
        mask = np.uint8(mask == oid)
        return frame, mask

    def load_pairs(self, vid, frame_ids):
        frames, masks = [], []
        for frame_id in frame_ids:
            # obj_id = 1
            frame = self.load_pair(vid, 1, frame_id)
            # frame = self.resize(frame, None, self.size)
            frames.append(frame)
            # masks.append(mask)

        N_frames = np.stack(frames, axis=0)
        # N_masks = np.stack(masks, axis=0)

        Fs = torch.from_numpy(np.transpose(N_frames, (0, 3, 1, 2)).copy()).float()
        # Ms = torch.from_numpy(N_masks.copy()).float()
        return Fs  # , Ms

    def sample_frame_ids_base(self, frame_ids, anker, rnd, vid, oid):

        n_images = len(frame_ids)
        tt = 0
        while True:
            sample_skips = [rnd.randint(1, self.max_skip) for _ in range(1, self.N)]
            if sum(sample_skips) < n_images:
                break
            if tt > 100:
                sample_skips = [1] * (self.N - 1)
                break
            tt = tt + 1

        use_ids = [None] * self.N
        # start index
        anker_nnz = [i for i, e in enumerate(anker) if e != 0]
        n_skip = sum(sample_skips)

        if True:  # always forward
            anker_idx = [i for i in anker_nnz if i + (self.N // 2) <= n_images]  # 3 + 8/2 = 7
            if len(anker_idx) > 0:
                use_ids[0] = anker_idx[rnd.randint(0, len(anker_idx) - 1)]
            else:
                use_ids[0] = anker_nnz[0]

        frame_idx = use_ids[0]
        # for i in range(1, self.N):
        #     use_ids[i] = (use_ids[i - 1] + sample_skips[i - 1]) % n_images

        use_ids = [i for i in range(frame_idx - self.N // 2 * 1 + 1, frame_idx + self.N // 2 * 1 + 1, 1)]
        for i in range(len(use_ids)):
            if use_ids[i] < 0:
                use_ids[i] = 0
            elif use_ids[i] >= n_images:
                use_ids[i] = n_images - 1
        # all_frames = np.asarray(frames)[all_frames]
        # all_frames1 = []
        use_frame_ids = [frame_ids[i] for i in use_ids]

        target_frame_idx = frame_ids[frame_idx]
        return use_frame_ids, target_frame_idx

    def __len__(self):
        """Return the number of entries in ``self.videos``."""
        n_samples = len(self.videos)
        return n_samples

    def __getitem__(self, index):
        """Return sample ``index`` as a dict.

        For the evaluation split the dict holds one ``N``-frame clip per video
        frame (``'clip'``), the lower-cased expression (``'query'``) and
        per-resolution ``'segment'`` / ``'mask'`` tensors. For the training
        split it holds a single clip around a randomly chosen anchor frame,
        the anchor frame's image (``'target_frame'``) and the object's mask in
        that frame (``'fine_gt_mask'``).
        """
        is_full = 1
        if not self.train:
            return self._get_eval_item(index, is_full)
        return self._get_train_item(index, is_full)

    def _get_eval_item(self, index, is_full):
        """Build the evaluation sample: one clip per frame of the video."""
        vid, oid, frame_ids, sent = self.videos[index]
        n_frames = len(frame_ids)
        half = self.N // 2
        frame_dir = self.image_dir + '/' + vid + '/'

        # Each frame appears in up to N overlapping windows; normalise it once
        # and reuse the result instead of re-reading the file every time.
        normalized = {}
        feats = []
        all_segments = {}
        mask = {}
        for curr_id, target_idx in enumerate(frame_ids):
            target_frame = io.imread(frame_dir + target_idx + '.jpg')
            for r in self.resolution:
                # NOTE(review): these entries are overwritten on every frame,
                # so after the loop only the LAST frame's segments and masks
                # remain -- confirm this is what the consumer expects.
                # NOTE(review): the resized image is discarded; the ``* 255``
                # mirrors the call made on binary masks in collate_fn and
                # presumably wraps around for uint8 frames.
                _, valid, pad = resize_and_pad(target_frame * 255, r, interpolation=cv2.INTER_AREA)
                cur_frame = resize(target_frame, r)
                mask[r] = torch.from_numpy(np.asarray([valid])).long()

                # Labels start at 1 so that 0 can mark the padded area.
                segment = slic(cur_frame, n_segments=self.n_segments[r] - 1) + 1
                if r - cur_frame.shape[0] - pad < 0:
                    pad_config = [[0, 0], [pad, r - cur_frame.shape[1] - pad]]
                else:
                    pad_config = [[pad, r - cur_frame.shape[0] - pad], [0, 0]]
                segment = np.pad(segment, pad_config, mode='constant', constant_values=0)
                all_segments[r] = torch.from_numpy(np.asarray([segment])).long()

            # NOTE(review): this window is [curr - N//2, curr + N//2), which
            # puts the current frame at index N//2, whereas the training window
            # puts its target at index N//2 - 1 -- confirm which is intended.
            window = [
                min(max(pos, 0), n_frames - 1)
                for pos in range(curr_id - half, curr_id + half)
            ]

            clip_frames = []
            for pos in window:
                name = frame_ids[pos]
                if name not in normalized:
                    image = io.imread(frame_dir + name + '.jpg')
                    image, valid_mask, _ = resize_and_pad(image, limit_size=320)
                    normalized[name] = 2 * (image / 255.0) - 1
                clip_frames.append(normalized[name])

            # Move the last axis of the stacked frames to the front.
            feats.append(np.asarray(clip_frames).transpose((3, 0, 1, 2)))

        feats = np.asarray(feats).astype(np.float32)
        query = sent.lower()

        return {
            'query': query,
            'clip': torch.from_numpy(feats),
            'full': torch.tensor(is_full),
            'segment': all_segments,
            'mask': mask,
            'sample_info': [vid, oid, frame_ids, sent],
            'index': index,
        }

    def _get_train_item(self, index, is_full):
        """Build the training sample: one clip plus the target frame's mask."""
        rnd = random.Random()

        vid, oid, frame_ids, anker, sent = self.videos[index]
        use_frame_ids, target_frame_idx = self.sample_frame_ids_base(frame_ids, anker, rnd, vid, oid)

        # The sampler places the anchor frame at window index N//2 - 1
        # (3 for N == 8, the value that used to be hard-coded here).
        target_pos = self.N // 2 - 1
        target_id = use_frame_ids[target_pos]

        img_name = self.data_root + '/' + self.split + '/JPEGImages/' + vid + '/{}.jpg'.format(target_id)
        target_frame = io.imread(img_name)

        frames, masks = [], []
        flip = False  # horizontal-flip augmentation is currently disabled
        for frame_id in use_frame_ids:
            frm, msk = self.load_pair(vid, oid, frame_id, flip)
            frames.append(frm)
            masks.append(msk)

        fine_gt_mask = np.stack(masks, axis=0)[target_pos, :, :]
        if fine_gt_mask.sum() == 0:
            # Report the target frame (previously the last frame of the clip
            # was printed, which is not the frame the empty mask belongs to).
            print(vid, oid, target_id)
        if flip:
            fine_gt_mask = Image.fromarray(fine_gt_mask)
            fine_gt_mask = np.asarray(F.hflip(fine_gt_mask))

        # Move the last axis of the stacked frames to the front.
        feat = np.asarray(frames).transpose((3, 0, 1, 2))
        query = sent.lower()

        return {
            'query': query,
            'clip': feat,
            'full': is_full,
            'fine_gt_mask': fine_gt_mask,
            'target_frame': target_frame,
            'sample_info': [vid, oid, frame_ids, anker, sent],
            'sample_information': [vid, oid, target_id, anker, sent],
            'index': index,
        }


class RefSlicRGB:
    def __init__(self, args):
        """Load the embedding models and create the train/test subsets.

        Args:
            args: mapping that must provide ``'vocab_path'`` (binary word2vec
                file) and ``'max_num_words'``; it is also forwarded to both
                ``RefSubset`` instances.
        """
        vocab_path = args['vocab_path']
        self.word2vec = KeyedVectors.load_word2vec_format(vocab_path, binary=True)
        self.args = args
        self.max_num_words = args['max_num_words']

        # Side lengths at which masks/segments are produced, with the segment
        # budget per side length: 16 for the coarsest, 32 for all others.
        self.resolution = [10, 20, 40, 80, 160, 320]
        self.n_segments = {side: 32 for side in self.resolution}
        self.n_segments[self.resolution[0]] = 16

        self.bert_embedding = BertEmbedding()
        self.train_set_ = RefSubset(self.word2vec, args, train=True)
        self.test_set_ = RefSubset(self.word2vec, args, train=False)

    @property
    def train_set(self):
        """The training subset created in ``__init__``."""
        subset = self.train_set_
        return subset

    @property
    def test_set(self):
        """The evaluation subset created in ``__init__``."""
        subset = self.test_set_
        return subset

    def collate_fn(self, samples):
        """Collate training samples into one batch.

        Every sample must provide the keys ``'clip'``, ``'full'``, ``'index'``,
        ``'query'``, ``'fine_gt_mask'``, ``'target_frame'``, ``'sample_info'``
        and ``'sample_information'`` (evaluation samples lack some of these).

        Returns:
            A dict with ``'net_input'`` (batched tensors; ``fine_gt_mask``,
            ``segment`` and ``mask`` are dicts keyed by resolution) plus the
            per-sample ``'target_frame'``, ``'ori_fine_gt_mask'``,
            ``'video_info'`` and ``'video_information'`` lists.
        """
        bsz = len(samples)

        clip = torch.from_numpy(np.asarray([sample['clip'] for sample in samples])).float()
        full = torch.from_numpy(np.asarray([sample['full'] for sample in samples])).long()
        index = torch.from_numpy(np.asarray([sample['index'] for sample in samples])).long()

        # Each result item is indexed as (tokens, per-token embeddings); only
        # the embeddings are used. The unused token-index lookup was removed.
        result = self.bert_embedding([sample['query'] for sample in samples])
        token_feats = [np.asarray(item[1]) for item in result]

        # Zero-pad every query to the longest (capped) length in the batch.
        query_len = [min(len(feat), self.max_num_words) for feat in token_feats]
        max_len = max(query_len)
        # 768 is hard-coded; presumably the embedding width of the BERT model.
        padded = np.zeros([bsz, max_len, 768], dtype=np.float32)
        for i, feat in enumerate(token_feats):
            keep = min(feat.shape[0], max_len)
            padded[i, :keep] = feat[:keep]
        query = torch.from_numpy(padded).float()
        query_len = torch.from_numpy(np.asarray(query_len)).long()

        ori_fine_gt_mask = [sample['fine_gt_mask'] for sample in samples]

        fine_gt_mask = {}
        all_segments = {}
        mask = {}
        for r in self.resolution:
            gt_r, valid_r, segments_r = [], [], []
            for sample in samples:
                a, b, pad = resize_and_pad(sample['fine_gt_mask'] * 255, r, interpolation=cv2.INTER_AREA)
                gt_r.append(a > 50)  # re-binarise the area-interpolated mask
                valid_r.append(b)

                cur_frame = resize(sample['target_frame'], r)
                # Labels start at 1 so that 0 can mark the padded area.
                segment = slic(cur_frame, n_segments=self.n_segments[r] - 1) + 1
                if r - cur_frame.shape[0] - pad < 0:
                    pad_config = [[0, 0], [pad, r - cur_frame.shape[1] - pad]]
                else:
                    pad_config = [[pad, r - cur_frame.shape[0] - pad], [0, 0]]
                segment = np.pad(segment, pad_config, mode='constant', constant_values=0)
                segments_r.append(segment)

            fine_gt_mask[r] = torch.from_numpy(np.asarray(gt_r)).long()
            mask[r] = torch.from_numpy(np.asarray(valid_r)).long()
            all_segments[r] = torch.from_numpy(np.asarray(segments_r)).long()

        return {
            'net_input': {
                'clip': clip,
                'query': query,
                'query_len': query_len,
                'fine_gt_mask': fine_gt_mask,
                'segment': all_segments,
                'mask': mask,
                'full': full,
                'index': index,
            },
            'target_frame': [sample['target_frame'] for sample in samples],
            'ori_fine_gt_mask': ori_fine_gt_mask,
            'video_info': [sample['sample_info'] for sample in samples],
            'video_information': [sample['sample_information'] for sample in samples]
        }


class Refer_Youtube(data.Dataset):

    def __init__(self, data_root, split, N=8, size=(320, 320), max_skip=1, query_len=20, eval=False, jitter=True,
                 bert=False, scale=1.0):
        self.data_root = Path(data_root)

        split_type = split.split('_')[0]

        self.split = split_type

        self.N = N
        self.size = size
        self.max_skip = max_skip
        self.query_len = query_len
        self.eval = eval
        self.jitter = jitter
        self.bert = bert
        self.scale = scale

        self.max_frames = 36

        self.image_dir = self.data_root / split / 'JPEGImages'
        self.mask_dir = self.data_root / split / 'Annotations'

        self.set_meta_file()

    def set_meta_file(self):

        mymeta_path = self.data_root / self.split / 'mymeta.pkl'
        if mymeta_path.exists():
            with mymeta_path.open('rb') as f:
                self.videos = pickle.load(f)
        else:
            data = json.load(open('/home/user/data/rvos/' + self.split + '/meta_expressions.json'))

            self.videos = []
            for vid, objs in tqdm.tqdm(data['videos'].items(), desc='Data processing'):
                # vid = '0a2f2bd294'
                # objs = data['videos'][vid]
                if self.eval:
                    for obj_id, obj in objs['expressions'].items():
                        # oid = int(obj['obj_id'])

                        sents = obj['exp']
                        if len(sents) > 0:
                            self.videos.append([vid, objs['frames'], sents])


                else:
                    for obj_id, obj in objs['expressions'].items():
                        oid = int(obj['obj_id'])

                        sents = obj['exp']
                        if len(sents) == 0:
                            print("Not included (no sents): ", vid, oid)

                        anker = []
                        for frm in objs['frames']:
                            mask_name = self.data_root / self.split / 'Annotations' / vid / '{}.png'.format(frm)
                            mask = np.uint8(Image.open(mask_name).convert('P'))
                            mask = np.uint8(mask == oid)
                            if float(mask.sum()) / mask.size > np.square(0.02):
                                anker += [1]
                            else:
                                anker += [0]

                        if sum(anker) >= 3:
                            print("Add")
                            # for sent in sents:
                            self.videos.append([vid, oid, objs['frames'], anker, sents])
                        else:
                            print("Not included : ", vid, oid, anker)

            with mymeta_path.open('wb') as f:
                pickle.dump(self.videos, f, pickle.HIGHEST_PROTOCOL)

        len_videos = len(self.videos)
        if self.scale < 1.0:
            len_videos = int(len_videos * self.scale)
        self.videos = self.videos[:len_videos]

    def __len__(self):
        """Return the number of entries in ``self.videos``."""
        return len(self.videos)

    def random_crop(self, frame, mask, size, rnd):
        """Randomly upscale ``frame``/``mask`` and cut one random ``size`` crop.

        Args:
            frame: image array indexed as ``[row, col, channel]``.
            mask: mask array indexed as ``[row, col]``.
            size: ``(rows, cols)`` of the crop.
            rnd: ``random.Random``-like generator.

        Returns:
            ``(crop_frame, crop_mask)`` cut at the same random location.
        """
        # Smallest scale at which the resized frame still covers ``size``;
        # the sampled scale stays within (min_scale, 1.875 * min_scale].
        # ``float``/``int`` replace the ``np.float``/``np.int`` aliases, which
        # no longer exist in current NumPy releases.
        min_scale = max(size[0] / float(frame.shape[0]), size[1] / float(frame.shape[1]))
        scale = max(rnd.uniform(min_scale + 0.01, 1.875 * min_scale), min_scale + 0.01)

        # cv2 expects dsize as (width, height).
        dsize = (int(frame.shape[1] * scale), int(frame.shape[0] * scale))
        trans_frame = cv2.resize(frame, dsize=dsize, interpolation=cv2.INTER_LINEAR)
        trans_mask = cv2.resize(mask, dsize=dsize, interpolation=cv2.INTER_NEAREST)

        # A single random crop is taken; no attempt is made to keep the object
        # inside it (the former retry loop always stopped after one pass).
        cr_y = rnd.randint(0, trans_mask.shape[0] - size[0])
        cr_x = rnd.randint(0, trans_mask.shape[1] - size[1])
        crop_mask = trans_mask[cr_y:cr_y + size[0], cr_x:cr_x + size[1]]
        crop_frame = trans_frame[cr_y:cr_y + size[0], cr_x:cr_x + size[1], :]

        return crop_frame, crop_mask

    def random_jitter(self, frame, mask, size, rnd):
        """Resize to slightly more than ``size`` and cut a random ``size`` crop.

        Up to 100 crop positions are tried; the first crop that retains more
        than 80% of the resized mask's sum is returned, otherwise the last one
        tried.

        Args:
            frame: image array indexed as ``[row, col, channel]``.
            mask: mask array indexed as ``[row, col]``.
            size: ``(rows, cols)`` of the crop.
            rnd: ``random.Random``-like generator.

        Returns:
            ``(crop_frame, crop_mask)`` cut at the same location.
        """
        scale = rnd.uniform(1, 1.1)
        # cv2 expects dsize as (width, height) while ``size`` is (rows, cols);
        # the previous (rows, cols) order only worked for square sizes.
        dsize = (int(size[1] * scale), int(size[0] * scale))

        trans_frame = cv2.resize(frame, dsize=dsize, interpolation=cv2.INTER_LINEAR)
        trans_mask = cv2.resize(mask, dsize=dsize, interpolation=cv2.INTER_NEAREST)

        np_in1 = np.sum(trans_mask)

        for _ in range(100):
            cr_y = rnd.randint(0, trans_mask.shape[0] - size[0])
            cr_x = rnd.randint(0, trans_mask.shape[1] - size[1])
            crop_mask = trans_mask[cr_y:cr_y + size[0], cr_x:cr_x + size[1]]
            crop_frame = trans_frame[cr_y:cr_y + size[0], cr_x:cr_x + size[1], :]
            if np.sum(crop_mask) > 0.8 * np_in1:
                break

        return crop_frame, crop_mask

    def resize(self, frame, mask, size):
        """Resize ``frame`` (bilinear) and optionally ``mask`` (nearest) to ``size``.

        Args:
            frame: image array accepted by ``cv2.resize``.
            mask: mask array, or ``None`` to resize only the frame.
            size: two-element sequence passed to ``cv2.resize`` as ``dsize``.

        Returns:
            ``resize_frame`` when ``mask`` is ``None``, otherwise
            ``(resize_frame, resize_mask)``.
        """
        # The unused aspect-preserving scale computation was removed: it relied
        # on ``np.float``/``np.int``, which current NumPy no longer provides.
        # NOTE(review): cv2 reads dsize as (width, height), whereas the crop
        # helpers of this class treat ``size`` as (rows, cols); the two agree
        # only for square sizes -- confirm before using non-square sizes.
        dsize = (size[0], size[1])
        resize_frame = cv2.resize(frame, dsize=dsize, interpolation=cv2.INTER_LINEAR)
        if mask is None:
            return resize_frame
        resize_mask = cv2.resize(mask, dsize=dsize, interpolation=cv2.INTER_NEAREST)
        return resize_frame, resize_mask

    def load_pair(self, vid, oid, fid):
        """Load frame ``fid`` of video ``vid`` as a float32 RGB array divided by 255.

        ``oid`` is accepted for interface compatibility but not used: only the
        frame is returned, no mask.
        """
        img_name = self.data_root / self.split / 'JPEGImages' / vid / '{}.jpg'.format(fid)
        rgb = Image.open(img_name).convert('RGB')
        return np.float32(rgb) / 255.

    def load_pairs(self, vid, frame_ids):
        frames, masks = [], []
        for frame_id in frame_ids:
            # obj_id = 1
            frame = self.load_pair(vid, 1, frame_id)
            frame = self.resize(frame, None, self.size)
            frames.append(frame)
            # masks.append(mask)

        N_frames = np.stack(frames, axis=0)
        # N_masks = np.stack(masks, axis=0)

        Fs = torch.from_numpy(np.transpose(N_frames, (0, 3, 1, 2)).copy()).float()
        # Ms = torch.from_numpy(N_masks.copy()).float()
        return Fs  # , Ms

    def sample_frame_ids_base(self, frame_ids, anker, rnd, vid, oid):

        n_images = len(frame_ids)
        tt = 0
        while True:
            sample_skips = [rnd.randint(1, self.max_skip) for _ in range(1, self.N)]
            if sum(sample_skips) < n_images:
                break
            if tt > 100:
                sample_skips = [1] * (self.N - 1)
                break
            tt = tt + 1

        use_ids = [None] * self.N
        # start index
        anker_nnz = [i for i, e in enumerate(anker) if e != 0]
        n_skip = sum(sample_skips)

        if True:  # always forward
            anker_idx = [i for i in anker_nnz if i + n_skip < n_images]
            if len(anker_idx) > 0:
                use_ids[0] = anker_idx[rnd.randint(0, len(anker_idx) - 1)]
            else:
                use_ids[0] = anker_nnz[0]

        for i in range(1, self.N):
            use_ids[i] = (use_ids[i - 1] + sample_skips[i - 1]) % n_images

        use_frame_ids = [frame_ids[i] for i in use_ids]
        return use_frame_ids

    def __getitem__(self, index):

        if self.eval:
            vid, frame_ids, sent = self.videos[index]

            Fs = self.load_pairs(vid, frame_ids)

            num_frames = len(Fs)
            if num_frames < self.max_frames:
                pad_frames = self.max_frames - num_frames
                Fs = F.pad(Fs, (0, 0) * 3 + (0, pad_frames))
                # Ms = F.pad(Ms, (0, 0) * 2 + (0, pad_frames))

            words = sent
            # ann_id = '{}_{}'.format(vid, oid)

            meta = {'sent': sent}
            return Fs, words, num_frames, meta

        rnd = random.Random()

        vid, oid, frame_ids, anker, sent = self.videos[index]
        use_frame_ids = self.sample_frame_ids_base(frame_ids, anker, rnd, vid, oid)

        # get frames and masks
        frames, masks = [], []
        for frame_id in use_frame_ids:
            frm, msk = self.load_pair(vid, oid, frame_id)
            if self.jitter:
                frm, msk = self.random_jitter(frm, msk, self.size, rnd)
            else:
                frm, msk = self.resize(frm, msk, self.size)
            frames.append(frm)
            masks.append(msk)

        frames = np.stack(frames, axis=0)
        masks = np.stack(masks, axis=0)

        Fs = torch.from_numpy(np.transpose(frames, (0, 3, 1, 2)).copy()).float()
        Ms = torch.from_numpy(masks.copy()).float()
        words = sent
        ann_id = '{}_{}'.format(vid, oid)

        return Fs, Ms, words, ann_id

    # def tokenize_sent(self, sent):
    #     return self.corpus.tokenize(sent, self.query_len)
    #
    # def untokenize_word_vector(self, words):
    #     return self.corpus.untokenize(words)


if __name__ == '__main__':
    # Smoke test: build the dataset and print the query tensor of every
    # training batch.
    import argparse

    parser = argparse.ArgumentParser(description='Iterate over the RefSlicRGB training set.')
    # RefSlicRGB reads both keys from ``args``; an empty dict fails with KeyError.
    parser.add_argument('--vocab_path', required=True,
                        help='binary word2vec file loaded by RefSlicRGB')
    parser.add_argument('--max_num_words', type=int, default=20,
                        help='maximum number of query tokens kept per sample')
    cli = parser.parse_args()

    args = {
        'vocab_path': cli.vocab_path,
        'max_num_words': cli.max_num_words,
    }
    dataset = RefSlicRGB(args)
    loader = DataLoader(dataset.train_set, batch_size=4, shuffle=True, num_workers=1,
                        pin_memory=False, collate_fn=dataset.collate_fn)
    print(len(loader))
    for batch in loader:
        print(batch["net_input"]["query"])