import csv
import os
import random

import cv2
import h5py
import scipy.io as scio
import numpy as np
import torch
from gensim.models import KeyedVectors
from nltk import word_tokenize
from torch.utils.data import Dataset, DataLoader

from datasets.utils import resize_and_pad


class JHMDBSubset(Dataset):
    """One subset (train or test) of J-HMDB, indexed by (video, frame) pairs.

    Every item bundles the pre-extracted I3D feature stored for a frame, the
    word embeddings of the video's language query, and the puppet mask of the
    frame.

    Args:
        videos: Mapping from video id to its metadata; the ``'query'`` and
            ``'class'`` entries are read here.
        samples: Sequence of ``(video_id, frame_idx)`` pairs.
        word2vec: Embedding lookup supporting ``word in word2vec`` and
            ``word2vec[word]``.
        args: Configuration object; stored but not read by this class.
    """

    def __init__(self, videos, samples, word2vec, args):
        self.videos = videos
        self.samples = samples
        self.word2vec = word2vec
        self.args = args

    def __len__(self):
        """Return the number of (video, frame) samples."""
        return len(self.samples)

    @staticmethod
    def _resolve_mask_index(frame_idx, num_masks, num_frames, tag=''):
        """Map a frame index onto the mask stack, clamping on a length mismatch.

        When the mask stack and the frame directory agree in length the index
        is returned untouched (an out-of-range index then still fails loudly
        when the mask is sliced). When they disagree, a warning is emitted and
        the index is limited to the last available mask.

        Args:
            frame_idx: Requested frame index.
            num_masks: Number of masks available for the video.
            num_frames: Number of image files found for the video.
            tag: Text identifying the video in the warning message.

        Returns:
            The index to use along the mask axis.
        """
        if num_masks == num_frames:
            return frame_idx

        import warnings
        warnings.warn(
            'J-HMDB %s: %d puppet masks but %d frames; clamping frame index'
            % (tag, num_masks, num_frames),
            stacklevel=2,
        )
        return min(frame_idx, num_masks - 1)

    def __getitem__(self, index):
        """Load the clip feature, query embeddings and mask for one sample.

        Args:
            index: Position in ``self.samples``.

        Returns:
            dict with keys ``'query'`` (array of word vectors for the
            in-vocabulary query words), ``'clip'`` (array loaded from the
            ``.npy`` feature file), ``'fine_gt_mask'`` (the ``part_mask``
            slice for the frame) and ``'sample_info'``
            (``[video_id, class_, frame_idx, query]``).
        """
        video_id, frame_idx = self.samples[index]
        frame_idx = int(frame_idx)
        info = self.videos[video_id]
        class_ = info['class']

        # clip feature (looked up with the requested, un-clamped frame index)
        i3d_path = '/home/user/data/J-HMDB/i3d-rgb-MaxPool3d_4a_3x3/%s/%d.npy' % (video_id, frame_idx)
        feat = np.load(i3d_path)

        # Only the number of frames is needed; the image itself is not part of
        # the returned sample, so it is not decoded.
        frame_path = '/home/user/data/J-HMDB/Rename_Images/%s/%s' % (class_, video_id)
        num_frames = len(os.listdir(frame_path))

        # words feature: lower-cased tokens, out-of-vocabulary words dropped
        tokens = [word.lower() for word in word_tokenize(info['query'])]
        query = np.asarray([self.word2vec[word] for word in tokens if word in self.word2vec])

        # fine-grained mask; the last axis of 'part_mask' is indexed by frame
        gt_path = '/home/user/data/J-HMDB_used/puppet_mask/%s/%s/puppet_mask.mat' % (class_, video_id)
        fine_gt_mask = scio.loadmat(gt_path)['part_mask']
        frame_idx = self._resolve_mask_index(
            frame_idx, fine_gt_mask.shape[2], num_frames, '%s/%s' % (class_, video_id))
        fine_gt_mask = fine_gt_mask[:, :, frame_idx]

        return {
            'query': query,
            'clip': feat,
            'fine_gt_mask': fine_gt_mask,
            # NOTE(review): the last entry is the embedded query, not the raw
            # sentence -- kept as-is for compatibility; confirm with consumers.
            'sample_info': [video_id, class_, frame_idx, query]
        }


class JHMDB:
    """J-HMDB wrapper exposing train/test subsets and a batch collator.

    Construction loads the word vectors, the per-video metadata and the split
    files, then builds one ``JHMDBSubset`` for training and one for testing.

    Args:
        args: Mapping that must provide ``'vocab_path'`` (binary word2vec
            file) and ``'max_num_words'`` (cap on the query length kept per
            sample). It is also forwarded unchanged to the subsets.
    """

    # Dimensionality of the word vectors padded into a batch by `collate_fn`.
    word_dim = 300

    def __init__(self, args):
        self.word2vec = KeyedVectors.load_word2vec_format(args['vocab_path'], binary=True)
        self.args = args
        self.max_num_words = args['max_num_words']
        # Target sizes passed to `resize_and_pad`; one mask tensor is produced
        # per entry. (Attribute spelling kept for backward compatibility.)
        self.resoluation = [32, 128]

        self._read_video_info()
        self._read_dataset_samples()
        self.train_set_ = JHMDBSubset(self.video_info, self.train_samples, self.word2vec, args)
        self.test_set_ = JHMDBSubset(self.video_info, self.test_samples, self.word2vec, args)

    def _read_video_info(self):
        """Load the per-video metadata JSON into ``self.video_info``."""
        import json
        with open('/home/user/data/J-HMDB/video_info.json', 'r') as fp:
            self.video_info = json.load(fp)

    def _read_dataset_samples(self):
        """Build ``self.train_samples`` / ``self.test_samples`` from the split files.

        Each non-empty line of a split file is ``<video file> <flag>``. Flag
        ``'1'`` puts every frame of the video into the training list; any
        other flag puts three evenly spaced frames (first, middle, last) into
        the test list.
        """
        self.split_path = '/home/user/data/J-HMDB/splits/'
        self.train_samples, self.test_samples = [], []

        # NOTE(review): every file in the directory is read. If it holds
        # several alternative splits of the same videos, a video can end up in
        # both lists -- confirm the directory contains a single split.
        # `os.listdir` returns entries in arbitrary order; sort so the sample
        # order is reproducible across machines.
        for split in sorted(os.listdir(self.split_path)):
            with open(os.path.join(self.split_path, split)) as fp:
                for line in fp:
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank / trailing lines
                    video, flag = fields
                    video_id = video[:-4]  # drop the 4-character file extension
                    num_frames = len(self.video_info[video_id]['frames'])
                    if flag == '1':
                        self.train_samples.extend((video_id, i) for i in range(num_frames))
                    else:
                        idx = np.linspace(0, num_frames - 1, num=3, endpoint=True)
                        self.test_samples.extend(
                            (video_id, i) for i in idx.astype(np.int64).tolist())

        print('samples for training: {}, samples for testing: {}'.format(len(self.train_samples), len(self.test_samples)))

    @property
    def train_set(self):
        """The training ``JHMDBSubset``."""
        return self.train_set_

    @property
    def test_set(self):
        """The test ``JHMDBSubset``."""
        return self.test_set_

    def _collate_queries(self, samples):
        """Zero-pad the per-sample word vectors into one batch tensor.

        Queries longer than ``self.max_num_words`` are truncated. A sample
        whose query has no in-vocabulary word contributes an all-zero row and
        a length of 0.

        Returns:
            ``(query, query_len)``: a float tensor of shape
            ``(batch, longest kept query, word_dim)`` and a long tensor with
            the kept length of every query.
        """
        query_len = [min(len(sample['query']), self.max_num_words) for sample in samples]
        query = np.zeros([len(samples), max(query_len), self.word_dim], dtype=np.float32)
        for row, (sample, keep) in enumerate(zip(samples, query_len)):
            # An empty query is a 1-D array of shape (0,), which cannot be
            # broadcast into the (0, word_dim) slice -- skip the assignment.
            if keep:
                query[row, :keep] = sample['query'][:keep]
        return (torch.from_numpy(query).float(),
                torch.from_numpy(np.asarray(query_len)).long())

    def _collate_masks(self, ori_fine_gt_mask):
        """Resize every ground-truth mask to each configured resolution.

        Args:
            ori_fine_gt_mask: List of 2-D masks, one per sample.

        Returns:
            ``(fine_gt_mask, mask)``: two dicts keyed by resolution, each
            holding a long tensor stacked over the batch. ``fine_gt_mask``
            is the resized mask binarised at 125; ``mask`` is the second
            value returned by ``resize_and_pad`` (presumably the map of valid,
            non-padded pixels -- TODO confirm against ``datasets.utils``).

        Raises:
            ValueError: If a mask is taller than it is wide.
        """
        fine_gt_mask, mask = {}, {}
        for r in self.resoluation:
            resized, valid = [], []
            for gt in ori_fine_gt_mask:
                ori_h, ori_w = gt.shape
                if ori_h > ori_w:
                    raise ValueError(
                        'expected a landscape mask, got height %d > width %d' % (ori_h, ori_w))
                # Scale to 0/255 before area interpolation, then threshold
                # back to a binary mask.
                a, b, _ = resize_and_pad(gt * 255, r, interpolation=cv2.INTER_AREA)
                resized.append(a > 125)
                valid.append(b)
            fine_gt_mask[r] = torch.from_numpy(np.asarray(resized)).long()
            mask[r] = torch.from_numpy(np.asarray(valid)).long()
        return fine_gt_mask, mask

    def collate_fn(self, samples):
        """Assemble a list of ``JHMDBSubset`` items into one batch.

        Only fine-grained masks are collated; no coarse box masks or anchor
        masks are produced.

        Args:
            samples: List of dicts as returned by ``JHMDBSubset.__getitem__``.

        Returns:
            dict with ``'net_input'`` (``'clip'``, ``'query'``,
            ``'query_len'``, and the per-resolution dicts ``'fine_gt_mask'``
            and ``'mask'``) and ``'ori_fine_gt_mask'`` (the untouched
            per-sample masks).
        """
        # clip
        clip = torch.from_numpy(np.asarray([sample['clip'] for sample in samples])).float()

        # query
        query, query_len = self._collate_queries(samples)

        # masks
        ori_fine_gt_mask = [sample['fine_gt_mask'] for sample in samples]
        fine_gt_mask, mask = self._collate_masks(ori_fine_gt_mask)

        return {
            'net_input': {
                'clip': clip,
                'query': query,
                'query_len': query_len,
                'fine_gt_mask': fine_gt_mask,
                'mask': mask
            },
            'ori_fine_gt_mask': ori_fine_gt_mask
        }


if __name__ == '__main__':
    # Smoke test: build the dataset, draw one shuffled training batch through
    # the collator, then stop.
    config = dict(
        videoset_path="/home/user/data/A2D/Release/videoset.csv",
        annotation_path="/home/user/data/A2D/Release/Annotations",
        vocab_path="/home/user/code/mm-2020/data/glove_a2d.bin",
        sample_path="/home/user/data/A2D/a2d_annotation2.txt",
        max_num_words=20,
    )
    jhmdb = JHMDB(config)
    train_loader = DataLoader(
        jhmdb.train_set,
        batch_size=4,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
        collate_fn=jhmdb.collate_fn,
    )
    for _ in train_loader:
        # One successfully collated batch is enough.
        exit(0)
