import os

import cv2
import numpy as np
import scipy.io as scio
import torch
from PIL import Image
from gensim.models import KeyedVectors
from nltk import word_tokenize
from skimage import io
from skimage.segmentation import slic
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from bert_embedding import BertEmbedding
from nltk.corpus import stopwords, wordnet
from datasets.utils import resize_and_pad, resize
#from utils import resize_and_pad, resize

class JHMDBSubset(Dataset):
    """Frame-level samples of J-HMDB videos.

    Each entry of ``samples`` is a ``(video_id, frame_idx)`` pair and
    ``videos[video_id]`` is a dict providing at least the ``'query'`` and
    ``'class'`` keys read below.  ``word2vec`` and ``args`` are stored but not
    used by this class.
    """

    # Root of the J-HMDB data tree; override on the class or on an instance to
    # read the data from another location.
    data_root = '/home/user/data/J-HMDB'
    # Value passed as ``limit_size`` to ``resize_and_pad`` for every frame.
    limit_size = 320
    # Per-channel statistics handed to ``transforms.Normalize`` for the target frame.
    mean_pix = np.asarray([120.39586422, 115.59361427, 104.54012653]) / 255.0
    std_pix = np.asarray([70.68188272, 68.27635443, 72.54505529]) / 255.0

    def __init__(self, videos, samples, word2vec, args):
        self.videos = videos
        self.samples = samples
        self.word2vec = word2vec
        self.args = args

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Load one sample.

        Returns a dict with:
            ``'query'``: the lower-cased query string,
            ``'clip'``: the stacked 8-frame window around the target frame with
                the last axis of each frame moved to the front,
            ``'target_frame1'``: the resized, normalised target frame tensor,
            ``'target_frame'``: the target frame as returned by ``io.imread``,
            ``'fine_gt_mask'``: one 2-D slice of the ``'part_mask'`` array,
            ``'sample_info'``: ``[video_id, class, frame_idx, raw query, pad]``
                where ``frame_idx`` is the (possibly clamped) mask index.
        """
        video_id, frame_idx = self.samples[index]
        frame_idx = int(frame_idx)
        video = self.videos[video_id]
        raw_query = video['query']
        class_ = video['class']

        frames = self._list_frames(class_, video_id)

        target_frame = io.imread(frames[frame_idx])
        target_frame1, _, _ = resize_and_pad(target_frame, limit_size=self.limit_size)
        target_frame1 = transforms.Compose([
            np.asarray,
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean_pix, std=self.std_pix)
        ])(target_frame1)

        clip, pad = self._load_clip(frames, frame_idx)

        # The mask file may hold fewer frames than the image directory, so the
        # index reported in ``sample_info`` is the one actually used for the mask.
        fine_gt_mask, frame_idx = self._load_mask(class_, video_id, frame_idx)

        return {
            'query': raw_query.lower(),
            'clip': clip,
            'target_frame1': target_frame1,
            'target_frame': target_frame,
            'fine_gt_mask': fine_gt_mask,
            'sample_info': [video_id, class_, frame_idx, raw_query, pad]
        }

    def _list_frames(self, class_, video_id):
        """Return the frame image paths of a video, ordered by numeric file name."""
        frame_dir = os.path.join(self.data_root, 'Rename_Images', class_, video_id)
        names = [name for name in os.listdir(frame_dir)
                 if name not in ('.AppleDouble', '.DS_Store')]
        names.sort(key=lambda name: int(os.path.splitext(name)[0]))
        return [os.path.join(frame_dir, name) for name in names]

    @staticmethod
    def _clip_window(frame_idx, num_frames, half_window=4, step=1):
        """Return the frame indices ``frame_idx - 4 .. frame_idx + 3`` clamped to the video."""
        last = num_frames - 1
        return [min(max(i, 0), last)
                for i in range(frame_idx - half_window * step,
                               frame_idx + half_window * step, step)]

    def _load_clip_frame(self, path):
        """Read one clip frame; return ``(2 * (image / 255.0) - 1, pad)``."""
        frame = cv2.imread(path)
        if frame is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('cv2.imread could not read frame: %s' % path)
        # Reverse the channel order of the cv2 image; the PIL round trip is kept
        # from the original pipeline (presumably to hand resize_and_pad a plain
        # array -- TODO confirm whether it is still required).
        image = np.asarray(Image.fromarray(frame[:, :, [2, 1, 0]]))
        image, _, pad = resize_and_pad(image, limit_size=self.limit_size)
        return 2 * (image / 255.0) - 1, pad

    def _load_clip(self, frames, frame_idx):
        """Build the clip around ``frame_idx``; return ``(clip, pad of the last frame)``."""
        loaded = {}  # clamped windows repeat indices; decode each frame only once
        pad = None
        stacked = []
        for i in self._clip_window(frame_idx, len(frames)):
            if i not in loaded:
                loaded[i] = self._load_clip_frame(frames[i])
            image_norm, pad = loaded[i]
            stacked.append(image_norm)
        # Stack on a new leading axis, then move each frame's last axis to the front.
        return np.asarray(stacked).transpose((3, 0, 1, 2)), pad

    def _load_mask(self, class_, video_id, frame_idx):
        """Return ``(mask slice, index used)`` from the video's ``part_mask`` array."""
        gt_path = os.path.join(self.data_root, 'puppet_mask', class_, video_id,
                               'puppet_mask.mat')
        part_mask = scio.loadmat(gt_path)['part_mask']
        # Fall back to the last annotated frame when the requested one is missing.
        frame_idx = min(frame_idx, part_mask.shape[2] - 1)
        return part_mask[:, :, frame_idx], frame_idx


class JHMDBSlicRGB:
    """J-HMDB train/test sets plus the batch collation for them.

    ``args`` must support ``args['vocab_path']`` (word2vec binary file) and
    ``args['max_num_words']`` (maximum number of query tokens kept per sample).
    Construction loads the word vectors, a BERT embedder, the ``cnt.pkl``
    mapping from the current working directory, the video metadata and the
    split files, then builds the two ``JHMDBSubset`` instances.
    """

    # Root of the J-HMDB data tree; override on a subclass to relocate the data.
    data_root = '/home/user/data/J-HMDB'

    def __init__(self, args):
        self.word2vec = KeyedVectors.load_word2vec_format(args['vocab_path'], binary=True)
        self.args = args
        self.max_num_words = args['max_num_words']
        # Output sizes produced by collate_fn, and the SLIC segment budget for each.
        self.resolution = [10, 20, 40, 80, 160, 320]
        self.n_segments = dict(zip(self.resolution, [16, 32, 32, 32, 32, 32]))

        self.bert_embedding = BertEmbedding()
        import pickle
        # NOTE: pickle.load can execute arbitrary code; 'cnt.pkl' must come
        # from a trusted source.
        with open('cnt.pkl', 'rb') as fp:
            self.id2idx = pickle.load(fp)
        self._read_video_info()
        self._read_dataset_samples()
        self.train_set_ = JHMDBSubset(self.video_info, self.train_samples, self.word2vec, args)
        self.test_set_ = JHMDBSubset(self.video_info, self.test_samples, self.word2vec, args)
        # Point the subsets at the same data tree.  This is a plain attribute
        # assignment, so it has no effect on a subset that does not read it.
        for subset in (self.train_set_, self.test_set_):
            subset.data_root = self.data_root

    def _read_video_info(self):
        """Load ``video_info.json`` from the data root into ``self.video_info``."""
        import json
        with open(os.path.join(self.data_root, 'video_info.json'), 'r') as fp:
            self.video_info = json.load(fp)

    def _read_dataset_samples(self):
        """Fill ``train_samples`` / ``test_samples`` from every file in the splits directory.

        Each non-blank line is ``<video file name> <flag>``.  Flag ``'1'``
        contributes every frame of the video to the training samples; any other
        flag contributes three frames of the video to the test samples.
        """
        # The trailing '' keeps the trailing separator of the original path.
        self.split_path = os.path.join(self.data_root, 'splits', '')
        self.train_samples, self.test_samples = [], []

        # Sorted so that the sample order does not depend on the file system.
        for split in sorted(os.listdir(self.split_path)):
            split_file = os.path.join(self.split_path, split)
            # Skip hidden entries (e.g. '.DS_Store', '.AppleDouble') and directories.
            if split.startswith('.') or not os.path.isfile(split_file):
                continue
            with open(split_file) as fp:
                for line in fp:
                    fields = line.split()
                    if not fields:
                        continue  # blank line
                    video, flag = fields
                    video_id = video[:-4]  # drop the 4-character extension
                    num_frames = len(self.video_info[video_id]['frames'])
                    if flag == '1':
                        self.train_samples.extend((video_id, i) for i in range(num_frames))
                    else:
                        idx = np.linspace(4, num_frames - 5, num=3,
                                          endpoint=True).astype(np.int64)
                        # Videos shorter than 9 frames would otherwise yield
                        # indices outside [0, num_frames - 1].
                        idx = np.clip(idx, 0, max(num_frames - 1, 0))
                        self.test_samples.extend((video_id, i) for i in idx)

        print(
            'samples for training: {}, samples for testing: {}'.format(len(self.train_samples), len(self.test_samples)))

    @property
    def train_set(self):
        return self.train_set_

    @property
    def test_set(self):
        return self.test_set_

    def _embed_queries(self, sentences):
        """Embed the sentences with BERT, keeping only tokens known to ``id2idx``.

        Returns ``(embeddings, word_ids)``: per sentence, an array of the kept
        token embeddings and the list of their ``id2idx`` values.
        """
        token_to_idx = self.bert_embedding.vocab.token_to_idx
        embeddings, word_ids = [], []
        for tokens, vectors in self.bert_embedding(sentences):
            kept_vectors, kept_ids = [], []
            for token, vector in zip(tokens, vectors):
                vocab_idx = token_to_idx[token]
                if vocab_idx != 0 and vocab_idx in self.id2idx:
                    kept_vectors.append(vector)
                    kept_ids.append(self.id2idx[vocab_idx])
            embeddings.append(np.asarray(kept_vectors))
            word_ids.append(kept_ids)
        return embeddings, word_ids

    @staticmethod
    def _pad_queries(embeddings, word_ids, max_num_words, emb_dim=768):
        """Zero-pad per-sample token embeddings and ids to a common length.

        Returns ``(padded, ids, lengths)`` with shapes ``(B, L, emb_dim)``,
        ``(B, L)`` and ``(B,)`` where ``L`` is the longest kept length in the
        batch, capped at ``max_num_words``.
        """
        lengths = [min(len(emb), max_num_words) for emb in embeddings]
        width = max(lengths)
        padded = np.zeros([len(embeddings), width, emb_dim], dtype=np.float32)
        # Integer from the start: a float32 intermediate is not exact above 2**24.
        ids = np.zeros([len(embeddings), width], dtype=np.int64)
        for row, (emb, sample_ids, keep) in enumerate(zip(embeddings, word_ids, lengths)):
            if keep == 0:
                # A sample with no kept token is an array of shape (0,), which
                # cannot be broadcast into a (0, emb_dim) slice; leave the row zero.
                continue
            padded[row, :keep] = emb[:keep]
            ids[row, :keep] = sample_ids[:keep]
        return padded, ids, np.asarray(lengths)

    def collate_fn(self, samples):
        """Collate ``JHMDBSubset`` items into one batch.

        Returns a dict with ``'net_input'`` (clip, padded query embeddings with
        lengths and word ids, the normalised target frames, and per-resolution
        dicts of ground-truth masks, SLIC segment maps and the second output of
        ``resize_and_pad``), plus the raw ``'target_frame'`` images,
        ``'ori_fine_gt_mask'`` and ``'video_info'`` lists.

        Raises:
            ValueError: if a ground-truth mask is taller than it is wide.
        """
        # clip
        clip = torch.from_numpy(np.asarray([sample['clip'] for sample in samples])).float()

        # query
        embeddings, word_ids = self._embed_queries([sample['query'] for sample in samples])
        query, query_idx, query_len = self._pad_queries(embeddings, word_ids, self.max_num_words)
        query = torch.from_numpy(query).float()
        query_len = torch.from_numpy(query_len).long()
        query_idx = torch.from_numpy(query_idx).long()

        ori_fine_gt_mask = [sample['fine_gt_mask'] for sample in samples]
        # The segment padding below only pads rows, so every frame must be at
        # least as wide as it is tall.
        for gt in ori_fine_gt_mask:
            ori_h, ori_w = gt.shape
            if ori_h > ori_w:
                raise ValueError(
                    'expected height <= width for ground-truth masks, got %dx%d' % (ori_h, ori_w))

        # Coarse-box and anchor masks are not produced here: the samples carry
        # no 'coarse_gt_box' entry.
        fine_gt_mask = {}
        all_segments = {}
        mask = {}

        for r in self.resolution:
            masks_r, valid_r, segments_r = [], [], []
            for sample in samples:
                resized_gt, valid, pad = resize_and_pad(sample['fine_gt_mask'] * 255, r,
                                                        interpolation=cv2.INTER_AREA)
                masks_r.append(resized_gt > 50)
                valid_r.append(valid)

                cur_frame = resize(sample['target_frame'], r)
                # Shift the SLIC labels up by one so that 0 is free for the padding rows.
                segment = slic(cur_frame, n_segments=self.n_segments[r] - 1) + 1
                pad_config = [[pad, r - cur_frame.shape[0] - pad], [0, 0]]
                segment = np.pad(segment, pad_config, mode='constant', constant_values=0)
                segments_r.append(segment)

            fine_gt_mask[r] = torch.from_numpy(np.asarray(masks_r)).long()
            mask[r] = torch.from_numpy(np.asarray(valid_r)).long()
            all_segments[r] = torch.from_numpy(np.asarray(segments_r)).long()

        return {
            'net_input': {
                'clip': clip,
                'query': query,
                'query_len': query_len,
                'query_idx': query_idx,
                'fine_gt_mask': fine_gt_mask,
                'segment': all_segments,
                'target_frame': torch.stack([sample['target_frame1'] for sample in samples], 0).float(),
                'mask': mask
            },
            'target_frame': [sample['target_frame'] for sample in samples],
            'ori_fine_gt_mask': ori_fine_gt_mask,
            'video_info': [sample['sample_info'] for sample in samples]
        }


# if __name__ == '__main__':
#     args = {
#         "videoset_path": "/home/user/data/A2D/Release/videoset.csv",
#         "annotation_path": "/home/user/data/A2D/Release/Annotations",
#         "vocab_path": "/home/user/Archive/data/glove_a2d.bin",
#         "sample_path": "/home/user/data/A2D/a2d_annotation2.txt",
#         "max_num_words": 20,
#     }
#     dataset = JHMDBSlicRGB(args)
#     # dataset.train_set[66]
#     # exit(0)
#     loader = DataLoader(dataset.train_set, batch_size=4, shuffle=True, num_workers=1,
#                         pin_memory=True, collate_fn=dataset.collate_fn)
#     for batch in loader:
#         for k, v in batch['net_input'].items():
#             print(k, v.size())
#         exit(0)
