# from memory_profiler import profile

import os
import json
import random
import math
import torch
from torch.utils.data import DataLoader
import copy

from utils import *
from loading import *

class OpenPVSGDataset():
    """Map-style dataset of captioned video clips with per-frame annotations.

    Each item corresponds to one caption of one video listed in the dataset
    JSON for the requested ``phase``.  A caption is kept only if its cleaned
    description has an entry in the GPT cache, the video file exists on disk
    and its ``(start, end)`` range is non-empty.

    ``__getitem__`` returns ``(datapoint, reshaped_video, norm_reshaped_video)``
    and ``collate_fn`` flattens a list of such triples into batch-level lists.

    NOTE(review): ``clean_cap``, ``get_start_end``, ``load_video``,
    ``get_mask_paths``, ``load_annotations``, ``get_overlap``, ``transform``,
    ``norm_x``, ``norm_y``, ``cv2`` and ``np`` are expected to come from the
    star imports of ``utils`` / ``loading``; their exact contracts are not
    visible in this file.
    """

    def __init__(self, dataset_dir, dataset_filename, device, data_percentage, cache_path, phase="train", max_vid_len=10, max_obj_per_frame=8) -> None:
        """Index all usable (video, caption) pairs for ``phase``.

        Args:
            dataset_dir: Root directory holding the dataset JSON and the
                ``VidOR`` / ``EpicKitchen`` / ``Ego4D`` sub-directories.
            dataset_filename: Name of the dataset JSON inside ``dataset_dir``.
            device: Device the frame tensors are moved to in ``__getitem__``.
            data_percentage: Percentage (0-100) of the shuffled datapoints to keep.
            cache_path: Path to a JSON file mapping cleaned caption
                descriptions to their GPT specification.
            phase: Key of the split to load from each dataset's split info.
            max_vid_len: Maximum number of frames kept per clip.
            max_obj_per_frame: Maximum number of objects kept per frame.
        """
        dataset_path = os.path.join(dataset_dir, dataset_filename)
        with open(dataset_path, 'r') as dataset_file:
            raw_dataset = json.load(dataset_file)
        with open(cache_path, 'r') as cache_file:
            gpt_cache = json.load(cache_file)

        # Every sub-dataset keeps its videos and masks under its own folder.
        self.video_dirs = {
            'vidor': os.path.join(dataset_dir, "VidOR/videos"),
            'epic_kitchen': os.path.join(dataset_dir, "EpicKitchen/videos"),
            'ego4d': os.path.join(dataset_dir, "Ego4D/videos"),
        }
        self.mask_dirs = {
            'vidor': os.path.join(dataset_dir, "VidOR/masks"),
            'epic_kitchen': os.path.join(dataset_dir, "EpicKitchen/masks"),
            'ego4d': os.path.join(dataset_dir, "Ego4D/masks"),
        }

        data_lookup = {dp['video_id']: dp for dp in raw_dataset['data']}

        self.THING_CLASSES = raw_dataset['objects']['thing']  # 115
        self.STUFF_CLASSES = raw_dataset['objects']['stuff']  # 11
        self.BACKGROUND_CLASSES = ['background']
        self.CLASSES = self.THING_CLASSES + self.STUFF_CLASSES
        self.num_thing_classes = len(self.THING_CLASSES)
        self.num_stuff_classes = len(self.STUFF_CLASSES)
        self.num_classes = len(self.CLASSES)  # 126
        # Class name -> integer id; 'background' takes the last id.
        all_names = self.CLASSES + self.BACKGROUND_CLASSES
        self.cates2id = dict(zip(all_names, range(len(all_names))))

        dataset = []
        for dataset_name, data_split in raw_dataset['split'].items():
            for data_id in data_split[phase]:
                record = data_lookup[data_id]
                video_path = os.path.join(self.video_dirs[dataset_name], f"{data_id}.mp4")

                for caption in record['captions']:
                    clean_des = clean_cap(caption['description'])
                    if clean_des not in gpt_cache:
                        continue
                    if not os.path.exists(video_path):
                        continue

                    # Skip captions whose time range is empty or inverted.
                    start, end = get_start_end(caption=caption)
                    if not start < end:
                        continue

                    dataset.append({'data_id': data_id,
                                    'caption': caption,
                                    'gpt_spec': gpt_cache[clean_des],
                                    'dataset': dataset_name,
                                    'objects': record['objects'],
                                    'meta': record['meta'],
                                    'relations': record['relations']})

        # Shuffle the dataset so that cutting the dataset will still give an
        # in-distribution dataset.
        random.shuffle(dataset)

        dp_count = math.floor(data_percentage / 100 * len(dataset))
        self.dataset = dataset[:dp_count]

        self.device = device
        self.max_vid_len = max_vid_len
        self.max_obj_per_frame = max_obj_per_frame

    def __len__(self):
        """Return the number of retained datapoints."""
        return len(self.dataset)

    def process_val(self, x, max_val):
        """Clamp ``x`` into the closed interval ``[0, max_val]``."""
        x = max(0, x)
        x = min(x, max_val)
        return x

    @staticmethod
    def _subsample(video, masks, relations, max_len):
        """Keep every k-th frame so that at most ``max_len`` frames remain.

        ``video``, ``masks`` and ``relations`` are all subsampled by their
        position inside the clip, so entry ``j`` of each output refers to the
        same frame.  Inputs are returned untouched when the clip is already
        short enough.
        """
        if len(video) <= max_len:
            return video, masks, relations

        sample_rate = math.ceil(len(video) / max_len)
        sampled_video = [frame for idx, frame in enumerate(video) if idx % sample_rate == 0]
        return sampled_video, masks[::sample_rate], relations[::sample_rate]

    def _clean_frame_labels(self, frame_labels, v_width, v_height, x_portion, y_portion):
        """Rescale, filter and truncate the annotations of a single frame.

        Boxes are clamped to the original frame, rescaled to the normalized
        resolution (the box dicts are updated in place), dropped unless they
        are more than 5 px wide and tall, and finally only the
        ``max_obj_per_frame`` largest ones are kept, largest first.
        """
        candidates = []
        annotations = zip(frame_labels['gt_labels'], frame_labels['gt_masks'],
                          frame_labels['gt_instance_ids'], frame_labels['gt_bboxes'])
        for label, mask, instance_id, bbox2d in annotations:
            x1 = self.process_val(bbox2d['x1'], v_width)
            x2 = self.process_val(bbox2d['x2'], v_width)
            y1 = self.process_val(bbox2d['y1'], v_height)
            y2 = self.process_val(bbox2d['y2'], v_height)

            bbox2d['x1'], bbox2d['x2'] = int(x1 * x_portion), int(x2 * x_portion)
            bbox2d['y1'], bbox2d['y2'] = int(y1 * y_portion), int(y2 * y_portion)
            assert bbox2d['x1'] <= norm_x
            assert bbox2d['x2'] <= norm_x
            assert bbox2d['y1'] <= norm_y
            assert bbox2d['y2'] <= norm_y

            if bbox2d['y2'] > bbox2d['y1'] + 5 and bbox2d['x2'] > bbox2d['x1'] + 5:
                size = (bbox2d['y2'] - bbox2d['y1']) * (bbox2d['x2'] - bbox2d['x1'])
                candidates.append((size, bbox2d, label, mask, instance_id))

        # Stable sort, largest area first; ties keep annotation order.
        candidates.sort(key=lambda entry: -entry[0])
        kept = candidates[:self.max_obj_per_frame]

        return {
            'gt_labels': [label for _, _, label, _, _ in kept],
            'gt_masks': [mask for _, _, _, mask, _ in kept],
            'gt_instance_ids': [instance_id for _, _, _, _, instance_id in kept],
            'gt_bboxes': [bbox2d for _, bbox2d, _, _, _ in kept],
        }

    def __getitem__(self, i):
        """Load clip ``i`` together with its masks and per-frame relations.

        Returns:
            A tuple ``(datapoint, reshaped_video, norm_reshaped_video)``:
            ``datapoint`` is a deep copy of the indexed entry extended with
            ``start_time``, ``end_time``, per-frame ``masks`` and per-frame
            ``relations``; ``reshaped_video`` is a list of resized float32
            frame tensors; ``norm_reshaped_video`` is a list of the same
            frames moved to channel-first layout and passed through
            ``transform``.  Both tensor lists live on ``self.device``.
        """
        # Deep copy so that per-item edits never leak into the stored index.
        datapoint = copy.deepcopy(self.dataset[i])
        vid_id = datapoint['data_id']
        dataset = datapoint['dataset']
        caption = datapoint['caption']

        video_path = os.path.join(self.video_dirs[dataset], f"{vid_id}.mp4")
        mask_dir = os.path.join(self.mask_dirs[dataset], vid_id)

        # Load video and clip the caption range to the frames actually read.
        start, end = get_start_end(caption=caption)
        assert start < end
        video = load_video(video_path, start, end)
        start = max(start, 0)
        end = min(end, start + video.shape[0])
        mask_paths = get_mask_paths(mask_dir=mask_dir, start_time=start, end_time=end)

        datapoint['start_time'] = start
        datapoint['end_time'] = end

        # Load masks
        datapoint['masks'] = [load_annotations(datapoint, mask_path, self.cates2id)
                              for mask_path in mask_paths]

        # Expand interval-based relations into one list per frame in [start, end).
        relations_by_frame = {t: [] for t in range(start, end)}
        for sub_id, obj_id, rel, time_ls in datapoint['relations']:
            for from_t, to_t in time_ls:
                lap_start, lap_end = get_overlap((from_t, to_t), (start, end))
                if lap_start == -1:  # no overlap with this clip
                    continue
                for t in range(lap_start, lap_end):
                    relations_by_frame[t].append((sub_id, obj_id, rel))
        frame_relations = list(relations_by_frame.values())

        # Sample the video if too large; masks and relations are sampled at
        # the same clip-relative positions so they stay aligned with frames.
        video, datapoint['masks'], datapoint['relations'] = self._subsample(
            video, datapoint['masks'], frame_relations, self.max_vid_len)

        # Normalize video color and shape
        v_height = video[0].shape[0]
        v_width = video[0].shape[1]
        x_portion, y_portion = norm_x / v_width, norm_y / v_height

        datapoint['masks'] = [
            self._clean_frame_labels(frame_labels, v_width, v_height, x_portion, y_portion)
            for frame_labels in datapoint['masks']
        ]

        reshaped_video = []
        norm_reshaped_video = []
        for frame in video:
            new_frame = cv2.resize(frame, (norm_x, norm_y))
            reshaped_video.append(torch.tensor(new_frame, dtype=torch.float32).to(self.device))
            # Channel-last -> channel-first before applying ``transform``.
            new_frame = np.moveaxis(new_frame, -1, 0)
            new_frame = transform(torch.tensor(new_frame, dtype=torch.float32))
            norm_reshaped_video.append(new_frame.to(self.device))

        return datapoint, reshaped_video, norm_reshaped_video

    @staticmethod
    def collate_fn(batch):
        """Flatten a list of ``__getitem__`` results into batch-level lists.

        Returns an 11-tuple, in this order:
            ids, captions, gt bboxes (flat over frames), gt masks (flat),
            ordered object pairs ``(video_idx, frame_idx, (oid1, oid2))`` for
            objects co-occurring in a frame, object ids
            ``(video_idx, frame_idx, oid)``, cumulative frame counts marking
            where each video ends, the stacked raw resized frames, object
            labels ``(video_idx, frame_idx, label)``, per-video relation
            lists and per-video GPT specs.

        Note that the normalized frames are only used for counting; the
        stacked tensor holds the un-normalized resized frames.
        """
        batched_reshaped_raw_videos = []
        batched_captions = []

        batched_obj_pairs = []
        batched_ids = []
        batched_video_splits = []
        batched_gpt_specs = []

        batched_gt_bboxes = []
        batched_gt_masks = []
        batched_gt_obj_names = []
        batched_gt_object_rels = []

        batched_object_ids = []

        frame_ct_in_video = 0
        for data_id, (datapoint, reshaped_raw_video, video) in enumerate(batch):
            batched_reshaped_raw_videos += reshaped_raw_video
            batched_ids.append(datapoint['data_id'])
            batched_captions.append(datapoint['gpt_spec']['caption'])
            batched_gpt_specs.append(datapoint['gpt_spec'])
            batched_gt_object_rels.append(datapoint['relations'])
            bounding_box_info = datapoint['masks']

            # All instance ids seen anywhere in this video; its iteration
            # order fixes the order of the pairs emitted for every frame.
            all_obj_ids = set()
            for frame in bounding_box_info:
                all_obj_ids.update(frame['gt_instance_ids'])

            for frame_id, frame in enumerate(bounding_box_info):
                instance_ids = frame['gt_instance_ids']

                batched_gt_bboxes += frame['gt_bboxes']
                batched_gt_masks += frame['gt_masks']
                batched_gt_obj_names += [(data_id, frame_id, label) for label in frame['gt_labels']]
                batched_object_ids += [(data_id, frame_id, oid) for oid in instance_ids]

                # Ordered pairs of distinct objects visible in this frame.
                ids_in_frame = set(instance_ids)
                present = [oid for oid in all_obj_ids if oid in ids_in_frame]
                for oid1 in present:
                    for oid2 in present:
                        if oid1 != oid2:
                            batched_obj_pairs.append((data_id, frame_id, (oid1, oid2)))

            frame_ct_in_video += len(video)
            batched_video_splits.append(frame_ct_in_video)

        return batched_ids, batched_captions, batched_gt_bboxes, batched_gt_masks, \
            batched_obj_pairs, batched_object_ids, batched_video_splits, \
            torch.stack(batched_reshaped_raw_videos), batched_gt_obj_names, \
            batched_gt_object_rels, batched_gpt_specs

def open_pvsg_loader(cache_path, dataset_dir, dataset_name, batch_size, device, training_percentage=100, testing_percentage=100, max_video_len=8):
    """Build the train and validation datasets and their data loaders.

    ``dataset_name`` is forwarded to ``OpenPVSGDataset`` as the dataset JSON
    filename.  Both loaders use ``OpenPVSGDataset.collate_fn``, do not
    shuffle, and drop the last incomplete batch.

    Returns:
        ``(train_dataset, valid_dataset, train_loader, test_loader)`` where
        ``test_loader`` iterates over the ``"val"`` split.
    """

    def _make_split(phase, percentage):
        # One dataset plus its loader for the requested phase.
        split_dataset = OpenPVSGDataset(
            dataset_dir,
            dataset_name,
            cache_path=cache_path,
            device=device,
            phase=phase,
            data_percentage=percentage,
            max_vid_len=max_video_len,
        )
        split_loader = DataLoader(
            split_dataset,
            batch_size,
            collate_fn=OpenPVSGDataset.collate_fn,
            shuffle=False,
            drop_last=True,
        )
        return split_dataset, split_loader

    train_dataset, train_loader = _make_split("train", training_percentage)
    valid_dataset, test_loader = _make_split("val", testing_percentage)
    return (train_dataset, valid_dataset, train_loader, test_loader)