import csv
import os
import random

import cv2
import h5py
import numpy as np
import torch
from gensim.models import KeyedVectors
from nltk import word_tokenize
from skimage import io
from skimage.segmentation import slic
from torch.utils.data import Dataset, DataLoader

from datasets.utils import resize_and_pad, resize


class A2DSubset(Dataset):
    """One split (train or test) of the A2D referring-segmentation samples.

    Each item corresponds to one ``(video_id, instance_id, frame_idx, query)``
    entry of ``samples`` and bundles the RGB target frame, a pre-extracted
    I3D feature, the query's word vectors, and the selected instance's fine
    mask and bounding box read from the per-frame ``.h5`` annotation.
    """

    # Data locations. Override on the class (or a subclass) to relocate data.
    I3D_ROOT = '/home/user/data/A2D/i3d-rgb-MaxPool3d_4a_3x3'
    ANNOTATION_ROOT = '/home/user/data/A2D/a2d_annotation_with_instances'
    FRAME_ROOT = '/home/user/data/A2D/rgb-frame'

    def __init__(self, videos, samples, word2vec, args):
        """Store the split description.

        :param videos: per-video metadata dict (not read by ``__getitem__``).
        :param samples: sequence of ``[video_id, instance_id, frame_idx, query]``.
        :param word2vec: mapping supporting ``word in word2vec`` and
            ``word2vec[word]`` that returns a word vector.
        :param args: configuration object (stored, not read here).
        """
        self.videos = videos
        self.samples = samples
        self.word2vec = word2vec
        self.args = args

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _select_instance(instance, all_masks, all_boxes, instance_id):
        """Pick the mask(s) and box of ``instance_id`` from one frame's annotation.

        :param instance: 1-D array of instance ids, aligned with the first axis
            of a 3-D ``all_masks``.
        :param all_masks: a single 2-D mask, or a stack of masks (3-D).
        :param all_boxes: boxes, one per row, aligned with ``all_masks``.
        :param instance_id: id to select (anything ``int()`` accepts).
        :return: ``(mask, box)`` where ``mask`` is 3-D with at least one entry.
        :raises AssertionError: on inconsistent annotation shapes or no match.
        """
        assert len(all_masks.shape) == 2 or len(all_masks.shape) == 3
        if len(all_masks.shape) == 2:
            # Single-instance frame: there is nothing to select.
            mask = all_masks[np.newaxis]
            coarse_gt_box = all_boxes[0]
        else:
            if instance.shape[0] < all_masks.shape[0]:
                # Some annotations carry trailing masks/boxes with no instance
                # id; drop the surplus from the end so the arrays align.
                tail = instance.shape[0] - all_masks.shape[0]
                all_masks = all_masks[:tail]
                all_boxes = all_boxes[:tail]
            assert instance.shape[0] == all_masks.shape[0]
            assert all_masks.shape[0] == all_boxes.shape[0]
            idx = instance == int(instance_id)
            mask = all_masks[idx]
            coarse_gt_box = all_boxes[idx][0]
        assert len(mask.shape) == 3
        assert mask.shape[0] > 0
        return mask, coarse_gt_box

    def __getitem__(self, index):
        """Load one sample.

        :return: dict with ``query`` (word vectors of the in-vocabulary tokens),
            ``clip`` (I3D feature), ``target_frame`` (RGB image),
            ``fine_gt_mask`` (2-D mask of the instance), ``coarse_gt_box``
            (``[w_min, h_min, w_max, h_max]``) and ``sample_info``
            (``[video_id, instance_id, frame_idx, query_text]``).
        """
        video_id, instance_id, frame_idx, query_text = self.samples[index]
        frame_idx = int(frame_idx)
        i3d_path = os.path.join(self.I3D_ROOT, video_id, '%d.npy' % frame_idx)
        # Annotation files are numbered from 1, frame indices from 0.
        h5_path = os.path.join(self.ANNOTATION_ROOT, video_id, '%05d.h5' % (frame_idx + 1))
        frame_dir = os.path.join(self.FRAME_ROOT, video_id)

        # Frame files are named "<number>.<ext>"; sort numerically, not lexically.
        frame_names = sorted(os.listdir(frame_dir), key=lambda name: int(name[:-4]))
        target_frame = io.imread(os.path.join(frame_dir, frame_names[frame_idx]))

        # clip feature
        feat = np.load(i3d_path)

        # Word features: out-of-vocabulary tokens are silently dropped, so the
        # result may be empty.
        tokens = [word.lower() for word in word_tokenize(query_text)]
        query = np.asarray([self.word2vec[word] for word in tokens if word in self.word2vec])

        with h5py.File(h5_path, mode='r') as fp:
            instance = np.asarray(fp['instance'])
            all_masks = np.asarray(fp['reMask'])
            all_boxes = np.asarray(fp['reBBox']).transpose([1, 0])  # [w_min, h_min, w_max, h_max]
        if video_id == 'EadxBPmQvtg' and frame_idx == 24:
            # Known bad annotation: this frame lists one instance id too many.
            instance = instance[:-1]

        mask, coarse_gt_box = self._select_instance(instance, all_masks, all_boxes, instance_id)
        # Stored masks are transposed relative to the image; swap the last two
        # axes and keep the first matching mask.
        fine_gt_mask = np.transpose(mask, (0, 2, 1))[0]

        return {
            'query': query,
            'clip': feat,
            'target_frame': target_frame,
            'fine_gt_mask': fine_gt_mask,
            'coarse_gt_box': coarse_gt_box,
            # Keep the raw query text here, not the embedding array.
            'sample_info': [video_id, instance_id, frame_idx, query_text],
        }


class A2DSlic:
    """A2D dataset wrapper that builds the train/test splits and a collate_fn.

    ``args`` must provide ``vocab_path`` (binary word2vec file),
    ``annotation_path``, ``max_num_words``, ``videoset_path`` and
    ``sample_path``. ``collate_fn`` turns a list of ``A2DSubset`` items into
    batched tensors. For every width in ``self.resoluation`` it also produces
    resized fine/coarse masks, random anchor masks and SLIC superpixel maps.
    """

    def __init__(self, args):
        self.word2vec = KeyedVectors.load_word2vec_format(args['vocab_path'], binary=True)
        self.args = args
        self.col_path = os.path.join(args['annotation_path'], 'col')  # rgb
        self.mat_path = os.path.join(args['annotation_path'], 'mat')  # matrix
        self.max_num_words = args['max_num_words']
        # Target widths (pixels) for the per-resolution outputs of collate_fn.
        # (The attribute's spelling is kept for backward compatibility.)
        self.resoluation = [32, 128]
        # SLIC superpixel budget per target width.
        self.n_segments = {
            32: 32,
            128: 64,
        }

        self._read_video_info()
        self._read_dataset_samples()
        self.train_set_ = A2DSubset(self.train_videos, self.train_samples, self.word2vec, args)
        self.test_set_ = A2DSubset(self.test_videos, self.test_samples, self.word2vec, args)

    def _read_video_info(self):
        """Parse ``videoset_path`` into ``self.train_videos`` / ``self.test_videos``.

        Each CSV row is keyed by its first column (video id). Column 8 selects
        the split: 0 goes to train, anything else to test. ``frame_idx`` lists
        the files found under ``col_path/<video_id>`` as sorted 0-based indices.
        Each file name is assumed to be a number followed by a 4-character
        suffix such as ``.png``.
        """
        self.train_videos, self.test_videos = {}, {}
        with open(self.args['videoset_path'], newline='') as fp:
            reader = csv.reader(fp, delimiter=',')
            for row in reader:
                frame_idx = sorted(int(name[:-4]) - 1
                                   for name in os.listdir(os.path.join(self.col_path, row[0])))
                video_info = {
                    'label': int(row[1]),
                    'timestamps': [row[2], row[3]],
                    'size': [int(row[4]), int(row[5])],  # [height, width]
                    'num_frames': int(row[6]),
                    'num_annotations': int(row[7]),
                    'frame_idx': frame_idx,
                }
                if int(row[8]) == 0:
                    self.train_videos[row[0]] = video_info
                else:
                    self.test_videos[row[0]] = video_info
        print('videos for training: {}, videos for testing: {}'.format(len(self.train_videos), len(self.test_videos)))

    def _read_dataset_samples(self):
        """Split the rows of ``sample_path`` into train and test sample lists.

        A sample goes to train when its video is a training video. Every other
        sample, including any whose video is absent from the videoset file,
        goes to test.
        """
        self.train_samples, self.test_samples = [], []

        with open(self.args['sample_path'], newline='') as fp:
            for row in csv.DictReader(fp):
                sample = [row['video_id'], row['instance_id'], row['frame_idx'], row['query']]
                if row['video_id'] in self.train_videos:
                    self.train_samples.append(sample)
                else:
                    self.test_samples.append(sample)
        print(
            'samples for training: {}, samples for testing: {}'.format(len(self.train_samples), len(self.test_samples)))

    @property
    def train_set(self):
        return self.train_set_

    @property
    def test_set(self):
        return self.test_set_

    def _collate_query(self, samples):
        """Zero-pad per-sample word-vector sequences into one batch.

        Each sequence is truncated to ``self.max_num_words``. Word vectors are
        assumed to be 300-dimensional.

        :return: ``(query, query_len)``: a float tensor ``[B, L, 300]`` where
            ``L`` is the longest kept length, and a long tensor ``[B]``.
        """
        query_len = [min(len(sample['query']), self.max_num_words) for sample in samples]
        query = np.zeros([len(samples), max(query_len), 300], dtype=np.float32)
        for i, (sample, keep) in enumerate(zip(samples, query_len)):
            # A query with no in-vocabulary word arrives as a 1-D empty array,
            # which cannot be broadcast into a (0, 300) slice; leave its row zero.
            if keep > 0:
                query[i, :keep] = sample['query'][:keep]
        return torch.from_numpy(query).float(), torch.from_numpy(np.asarray(query_len)).long()

    @staticmethod
    def _anchor_mask(shape, box, pad, cur_h, cur_w):
        """Build a randomly placed anchor label map for one box.

        :param shape: shape of the (padded) output map.
        :param box: ``(w_min, h_min, w_max, h_max)`` in resized, unpadded pixels.
        :param pad: number of padding rows above the resized content.
        :param cur_h: height of the resized content.
        :param cur_w: width of the resized content.
        :return: int64 map that is 0 everywhere except for these labels:
            1 / 2 -- two 1-pixel lines inside the box; both are rows spanning the
                     box width, or both are columns spanning its height (50/50);
            3 / 4 -- a 1-pixel strip just above / below the box;
            5 / 6 -- a 1-pixel strip just left / right of the box.
            Each strip is shifted one more pixel outward with probability 0.5
            when there is room.
        """
        w_min, h_min, w_max, h_max = box
        anch_mask = np.zeros(shape, dtype=np.int64)
        anch_h, anch_w = 1, 1  # strip thickness

        if random.random() < 0.5:
            a, b = np.random.choice(range(h_min, h_max + 1), size=2, replace=(h_min == h_max))
            anch_mask[a + pad:a + 1 + pad, w_min:w_max + 1] = 1
            anch_mask[b + pad:b + 1 + pad, w_min:w_max + 1] = 2
        else:
            a, b = np.random.choice(range(w_min, w_max + 1), size=2, replace=(w_min == w_max))
            anch_mask[pad + h_min:pad + h_max + 1, a:a + 1] = 1
            anch_mask[pad + h_min:pad + h_max + 1, b:b + 1] = 2

        # above the box
        h1 = h_min - anch_h if h_min >= anch_h else 0
        h2 = h_min
        if random.random() < 0.5 and h1 >= 1:
            h1, h2 = h1 - 1, h2 - 1
        anch_mask[h1 + pad:h2 + pad, w_min:w_max + 1] = 3

        # below the box
        h1 = h_max + 1
        h2 = h_max + 1 + anch_h if h_max + 1 + anch_h <= cur_h else cur_h
        if random.random() < 0.5 and h2 <= cur_h - 1:
            h1, h2 = h1 + 1, h2 + 1
        anch_mask[h1 + pad:h2 + pad, w_min:w_max + 1] = 4

        # left of the box
        w1 = w_min - anch_w if w_min >= anch_w else 0
        w2 = w_min
        if random.random() < 0.5 and w1 >= 1:
            w1, w2 = w1 - 1, w2 - 1
        anch_mask[pad + h_min:pad + h_max + 1, w1:w2] = 5

        # right of the box
        w1 = w_max + 1
        w2 = w_max + 1 + anch_w if w_max + 1 + anch_w <= cur_w else cur_w
        if random.random() < 0.5 and w2 <= cur_w - 1:
            w1, w2 = w1 + 1, w2 + 1
        anch_mask[pad + h_min:pad + h_max + 1, w1:w2] = 6

        return anch_mask

    def collate_fn(self, samples):
        """Batch ``A2DSubset`` items for a ``DataLoader``.

        :return: dict whose ``net_input`` holds ``clip``, ``query`` and
            ``query_len`` tensors. It also holds ``coarse_gt_mask``,
            ``fine_gt_mask``, ``anch_mask``, ``segment`` and ``mask``, each a
            dict keyed by resolution ``r`` mapping to a long tensor of stacked
            per-sample maps. The original boxes and fine masks are passed
            through as lists.
        """
        clip = torch.from_numpy(np.asarray([sample['clip'] for sample in samples])).float()
        query, query_len = self._collate_query(samples)

        ori_coarse_gt_box = [sample['coarse_gt_box'] for sample in samples]
        ori_fine_gt_mask = [sample['fine_gt_mask'] for sample in samples]

        coarse_gt_mask, fine_gt_mask, all_anch_mask, all_segments, mask = {}, {}, {}, {}, {}

        for r in self.resoluation:
            coarse_gt_mask[r], fine_gt_mask[r], mask[r] = [], [], []
            all_anch_mask[r], all_segments[r] = [], []
            for i, sample in enumerate(samples):
                ori_h, ori_w = ori_fine_gt_mask[i].shape
                # Content is resized to width r; height scales proportionally.
                cur_h, cur_w = int(r / ori_w * ori_h), r
                assert ori_h <= ori_w
                # Third output is the number of padding rows above the content
                # (see pad_config below). The second is stored as `mask`.
                resized, aux_mask, pad = resize_and_pad(sample['fine_gt_mask'] * 255, r,
                                                        interpolation=cv2.INTER_AREA)
                fine = resized > 125
                fine_gt_mask[r].append(fine)
                mask[r].append(aux_mask)

                # Superpixels are shifted by +1 so that label 0 marks padding.
                cur_frame = resize(sample['target_frame'], r)
                segment = slic(cur_frame, n_segments=self.n_segments[r] - 1) + 1
                pad_config = [[pad, r - cur_frame.shape[0] - pad], [0, 0]]
                segment = np.pad(segment, pad_config, mode='constant', constant_values=0)
                all_segments[r].append(segment)

                w_min, h_min, w_max, h_max = np.asarray(ori_coarse_gt_box[i] * r / ori_w, dtype=np.int32)
                cmask = np.zeros_like(fine)
                cmask[pad + h_min:pad + h_max + 1, w_min:w_max + 1] = 1
                coarse_gt_mask[r].append(cmask)

                all_anch_mask[r].append(
                    self._anchor_mask(fine.shape, (w_min, h_min, w_max, h_max), pad, cur_h, cur_w))

            coarse_gt_mask[r] = torch.from_numpy(np.asarray(coarse_gt_mask[r])).long()
            fine_gt_mask[r] = torch.from_numpy(np.asarray(fine_gt_mask[r])).long()
            mask[r] = torch.from_numpy(np.asarray(mask[r])).long()
            all_anch_mask[r] = torch.from_numpy(np.asarray(all_anch_mask[r])).long()
            all_segments[r] = torch.from_numpy(np.asarray(all_segments[r])).long()

        return {
            'net_input': {
                'clip': clip,
                'query': query,
                'query_len': query_len,
                'coarse_gt_mask': coarse_gt_mask,
                'fine_gt_mask': fine_gt_mask,
                'anch_mask': all_anch_mask,
                'segment': all_segments,
                'mask': mask
            },
            'ori_coarse_gt_box': ori_coarse_gt_box,
            'ori_fine_gt_mask': ori_fine_gt_mask
        }


if __name__ == '__main__':
    # Smoke test: build the dataset and push one training batch through the
    # DataLoader and collate_fn. The paths below are machine-specific.
    args = {
        "videoset_path": "/home/user/data/A2D/Release/videoset.csv",
        "annotation_path": "/home/user/data/A2D/Release/Annotations",
        "vocab_path": "/home/user/code/mm-2020/data/glove_a2d.bin",
        "sample_path": "/home/user/data/A2D/a2d_annotation2.txt",
        "max_num_words": 20,
    }
    dataset = A2DSlic(args)
    loader = DataLoader(dataset.train_set, batch_size=4, shuffle=True, num_workers=1,
                        pin_memory=True, collate_fn=dataset.collate_fn)
    for batch in loader:
        # One batch is enough. Leave the loop normally instead of calling the
        # interactive-only `exit()` helper.
        break
