import csv
import os
import random


import cv2
import h5py
import numpy as np
import torch
from PIL import Image
from bert_embedding import BertEmbedding
from gensim.models import KeyedVectors
from nltk import word_tokenize
from nltk.corpus import stopwords, wordnet
from skimage import io
from skimage.segmentation import slic
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import torchvision.transforms.functional as F

from utils import resize_and_pad, resize
#from utils import resize_and_pad, resize
class A2DSubset(Dataset):
    """One split (train or test) of the A2D sentence-grounding samples.

    Each item is built from a ``[video_id, instance_id, frame_idx, query]``
    entry of ``samples`` and bundles an 8-frame RGB clip around the annotated
    frame, the annotated frame itself, and the instance mask / box read from
    the per-frame HDF5 annotation file.

    The data locations below are class attributes so that they can be
    overridden (on the class or on a subclass) without editing the methods.
    """

    # Root directory holding ``<video_id>/<frame:05d>.h5`` annotation files.
    ANNOTATION_ROOT = '/home/user/data/A2D/a2d_annotation_with_instances'
    # Root directory holding one sub-directory of frame images per video.
    FRAME_ROOT = '/home/user/data/A2D/Release/pngs320H/'
    # Whitespace-separated table: field 1 is a class id, fields 3..5 are R, G, B.
    COLOR_PATH = '/home/user/Archive/color'
    # 0-based frame whose annotation is used when the requested one is missing.
    FALLBACK_FRAME = 24
    # Value forwarded as ``limit_size`` to ``resize_and_pad``.
    LIMIT_SIZE = 320

    def __init__(self, videos, samples, word2vec, args, transforms=None, train=False):
        """Store the split description and load the class-id -> RGB table.

        Args:
            videos: mapping ``video_id -> info dict``; ``info['num_frames']``
                is checked against the frame directory in ``__getitem__``.
            samples: list of ``[video_id, instance_id, frame_idx, query]``.
            word2vec: stored as ``self.word2vec``; not used by this class.
            args: stored as ``self.args``; not used by this class.
            transforms: stored as ``self.trans``; not applied by this class.
            train: when true, clips are horizontally flipped with p=0.5.
        """
        self.videos = videos
        self.samples = samples
        self.word2vec = word2vec
        self.args = args
        self.trans = transforms

        self.is_full = np.zeros(len(self.samples), dtype=np.int64)
        # NOTE: this seeds NumPy's *global* RNG. The side effect is kept so
        # that existing runs stay reproducible; it makes the shuffle below
        # (and every later np.random call) deterministic.
        np.random.seed(88)
        self.full_videos = list(videos.keys())
        np.random.shuffle(self.full_videos)
        self.full_videos = self.full_videos[:int(0.5 * len(self.full_videos))]
        self.train = train

        self.rgb = {}
        with open(self.COLOR_PATH) as fp:
            for line in fp:
                fields = line.strip().split()
                if not fields:
                    # Tolerate blank lines instead of failing with IndexError.
                    continue
                self.rgb[int(fields[1])] = np.asarray(
                    [int(fields[3]), int(fields[4]), int(fields[5])])

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load the clip, target frame and annotation for sample ``index``.

        Returns:
            dict with keys ``query`` (lower-cased sentence), ``clip``
            (channels-first stack of 8 frames scaled to [-1, 1] and zeroed
            outside the valid area), ``target_frame`` (raw image),
            ``target_frame1`` (resized, padded, mean/std-normalised tensor),
            ``fine_gt_mask``, ``coarse_gt_box``, ``full`` (always 1) and
            ``sample_info``.

        Raises:
            RuntimeError: if the annotated frame cannot be read.
            IndexError: if the instance id is absent from the annotation file.
        """
        video_id, instance_id, frame_idx, query = self.samples[index]
        query = query.lower()
        frame_idx = int(frame_idx)
        is_full = 1

        # Annotation files are numbered from 1; fall back to a fixed frame
        # when the requested one has no annotation file.
        h5_path = os.path.join(self.ANNOTATION_ROOT, video_id,
                               '%05d.h5' % (frame_idx + 1))
        if not os.path.exists(h5_path):
            h5_path = os.path.join(self.ANNOTATION_ROOT, video_id,
                                   '%05d.h5' % (self.FALLBACK_FRAME + 1))

        frame_path = os.path.join(self.FRAME_ROOT, video_id)
        frames = [os.path.join(frame_path, name)
                  for name in sorted(os.listdir(frame_path))]
        assert len(frames) == self.videos[video_id]['num_frames'], \
            'frame count mismatch for video %s' % video_id

        try:
            target_path = frames[frame_idx]
            target_frame = io.imread(target_path)
        except Exception as err:
            # Surface the failing sample instead of silently terminating the
            # process (the previous behaviour was print + exit(0)).
            raise RuntimeError(
                'failed to read frame %d of video %s from %s'
                % (frame_idx, video_id, frame_path)) from err

        mean_pix = np.asarray([120.39586422, 115.59361427, 104.54012653]) / 255.0
        std_pix = np.asarray([70.68188272, 68.27635443, 72.54505529]) / 255.0
        limit_size = self.LIMIT_SIZE

        target_frame1, _, _ = resize_and_pad(target_frame, limit_size=limit_size)
        target_frame1 = transforms.Compose([
            lambda x: np.asarray(x),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean_pix, std=std_pix)
        ])(target_frame1)

        # Window of 8 frames: 3 before the target, the target, 4 after;
        # indices falling outside the video are clamped to its ends.
        last = len(frames) - 1
        window = [min(max(i, 0), last) for i in range(frame_idx - 3, frame_idx + 5)]

        # The random draw happens for every item so the RNG stream does not
        # depend on the split; flipping itself is training-only.
        # NOTE(review): only the clip and fine_gt_mask are mirrored below;
        # target_frame, target_frame1 and coarse_gt_box are returned
        # un-flipped. Confirm whether that asymmetry is intended.
        flip = np.random.random() < 0.5 and self.train

        clip_frames = []
        for i in window:
            image = Image.fromarray(io.imread(frames[i]))
            if flip:
                image = F.hflip(image)
            image = np.asarray(image)

            image, valid_mask, _ = resize_and_pad(image, limit_size=limit_size)
            image_norm = 2 * (image / 255.0) - 1
            # Zero out the padded region reported by resize_and_pad.
            image_norm = image_norm * valid_mask[:, :, np.newaxis]
            clip_frames.append(image_norm)
        # (T, H, W, C) -> (C, T, H, W)
        feat = np.asarray(clip_frames).transpose((3, 0, 1, 2))

        with h5py.File(h5_path, mode='r') as fp:
            instance = np.asarray(fp['instance'])
            all_masks = np.asarray(fp['reMask'])
            all_boxes = np.asarray(fp['reBBox']).transpose([1, 0])  # [w_min, h_min, w_max, h_max]
            all_ids = np.asarray(fp['id'])

        if len(all_masks.shape) == 3 and instance.shape[0] != all_masks.shape[0]:
            print(video_id, frame_idx + 1, instance.shape, all_masks.shape)

        assert len(all_masks.shape) == 2 or len(all_masks.shape) == 3
        if len(all_masks.shape) == 2:
            # Single-instance file: the mask is stored without an instance axis.
            mask = all_masks[np.newaxis]
            class_id = int(all_ids[0][0])
            coarse_gt_box = all_boxes[0]
        else:
            instance_id = int(instance_id)
            matches = np.where(instance == instance_id)[0]
            if matches.size == 0:
                raise IndexError('instance %d not found in %s' % (instance_id, h5_path))
            idx = matches[0]

            mask = all_masks[idx][np.newaxis]
            coarse_gt_box = all_boxes[idx]
            class_id = int(all_ids[0][idx])

        assert len(mask.shape) == 3
        assert mask.shape[0] > 0

        # Swap the two spatial axes of the stored mask before use.
        fine_gt_mask = np.transpose(np.asarray(mask), (0, 2, 1))[0]
        if flip:
            fine_gt_mask = Image.fromarray(fine_gt_mask)
            fine_gt_mask = np.asarray(F.hflip(fine_gt_mask))

        return {
            'query': query,
            'clip': feat,
            'target_frame': target_frame,
            'fine_gt_mask': fine_gt_mask,
            'target_frame1': target_frame1,
            'full': is_full,
            'coarse_gt_box': coarse_gt_box,
            'sample_info': [video_id, instance_id, frame_idx, self.samples[index][-1], self.rgb[class_id]]
        }


class A2DSlicRGB:
    """A2D dataset front-end: reads the split files, builds the train/test
    ``A2DSubset`` objects and provides the batch ``collate_fn``.

    ``args`` is a dict with the keys ``vocab_path``, ``annotation_path``,
    ``max_num_words``, ``videoset_path`` and ``sample_path``.
    """

    # Width of the per-token vectors produced by ``self.bert_embedding`` that
    # ``collate_fn`` pads into a dense batch.
    EMBED_DIM = 768

    def __init__(self, args):
        """Load vocabulary resources, parse the split files, build subsets."""
        self.word2vec = KeyedVectors.load_word2vec_format(args['vocab_path'], binary=True)

        self.args = args
        self.col_path = os.path.join(args['annotation_path'], 'col')  # rgb
        self.mat_path = os.path.join(args['annotation_path'], 'mat')  # matrix
        self.max_num_words = args['max_num_words']
        # Output widths at which collate_fn produces masks and segments.
        self.resolution = [10, 20, 40, 80, 160, 320]
        # SLIC segment budget per resolution: 16 at the coarsest, 32 elsewhere.
        self.n_segments = {
            r: (16 if level == 0 else 32) for level, r in enumerate(self.resolution)
        }

        # Maps a raw vocabulary index of the BERT tokenizer to the compact
        # index emitted as ``query_idx``.
        # NOTE: pickle can execute arbitrary code; only load a trusted cnt.pkl.
        import pickle
        with open('cnt.pkl', 'rb') as fp:
            self.id2idx = pickle.load(fp)

        self.bert_embedding = BertEmbedding()
        self._read_video_info()
        self._read_dataset_samples()
        self.train_set_ = A2DSubset(self.train_videos, self.train_samples, self.word2vec, args,
                                    transforms=transforms.Compose([
                                        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                                        transforms.RandomHorizontalFlip(),
                                        lambda x: np.asarray(x),
                                    ]), train=True)
        self.test_set_ = A2DSubset(self.test_videos, self.test_samples, self.word2vec, args)

    def _read_video_info(self):
        """Parse ``args['videoset_path']`` into ``train_videos``/``test_videos``.

        Column 8 of each CSV row selects the split (0 -> train, otherwise
        test). The annotated frame indices are derived from the file names in
        ``<col_path>/<video_id>/`` (name without its 4-character extension,
        minus one).
        """
        self.train_videos, self.test_videos = {}, {}
        with open(self.args['videoset_path'], newline='') as fp:
            for row in csv.reader(fp, delimiter=','):
                if not row:
                    # csv.reader yields [] for blank lines; skip them.
                    continue
                video_id = row[0]
                annotated = os.listdir(os.path.join(self.col_path, video_id))
                frame_idx = sorted(int(name[:-4]) - 1 for name in annotated)
                video_info = {
                    'label': int(row[1]),
                    'timestamps': [row[2], row[3]],
                    'size': [int(row[4]), int(row[5])],  # [height, width]
                    'num_frames': int(row[6]),
                    'num_annotations': int(row[7]),
                    'frame_idx': frame_idx,
                }
                if int(row[8]) == 0:
                    self.train_videos[video_id] = video_info
                else:
                    self.test_videos[video_id] = video_info

    def _read_dataset_samples(self):
        """Parse ``args['sample_path']`` into ``train_samples``/``test_samples``.

        A row goes to the training list when its video is in
        ``self.train_videos`` and to the test list otherwise. Summary counts
        are printed at the end.
        """
        self.train_samples, self.test_samples = [], []
        self.train_videos_set = set()
        self.test_videos_set = set()
        self.all_query = set()
        with open(self.args['sample_path'], newline='') as fp:
            for row in csv.DictReader(fp):
                entry = [row['video_id'], row['instance_id'], row['frame_idx'], row['query']]
                if row['video_id'] in self.train_videos:
                    self.train_samples.append(entry)
                    self.train_videos_set.add(row['video_id'])
                else:
                    self.test_samples.append(entry)
                    self.test_videos_set.add(row['video_id'])
                self.all_query.add((row['video_id'], row['query']))
        print('number of sentences: {}'.format(len(self.all_query)))
        print('videos for training: {}, videos for testing: {}'.format(len(self.train_videos_set),
                                                                       len(self.test_videos_set)))
        print(
            'samples for training: {}, samples for testing: {}'.format(len(self.train_samples), len(self.test_samples)))

    @property
    def train_set(self):
        """The training ``A2DSubset``."""
        return self.train_set_

    @property
    def test_set(self):
        """The test ``A2DSubset``."""
        return self.test_set_

    def _pad_queries(self, embeddings, word_ids):
        """Zero-pad per-sample token embeddings / ids into dense tensors.

        Args:
            embeddings: one array per sample, shape ``(n_i, EMBED_DIM)``; an
                empty array when no token of the query survived filtering.
            word_ids: one list of ``n_i`` compact word indices per sample.

        Returns:
            ``(query, query_idx, query_len)``: float tensor
            ``(B, L, EMBED_DIM)``, long tensor ``(B, L)`` and long tensor
            ``(B,)``, where every length is capped at ``self.max_num_words``
            and ``L`` is the largest capped length in the batch.
        """
        lengths = [min(len(emb), self.max_num_words) for emb in embeddings]
        longest = max(lengths)
        padded = np.zeros([len(embeddings), longest, self.EMBED_DIM], dtype=np.float32)
        # Integer dtype from the start: indices never pass through float32.
        padded_idx = np.zeros([len(embeddings), longest], dtype=np.int64)
        for i, (emb, ids, keep) in enumerate(zip(embeddings, word_ids, lengths)):
            if keep == 0:
                # An empty query has a 1-D (0,) embedding array that cannot be
                # broadcast into the (0, EMBED_DIM) slot; leave it as padding.
                continue
            padded[i, :keep] = emb[:keep]
            padded_idx[i, :keep] = ids[:keep]
        return (torch.from_numpy(padded).float(),
                torch.from_numpy(padded_idx).long(),
                torch.from_numpy(np.asarray(lengths)).long())

    @staticmethod
    def _build_anchor_mask(like, box, pad, cur_h, cur_w):
        """Draw the randomised anchor-label mask for one bounding box.

        Labels written into an int64 array shaped like ``like``:
        1 and 2 are two randomly chosen one-pixel lines inside the box (both
        rows or both columns, picked with equal probability); 3/4 are
        one-pixel strips just above/below the box and 5/6 just left/right of
        it, each pushed one pixel further out with probability 0.5 when that
        stays inside ``cur_h`` x ``cur_w``.

        Args:
            like: array whose shape the mask copies.
            box: ``(w_min, h_min, w_max, h_max)``, inclusive, in unpadded
                coordinates of the current resolution.
            pad: row offset added to every row coordinate.
            cur_h, cur_w: height and width of the unpadded image.
        """
        w_min, h_min, w_max, h_max = box
        anch_mask = np.zeros_like(like, dtype=np.int64)

        # The sequence of random draws below is part of the behaviour; keep it.
        if random.random() < 0.5:
            a, b = np.random.choice(range(h_min, h_max + 1), size=2, replace=(h_min == h_max))
            anch_mask[a + pad:a + 1 + pad, w_min:w_max + 1] = 1
            anch_mask[b + pad:b + 1 + pad, w_min:w_max + 1] = 2
        else:
            a, b = np.random.choice(range(w_min, w_max + 1), size=2, replace=(w_min == w_max))
            anch_mask[pad + h_min:pad + h_max + 1, a:a + 1] = 1
            anch_mask[pad + h_min:pad + h_max + 1, b:b + 1] = 2

        # Strip above the box (empty when the box touches the top edge).
        h1, h2 = (h_min - 1 if h_min >= 1 else 0), h_min
        if random.random() < 0.5 and h1 >= 1:
            h1, h2 = h1 - 1, h2 - 1
        anch_mask[h1 + pad:h2 + pad, w_min:w_max + 1] = 3

        # Strip below the box.
        h1, h2 = h_max + 1, (h_max + 2 if h_max + 2 <= cur_h else cur_h)
        if random.random() < 0.5 and h2 <= cur_h - 1:
            h1, h2 = h1 + 1, h2 + 1
        anch_mask[h1 + pad:h2 + pad, w_min:w_max + 1] = 4

        # Strip left of the box.
        w1, w2 = (w_min - 1 if w_min >= 1 else 0), w_min
        if random.random() < 0.5 and w1 >= 1:
            w1, w2 = w1 - 1, w2 - 1
        anch_mask[pad + h_min:pad + h_max + 1, w1:w2] = 5

        # Strip right of the box.
        w1, w2 = w_max + 1, (w_max + 2 if w_max + 2 <= cur_w else cur_w)
        if random.random() < 0.5 and w2 <= cur_w - 1:
            w1, w2 = w1 + 1, w2 + 1
        anch_mask[pad + h_min:pad + h_max + 1, w1:w2] = 6

        return anch_mask

    def collate_fn(self, samples):
        """Assemble a list of ``A2DSubset`` items into one batch.

        Returns a dict with ``net_input`` (clip, padded query tensors, and
        per-resolution dicts ``coarse_gt_mask``, ``fine_gt_mask``,
        ``anch_mask``, ``segment``, ``mask`` keyed by width, plus the
        normalised target frames and the ``full`` flags) together with the raw
        ``target_frame``, ``ori_coarse_gt_box``, ``ori_fine_gt_mask`` and
        ``video_info`` lists.
        """
        clip = torch.from_numpy(np.asarray([sample['clip'] for sample in samples])).float()
        full = torch.from_numpy(np.asarray([sample['full'] for sample in samples])).long()

        # Embed every query and keep only tokens present in id2idx
        # (vocabulary index 0 is always dropped).
        embeddings, word_ids = [], []
        for tokens, vectors in self.bert_embedding([sample['query'] for sample in samples]):
            kept_vecs, kept_ids = [], []
            for word, emb in zip(tokens, vectors):
                idx = self.bert_embedding.vocab.token_to_idx[word]
                if idx in self.id2idx and idx != 0:
                    kept_vecs.append(emb)
                    kept_ids.append(self.id2idx[idx])
            embeddings.append(np.asarray(kept_vecs))
            word_ids.append(kept_ids)
        query, query_idx, query_len = self._pad_queries(embeddings, word_ids)

        ori_coarse_gt_box = [sample['coarse_gt_box'] for sample in samples]
        ori_fine_gt_mask = [sample['fine_gt_mask'] for sample in samples]

        coarse_gt_mask = {}
        fine_gt_mask = {}
        all_anch_mask = {}
        all_segments = {}
        mask = {}

        for r in self.resolution:
            coarse_r, fine_r, valid_r, anch_r, segment_r = [], [], [], [], []
            for i, sample in enumerate(samples):
                ori_h, ori_w = ori_fine_gt_mask[i].shape
                assert ori_h <= ori_w
                # Size of the frame once its width is scaled to r.
                cur_h, cur_w = int(r / ori_w * ori_h), r

                fine, valid, pad = resize_and_pad(sample['fine_gt_mask'] * 255, r,
                                                  interpolation=cv2.INTER_LINEAR)
                fine = fine > 100
                fine_r.append(fine)
                valid_r.append(valid)

                # Superpixels on the resized frame; labels start at 1 so that
                # the padded rows can use 0.
                cur_frame = resize(sample['target_frame'], r)
                segment = slic(cur_frame, n_segments=self.n_segments[r] - 1) + 1
                pad_config = [[pad, r - cur_frame.shape[0] - pad], [0, 0]]
                segment_r.append(np.pad(segment, pad_config, mode='constant', constant_values=0))

                # Scale the box to the current resolution (truncating).
                box = np.asarray(ori_coarse_gt_box[i] * r / ori_w, dtype=np.int32)
                w_min, h_min, w_max, h_max = box
                cmask = np.zeros_like(fine)
                cmask[pad + h_min:pad + h_max + 1, w_min:w_max + 1] = 1
                coarse_r.append(cmask)

                anch_r.append(self._build_anchor_mask(fine, box, pad, cur_h, cur_w))

            coarse_gt_mask[r] = torch.from_numpy(np.asarray(coarse_r)).long()
            fine_gt_mask[r] = torch.from_numpy(np.asarray(fine_r)).float()
            mask[r] = torch.from_numpy(np.asarray(valid_r)).long()
            all_anch_mask[r] = torch.from_numpy(np.asarray(anch_r)).long()
            all_segments[r] = torch.from_numpy(np.asarray(segment_r)).long()

        return {
            'net_input': {
                'clip': clip,
                'query': query,
                'query_idx': query_idx,
                'query_len': query_len,
                'coarse_gt_mask': coarse_gt_mask,
                'fine_gt_mask': fine_gt_mask,
                'anch_mask': all_anch_mask,
                'segment': all_segments,
                'target_frame': torch.stack([sample['target_frame1'] for sample in samples], 0).float(),
                'mask': mask,
                'full': full
            },
            'target_frame': [sample['target_frame'] for sample in samples],
            'ori_coarse_gt_box': ori_coarse_gt_box,
            'ori_fine_gt_mask': ori_fine_gt_mask,
            'video_info': [sample['sample_info'] for sample in samples]
        }


if __name__ == '__main__':
    # Smoke test: build the dataset, fetch one item directly, then iterate the
    # training loader once and print every clip batch.
    args = {
        "videoset_path": "/home/user/data/A2D/Release/videoset.csv",
        "annotation_path": "/home/user/data/A2D/Release/Annotations",
        "vocab_path": "/home/user/data/A2D/a2d_annotation.txt",
        "sample_path": "/home/user/data/A2D/a2d_annotation2.txt",
        "max_num_words": 20,
    }
    dataset = A2DSlicRGB(args)
    dataset.train_set[66]

    loader = DataLoader(dataset.train_set, batch_size=32, shuffle=True, num_workers=1,
                        pin_memory=True, collate_fn=dataset.collate_fn)
    for batch in loader:
        print(batch['net_input']['clip'])

    # NOTE: this script used to finish by dumping ``{0: 0}`` to 'cnt.pkl'
    # (the code that populated the mapping had been commented out). That file
    # is the vocabulary mapping A2DSlicRGB loads in __init__, so a single run
    # overwrote it with an empty mapping. The write has been removed.