#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This started as a copy of https://bitbucket.org/RSKothari/multiset_gaze/src/master/ 
with additional changes and modifications to adjust it to our implementation. 

Copyright (c) 2021 Rakshit Kothari, Aayush Chaudhary, Reynold Bailey, Jeff Pelz, 
and Gabriel Diaz
"""
import torch
import numpy as np
import torch.nn.functional as F

from einops import rearrange

from helperfunctions.utils import create_meshgrid, soft_heaviside


def get_seg_loss(gt_dict, pd_dict, alpha):
    # Custom function to iteratively go over each sample in a batch and
    # compute loss.
    # cond: Mask exist -> 1, else 0

    op = pd_dict['predict']
    target = gt_dict['mask']
    spatWts = gt_dict['spatial_weights']
    distMap = gt_dict['distance_map']

    device = op.device

    # Mask availability
    cond = gt_dict['mask_available'].to(device)

    B = op.shape[0]
    loss_seg = []

    for i in range(0, B):
        if cond[i]:
            # Valid mask exists
            l_sl = SurfaceLoss(op[i, ...].unsqueeze(0),
                               distMap[i, ...].unsqueeze(0))
            l_cE = wCE(op[i, ...],
                       target[i, ...],
                       spatWts[i, ...])
            l_gD = GDiceLoss(op[i, ...].unsqueeze(0),
                             target[i, ...].unsqueeze(0),
                             F.softmax)
            loss_seg.append(alpha*l_sl + (1-alpha)*l_gD + l_cE)

    if len(loss_seg) > 0:
        total_loss = torch.sum(torch.stack(loss_seg))/torch.sum(cond)

        assert total_loss.dim() == 0, 'Segmentation losses must be a scalar'
        return total_loss
    else:
        return torch.tensor([0.0]).to(device)

def get_com(seg_map, temperature=4):
    '''
    获取质心以获得检测到的瞳孔或虹膜中心。质心在规范化空间中，范围在 -1 到 1 之间。

    参数:
    seg_map : torch.Tensor
        瞳孔或虹膜预测的单通道张量，形状为 BXHXW，其中 B 是批次大小，H 和 W 是高度和宽度。
    temperature : int, optional
        温度参数，默认为 4。

    返回:
    predPts : torch.Tensor
        包含质心坐标的张量，形状为 (B*F, 2)，其中 B 是批次大小，F 是帧数。
    '''
    device = seg_map.device

    # 将批次大小和帧数的维度合并
    seg_map = rearrange(seg_map, 'b f h w -> (b f) h w')

    B, H, W = seg_map.shape
    wtMap = F.softmax(rearrange(seg_map, 'b h w -> b (h w)') * temperature, dim=1)  # [B, HXW]

    XYgrid = create_meshgrid(H, W, normalized_coordinates=True)  # 1xHxWx2
    xloc = XYgrid[0, :, :, 0].reshape(-1).to(device)
    yloc = XYgrid[0, :, :, 1].reshape(-1).to(device)

    xpos = torch.sum(wtMap * xloc, -1, keepdim=True)
    ypos = torch.sum(wtMap * yloc, -1, keepdim=True)
    predPts = torch.stack([xpos, ypos], dim=1).squeeze()

    # 返回的张量维度为 [BXF, HXW]
    return predPts



def get_uncertain_l1_loss(ip_vector,
                          target_vector,
                          weights_per_channel,
                          uncertain,
                          cond,
                          do_aleatoric=False):
    ''' 计算有效样本上的 L1 或 aleatoric L1 距离 '''

    if weights_per_channel is None:
        weights_per_channel = torch.ones(ip_vector.shape[1],
                                         dtype=ip_vector.dtype).to(ip_vector.device)
    else:
        weights_per_channel = torch.tensor(weights_per_channel).to(ip_vector.dtype).to(ip_vector.device)

    if torch.any(cond):
        loss_per_sample = F.l1_loss(ip_vector, target_vector, reduction='none')

        if do_aleatoric:
            # 如果存在不确定估计，则使用 aleatoric 公式
            loss_per_sample = .1 * uncertain + \
                loss_per_sample / torch.exp(uncertain)

        # 在维度上求和，并在样本间加权平均
        loss_per_sample = loss_per_sample * weights_per_channel
        loss_per_sample = torch.sum(loss_per_sample, dim=1) * cond
        total_loss = torch.sum(loss_per_sample) / torch.sum(cond)

        assert total_loss.dim() == 0, 'L1 损失必须是标量'
        return total_loss
    else:
        # 没有找到有效样本
        return torch.tensor([0.0]).to(ip_vector.device)


def SurfaceLoss(x, distmap):
    '''
    表面损失。对于没有地面实况的类别，distmap 理想情况下应填充为 0。

    参数:
    x : torch.Tensor
        分割网络输出的张量，形状为 (B, C, H, W)，其中 B 是批次大小，C 是通道数，H 和 W 是高度和宽度。
    distmap : torch.Tensor
        距离图的张量，形状与 x 相同。

    返回:
    score : torch.Tensor
        损失分数，形状为 (B,)。
    '''
    x = torch.softmax(x, dim=1)
    score = x.flatten(start_dim=2) * distmap.flatten(start_dim=2)
    score = torch.mean(score, dim=2)  # 每个通道的像素平均值
    score = torch.mean(score, dim=1)  # 通道间的平均值
    return score



def GDiceLoss(ip, target, norm=F.softmax):
    '''
    Generalized Dice Loss.

    参数:
    ip : torch.Tensor
        模型的输出张量，形状为 (B, C, H, W)，其中 B 是批次大小，C 是通道数，H 和 W 是高度和宽度。
    target : torch.Tensor
        目标张量，形状为 (B, H, W)。
    norm : function, optional
        规范化函数，默认为 F.softmax。

    返回:
    torch.Tensor
        损失值。
    '''
    mxLabel = ip.shape[1]
    allClasses = np.arange(mxLabel, )
    labelsPresent = np.unique(target.cpu().numpy())

    device = target.device

    Label = (np.arange(mxLabel) == target.cpu().numpy()[..., None]).astype(np.uint8)
    Label = np.moveaxis(Label, 3, 1)
    target = torch.from_numpy(Label).to(device).to(ip.dtype)

    loc_rm = np.where(~np.in1d(allClasses, labelsPresent))[0]

    assert ip.shape == target.shape
    ip = norm(ip, dim=1)  # 在通道上进行 Softmax 或 Sigmoid
    ip = torch.flatten(ip, start_dim=2, end_dim=-1)
    target = torch.flatten(target, start_dim=2, end_dim=-1).to(device).to(ip.dtype)

    numerator = ip * target
    denominator = ip + target

    # 对于在目标中不存在但存在于输入中的类别，设置权重为0
    class_weights = 1. / (torch.sum(target, dim=2) ** 2).clamp(1e-5)
    if loc_rm.size > 0:
        for i in np.nditer(loc_rm):
            class_weights[:, i.item()] = 0
    A = class_weights * torch.sum(numerator, dim=2)
    B = class_weights * torch.sum(denominator, dim=2)
    dice_metric = 2. * torch.sum(A, dim=1) / torch.sum(B, dim=1)
    return torch.mean(1 - dice_metric.clamp(1e-5))


def wCE(ip, target, spatWts):
    '''
    Weighted Cross Entropy Loss.
    参数:
    ip : torch.Tensor
        模型的输出张量，形状为 (B, C, H, W)，其中 B 是批次大小，C 是通道数，H 和 W 是高度和宽度。
    target : torch.Tensor
        目标张量，形状为 (B, H, W)。
    spatWts : torch.Tensor
        空间权重张量，形状为 (B, C)。
    返回:
    torch.Tensor
        损失值。
    '''
    mxLabel = ip.shape[0]
    allClasses = np.arange(mxLabel, )
    labelsPresent = np.unique(target.cpu().numpy())
    rmIdx = allClasses[~np.in1d(allClasses, labelsPresent)]
    if rmIdx.size > 0:
        loss = spatWts.view(1, -1) * F.cross_entropy(ip.view(1, mxLabel, -1),
                                                      target.view(1, -1),
                                                      ignore_index=rmIdx.item())
    else:
        loss = spatWts.view(1, -1) * F.cross_entropy(ip.view(1, mxLabel, -1),
                                                      target.long().view(1, -1))
    loss = torch.mean(loss)
    return loss



def get_mask(mesh, opEl):
    '''
    根据椭圆参数获取遮罩。
    参数:
    mesh : torch.Tensor
        网格点坐标张量，形状为 (..., 2)。
    opEl : torch.Tensor
        椭圆参数张量，包括中心坐标和长短轴长度及旋转角度，形状为 (4,)。
    返回:
    torch.Tensor, torch.Tensor
        正遮罩和负遮罩。
    '''
    # posmask: 椭圆外部为正
    # negmask: 椭圆内部为正
    X = (mesh[..., 0] - opEl[0]) * torch.cos(opEl[-1]) + \
        (mesh[..., 1] - opEl[1]) * torch.sin(opEl[-1])
    Y = -(mesh[..., 0] - opEl[0]) * torch.sin(opEl[-1]) + \
        (mesh[..., 1] - opEl[1]) * torch.cos(opEl[-1])
    posmask = (X / opEl[2]) ** 2 + (Y / opEl[3]) ** 2 - 1
    negmask = 1 - (X / opEl[2]) ** 2 - (Y / opEl[3]) ** 2

    posmask = soft_heaviside(posmask, sc=64, mode=3)
    negmask = soft_heaviside(negmask, sc=64, mode=3)
    return posmask, negmask

