import numpy as np
import os
import cv2
from PIL import Image
from utils import *

def bitmasks2bboxes(bitmasks):
    """Compute tight bounding boxes for a sequence of binary masks.

    Args:
        bitmasks: sequence of equally-shaped 2-D arrays (rows x columns);
            any non-zero pixel counts as foreground.

    Returns:
        list of dicts with inclusive pixel coordinates: ``x1``/``x2`` are the
        first/last columns containing foreground, ``y1``/``y2`` the first/last
        rows. A mask with no foreground pixels produces no entry, so the
        result can be shorter than ``bitmasks``. An empty input yields ``[]``.
    """
    if len(bitmasks) == 0:
        # np.stack raises on an empty sequence; there is nothing to box.
        return []
    bitmasks_array = np.stack(bitmasks)
    # Collapsing axis 1 (rows) tells which columns are occupied; collapsing
    # axis 2 (columns) tells which rows are occupied -- for every mask at once.
    cols_any = np.any(bitmasks_array, axis=1)
    rows_any = np.any(bitmasks_array, axis=2)
    boxes = []
    for cols, rows in zip(cols_any, rows_any):
        xs = np.where(cols)[0]
        ys = np.where(rows)[0]
        if len(xs) > 0 and len(ys) > 0:
            boxes.append({'x1': xs[0], 'x2': xs[-1], 'y1': ys[0], 'y2': ys[-1]})
    return boxes

# from memory_profiler import profile

# @profile
def load_annotations(datapoint, mask_path, cates2id):
    """Load per-instance labels, masks, ids and boxes from an instance-id mask.

    Args:
        datapoint: annotation record. ``datapoint['objects']`` is a list whose
            entry ``i - 1`` describes pixel value ``i`` and has a
            ``'category'`` key; ``datapoint['data_id']`` is only used in the
            message printed when no instance is found.
        mask_path: path to a single-channel (palette) image in which 0 marks
            background and ``k >= 1`` marks the k-th object.
        cates2id: mapping from category name to semantic id. It must contain
            ``'background'`` whenever the mask has 0-valued pixels.

    Returns:
        dict with
          - ``gt_labels``: list of category names, one per kept instance;
          - ``gt_masks``: list of 0/1 integer arrays shaped like the mask;
          - ``gt_instance_ids``: 1-D integer array of the kept pixel values;
          - ``gt_bboxes``: boxes from ``bitmasks2bboxes`` (``[]`` if no
            instances);
          - ``gt_semantic_seg``: per-pixel semantic id, ``-1`` where the pixel
            value has no entry in ``datapoint['objects']``.

    Raises:
        KeyError: if a needed category (or ``'background'``) is missing
            from ``cates2id``.
    """
    # Close the file handle deterministically; np.array forces the decode
    # before the image is closed. Widen to int64 so the -1 fill and id
    # comparisons cannot overflow a narrower source dtype.
    with Image.open(mask_path) as mask_image:
        pan_mask = np.array(mask_image).astype(np.int64)
    objects_info = datapoint['objects']

    # -1 marks pixels whose value has no object record.
    gt_semantic_seg = -1 * np.ones_like(pan_mask)
    classes = []
    masks = []
    instance_ids = []
    for instance_id in np.unique(pan_mask):
        region = pan_mask == instance_id  # computed once per id
        if instance_id == 0:
            # Background ("void") only feeds the semantic map; it gets no
            # label, mask or box.
            gt_semantic_seg[region] = cates2id['background']
            continue
        if instance_id > len(objects_info):
            # Pixel value with no object record: leave it unlabeled.
            continue
        category = objects_info[instance_id - 1]['category']
        gt_semantic_seg[region] = cates2id[category]
        classes.append(category)
        instance_ids.append(instance_id)
        # Builtin ``int`` is what the removed ``np.int`` alias pointed to.
        masks.append(region.astype(int))

    if not classes:
        # A few images are annotated as all background: no labels or masks.
        print('{} is annotated as all background!'.format(
            datapoint['data_id']))

    return {
        'gt_labels': classes,
        'gt_masks': masks,
        'gt_instance_ids': np.array(instance_ids, dtype=int),
        'gt_bboxes': bitmasks2bboxes(masks) if masks else [],
        'gt_semantic_seg': gt_semantic_seg,
    }

def load_video(video_path, start_time, end_time):
    """Decode the frames with index in ``[start_time, end_time)`` into one array.

    ``start_time`` and ``end_time`` are 0-based frame indices, not seconds:
    frames are counted in the order ``cap.read()`` returns them.

    Args:
        video_path: path to a video file that OpenCV can open.
        start_time: index of the first frame to keep.
        end_time: index one past the last frame to keep.

    Returns:
        ``np.stack`` of the kept frames, or ``[]`` when the file does not
        exist or no frame inside the requested window could be decoded.
    """
    if not os.path.exists(video_path):
        return []
    cap = cv2.VideoCapture(video_path)
    video = []
    try:
        current_frame = 0
        while cap.isOpened() and current_frame < end_time:
            ret, frame = cap.read()
            if not ret:
                # End of stream or decode failure: stop rather than counting
                # past the end of the file.
                break
            if current_frame >= start_time:
                video.append(frame)
            current_frame += 1
    finally:
        # Always free the capture, even if decoding raises.
        cap.release()

    if not video:
        # np.stack raises on an empty list; mirror the missing-file result.
        return []
    return np.stack(video)