# from memory_profiler import profile

import os 
import json
import random
import math
import torch
from torch.utils.data import DataLoader
import copy

from utils import *
from loading import *

class OpenPVSGDataset():
    """Caption-level video dataset built from the OpenPVSG annotation JSON.

    One datapoint is created per (video, caption) pair whose cleaned caption
    has an entry in the GPT cache and whose ``.mp4`` exists on disk.
    ``__getitem__`` loads the caption's clip with its per-frame object
    annotations and relations, subsamples clips longer than ``max_vid_len``
    frames, and resizes frames/boxes to ``norm_x`` x ``norm_y`` (module-level
    names provided by the star imports at the top of this file).
    """

    def __init__(self, dataset_dir, dataset_filename, device, data_percentage, cache_path,
                 phase="train", max_vid_len=10, max_obj_per_frame=8, neg_example_ct=5) -> None:
        """Load annotations and build the shuffled, truncated list of datapoints.

        Args:
            dataset_dir: root directory holding the annotation JSON,
                ``neg_examples.json`` and the per-source ``videos``/``masks`` folders.
            dataset_filename: annotation JSON file name inside ``dataset_dir``.
            device: torch device that ``__getitem__`` moves frame tensors to.
            data_percentage: percentage (0-100) of the shuffled datapoints to keep.
            cache_path: JSON file mapping cleaned captions to their GPT specs.
            phase: split key looked up in each source's split lists (e.g. "train", "val").
            max_vid_len: clips longer than this are uniformly subsampled in ``__getitem__``.
            max_obj_per_frame: keep at most this many (largest) boxes per frame.
            neg_example_ct: maximum number of negative examples sampled per kind.
        """
        dataset_path = os.path.join(dataset_dir, dataset_filename)
        # Context managers close the JSON file handles as soon as they are parsed.
        with open(dataset_path, 'r') as f:
            raw_dataset = json.load(f)
        with open(cache_path, 'r') as f:
            gpt_cache = json.load(f)
        with open(os.path.join(dataset_dir, "neg_examples.json"), 'r') as f:
            negative_examples = json.load(f)

        self.video_dirs = {
            'vidor': os.path.join(dataset_dir, "VidOR/videos"),
            'epic_kitchen': os.path.join(dataset_dir, "EpicKitchen/videos"),
            'ego4d': os.path.join(dataset_dir, "Ego4D/videos"),
        }
        # Masks mirror the video layout: <Source>/masks for every source
        # (EpicKitchen masks live under EpicKitchen/, not VidOR/).
        self.mask_dirs = {
            'vidor': os.path.join(dataset_dir, "VidOR/masks"),
            'epic_kitchen': os.path.join(dataset_dir, "EpicKitchen/masks"),
            'ego4d': os.path.join(dataset_dir, "Ego4D/masks"),
        }

        data_lookup = {dp['video_id']: dp for dp in raw_dataset['data']}

        self.THING_CLASSES = raw_dataset['objects']['thing']  # 115
        self.STUFF_CLASSES = raw_dataset['objects']['stuff']  # 11
        self.BACKGROUND_CLASSES = ['background']
        self.CLASSES = self.THING_CLASSES + self.STUFF_CLASSES
        self.num_thing_classes = len(self.THING_CLASSES)
        self.num_stuff_classes = len(self.STUFF_CLASSES)
        self.num_classes = len(self.CLASSES)  # 126
        # Category name -> integer id; 'background' gets the last id.
        self.cates2id = dict(
            zip(self.CLASSES + self.BACKGROUND_CLASSES,
                range(len(self.CLASSES + self.BACKGROUND_CLASSES))))

        self.neg_example_ct = neg_example_ct

        dataset = []
        for dataset_name, data_split in raw_dataset['split'].items():
            for data_id in data_split[phase]:
                entry = data_lookup[data_id]
                video_path = os.path.join(self.video_dirs[dataset_name], f"{data_id}.mp4")
                if not os.path.exists(video_path):
                    continue

                for caption in entry['captions']:
                    clean_des = clean_cap(caption['description'])
                    if clean_des not in gpt_cache:
                        continue

                    start, end = get_start_end(caption=caption)
                    if not start < end:
                        continue

                    dataset.append({
                        'data_id': data_id,
                        'caption': caption,
                        'gpt_spec': gpt_cache[clean_des],
                        'dataset': dataset_name,
                        'objects': entry['objects'],
                        'meta': entry['meta'],
                        'relations': entry['relations'],
                        # NOTE: assumes every cached caption also has negatives
                        # in neg_examples.json (KeyError otherwise).
                        'neg_example': negative_examples[clean_des],
                    })

        # Shuffle so that truncating to data_percentage still yields an
        # in-distribution subset.
        random.shuffle(dataset)
        dp_count = math.floor(data_percentage / 100 * len(dataset))
        self.dataset = dataset[:dp_count]

        self.device = device
        self.max_vid_len = max_vid_len
        self.max_obj_per_frame = max_obj_per_frame

    def __len__(self):
        return len(self.dataset)

    def process_val(self, x, max_val):
        """Clamp ``x`` into the closed interval ``[0, max_val]``."""
        return min(max(0, x), max_val)

    @staticmethod
    def _sample_negative_index(i, n):
        """Return a uniformly random index in ``range(n)`` other than ``i``.

        Consumes the RNG exactly like ``random.choice`` over ``range(n)`` with
        ``i`` removed, without materializing that O(n) list.

        Raises:
            IndexError: if ``i`` is not in ``range(n)``.
            ValueError: if ``n < 2`` (no other datapoint exists).
        """
        if not 0 <= i < n:
            raise IndexError(f"index {i} out of range for dataset of size {n}")
        if n < 2:
            raise ValueError("need at least two datapoints to sample a negative spec")
        j = random.randrange(n - 1)
        # Skip over i: positions >= i map to the next index.
        return j + 1 if j >= i else j

    @staticmethod
    def _every_nth(seq, step):
        """Return the items of ``seq`` at positions 0, step, 2*step, ... as a list."""
        return [item for pos, item in enumerate(seq) if pos % step == 0]

    def _clean_frame_objects(self, frame_labels, v_width, v_height, x_portion, y_portion):
        """Rescale one frame's boxes to ``norm_x`` x ``norm_y`` and keep the largest.

        Each box is clamped to the ``v_width`` x ``v_height`` frame and rescaled
        in place (the ``bbox2d`` dicts are mutated). Boxes whose sides are not
        both more than 5 pixels after rescaling are dropped. The survivors are
        ordered by decreasing area (stable for ties) and truncated to
        ``self.max_obj_per_frame``.
        """
        candidates = []
        for lb, mask, inst_id, bbox2d in zip(frame_labels['gt_labels'], frame_labels['gt_masks'],
                                             frame_labels['gt_instance_ids'], frame_labels['gt_bboxes']):
            x1 = self.process_val(bbox2d['x1'], v_width)
            x2 = self.process_val(bbox2d['x2'], v_width)
            y1 = self.process_val(bbox2d['y1'], v_height)
            y2 = self.process_val(bbox2d['y2'], v_height)

            bbox2d['x1'], bbox2d['x2'] = int(x1 * x_portion), int(x2 * x_portion)
            bbox2d['y1'], bbox2d['y2'] = int(y1 * y_portion), int(y2 * y_portion)
            assert bbox2d['x1'] <= norm_x
            assert bbox2d['x2'] <= norm_x
            assert bbox2d['y1'] <= norm_y
            assert bbox2d['y2'] <= norm_y

            if bbox2d['y2'] > bbox2d['y1'] + 5 and bbox2d['x2'] > bbox2d['x1'] + 5:
                size = (bbox2d['y2'] - bbox2d['y1']) * (bbox2d['x2'] - bbox2d['x1'])
                candidates.append((size, bbox2d, lb, mask, inst_id))

        kept = sorted(candidates, key=lambda c: -c[0])[:self.max_obj_per_frame]
        return {
            'gt_labels': [c[2] for c in kept],
            'gt_masks': [c[3] for c in kept],
            'gt_instance_ids': [c[4] for c in kept],
            'gt_bboxes': [c[1] for c in kept],
        }

    # @profile
    def __getitem__(self, i):
        """Load, trim, subsample and normalize datapoint ``i``.

        Returns:
            ``(datapoint, reshaped_video, norm_reshaped_video)``.
            ``datapoint`` is a deep copy of the stored entry, augmented with
            ``start_time``/``end_time``, ``neg_gpt_spec`` (the GPT spec of a
            different datapoint), sampled ``neg_example`` lists, per-frame
            ``masks`` (boxes rescaled to ``norm_x`` x ``norm_y``) and per-frame
            ``relations``. ``reshaped_video`` holds the resized frames as float32
            tensors. ``norm_reshaped_video`` holds the channel-first frames after
            ``transform``. All tensors are on ``self.device``.
        """
        # Another datapoint's GPT spec serves as this item's negative spec.
        neg_spec_i = self._sample_negative_index(i, len(self.dataset))

        datapoint = copy.deepcopy(self.dataset[i])
        vid_id = datapoint['data_id']
        source = datapoint['dataset']
        neg_pool = datapoint['neg_example']
        datapoint['neg_example'] = {
            key: random.sample(neg_pool[key], k=min(self.neg_example_ct, len(neg_pool[key])))
            for key in ('neg_entity', 'neg_binary')
        }
        caption = datapoint['caption']

        video_path = os.path.join(self.video_dirs[source], f"{vid_id}.mp4")
        mask_dir = os.path.join(self.mask_dirs[source], vid_id)

        # Load the caption's clip; __init__ already filtered out start >= end.
        start, end = get_start_end(caption=caption)
        assert start < end
        video = load_video(video_path, start, end)
        start = max(start, 0)
        # The decoded clip may be shorter than requested; clamp end to it.
        end = min(end, start + video.shape[0])
        mask_paths = get_mask_paths(mask_dir=mask_dir, start_time=start, end_time=end)

        datapoint['start_time'] = start
        datapoint['end_time'] = end
        datapoint['neg_gpt_spec'] = self.dataset[neg_spec_i]['gpt_spec']
        datapoint['masks'] = [load_annotations(datapoint, mask_path, self.cates2id)
                              for mask_path in mask_paths]

        # Bucket each (sub, obj, rel) into every frame of [start, end) its
        # time spans overlap; get_overlap signals "no overlap" with -1.
        frame_relations = {t: [] for t in range(start, end)}
        for sub_id, obj_id, rel, time_ls in datapoint['relations']:
            for from_t, to_t in time_ls:
                lap_start, lap_end = get_overlap((from_t, to_t), (start, end))
                if lap_start != -1:
                    for t in range(lap_start, lap_end):
                        frame_relations[t].append((sub_id, obj_id, rel))
        datapoint['relations'] = list(frame_relations.values())

        # Uniformly subsample long clips. Frames, masks and relations are all
        # indexed by 0-based clip position so they stay aligned even when
        # `start` is not a multiple of `sample_rate`.
        if len(video) > self.max_vid_len:
            sample_rate = math.ceil(len(video) / self.max_vid_len)
            video = self._every_nth(video, sample_rate)
            datapoint['masks'] = self._every_nth(datapoint['masks'], sample_rate)
            datapoint['relations'] = self._every_nth(datapoint['relations'], sample_rate)

        # Rescale boxes from the source resolution to norm_x x norm_y.
        v_height = video[0].shape[0]
        v_width = video[0].shape[1]
        x_portion, y_portion = norm_x / v_width, norm_y / v_height
        datapoint['masks'] = [
            self._clean_frame_objects(frame_labels, v_width, v_height, x_portion, y_portion)
            for frame_labels in datapoint['masks']
        ]

        reshaped_video = []
        norm_reshaped_video = []
        for frame in video:
            new_frame = cv2.resize(frame, (norm_x, norm_y))
            reshaped_video.append(torch.tensor(new_frame, dtype=torch.float32).to(self.device))
            # Move the last (channel) axis first before applying `transform`.
            chw_frame = np.moveaxis(new_frame, -1, 0)
            norm_frame = transform(torch.tensor(chw_frame, dtype=torch.float32))
            norm_reshaped_video.append(norm_frame.to(self.device))

        return datapoint, reshaped_video, norm_reshaped_video

    @staticmethod
    def collate_fn(batch):
        """Flatten a list of ``__getitem__`` outputs into a single batch.

        Resized raw frames of all videos are concatenated and stacked.
        ``batched_video_splits`` holds the cumulative frame count after each
        video, computed from the length of the normalized frame list.
        Per-object entries are tagged ``(index of video in batch, frame index,
        value)``. Object pairs are the ordered pairs of distinct instance ids
        that co-occur in a frame.

        Returns:
            A 13-tuple, in this order: ids, captions, gt boxes, gt masks,
            object pairs, object ids, video splits, stacked resized raw frames,
            object names, per-frame relations, GPT specs, negative GPT specs,
            and negative examples.
        """
        batched_reshaped_raw_videos = []
        batched_captions = []
        batched_obj_pairs = []
        batched_ids = []
        batched_video_splits = []
        batched_gpt_specs = []
        batched_neg_gpt_specs = []
        batched_neg_examples = []
        batched_gt_bboxes = []
        batched_gt_masks = []
        batched_gt_obj_names = []
        batched_gt_object_rels = []
        batched_object_ids = []

        frame_ct_in_video = 0
        for data_id, (datapoint, reshaped_raw_video, video) in enumerate(batch):
            batched_reshaped_raw_videos += reshaped_raw_video
            batched_ids.append(datapoint['data_id'])
            batched_captions.append(datapoint['gpt_spec']['caption'])
            batched_gpt_specs.append(datapoint['gpt_spec'])
            batched_neg_gpt_specs.append(datapoint['neg_gpt_spec'])
            batched_neg_examples.append(datapoint['neg_example'])
            batched_gt_object_rels.append(datapoint['relations'])

            bounding_box_info = datapoint['masks']
            # Every instance id seen anywhere in this video; its iteration
            # order fixes the order in which pairs are emitted below.
            all_obj_ids = {oid for frame in bounding_box_info for oid in frame['gt_instance_ids']}

            for frame_id, frame in enumerate(bounding_box_info):
                obj_ids_in_frame = frame['gt_instance_ids']
                batched_gt_bboxes += frame['gt_bboxes']
                batched_gt_masks += frame['gt_masks']
                batched_gt_obj_names += [(data_id, frame_id, label) for label in frame['gt_labels']]
                batched_object_ids += [(data_id, frame_id, oid) for oid in obj_ids_in_frame]

                # Filter to the ids present in this frame first (same order as
                # all_obj_ids), then emit ordered pairs of distinct ids.
                present = set(obj_ids_in_frame)
                in_frame = [oid for oid in all_obj_ids if oid in present]
                batched_obj_pairs += [(data_id, frame_id, (oid1, oid2))
                                      for oid1 in in_frame
                                      for oid2 in in_frame
                                      if oid1 != oid2]

            frame_ct_in_video += len(video)
            batched_video_splits.append(frame_ct_in_video)

        return (batched_ids, batched_captions, batched_gt_bboxes, batched_gt_masks,
                batched_obj_pairs, batched_object_ids, batched_video_splits,
                torch.stack(batched_reshaped_raw_videos), batched_gt_obj_names,
                batched_gt_object_rels, batched_gpt_specs, batched_neg_gpt_specs,
                batched_neg_examples)

def open_pvsg_loader(cache_path, dataset_dir, dataset_name, batch_size, device, training_percentage=100, testing_percentage=100, max_video_len=8):
    """Build the "train" and "val" OpenPVSG datasets plus their DataLoaders.

    Both loaders use ``OpenPVSGDataset.collate_fn``, keep dataset order
    (``shuffle=False``) and drop an incomplete final batch (``drop_last=True``).

    Returns:
        ``(train_dataset, valid_dataset, train_loader, test_loader)``, where
        ``test_loader`` iterates ``valid_dataset`` (the "val" split).
    """
    def build_loader(split_dataset):
        # Same batching policy for both splits.
        return DataLoader(split_dataset, batch_size, collate_fn=OpenPVSGDataset.collate_fn,
                          shuffle=False, drop_last=True)

    common = dict(cache_path=cache_path, device=device, max_vid_len=max_video_len)

    train_dataset = OpenPVSGDataset(dataset_dir, dataset_name, phase="train",
                                    data_percentage=training_percentage, **common)
    train_loader = build_loader(train_dataset)

    valid_dataset = OpenPVSGDataset(dataset_dir, dataset_name, phase="val",
                                    data_percentage=testing_percentage, **common)
    test_loader = build_loader(valid_dataset)

    return (train_dataset, valid_dataset, train_loader, test_loader)
