import torch
import torch.nn as nn
import torchvision.transforms.functional as F

import facer

from torch.nn.modules.loss import CrossEntropyLoss

class DiceLoss(nn.Module):
    """Soft multi-class Dice loss returning one loss value per sample.

    For every class ``c`` and sample ``b`` it computes
    ``1 - (2*sum(p*t) + smooth) / (sum(p*p) + sum(t*t) + smooth)`` over all
    non-batch positions, then sums the (optionally weighted) class terms and
    divides by ``n_classes``.
    """

    def __init__(self, n_classes):
        """
        Args:
            n_classes: number of classes. Target ids outside ``[0, n_classes)``
                get an all-zero one-hot row, so they add nothing to the target
                terms but predicted mass at those positions is still counted.
        """
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        """Convert class ids ``(B, *spatial)`` to a float one-hot ``(B, n_classes, *spatial)``.

        Ids are compared with ``==``, so float ids must hold exact integer values.
        """
        class_ids = torch.arange(self.n_classes, device=input_tensor.device)
        # Shape (1, C, 1, ..., 1) so it broadcasts against (B, 1, *spatial).
        class_ids = class_ids.view(1, -1, *([1] * (input_tensor.dim() - 1)))
        return (input_tensor.unsqueeze(1) == class_ids).float()

    def _dice_loss(self, score, target):
        """Return ``1 - dice`` per sample for a single class.

        Args:
            score: predicted probabilities for one class, ``(B, *spatial)``.
            target: binary mask for that class, same shape as ``score``.

        Returns:
            Tensor of shape ``(B,)``.
        """
        target = target.float()
        smooth = 1e-5
        # Flatten every non-batch dim so 2-D and N-D (e.g. volumetric) maps
        # are reduced identically to one value per sample.
        score = score.flatten(1)
        target = target.flatten(1)
        intersect = torch.sum(score * target, dim=1)
        y_sum = torch.sum(target * target, dim=1)
        z_sum = torch.sum(score * score, dim=1)
        return 1 - (2 * intersect + smooth) / (z_sum + y_sum + smooth)

    def forward(self, inputs, target, weight=None, softmax=False):
        """Compute the class-averaged Dice loss.

        Args:
            inputs: per-class scores ``(B, n_classes, *spatial)``; treated as
                probabilities unless ``softmax=True``.
            target: class ids ``(B, *spatial)`` or ``(B, 1, *spatial)``.
            weight: optional per-class weights indexable by class id
                (defaults to 1 for every class).
            softmax: if True, apply softmax over dim 1 of ``inputs`` first.

        Returns:
            Tensor of shape ``(B,)``: weighted sum of per-class losses divided
            by ``n_classes``.
        """
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        # Drop an explicit singleton channel dim on the target, if present.
        if target.dim() == inputs.dim() and target.size(1) == 1:
            target = target.squeeze(1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())
        loss = 0.0
        for i in range(self.n_classes):
            loss += self._dice_loss(inputs[:, i], target[:, i]) * weight[i]
        return loss / self.n_classes


class SegLoss(nn.Module):
    """
    Segmentation loss computed with a frozen, pretrained face-parsing model.

    Used to constrain the training of the generator G: the generated face is
    passed through the frozen parser and its class logits are compared with
    ground-truth parsing labels using an equal mix of cross-entropy and Dice
    loss. The parser's weights are frozen, but gradients still flow back to
    the generated image.
    """

    def __init__(self, model_path, device='cuda', num_classes=19):
        """
        Args:
            model_path: path to the pretrained face-parsing weights, forwarded
                to ``facer.face_parser``.
            device: device the parser is loaded on.
            num_classes: number of segmentation classes (used by the Dice term).
        """
        super(SegLoss, self).__init__()

        self.device = device
        self.num_classes = num_classes

        # Load the pretrained "farl/celebm/448" face parser.
        self.seg_model = facer.face_parser(
            "farl/celebm/448",
            self.device,
            model_path=model_path,
        )
        # Freeze the parser: its parameters are never updated. See train()
        # below for why eval mode has to be re-asserted.
        self.seg_model.requires_grad_(False)
        self.seg_model.eval()
        self.seg_model.to(device)

        # Per-pixel CE (averaged per sample in forward) plus soft Dice.
        self.seg_ce_loss = CrossEntropyLoss(reduction='none')
        self.dice_loss = DiceLoss(num_classes)

    def train(self, mode=True):
        """Set the training mode of this module while keeping the parser in eval mode.

        ``nn.Module.train`` recurses into every registered submodule, so
        without this override a parent ``model.train()`` call would flip the
        frozen parser into training mode and change the behavior of any
        mode-dependent layers it contains (dropout, batch norm).
        """
        super(SegLoss, self).train(mode)
        self.seg_model.eval()
        return self

    def forward(self, generated_face, seg_labels):
        """
        Compute the per-sample segmentation loss of generated faces.

        Args:
            generated_face: generated images ``(B, C, 512, 512)``, fed to the
                parser network as-is (presumably RGB in ``[0, 1]`` — TODO
                confirm against the parser's expected preprocessing).
            seg_labels: ground-truth class ids ``(B, 512, 512)`` or
                ``(B, 1, 512, 512)``; integer or float. Float ids are rounded
                to the nearest integer before use.

        Returns:
            Tensor of shape ``(B,)``: ``0.5 * CE + 0.5 * Dice`` per sample.
        """
        assert generated_face.shape[-2:] == (512, 512), (
            f"generated_face must be 512x512, got {tuple(generated_face.shape[-2:])}"
        )
        assert seg_labels.shape[-2:] == (512, 512), (
            f"seg_labels must be 512x512, got {tuple(seg_labels.shape[-2:])}"
        )

        # Accept labels that carry an explicit singleton channel dimension.
        if seg_labels.dim() == 4 and seg_labels.shape[1] == 1:
            seg_labels = seg_labels.squeeze(1)
        # Float labels (e.g. ToTensor(mask) * 255) can carry rounding error;
        # round before the integer cast so that CE and Dice see the same ids.
        if seg_labels.is_floating_point():
            seg_labels = seg_labels.round()
        seg_labels = seg_labels.long()

        # Deliberately no torch.no_grad(): the parser is frozen via
        # requires_grad_(False), but the loss must back-propagate to G.
        seg_logits, _ = self.seg_model.net(generated_face)

        loss_seg_ce = self.seg_ce_loss(seg_logits, seg_labels).mean(dim=[1, 2])
        loss_seg_dice = self.dice_loss(seg_logits, seg_labels, softmax=True)

        return 0.5 * loss_seg_ce + 0.5 * loss_seg_dice
    
# Usage example
if __name__ == "__main__":
    from PIL import Image

    # Initialize SegLoss with the pretrained face-parser checkpoint.
    seg_loss = SegLoss(
        model_path="ckpts/face_parsing.farl.celebm.main_ema_181500_jit.pt",
        device='cuda',
        num_classes=19,
    )

    # Stand-ins for a G-generated image and its ground-truth parsing mask.
    with Image.open("MM-Celeba-HQ/test_data/face/29999.jpg") as image:
        # Force 3 channels in case the JPEG is grayscale/CMYK; to_tensor
        # scales pixel values to [0, 1].
        generated_image = F.to_tensor(image.convert("RGB")).unsqueeze(0).cuda()

    with Image.open("MM-Celeba-HQ/test_data/mask/29999.png") as label:
        # pil_to_tensor keeps the raw uint8 class ids (no /255 scaling and no
        # float round-trip). Assumes a single-channel (L/P) mask, so the result
        # is (1, H, W) and its leading dim acts as a batch of one.
        seg_label = F.pil_to_tensor(label).long().cuda()

    print(
        f"image: {tuple(generated_image.shape)}, label: {tuple(seg_label.shape)}, "
        f"class ids: {seg_label.unique().tolist()}"
    )

    # Compute the loss; no gradients are needed for this demo.
    with torch.no_grad():
        loss = seg_loss(generated_image, seg_label)
    # SegLoss returns one value per sample; average for a single scalar.
    print(f"Segmentation Loss: {loss.mean().item():.4f}")