from pathlib import Path
import logging
import cv2
import torch
from torchvision.transforms import Normalize
import numpy as np
from PIL import Image

import pdb

logger = logging.getLogger(__name__)


def mask_iou(pred, target, eps=1e-7):
    """
    Mean per-sample IoU between thresholded logits and target masks.

    ``pred`` is passed through a sigmoid and binarised at 0.5. For samples whose
    target mask is completely empty, the score is the fraction of pixels where
    prediction and target are both background (background agreement / H*W).

    :param pred: [N x H x W] raw logits
    :param target: [N x H x W] binary target masks
    :param eps: 1e-7 or so, added to the union to avoid division by zero
    :return: 0-dim tensor, the IoU averaged over the N samples
    """
    assert pred.ndim == 3 and pred.shape == target.shape, f"pred shape {pred.shape} and target shape {target.shape}"

    batch_size = pred.size(0)
    height, width = pred.size(-2), pred.size(-1)
    empty_target = target.sum((1, 2)) == 0

    binary_pred = (torch.sigmoid(pred) > 0.5).int()
    intersection = (binary_pred * target).sum((1, 2))
    union = torch.max(binary_pred, target).sum((1, 2))

    # Empty targets would make IoU 0/0; score background agreement over the whole frame instead.
    background_overlap = ((1 - target) * (1 - binary_pred)).sum((1, 2))
    intersection[empty_target] = background_overlap[empty_target]
    union[empty_target] = width * height

    per_sample = intersection / (union + eps)
    return per_sample.sum() / batch_size


def _eval_pr(y_pred, y, num):
    """
    Precision and recall of ``y_pred`` against ``y`` at ``num`` thresholds.

    :param y_pred: [B, H, W] scores; presumably probabilities in [0, 1]
        (Eval_Fmeasure passes sigmoid outputs)
    :param y: [B, H, W] binary ground truth
    :param num: number of evenly spaced thresholds from 0 up to (1 - 1e-10), 255 or so
    :return: (precision, recall), each [B, num]; column k uses the k-th threshold
    """
    thresholds = torch.linspace(0, 1 - 1e-10, num, device=y_pred.device).reshape(1, num, 1, 1)
    y_pred, y = y_pred.unsqueeze(1), y.unsqueeze(1)  # [B, 1, H, W], broadcasts against thresholds
    # NOTE: materialises a [B, num, H, W] tensor, so memory grows linearly with num.
    binarized = (y_pred >= thresholds).float()  # [B, num, H, W]
    tp = (binarized * y).sum((-2, -1))  # [B, num]
    # The tiny epsilons avoid 0/0: an empty prediction gets precision 0, an empty GT recall 0.
    precision = tp / (binarized.sum((-2, -1)) + 1e-20)
    recall = tp / (y.sum((-2, -1)) + 1e-20)

    return precision, recall


def Eval_Fmeasure(pred, gt, pr_num=255):
    """
    Max-over-thresholds F-measure (beta^2 = 0.3) for a batch of masks.

    Samples whose ground truth is completely empty are excluded. For the remaining
    samples the F-score is computed at ``pr_num`` thresholds, averaged over samples
    per threshold, and the best threshold's mean is returned.

    :param pred: [N x H x W] raw logits (sigmoid is applied here)
    :param gt: [N x H x W] binary ground-truth masks; float, integer or bool dtype
    :param pr_num: number of thresholds, 255 or so
    :return: float F-measure; ``nan`` when every ground-truth mask is empty
    """
    pred = torch.sigmoid(pred)  # inputs are logits; the thresholds live in [0, 1]
    beta2 = 0.3
    print(f"{pred.size(0)} videos in this batch")
    # Examples with totally black GTs are out of consideration. Testing the per-mask
    # sum (instead of torch.mean) also works for integer/bool ground truth.
    non_black_id = gt.sum(dim=(-2, -1)) != 0
    if not bool(non_black_id.any()):
        # Nothing left to score; same result the empty-tensor path would produce.
        return float("nan")
    pred, gt = pred[non_black_id], gt[non_black_id]

    prec, recall = _eval_pr(pred, gt, pr_num)  # [B, num]
    f_score = (1 + beta2) * prec * recall / (beta2 * prec + recall)
    f_score[torch.isnan(f_score)] = 0  # 0/0 when precision == recall == 0
    return f_score.mean(0).max().item()

def draw_masks(img, mask, alpha=1, beta=0.5, gamma=0):
    """
    Blend ``mask`` onto ``img`` with ``cv2.addWeighted``.

    Computes ``img * alpha + mask * beta + gamma`` element-wise; OpenCV saturates
    the result to the input dtype. ``img`` and ``mask`` must have the same size
    and number of channels.

    :param img: image array
    :param mask: overlay array with the same shape as ``img``
    :param alpha: weight of ``img`` (default 1)
    :param beta: weight of ``mask`` (default 0.5)
    :param gamma: scalar added to every element (default 0)
    :return: the blended array
    """
    return cv2.addWeighted(img, alpha, mask, beta, gamma)
def save_mask_ms3(pred_masks, save_base_path, video_name_list, num_frames=5):
    """
    Binarise predicted mask logits and save each frame as a PNG.

    Files are written to ``save_base_path/<video_name>/<video_name>_<frame>.png``,
    with foreground 255 and background 0.

    :param pred_masks: mask logits whose last two dims are H x W and whose leading
        dims flatten to bs*num_frames, e.g. [bs*5, 1, 224, 224]
    :param save_base_path: path to save the masks (created if missing)
    :param video_name_list: name of each of the bs videos in the batch
    :param num_frames: frames per video (5 by default)
    """
    save_base_path = Path(save_base_path)
    save_base_path.mkdir(parents=True, exist_ok=True)

    height, width = pred_masks.shape[-2], pred_masks.shape[-1]
    # sigmoid > 0.5 binarises the logits; reshape folds the singleton channel dim and
    # groups frames per video (reshape, unlike view, also accepts non-contiguous input).
    binary = (torch.sigmoid(pred_masks) > 0.5).reshape(-1, num_frames, height, width)
    binary = binary.detach().cpu().numpy().astype(np.uint8)  # [bs, num_frames, H, W]
    binary *= 255

    for idx, video_masks in enumerate(binary):
        video_name = video_name_list[idx]
        mask_save_path = save_base_path / video_name
        mask_save_path.mkdir(parents=True, exist_ok=True)
        for frame_id, frame_mask in enumerate(video_masks):  # frame_mask: [H, W]
            output_name = f"{video_name}_{frame_id}.png"
            im = Image.fromarray(frame_mask).convert('P')
            im.save(str(mask_save_path / output_name), format='PNG')


def save_mask_s4(imgs, gt_masks, pred_masks, save_base_path, category_list, video_name_list, num_frames=5):
    """
    Save side-by-side ground-truth / prediction overlays for every frame.

    Each frame's ground-truth mask and binarised prediction are blended onto the
    de-normalised frame with ``draw_masks``. The two overlays are placed side by
    side (GT left, prediction right), converted to BGR and written with
    ``cv2.imwrite`` to ``save_base_path/<category>/<video_name>/<video_name>_<frame>.png``.

    :param imgs: normalised frames, 4-D [bs*num_frames, 3, H, W] (assumed); ``de_norm``
        (defined elsewhere) is assumed to map them back to 0..255 RGB -- TODO confirm
    :param gt_masks: binary {0, 1} ground-truth masks whose last two dims are H x W
        and whose leading dims flatten to bs*num_frames (assumed -- confirm with caller)
    :param pred_masks: mask logits, e.g. [bs*5, 1, 224, 224]
    :param save_base_path: output root directory (created if missing)
    :param category_list: category of each video in the batch
    :param video_name_list: name of each video in the batch
    :param num_frames: frames per video (5 by default)
    """
    save_base_path = Path(save_base_path)
    save_base_path.mkdir(parents=True, exist_ok=True)

    height, width = pred_masks.shape[-2], pred_masks.shape[-1]
    # Binarise the logits and group frames per video: [bs, num_frames, H, W], values 0/255.
    pred_np = (torch.sigmoid(pred_masks) > 0.5).reshape(-1, num_frames, height, width)
    pred_np = pred_np.detach().cpu().numpy().astype(np.uint8)
    pred_np *= 255

    # Frames as contiguous uint8 HWC arrays; clamping makes out-of-range values
    # saturate instead of wrapping around in the uint8 cast.
    imgs_np = de_norm(imgs.cpu()).permute(0, 2, 3, 1).detach().clamp(0, 255)
    imgs_np = np.ascontiguousarray(imgs_np.numpy().astype(np.uint8))  # [N, H, W, 3]

    # Flatten leading/channel dims so frame i of video b lives at row b*num_frames + i.
    gt_np = gt_masks.detach().cpu().numpy().astype(np.uint8)
    gt_np = gt_np.reshape(-1, gt_np.shape[-2], gt_np.shape[-1])
    gt_np *= 255

    for idx in range(pred_np.shape[0]):
        category, video_name = category_list[idx], video_name_list[idx]
        mask_save_path = save_base_path / category / video_name
        mask_save_path.mkdir(parents=True, exist_ok=True)
        for frame_id, pred_frame in enumerate(pred_np[idx]):  # pred_frame: [H, W]
            # imgs/gt are flat over (video, frame); indexing by frame_id alone would
            # reuse the first video's frames for every video when bs > 1.
            flat_id = idx * num_frames + frame_id
            one_img = imgs_np[flat_id]
            one_gt = np.repeat(gt_np[flat_id][..., None], 3, axis=-1)
            one_mask = np.repeat(pred_frame[..., None], 3, axis=-1)

            gt_img = draw_masks(img=one_img, mask=one_gt)
            mask_img = draw_masks(img=one_img, mask=one_mask)

            # GT overlay left, prediction overlay right; cv2.imwrite expects BGR.
            two_img = cv2.cvtColor(np.concatenate((gt_img, mask_img), axis=1), cv2.COLOR_RGB2BGR)
            output_path = mask_save_path / f"{video_name}_{frame_id}.png"
            # Written once with cv2: re-saving this BGR array through PIL would swap
            # red/blue and palette-quantise the overlay.
            if not cv2.imwrite(str(output_path), two_img):
                logger.warning("Failed to write %s", output_path)
