import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import torch

def increase_brightness(img, alpha=0.2):
    """Wash ``img`` out towards white.

    Computes ``alpha * img + (1 - alpha) * 255`` per pixel with
    ``cv2.addWeighted``, which saturates the result to the uint8 range.
    ``alpha`` is the weight of the *original* image: smaller values give a
    brighter, more faded result, and ``alpha=1`` returns the image unchanged.

    Args:
        img: uint8 ``np.ndarray`` image of any shape ``cv2.addWeighted``
            accepts, e.g. ``(H, W)``, ``(H, W, 3)`` or ``(H, W, 4)``.
        alpha: Weight of the original image, nominally in ``[0, 1]``.

    Returns:
        A new uint8 array with the same shape as ``img``.
    """
    # Match img's shape exactly so grayscale and 4-channel images work too.
    # Keep uint8 so addWeighted still rejects mismatched (e.g. float) inputs
    # instead of silently blending a [0, 1] image with 255.
    white_img = np.full(img.shape, 255, dtype=np.uint8)
    return cv2.addWeighted(img, alpha, white_img, 1 - alpha, 0)

def increase_brightness_except(img, bbox_ls, alpha=0.2):
    """Wash ``img`` out towards white everywhere except inside ``bbox_ls``.

    The faded image comes from ``increase_brightness(img, alpha)``. Each box
    region is then copied back from the untouched ``img``.

    Args:
        img: uint8 image array indexed as ``img[row, col, ...]``.
        bbox_ls: Iterable of ``(x1, y1, x2, y2)`` boxes in pixel coordinates,
            with ``x2``/``y2`` exclusive. Each value is converted with
            ``int()``, so float, NumPy or tensor scalars are accepted.
            Negative values are clamped to 0 so they do not wrap around to
            the far edge of the image.
        alpha: Weight of the original image in the faded area.

    Returns:
        A new array; ``img`` itself is not modified.
    """
    output_img = increase_brightness(img, alpha)
    for bbox in bbox_ls:
        # Plain slicing would treat a negative start as "from the end",
        # restoring the wrong region, so clamp before slicing.
        x1, y1, x2, y2 = (max(int(v), 0) for v in bbox)
        output_img[y1:y2, x1:x2] = img[y1:y2, x1:x2]
    return output_img
    
def increase_brightness_draw_outer_edge(img, bbox_ls, alpha=0.2, colormap_name='Set1', thickness=2):
    """Fade ``img`` towards white except inside ``bbox_ls``, then outline each box.

    Args:
        img: Tensor image. It is moved to CPU, converted to NumPy and
            truncated to uint8. Rows and columns are indexed first, so an
            ``(H, W, C)`` layout is assumed. Outline colours come from
            matplotlib as RGB, so ``img`` is presumably RGB (TODO confirm).
        bbox_ls: Iterable of ``(x1, y1, x2, y2)`` boxes in pixel coordinates.
            Each value is converted with ``int()``. When restoring box
            regions, negative values are clamped to 0 so they do not wrap.
        alpha: Weight of the original image in the faded area.
        colormap_name: Key into ``plt.colormaps``. Box ``i`` uses colormap
            entry ``i % N``, so colours cycle instead of every box past ``N``
            receiving the colormap's "over" colour.
        thickness: Outline thickness passed to ``cv2.rectangle``.

    Returns:
        A CPU float32 tensor with the same shape as the uint8 image.
    """
    # detach() so tensors that require grad can still be converted to numpy.
    img = img.detach().cpu().numpy().astype(np.uint8)
    output_img = increase_brightness(img, alpha)
    colormap = plt.colormaps[colormap_name]

    boxes = [tuple(int(v) for v in bbox) for bbox in bbox_ls]

    # Pass 1: restore every box region before drawing anything. If restoring
    # and drawing were interleaved, a later overlapping box would overwrite
    # (erase) the outline already drawn for an earlier box.
    for x1, y1, x2, y2 in boxes:
        rows, cols = slice(max(y1, 0), y2), slice(max(x1, 0), x2)
        output_img[rows, cols] = img[rows, cols]

    # Pass 2: draw the outlines using the unclamped coordinates. cv2 clips
    # drawing at the image border itself.
    for bbox_id, (x1, y1, x2, y2) in enumerate(boxes):
        rgb = mpl.colors.to_rgb(colormap(bbox_id % colormap.N))
        color = tuple(255 * c for c in rgb)
        output_img = cv2.rectangle(output_img, (x1, y1), (x2, y2), color, thickness)

    return torch.tensor(output_img, dtype=torch.float32)


def gen_mask(height, width, bbox_ls, background=0.3, from_obj=1.0, to_obj=0.55):
    """Build a ``(1, height, width)`` float32 weight mask for up to two boxes.

    Every pixel starts at ``background``. The first box in ``bbox_ls`` is
    filled with ``from_obj``. The optional second box is filled with
    ``to_obj`` and takes precedence where the two boxes overlap.

    Args:
        height: Mask height in pixels.
        width: Mask width in pixels.
        bbox_ls: Sequence of zero, one or two ``(x1, y1, x2, y2)`` boxes, with
            ``x2``/``y2`` exclusive. Coordinates are converted with ``int()``.
            Negative values are clamped to 0 so they do not wrap around to
            the far edge when slicing.
        background: Value for pixels outside every box.
        from_obj: Value for the first box.
        to_obj: Value for the second box.

    Returns:
        ``np.ndarray`` of shape ``(1, height, width)`` and dtype float32.

    Raises:
        ValueError: If ``bbox_ls`` contains more than two boxes.
    """
    # Validate up front, and with a real exception: an ``assert`` would be
    # stripped under ``python -O`` and extra boxes silently ignored.
    if len(bbox_ls) > 2:
        raise ValueError(f"gen_mask supports at most 2 boxes, got {len(bbox_ls)}")

    mask = np.full((1, height, width), background, dtype=np.float32)
    # zip stops after len(bbox_ls) boxes: one box gets from_obj, and a second
    # box then gets to_obj, written last so it wins on overlap.
    for bbox, value in zip(bbox_ls, (from_obj, to_obj)):
        x1, y1, x2, y2 = (max(int(v), 0) for v in bbox)
        mask[:, y1:y2, x1:x2] = value
    return mask

def mask_image(img, bbox_ls, background=0.3, from_obj=1.0, to_obj=0.8, color=0):
    """Dim ``img`` outside up to two boxes by blending it towards ``color``.

    Each pixel becomes ``mask * img + (1 - mask) * color``. ``mask`` comes
    from ``gen_mask`` and holds ``background`` outside the boxes,
    ``from_obj`` in the first box and ``to_obj`` in the second.

    Args:
        img: Tensor of shape ``(C, H, W)``. Values are truncated to uint8
            before blending, so a 0-255 pixel range is assumed (TODO confirm).
        bbox_ls: Zero, one or two ``(x1, y1, x2, y2)`` boxes, passed to
            ``gen_mask``.
        background: Mask weight outside the boxes, passed to ``gen_mask``.
        from_obj: Mask weight for the first box, passed to ``gen_mask``.
        to_obj: Mask weight for the second box, passed to ``gen_mask``.
        color: Fill value blended in wherever the mask is below 1. Either a
            scalar or a per-channel sequence of length ``C``.

    Returns:
        float32 tensor of shape ``(C, H, W)`` on ``img``'s device.
    """
    _, height, width = img.shape
    device = img.device
    mask = gen_mask(height, width, bbox_ls, background, from_obj, to_obj)
    # detach() so tensors that require grad can still be converted to numpy.
    img_np = img.detach().cpu().numpy().astype(np.uint8)
    # A scalar becomes (1, 1, 1) and a per-channel colour becomes (C, 1, 1);
    # either broadcasts against the (1, H, W) mask and the (C, H, W) image.
    # Kept as float32 rather than a uint8 canvas, so non-integer colours
    # are not truncated.
    fill = np.asarray(color, dtype=np.float32).reshape(-1, 1, 1)
    result_img = mask * img_np + (1 - mask) * fill
    return torch.tensor(result_img, dtype=torch.float32).to(device)