import csv
import os
import random


import cv2
import h5py
import numpy as np
import torch
from PIL import Image
from bert_embedding import BertEmbedding
from gensim.models import KeyedVectors
from nltk.corpus import stopwords, wordnet
from skimage import io
from skimage.segmentation import slic
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import torchvision.transforms.functional as F

from datasets.utils import resize_and_pad, resize


class A2DSubset(Dataset):
    """One split of the A2D sentence-grounded segmentation samples.

    Every item couples an 8-frame RGB clip around an annotated frame with the
    (lower-cased) language query, the mask of the referred instance and its
    bounding box.

    Args:
        videos: mapping ``video_id -> info dict``; ``info['num_frames']`` is
            checked against the number of frame files found on disk.
        samples: sequence of ``[video_id, instance_id, frame_idx, query]``.
        word2vec: word-embedding model; stored, not used by ``__getitem__``.
        args: configuration object; stored only.
        transforms: optional image transform; stored as ``self.trans`` but not
            applied by ``__getitem__``.
        train: when true, items are horizontally flipped with probability 0.5.
    """

    # Machine-specific data locations, kept as class attributes so that a
    # subclass (or a test) can redirect them without touching the code.
    COLOR_PATH = '/home/user/code/iccv-2021-seg/color'
    ANNOTATION_ROOT = '/home/user/data/A2D/a2d_annotation_with_instances'
    FRAME_ROOT = '/home/user/data/A2D/Release/pngs320H/'
    # Annotation used when the requested frame has no ``.h5`` file.
    FALLBACK_FRAME_IDX = 24
    # Target size handed to ``resize_and_pad``.
    LIMIT_SIZE = 320
    # Frame offsets (relative to the annotated frame) that make up the clip;
    # the centre frame is deliberately present twice.
    CLIP_OFFSETS = (-3, -2, -1, 0, 0, 1, 2, 3)

    def __init__(self, videos, samples, word2vec, args, transforms=None, train=False):
        self.videos = videos
        self.samples = samples
        self.word2vec = word2vec
        self.args = args
        self.trans = transforms
        self.train = train

        self.is_full = np.zeros(len(self.samples), dtype=np.int64)
        # NOTE: this seeds NumPy's *global* RNG, so everything drawn from
        # ``np.random`` afterwards is affected as well.
        np.random.seed(88)
        self.full_videos = list(videos.keys())
        np.random.shuffle(self.full_videos)
        self.full_videos = self.full_videos[:int(0.5 * len(self.full_videos))]

        # class id -> RGB colour, parsed from whitespace-separated columns
        # 1 (id) and 3..5 (r, g, b) of the colour file.
        self.rgb = {}
        with open(self.COLOR_PATH) as fp:
            for line in fp:
                tmp = line.strip().split()
                if not tmp:
                    continue  # tolerate blank lines
                self.rgb[int(tmp[1])] = np.asarray([int(tmp[3]), int(tmp[4]), int(tmp[5])])

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _swap_left_right(query):
        """Exchange every 'left' and 'right' substring of *query*.

        A placeholder is used so that the two replacements do not undo each
        other (chaining ``replace('left', 'right').replace('right', 'left')``
        turns every occurrence into 'left').
        """
        placeholder = '\0'
        return (query.replace('left', placeholder)
                .replace('right', 'left')
                .replace(placeholder, 'right'))

    def _read_annotation(self, h5_path, video_id, frame_idx, instance_id):
        """Load mask, box and class id of one instance from an ``.h5`` file.

        Returns:
            ``(mask, box, class_id, instance_id)`` where ``mask`` has a leading
            singleton axis, ``box`` is ``[w_min, h_min, w_max, h_max]`` and
            ``instance_id`` is converted to ``int`` only when the file holds
            several masks (it is returned unchanged otherwise).
        """
        with h5py.File(h5_path, mode='r') as fp:
            instance = np.asarray(fp['instance'])
            all_masks = np.asarray(fp['reMask'])
            all_boxes = np.asarray(fp['reBBox']).transpose([1, 0])  # [w_min, h_min, w_max, h_max]
            all_ids = np.asarray(fp['id'])

        if len(all_masks.shape) == 3 and instance.shape[0] != all_masks.shape[0]:
            # Diagnostic only: instance list and mask stack disagree in length.
            print(video_id, frame_idx + 1, instance.shape, all_masks.shape)

        assert len(all_masks.shape) == 2 or len(all_masks.shape) == 3
        if len(all_masks.shape) == 2:
            # A single mask is stored without an instance axis.
            mask = all_masks[np.newaxis]
            class_id = int(all_ids[0][0])
            box = all_boxes[0]
        else:
            instance_id = int(instance_id)
            idx = np.where(instance == instance_id)[0][0]
            mask = all_masks[idx][np.newaxis]
            box = all_boxes[idx]
            class_id = int(all_ids[0][idx])

        assert len(mask.shape) == 3
        assert mask.shape[0] > 0
        return mask, box, class_id, instance_id

    def __getitem__(self, index):
        """Build the sample at *index*.

        Returns:
            dict with keys ``query``, ``clip`` (channels first, then time),
            ``target_frame`` (raw image), ``target_frame1`` (normalised
            tensor), ``fine_gt_mask``, ``coarse_gt_box``, ``full`` and
            ``sample_info``.

        Raises:
            RuntimeError: if the annotated frame cannot be read.
        """
        video_id, instance_id, frame_idx, query = self.samples[index]
        # The random number is drawn even for the test split, so RNG
        # consumption does not depend on ``self.train``.
        flip = np.random.random() < 0.5 and self.train
        query = query.lower()
        frame_idx = int(frame_idx)
        if flip:
            # Mirroring the image exchanges the meaning of left and right.
            query = self._swap_left_right(query)
        is_full = 1

        h5_path = os.path.join(self.ANNOTATION_ROOT, video_id,
                               '%05d.h5' % (frame_idx + 1))
        if not os.path.exists(h5_path):
            h5_path = os.path.join(self.ANNOTATION_ROOT, video_id,
                                   '%05d.h5' % (self.FALLBACK_FRAME_IDX + 1))

        frame_path = os.path.join(self.FRAME_ROOT, video_id)
        frames = [os.path.join(frame_path, name)
                  for name in sorted(os.listdir(frame_path))]
        assert len(frames) == self.videos[video_id]['num_frames']

        try:
            target_frame = io.imread(frames[frame_idx])
        except Exception as e:
            # Surface the failure instead of terminating with a success code.
            raise RuntimeError('cannot read frame %d of video %s in %s'
                               % (frame_idx, video_id, frame_path)) from e

        mean_pix = np.asarray([120.39586422, 115.59361427, 104.54012653]) / 255.0
        std_pix = np.asarray([70.68188272, 68.27635443, 72.54505529]) / 255.0
        limit_size = self.LIMIT_SIZE

        # NOTE(review): on flipped samples the clip, the query and
        # ``fine_gt_mask`` are mirrored, but ``target_frame``,
        # ``target_frame1`` and ``coarse_gt_box`` are not. Confirm whether
        # that is intended before relying on flipped coarse supervision.
        target_frame1, _, _ = resize_and_pad(target_frame, limit_size=limit_size)
        target_frame1 = transforms.Compose([
            lambda x: np.asarray(x),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean_pix, std=std_pix)
        ])(target_frame1)

        # Clip: neighbouring frames, clamped to the valid index range.
        last = len(frames) - 1
        clip_frames = []
        for offset in self.CLIP_OFFSETS:
            path = frames[min(max(frame_idx + offset, 0), last)]
            image = Image.fromarray(io.imread(path))
            if flip:
                image = F.hflip(image)
            image = np.asarray(image)

            image, valid_mask, _ = resize_and_pad(image, limit_size=limit_size)
            # Scale to [-1, 1] and zero out the padded area.
            image_norm = 2 * (image / 255.0) - 1
            image_norm = image_norm * valid_mask[:, :, np.newaxis]
            clip_frames.append(image_norm)
        feat = np.asarray(clip_frames).transpose((3, 0, 1, 2))

        # Fine-grained mask and box of the referred instance.
        mask, coarse_gt_box, class_id, instance_id = self._read_annotation(
            h5_path, video_id, frame_idx, instance_id)

        # Masks are stored with the two spatial axes swapped.
        fine_gt_mask = np.transpose(np.asarray(mask), (0, 2, 1))[0]
        if flip:
            fine_gt_mask = np.asarray(F.hflip(Image.fromarray(fine_gt_mask)))

        return {
            'query': query,
            'clip': feat,
            'target_frame': target_frame,
            'fine_gt_mask': fine_gt_mask,
            'target_frame1': target_frame1,
            'full': is_full,
            'coarse_gt_box': coarse_gt_box,
            'sample_info': [video_id, instance_id, frame_idx, self.samples[index][-1], self.rgb[class_id]]
        }


class A2DSlicRGB:
    """A2D dataset wrapper: metadata parsing, train/test subsets and batching.

    Args:
        args: mapping with the keys ``vocab_path``, ``annotation_path``,
            ``max_num_words``, ``videoset_path`` and ``sample_path``.
    """

    # Width of the token embeddings written into the padded query tensor.
    BERT_DIM = 768

    def __init__(self, args):
        self.word2vec = KeyedVectors.load_word2vec_format(args['vocab_path'], binary=True)

        self.args = args
        self.col_path = os.path.join(args['annotation_path'], 'col')  # rgb
        self.mat_path = os.path.join(args['annotation_path'], 'mat')  # matrix
        self.max_num_words = args['max_num_words']
        # Side lengths at which masks, anchors and superpixels are produced.
        self.resolution = [10, 20, 40, 80, 160, 320]
        # SLIC segment budget per resolution (coarsest level uses fewer).
        self.n_segments = dict(zip(self.resolution, [16, 32, 32, 32, 32, 32]))

        self.bert_embedding = BertEmbedding()
        self._read_video_info()
        self._read_dataset_samples()
        self.train_set_ = A2DSubset(self.train_videos, self.train_samples, self.word2vec, args,
                                    transforms=transforms.Compose([
                                        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                                        transforms.RandomHorizontalFlip(),
                                        lambda x: np.asarray(x),
                                    ]), train=True)
        self.test_set_ = A2DSubset(self.test_videos, self.test_samples, self.word2vec, args)

    def _read_video_info(self):
        """Parse the video-set CSV into ``train_videos`` / ``test_videos``.

        Column 8 selects the split (``0`` means training). The annotated
        frame indices come from the file names in ``col_path/<video_id>``
        (last four characters dropped, converted to 0-based).
        """
        self.train_videos, self.test_videos = {}, {}
        with open(self.args['videoset_path'], newline='') as fp:
            for row in csv.reader(fp, delimiter=','):
                annotated = os.listdir(os.path.join(self.col_path, row[0]))
                frame_idx = sorted(int(name[:-4]) - 1 for name in annotated)
                video_info = {
                    'label': int(row[1]),
                    'timestamps': [row[2], row[3]],
                    'size': [int(row[4]), int(row[5])],  # [height, width]
                    'num_frames': int(row[6]),
                    'num_annotations': int(row[7]),
                    'frame_idx': frame_idx,
                }
                if int(row[8]) == 0:
                    self.train_videos[row[0]] = video_info
                else:
                    self.test_videos[row[0]] = video_info

    def _read_dataset_samples(self):
        """Read the sample CSV and split its rows into train and test samples.

        A row belongs to the training split when its ``video_id`` is a key of
        ``self.train_videos``; every other row goes to the test split. Each
        sample is ``[video_id, instance_id, frame_idx, query]`` with all
        fields kept as the strings read from the file.
        """
        self.train_samples, self.test_samples = [], []
        self.train_videos_set = set()
        self.test_videos_set = set()
        self.all_query = set()
        with open(self.args['sample_path'], newline='') as fp:
            for row in csv.DictReader(fp):
                sample = [row['video_id'], row['instance_id'], row['frame_idx'], row['query']]
                if row['video_id'] in self.train_videos:
                    self.train_samples.append(sample)
                    self.train_videos_set.add(row['video_id'])
                else:
                    self.test_samples.append(sample)
                    self.test_videos_set.add(row['video_id'])
                self.all_query.add((row['video_id'], row['query']))
        print('number of sentences: {}'.format(len(self.all_query)))
        print('videos for training: {}, videos for testing: {}'.format(len(self.train_videos_set),
                                                                       len(self.test_videos_set)))
        print(
            'samples for training: {}, samples for testing: {}'.format(len(self.train_samples), len(self.test_samples)))

    @property
    def train_set(self):
        return self.train_set_

    @property
    def test_set(self):
        return self.test_set_

    @staticmethod
    def _sample_anchor_mask(like, box, pad, cur_h, cur_w):
        """Draw a random anchor-label map for one (already rescaled) box.

        Labels: 1 and 2 mark two one-pixel-wide lines inside the box (both
        horizontal or both vertical, chosen at random); 3, 4, 5 and 6 mark
        one-pixel strips just above, below, left of and right of the box,
        each optionally pushed one pixel further away. 0 everywhere else.

        Args:
            like: array whose shape the result copies.
            box: ``[w_min, h_min, w_max, h_max]`` in resized coordinates.
            pad: row offset applied to every vertical coordinate.
            cur_h, cur_w: extent of the un-padded image content.
        """
        w_min, h_min, w_max, h_max = box
        anch_mask = np.zeros_like(like, dtype=np.int64)
        # Reject inverted boxes (same condition the former size check implied).
        assert h_max - h_min + 1 >= 0 and w_max - w_min + 1 >= 0
        anch_h, anch_w = 1, 1

        # NOTE: the order of the random draws below is part of the behaviour.
        if random.random() < 0.5:
            a, b = np.random.choice(range(h_min, h_max + 1), size=2, replace=(h_min == h_max))
            h1, h2 = a, a + 1
            anch_mask[h1 + pad:h2 + pad, w_min:w_max + 1] = 1
            h1, h2 = b, b + 1
            anch_mask[h1 + pad:h2 + pad, w_min:w_max + 1] = 2
        else:
            a, b = np.random.choice(range(w_min, w_max + 1), size=2, replace=(w_min == w_max))
            w1, w2 = a, a + 1
            anch_mask[pad + h_min:pad + h_max + 1, w1:w2] = 1
            w1, w2 = b, b + 1
            anch_mask[pad + h_min:pad + h_max + 1, w1:w2] = 2

        # Strip above the box.
        h1, h2 = h_min - anch_h if h_min >= anch_h else 0, h_min
        if random.random() < 0.5 and h1 >= 1:
            h1, h2 = h1 - 1, h2 - 1
        anch_mask[h1 + pad:h2 + pad, w_min:w_max + 1] = 3

        # Strip below the box.
        h1, h2 = h_max + 1, h_max + 1 + anch_h if h_max + 1 + anch_h <= cur_h else cur_h
        if random.random() < 0.5 and h2 <= cur_h - 1:
            h1, h2 = h1 + 1, h2 + 1
        anch_mask[h1 + pad:h2 + pad, w_min:w_max + 1] = 4

        # Strip left of the box.
        w1, w2 = w_min - anch_w if w_min >= anch_w else 0, w_min
        if random.random() < 0.5 and w1 >= 1:
            w1, w2 = w1 - 1, w2 - 1
        anch_mask[pad + h_min:pad + h_max + 1, w1:w2] = 5

        # Strip right of the box.
        w1, w2 = w_max + 1, w_max + 1 + anch_w if w_max + 1 + anch_w <= cur_w else cur_w
        if random.random() < 0.5 and w2 <= cur_w - 1:
            w1, w2 = w1 + 1, w2 + 1
        anch_mask[pad + h_min:pad + h_max + 1, w1:w2] = 6

        return anch_mask

    def collate_fn(self, samples):
        """Merge ``A2DSubset`` items into one batch.

        Queries are embedded with ``self.bert_embedding`` and zero-padded to
        the longest query of the batch (capped at ``max_num_words``). For
        every entry of ``self.resolution`` the fine mask, the box mask, the
        validity mask, a random anchor map and SLIC superpixels are produced.

        Returns:
            dict with ``net_input`` (tensors; the per-resolution entries are
            dicts keyed by resolution) plus the untouched ``target_frame``,
            ``ori_coarse_gt_box``, ``ori_fine_gt_mask`` and ``video_info``.
        """
        bsz = len(samples)

        clip = torch.from_numpy(np.asarray([sample['clip'] for sample in samples])).float()
        full = torch.from_numpy(np.asarray([sample['full'] for sample in samples])).long()

        # Element 1 of each result holds the per-token embeddings.
        result = self.bert_embedding([sample['query'] for sample in samples])
        token_embs = [np.asarray(item[1]) for item in result]
        query_len = [min(len(emb), self.max_num_words) for emb in token_embs]
        query = np.zeros([bsz, max(query_len), self.BERT_DIM]).astype(np.float32)
        for i, emb in enumerate(token_embs):
            keep = min(emb.shape[0], query.shape[1])
            query[i, :keep] = emb[:keep]
        query = torch.from_numpy(query).float()
        query_len = torch.from_numpy(np.asarray(query_len)).long()

        ori_coarse_gt_box = [sample['coarse_gt_box'] for sample in samples]
        ori_fine_gt_mask = [sample['fine_gt_mask'] for sample in samples]

        coarse_gt_mask = {}
        fine_gt_mask = {}
        all_anch_mask = {}
        all_segments = {}
        mask = {}

        for r in self.resolution:
            coarse_r, fine_r, valid_r, anchor_r, segment_r = [], [], [], [], []
            for i, sample in enumerate(samples):
                ori_h, ori_w = ori_fine_gt_mask[i].shape
                cur_h, cur_w = int(r / ori_w * ori_h), r
                assert ori_h <= ori_w

                # Resize the mask as an 8-bit image, then re-binarise it.
                resized, valid, pad = resize_and_pad(sample['fine_gt_mask'] * 255, r,
                                                     interpolation=cv2.INTER_LINEAR)
                fine = resized > 100
                fine_r.append(fine)
                valid_r.append(valid)

                # Superpixels start at 1 so that 0 can mark the padding rows.
                cur_frame = resize(sample['target_frame'], r)
                segment = slic(cur_frame, n_segments=self.n_segments[r] - 1) + 1
                pad_config = [[pad, r - cur_frame.shape[0] - pad], [0, 0]]
                segment = np.pad(segment, pad_config, mode='constant', constant_values=0)
                segment_r.append(segment)

                # Box rescaled by the same factor as the image width.
                box = np.asarray(ori_coarse_gt_box[i] * r / ori_w, dtype=np.int32)
                w_min, h_min, w_max, h_max = box
                cmask = np.zeros_like(fine)
                cmask[pad + h_min:pad + h_max + 1, w_min:w_max + 1] = 1
                coarse_r.append(cmask)

                anchor_r.append(self._sample_anchor_mask(fine, box, pad, cur_h, cur_w))

            coarse_gt_mask[r] = torch.from_numpy(np.asarray(coarse_r)).long()
            fine_gt_mask[r] = torch.from_numpy(np.asarray(fine_r)).float()
            mask[r] = torch.from_numpy(np.asarray(valid_r)).long()
            all_anch_mask[r] = torch.from_numpy(np.asarray(anchor_r)).long()
            all_segments[r] = torch.from_numpy(np.asarray(segment_r)).long()

        return {
            'net_input': {
                'clip': clip,
                'query': query,
                'query_len': query_len,
                'coarse_gt_mask': coarse_gt_mask,
                'fine_gt_mask': fine_gt_mask,
                'anch_mask': all_anch_mask,
                'segment': all_segments,
                'target_frame': torch.stack([sample['target_frame1'] for sample in samples], 0).float(),
                'mask': mask,
                'full': full
            },
            'target_frame': [sample['target_frame'] for sample in samples],
            'ori_coarse_gt_box': ori_coarse_gt_box,
            'ori_fine_gt_mask': ori_fine_gt_mask,
            'video_info': [sample['sample_info'] for sample in samples]
        }


if __name__ == '__main__':
    # Smoke test: build the dataset, fetch one training item directly, then
    # pull a single batch through a DataLoader and stop.
    args = dict([
        ("videoset_path", "/home/user/data/A2D/Release/videoset.csv"),
        ("annotation_path", "/home/user/data/A2D/Release/Annotations"),
        ("vocab_path", "/home/user/code/mm-2020/data/glove_a2d.bin"),
        ("sample_path", "/home/user/data/A2D/a2d_annotation2.txt"),
        ("max_num_words", 20),
    ])
    dataset = A2DSlicRGB(args)

    # Exercise __getitem__ once outside of any worker process.
    dataset.train_set[66]

    loader_options = {
        'batch_size': 4,
        'shuffle': True,
        'num_workers': 1,
        'pin_memory': True,
        'collate_fn': dataset.collate_fn,
    }
    loader = DataLoader(dataset.train_set, **loader_options)
    for batch in loader:
        # One successfully collated batch is enough.
        exit(0)
