
import torch
import numpy as np
import copy
import random

from detectron2.structures import Boxes, Instances
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms import ToPILImage


def draw_boxes(img, bboxes, filename):
    """Draw ``bboxes`` on ``img`` as green, 4-pixel-wide unfilled outlines and save to ``filename``.

    ``img`` and ``bboxes`` are passed straight to torchvision's
    ``draw_bounding_boxes``; the annotated tensor is converted to a PIL image
    and written to disk.
    """
    annotated = draw_bounding_boxes(img, bboxes, width=4, colors='green', fill=False)
    to_pil = ToPILImage()
    to_pil(annotated).save(filename)


def compute_iou(v, u):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Boxes whose intersection is empty (right edge left of left edge, or bottom
    above top) give exactly ``0.0``. Otherwise a ``1e-6`` term is added to both
    numerator and denominator, so zero-area unions do not divide by zero.
    """
    eps = 1e-6

    left = float(max(v[0], u[0]))
    top = float(max(v[1], u[1]))
    right = float(min(v[2], u[2]))
    bottom = float(min(v[3], u[3]))

    if right < left or bottom < top:
        return 0.0

    intersection = (right - left) * (bottom - top)
    area_v = float(v[2] - v[0]) * float(v[3] - v[1])
    area_u = float(u[2] - u[0]) * float(u[3] - u[1])
    union = area_v + area_u - intersection

    return (intersection + eps) / (union + eps)


def check_overlap(new_bbox, ann_list, iou_threshold=0.1):
    """Return True if ``new_bbox`` is free of significant overlap with ``ann_list``.

    The result is False as soon as any box in ``ann_list`` has an IoU with
    ``new_bbox`` strictly greater than ``iou_threshold``; an empty
    ``ann_list`` always gives True. (Despite the name, True means "no overlap".)
    """
    return not any(compute_iou(obj, new_bbox) > iou_threshold for obj in ann_list)


class ObjMem():
    """Per-category memory of object crops used for copy-paste augmentation.

    ``process_images`` harvests crops of objects from a batch's non-pseudo
    images into this memory. It then pastes randomly chosen remembered crops
    into every image of the batch, appending the pasted boxes and classes to
    that image's annotations.
    """

    def __init__(self, n_categories, max_size=20):
        """Create an empty memory for ``n_categories`` category ids.

        Args:
            n_categories: number of categories; category ids index the
                per-category lists below.
            max_size: how many of the most recent entries each bounded list keeps.
        """
        # One list of stored crops per category, trimmed to the newest max_size.
        self.obj_memory = [[] for _ in range(n_categories)]
        self.max_size = max_size
        # Running count of crops harvested per category; only categories with a
        # non-zero count are sampled when pasting.
        self.cat_freq = np.array([0 for _ in range(n_categories)])
        # Recent harvested box heights/widths per category (not read in this class).
        self.h_values = [[] for _ in range(n_categories)]
        self.w_values = [[] for _ in range(n_categories)]
        # Box counts of recent non-pseudo images; their mean sets how many
        # objects are pasted into each image.
        self.n_objs = []

    def add_obj(self, patch, cat_id):
        """Store an independent copy of ``patch`` under category ``cat_id``.

        Tensor patches are copied with ``clone`` into compact contiguous
        storage. ``copy.deepcopy`` of a sliced view would duplicate the whole
        underlying image storage, so every remembered crop would pin a full
        image in memory. Only the newest ``max_size`` entries are kept.

        Args:
            patch: the object crop to store (a tensor in this file; other
                objects are deep-copied).
            cat_id: category index; any int-like value (int, numpy integer,
                single-element tensor).
        """
        cat_id = int(cat_id)
        if isinstance(patch, torch.Tensor):
            stored = patch.detach().clone(memory_format=torch.contiguous_format)
        else:
            stored = copy.deepcopy(patch)
        self.obj_memory[cat_id].append(stored)
        self.obj_memory[cat_id] = self.obj_memory[cat_id][-self.max_size:]

    def get_obj(self, cat_id):
        """Return a uniformly random stored crop of category ``cat_id``.

        Raises:
            IndexError: if nothing has been stored for ``cat_id``.
        """
        return random.choice(self.obj_memory[cat_id])

    def process_images(self, data, step_iter=0, copies=2):
        """Harvest objects from ``data`` and paste remembered objects into each image.

        Args:
            data: list of dataset dicts. Each must provide ``'image'`` (indexed
                below as channels x height x width), ``'instances'`` with
                ``gt_boxes`` / ``gt_classes``, and ``'ann_type'``. Dicts whose
                ``ann_type`` is ``'pseudo'`` are not harvested into memory.
            step_iter: unused; kept for interface compatibility.
            copies: number of random placement attempts per pasted object.

        Returns:
            A list with one augmented deep copy per input dict. Each copy's
            ``'instances'`` is rebuilt from boxes and classes only, so any other
            instance fields are dropped.
        """
        self.n_objs += [len(img['instances'].gt_boxes.tensor) for img in data if img['ann_type'] != 'pseudo']
        self.n_objs = self.n_objs[-self.max_size:]
        # Paste as many objects as a recent non-pseudo image holds on average.
        nobjs_img = int(np.ceil(np.mean(self.n_objs))) if self.n_objs else 5
        n_obj_copies = max(0, nobjs_img)

        aug_data = []
        for img in data:
            aug_img = copy.deepcopy(img)
            img_h, img_w = aug_img['image'].shape[1], aug_img['image'].shape[2]
            bboxes = aug_img['instances'].gt_boxes.tensor
            labels = aug_img['instances'].gt_classes

            # Only objects from non-pseudo images enter the memory.
            if img['ann_type'] != 'pseudo':
                self._store_objects(img['image'], bboxes, labels)

            # Pasting happens for every image; aug_img['image'] is modified in place.
            bboxes, labels = self._paste_objects(aug_img['image'], bboxes, labels, n_obj_copies, copies)

            aug_img['instances'] = Instances((img_h, img_w))
            aug_img['instances'].gt_boxes = Boxes(bboxes)
            aug_img['instances'].gt_classes = labels
            aug_data.append(aug_img)

        return aug_data

    def _store_objects(self, image, bboxes, labels):
        """Harvest crops of ``image``'s usable boxes into memory.

        A box is skipped when it is at most 1 px wide or tall, or when its
        rounded coordinates have IoU > 0.2 with any other box of the image.
        """
        img_h, img_w = image.shape[1], image.shape[2]
        for i in range(len(bboxes)):
            box = bboxes[i]
            box_w, box_h = box[2] - box[0], box[3] - box[1]
            if float(box_w) <= 1.0 or float(box_h) <= 1.0:
                continue

            x1, y1, x2, y2 = (int(torch.round(box[k])) for k in range(4))
            others = torch.cat((bboxes[:i], bboxes[i + 1:]))
            if not check_overlap([x1, y1, x2, y2], others, 0.2):
                continue

            # Clamp the crop window to the image. A box starting outside it
            # would otherwise produce a negative (wrapping) or empty slice.
            cx1, cy1 = max(x1, 0), max(y1, 0)
            cx2, cy2 = min(x2, int(img_w)), min(y2, int(img_h))
            if cx2 <= cx1 or cy2 <= cy1:
                continue

            cat_id = int(labels[i])
            self.cat_freq[cat_id] += 1
            self.w_values[cat_id] = (self.w_values[cat_id] + [box_w])[-self.max_size:]
            self.h_values[cat_id] = (self.h_values[cat_id] + [box_h])[-self.max_size:]
            # add_obj makes the compact copy, so the view is passed directly.
            self.add_obj(image[:, cy1:cy2, cx1:cx2], cat_id)

    def _paste_objects(self, image, bboxes, labels, n_obj_copies, copies):
        """Paste up to ``n_obj_copies`` remembered crops into ``image`` in place.

        For each object, a category with stored crops is drawn uniformly and a
        random crop of it is taken. Up to ``copies`` random placements are then
        tried. The first placement that lies inside the image and has
        IoU <= 0.2 with every current box is used.

        Returns:
            ``(bboxes, labels)`` extended with one row per pasted object, on the
            same device and dtype as the inputs.
        """
        img_h, img_w = image.shape[1], image.shape[2]
        # cat_freq does not change while pasting, so the candidate set is fixed.
        valid_cats = np.where(self.cat_freq > 0)[0]
        if len(valid_cats) == 0:
            return bboxes, labels

        for _ in range(n_obj_copies):
            cat_id = int(np.random.choice(valid_cats))
            patch = self.get_obj(cat_id)
            box_h, box_w = patch.shape[1], patch.shape[2]
            area = box_h * box_w

            # Range of centres that keeps the un-resized crop inside the image.
            x_lo, x_hi = int(np.round(box_w / 2.0)), int(np.round(img_w - box_w / 2.0))
            y_lo, y_hi = int(np.round(box_h / 2.0)), int(np.round(img_h - box_h / 2.0))
            if x_lo >= x_hi or y_lo >= y_hi:
                continue  # crop does not fit in this image

            for _attempt in range(copies):
                cx, cy = np.random.randint(x_lo, x_hi), np.random.randint(y_lo, y_hi)

                new_h, new_w = box_h, box_w
                # Random resize with prob. 0.5: slightly enlarge small crops
                # (area <= 32*32), shrink larger ones.
                if np.random.rand() > 0.5:
                    if area <= 32 * 32:
                        scale = np.random.uniform(1.1, 1.2)
                    else:
                        scale = np.random.uniform(0.4, 0.9)
                    new_h, new_w = new_h * scale, new_w * scale

                if new_h <= 1.0 + 1e-5 or new_w <= 1.0 + 1e-5:
                    break

                x_min, y_min = int(cx - new_w / 2), int(cy - new_h / 2)
                x_max, y_max = int(x_min + new_w), int(y_min + new_h)
                if x_min < 0 or y_min < 0 or x_max > int(img_w) or y_max > int(img_h):
                    continue
                if not check_overlap([x_min, y_min, x_max, y_max], bboxes, 0.2):
                    continue

                resized = torch.nn.functional.interpolate(
                    patch.unsqueeze(0), size=(y_max - y_min, x_max - x_min))[0]
                image[:, y_min:y_max, x_min:x_max] = resized

                # Match the existing tensors' device/dtype so torch.cat works on any device.
                new_box = torch.tensor([[x_min, y_min, x_max, y_max]],
                                       dtype=bboxes.dtype, device=bboxes.device)
                new_label = torch.tensor([cat_id], dtype=labels.dtype, device=labels.device)
                bboxes = torch.cat([bboxes, new_box], 0)
                labels = torch.cat([labels, new_label], 0)
                break

        return bboxes, labels
