
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import torch

def show_anns(anns, borders=True):
    """Draw each mask in a random translucent colour on the current axes.

    Args:
        anns: sequence of 2-D masks, all of the same (H, W) shape. Non-boolean
            masks (e.g. 0/1 integers) are treated as ``mask != 0``.
        borders: if True, also outline each mask's external contours in blue.

    Returns:
        The float RGBA overlay of shape (H, W, 4) that was passed to
        ``ax.imshow``, or ``None`` if ``anns`` is empty (nothing is drawn).
        Colours come from ``np.random``, so they differ between calls unless
        the global NumPy seed is fixed.
    """
    if len(anns) == 0:
        return None
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((anns[0].shape[0], anns[0].shape[1], 4))
    img[:, :, 3] = 0  # fully transparent wherever no mask is painted
    for m in anns:
        # Index with a real bool mask: an integer 0/1 array would otherwise be
        # interpreted as row indices by NumPy fancy indexing.
        m = np.asarray(m, dtype=bool)
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask  # later masks overwrite earlier ones where they overlap
        if borders:
            contours, _ = cv2.findContours(
                m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            # epsilon is a max distance in pixels; 0.01 px only drops
            # (near-)collinear points rather than visibly smoothing the outline.
            contours = [
                cv2.approxPolyDP(contour, epsilon=0.01, closed=True)
                for contour in contours
            ]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1)

    ax.imshow(img)
    return img

def visualize_weights(anns, weights, method="max"):
    """Build a red RGBA overlay whose alpha channel encodes per-mask weights.

    Args:
        anns: sequence of 2-D masks, all of the same (H, W) shape. Non-boolean
            masks (e.g. 0/1 integers) are treated as ``mask != 0``.
        weights: one scalar weight per mask, paired with ``anns`` via ``zip``
            (extra items in the longer sequence are ignored).
        method: how overlapping masks combine into the alpha channel:
            ``"max"`` keeps the largest weight covering each pixel;
            ``"overlay"`` composites them as ``1 - prod(1 - w)``.

    Returns:
        float64 array of shape (H, W, 4) with red = 1, green = blue = 0 and the
        combined weights as alpha (0 where no mask covers a pixel), or ``None``
        if ``anns`` is empty.

    Raises:
        ValueError: if ``method`` is not ``"max"`` or ``"overlay"``.
    """
    if method not in ("max", "overlay"):
        raise ValueError(f"method must be 'max' or 'overlay', got {method!r}")
    if len(anns) == 0:
        return None

    img = np.zeros((anns[0].shape[0], anns[0].shape[1], 4))
    img[:, :, 0] = 1
    alpha = img[:, :, 3]  # basic-slice view: writes go straight into img
    if method == "overlay":
        # Accumulate the product of (1 - w), i.e. the remaining transparency.
        alpha[:] = 1
    for m, w in zip(anns, weights):
        # A real bool mask is required; integer masks would fancy-index rows.
        m = np.asarray(m, dtype=bool)
        if method == "max":
            alpha[m] = np.maximum(alpha[m], w)
        else:
            alpha[m] *= 1 - w
    if method == "overlay":
        alpha[:] = 1 - alpha
    return img

def show_img_and_mask(image, masks: np.ndarray, weights = None, divide_weights_by_area=True, save_path='./playground/tmp.png'):
    """Render ``image`` with its masks overlaid, then display or save it.

    Args:
        image: RGB image array of shape (H, W, 3); presumably uint8 in
            [0, 255], since the saved blend scales overlay colours by 255
            (TODO confirm with callers).
        masks: array of shape (N, H, W) with one mask per entry.
        weights: optional per-mask scalars. When given, masks are drawn as a
            red heat overlay via ``visualize_weights`` (max-combined);
            otherwise each mask gets a random colour via ``show_anns``.
        divide_weights_by_area: divide each weight by its mask's fraction of
            the image (floored at 0.001), which boosts small masks, then clip
            the result to [0, 0.75].
        save_path: if ``None``, show the figure with ``plt.show()``.
            Otherwise alpha-blend the overlay onto ``image`` and write it with
            ``cv2.imwrite``, creating the parent directory if needed.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file could not be written.
    """
    import os

    dpi = 100
    # matplotlib's figsize is (width, height); image.shape is (H, W, ...).
    plt.figure(figsize=(image.shape[1] / dpi, image.shape[0] / dpi), dpi=dpi)
    plt.imshow(image)
    if weights is not None:
        if divide_weights_by_area:
            mask_areas = masks.sum(axis=(1, 2)) / (masks.shape[1] * masks.shape[2])
            weights = weights / mask_areas.clip(0.001, 1)
            weights = weights.clip(0, 0.75)
        mask_img = visualize_weights(masks, weights)
        if mask_img is not None:
            # visualize_weights only builds the RGBA array; draw it so the
            # interactive (save_path=None) view shows the weights as well.
            plt.imshow(mask_img)
    else:
        mask_img = show_anns(masks)  # draws onto the current axes itself
    plt.axis('off')
    if save_path is None:
        plt.show()
        return
    plt.close()

    out_dir = os.path.dirname(save_path)
    if out_dir:
        # cv2.imwrite does not create directories; it just returns False.
        os.makedirs(out_dir, exist_ok=True)

    if len(masks) == 0:
        out_img = image
    else:
        alpha = mask_img[:, :, 3:]
        blended = image * (1 - alpha) + mask_img[:, :, :3] * alpha * 255
        # Clip before casting: alpha > 1 (unnormalised weights) would make
        # values fall outside [0, 255] and wrap around in uint8.
        out_img = np.clip(blended, 0, 255).astype(np.uint8)
    if not cv2.imwrite(save_path, cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite failed to write {save_path!r}")

def show_img(img:torch.Tensor, save_path='./playground/tmp.png'):
    """Min-max normalise an image tensor to 8-bit and save it with PIL.

    Args:
        img: image tensor. A 4-D input has its leading dim squeezed (so only a
            batch of 1 is supported). A 3-D tensor whose last dim is not 3 is
            treated as channels-first and permuted to channels-last. A 2-D
            tensor, or a single-channel result, is saved as grayscale. Any
            dtype is accepted; it is converted to float32 for normalisation.
        save_path: destination file; PIL infers the format from the extension.
            The parent directory is created if it does not exist.

    The full value range is stretched to [0, 255]. A constant image, which has
    no range to stretch, is saved as all zeros.
    """
    import os

    # detach() so tensors that require grad can be converted with .numpy().
    img = img.detach()
    if img.ndim == 4:
        img = img.squeeze(0)
    if img.ndim == 3 and img.shape[-1] != 3:
        img = img.permute(1, 2, 0)
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img.squeeze(-1)  # PIL cannot build an image from (H, W, 1)
    # numpy has no bfloat16, and integer inputs need float division.
    img = img.float()
    lo, hi = img.min(), img.max()
    span = hi - lo
    if span > 0:
        img = (img - lo) / span * 255
    else:
        img = torch.zeros_like(img)  # avoid 0/0 -> NaN -> garbage uint8
    arr = img.cpu().numpy().astype(np.uint8)

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(arr).save(save_path)
