import math

import torch
import torch.nn.functional as F

import numpy as np

from utils import load_nlcd_stats
from models.KAUNet.lib.functions.subtraction2_refpad import subtraction2_refpad
from models.KAUNet.lib.functions.dotproduction2_refpad import dotproduction2_refpad
from models.KAUNet.lib.functions.subtraction2_zeropad import subtraction2_zeropad

def multihead_edge_loss():
    """Build a loss that sums edge, area and prediction BCE terms per class.

    Returns:
        A callable ``loss(outputs, y_true)``. ``outputs`` unpacks into
        ``(out, outs, edges, areas)``; ``outs``, ``edges`` and ``areas`` are
        indexed with ``class_id - 1`` for class ids 1..4, while ``out`` is
        not used. ``y_true`` holds integer class labels.
    """

    def one_class_loss(i_class, y_true, edge, area, out):
        """Return the weighted BCE terms for the single class ``i_class``."""
        head = i_class - 1
        edge_logits, area_logits, pred_logits = edge[head], area[head], out[head]

        # The prediction head has to be 4-D; only the batch size is used.
        batch, _, _, _ = pred_logits.shape

        # Binary mask of the pixels labelled ``i_class``: bs * 1 * H * W
        class_mask = (y_true == i_class).float().unsqueeze(dim=1)

        # presumably differences to the 3x3 neighbours with reflection
        # padding -- TODO confirm against subtraction2_refpad. A pixel is
        # flagged as edge when any of those differences is non-zero.
        neighbour_diff = subtraction2_refpad(class_mask, class_mask, kernel_size=3, padding=1)
        is_edge = neighbour_diff.abs().sum(dim=2) != 0

        edge_loss = F.binary_cross_entropy_with_logits(
            edge_logits.view(batch, -1),
            is_edge.float().view(batch, -1),
            reduction="mean",
        )
        area_loss = F.binary_cross_entropy_with_logits(area_logits, class_mask, reduction="mean")
        pred_loss = F.binary_cross_entropy_with_logits(pred_logits, class_mask, reduction="mean")

        return edge_loss * 0.2 + area_loss * 0.1 + pred_loss * 1

    def loss(outputs, y_true):
        """Sum ``one_class_loss`` over class ids 1 to 4."""
        _, outs, edges, areas = outputs
        return sum(
            one_class_loss(class_id, y_true, edges, areas, outs)
            for class_id in range(1, 5)
        )

    return loss
        
        

def multihead_loss():
    """Build a loss that sums one binary cross-entropy term per class head.

    Returns:
        A callable ``loss(outputs, y_true, writer, epoch, save_tensorboard)``.
        ``outputs`` must unpack into five items; the first is ignored and the
        remaining four are the logits for class ids 1..4, each with the same
        shape as ``y_true``. When ``save_tensorboard`` is truthy the sigmoid
        of every head is also written to ``writer`` as one image.
    """

    def loss(outputs, y_true, writer, epoch, save_tensorboard):
        _, head1, head2, head3, head4 = outputs
        heads = (head1, head2, head3, head4)

        total = 0
        for class_id, logits in enumerate(heads, start=1):
            total += F.binary_cross_entropy_with_logits(logits, (y_true == class_id).float())

        if save_tensorboard:
            B, H, W = head1.shape
            # One grey-scale panel per head, replicated over 3 colour
            # channels: each panel is B*1*H*W*3.
            panels = [
                torch.sigmoid(logits).view(B, 1, H, W, 1).repeat(1, 1, 1, 1, 3)
                for logits in heads
            ]
            # Heads side by side (4*W wide), batch items stacked (B*H tall).
            grid = torch.cat(panels, dim=1).permute(0, 2, 1, 3, 4).reshape(B * H, 4 * W, 3)
            writer.add_image("validation-multihead-prediction", grid, epoch + 1, dataformats='HWC')

        return total

    return loss


def feature_triplet_center_loss(margin=0.1):
    """Build a triplet-center loss evaluated on two feature maps.

    Args:
        margin: Per-channel margin; it is multiplied by the channel count of
            the feature map before use.

    Returns:
        A callable ``loss(outputs, y_aps, writer, epoch, save_tensorboard)``.
        ``outputs`` must unpack into eleven items and ``y_aps`` into five;
        only the last two outputs (``d2`` and ``out``) and the first label
        map (``y_ap1``) take part in the loss. ``writer``, ``epoch`` and
        ``save_tensorboard`` are accepted but unused.
    """

    def half_squared_distance(a, b):
        # 0.5 * ||a - b||^2 along the last (channel) axis.
        diff = a - b
        return 0.5 * torch.pow(diff, 2).sum(dim=-1)

    def triplet_center_loss(feature, y):
        # feature: B * Cin * H * W
        # y: B * N_classes * H * W, is one_hot and float
        B, C, H, W = feature.shape
        scaled_margin = margin * C
        feature_cb = feature.transpose(0, 1)  # Cin * B * H * W

        members = []  # per present class: X * Cin features
        centers = []  # per present class: Cin mean feature
        for class_id in range(1, 5):
            in_class = y[:, class_id:class_id + 1, :, :] == 1
            mask = in_class.repeat(1, C, 1, 1).transpose(0, 1)  # Cin * B * H * W
            feats = feature_cb[mask].view(C, -1).transpose(0, 1)  # X * Cin
            if feats.numel() > 0:
                members.append(feats)
                centers.append(feats.mean(dim=0))

        # A triplet needs a negative center, so at least two classes.
        if len(members) < 2:
            return 0

        total = 0
        for own, feats in enumerate(members):
            dists = [half_squared_distance(feats, center) for center in centers]
            d_pos = dists[own]  # X, distance to the own class center
            rivals = dists[:own] + dists[own + 1:]
            d_neg = rivals[0]  # becomes the distance to the nearest other center
            for rival in rivals[1:]:
                d_neg = torch.minimum(d_neg, rival)
            total += (d_pos + scaled_margin - d_neg).clamp(min=0).mean()

        return total / C

    def loss(outputs, y_aps, writer, epoch, save_tensorboard):
        _x, _e1, _e2, _e3, _e4, _e5, _d5, _d4, _d3, d2, out = outputs
        y_ap1, _y_ap2, _y_ap4, _y_ap8, _y_ap16 = y_aps

        # Only the two feature maps paired with y_ap1 are supervised; the
        # encoder/decoder features paired with the other label maps are
        # unpacked but do not contribute.
        total = 0
        total += triplet_center_loss(d2, y_ap1) * 1
        total += triplet_center_loss(out, y_ap1) * 1
        return total

    return loss


def feature_supervised_loss():
    """Build a cross-entropy loss augmented with anchor-based feature terms.

    Returns:
        A callable ``loss(outputs, a_outputs, y_true, writer, epoch,
        save_tensorboard)`` that returns ``cross_entropy + 0.2 * feature``
        where the feature part sums ``feature_loss`` over e3, d4, e4, d5
        and e5.

    NOTE(review): ``argmax_pooling`` is called below but is neither defined
    nor imported in the visible part of this module -- confirm it is provided
    elsewhere, otherwise calling the returned loss raises ``NameError``.
    """

    def feature_loss(feature, anchor, y_true):
        """Combine a between-class anchor term and a within-class feature term."""
        # feature B * Cin * H * W
        # anchor N_classes * Cin * H * W
        # y_true B * N_classes * H * W, one hot tensor

        def dot_sum(a, b):
            # Mean over rows of the dot product along the last axis.
            return (a * b).sum(dim=-1).mean()

        # scale feature and anchor to unit vectors (along the channel axis);
        # the 1e-6 keeps all-zero vectors from dividing by zero
        feature_u = feature / (feature.norm(p=2, dim=1, keepdim=True) + 1e-6)
        anchor_u = anchor / (anchor.norm(p=2, dim=1, keepdim=True) + 1e-6)

        B, Cin, H, W = feature.shape
        anchor_centers = []

        l_between_class = 0
        l_within_class = 0

        # Class index 0 is skipped; only classes 1..4 contribute.
        for i in range(1, 5):
            # compute l_between_class: 1 + <center_i, center_j> for every
            # previously visited class j, i.e. 6 pairs for 4 classes
            anchor_center = anchor_u[i:i+1, :, :, :].mean(dim=(-1, -2))  # 1 * Cin
            for a_c in anchor_centers:
                l_between_class += (1 + dot_sum(a_c, anchor_center))
            anchor_centers.append(anchor_center)

            # compute l_within_class
            idx = (y_true[:, i:i+1, :, :] == 1).repeat(1, Cin, 1, 1).transpose(0, 1)  # Cin * B * H * W
            feature_inclass = feature_u.transpose(0, 1)[idx].view(Cin, -1).transpose(0, 1)  # X * Cin
            X = feature_inclass.shape[0]
            if X > 0:
                # NOTE(review): this adds the mean dot product between the
                # normalised features and their class anchor center, so
                # minimising the loss lowers that similarity -- confirm the
                # sign is intended.
                l_within_class += dot_sum(feature_inclass, anchor_center.repeat(X, 1))

            # print("Within class: ", l_within_class)
            # print("Between class: ", l_between_class)

        # Fixed normalisers. NOTE(review): 10 and 5 do not match the 6 pairs
        # / 4 classes actually accumulated above -- confirm.
        l_between_class /= 10.0
        l_within_class /= 5.0

        return l_between_class + l_within_class

    def loss(outputs, a_outputs, y_true, writer, epoch, save_tensorboard):
        """Return cross-entropy on ``out`` plus 0.2 times the feature terms."""
        # y_true: B*H*W, LongTensor
        [anchors, e1_a, e2_a, e3_a, e4_a, e5_a, d5_a, d4_a, d3_a, d2_a, out_a] = a_outputs
        [x, e1, e2, e3, e4, e5, d5, d4, d3, d2, out] = outputs

        l_ce = F.cross_entropy(out, y_true)

        # From here on y_true is the one-hot map, repeatedly pooled so that
        # it matches the resolution of the features it supervises.
        y_true = F.one_hot(y_true, num_classes=5).permute(0, 3, 1, 2).float()  # B * N_classes * H * W
        l_fs = 0

        # l_fs += feature_loss(e1, e1_a, y_true)
        # l_fs += feature_loss(d2, d2_a, y_true)

        y_true = argmax_pooling(y_true)  # B * N_classes * H/2 * W/2
        # l_fs += feature_loss(e2, e2_a, y_true)
        # l_fs += feature_loss(d3, d3_a, y_true)

        y_true = argmax_pooling(y_true)  # B * N_classes * H/4 * W/4
        l_fs += feature_loss(e3, e3_a, y_true)
        l_fs += feature_loss(d4, d4_a, y_true)

        # tensorboard: embeddings of e3/d4 (labelled with the H/4 label map)
        # and an image of the anchors. This must run before the next pooling
        # step because it reads the H/4-resolution y_true.
        if save_tensorboard:
            # assumes e3 has 256 channels -- TODO confirm against the model
            mat = e3.permute(0, 2, 3, 1).reshape(-1, 256)
            label_names = ["None", "Water", "Forest", "Field", "Others"]
            metadata = [label_names[i] for i in y_true.argmax(dim=1).view(-1)]
            # assumes x has at least 3 channels (4 per the view) and yields a
            # 60x60 grid of 4x4 patches, i.e. a 240x240 input -- TODO confirm
            label_img = torch.nn.functional.unfold(x, kernel_size=4, stride=4, padding=0).view(-1, 4, 4, 4, 60*60)[:, :3, :, :, :].permute(0, 4, 1, 2, 3).reshape(-1, 3, 4, 4)
            writer.add_embedding(mat, metadata=metadata, label_img=label_img, global_step=epoch+1, tag='Unet-e3', metadata_header=None)

            # assumes d4 has 256 channels as well -- TODO confirm
            mat = d4.permute(0, 2, 3, 1).reshape(-1, 256)
            writer.add_embedding(mat, metadata=metadata, label_img=label_img, global_step=epoch+1, tag='Unet-d4', metadata_header=None)

            # Anchors are drawn side by side with a 5-pixel zero border; the
            # reshape below hard-codes 5 anchors with 4 channels each, of
            # which the first 3 channels are shown.
            C, N, H, W = anchors.shape
            a = anchors.detach().cpu().numpy()
            a0 = np.zeros((C, N, H+10, W+10))
            a0[:, :, 5:H+5, 5:W+5] = a
            a0 = a0.transpose(2, 0, 3, 1).reshape(H+10, (W+10) * 5, 4)
            a0 = a0[:, :, :3]
            writer.add_image("validation-anchors", a0, epoch + 1, dataformats='HWC')

        y_true = argmax_pooling(y_true)  # B * N_classes * H/8 * W/8
        l_fs += feature_loss(e4, e4_a, y_true)
        l_fs += feature_loss(d5, d5_a, y_true)

        y_true = argmax_pooling(y_true)  # B * N_classes * H/16 * W/16
        l_fs += feature_loss(e5, e5_a, y_true)

        return l_ce + 0.2 * l_fs

    return loss



# def edge_loss():
#     def loss(edge, area, y_pred, y_true):
#         B, H, W = y_true.shape

#         y_true_edge = y_true.float().unsqueeze(dim=1)
#         edge_true = subtraction2_refpad(y_true_edge, y_true_edge, kernel_size=3, padding=1).abs().sum(dim=2) != 0
#         edge_true = edge_true.float().view(B, -1)
#         edge = edge.view(B, -1)

#         edge_loss = F.binary_cross_entropy_with_logits(edge, edge_true, reduction="mean")
#         area_loss = F.cross_entropy(area, y_true)
#         pred_loss = F.cross_entropy(y_pred, y_true)

#         return edge_loss * 0.2 + area_loss * 0.1 + pred_loss * 1
#     return loss
    
    
def edge_loss():
    """Build a cross-entropy loss whose logits are masked to label edges.

    Returns:
        A callable ``loss(y_pred, y_true)`` where ``y_true`` is a B*H*W label
        map and ``y_pred`` holds the class logits.
    """

    def loss(y_pred, y_true):
        batch, height, width = y_true.shape

        labels = y_true.float().unsqueeze(dim=1)
        # presumably label differences to the 3x3 neighbours with zero
        # padding -- TODO confirm against subtraction2_zeropad. A pixel is an
        # edge when any of those differences is non-zero.
        neighbour_diff = subtraction2_zeropad(labels, labels, kernel_size=3, padding=1)
        on_edge = (neighbour_diff.abs().sum(dim=2) != 0).reshape(batch, 1, height, width)

        # Logits of non-edge pixels are multiplied by False, i.e. zeroed; the
        # cross entropy is still averaged over every pixel.
        return F.cross_entropy(y_pred * on_edge, y_true)

    return loss

def att_loss():
    """Build a BCE loss that supervises attention logits with label agreement.

    Returns:
        A callable ``loss(att, y_true_onehot, kernel_size=3)``.
    """

    def loss(att, y_true_onehot, kernel_size=3):
        """Return the BCE between head-0 attention logits and neighbour labels.

        Args:
            att: B*mh*(k*k)*H*W attention logits; only head 0 is used.
            y_true_onehot: B*C*H*W one-hot labels.
            kernel_size: Side length ``k`` of the neighbourhood window.

        Returns:
            Scalar tensor with the mean binary cross entropy.
        """
        att = att[:, 0]  # head 0
        B, C, H, W = y_true_onehot.shape
        y_true_onehot = y_true_onehot.float()

        # Pad by kernel_size // 2 so the spatial size is preserved for any odd
        # kernel size (this equals the previous fixed padding of 1 for k=3).
        att_label = subtraction2_zeropad(
            y_true_onehot, y_true_onehot,
            kernel_size=kernel_size, padding=kernel_size // 2,
        )  # B*C*(k*k)*(HW)
        # Target is 1 where the one-hot difference is zero in every class
        # channel, i.e. the neighbour carries the same label.
        att_label = att_label.abs().sum(dim=1) == 0  # B*(k*k)*(HW)
        att_label = att_label.reshape(B, -1, H, W).float()

        # The logits variant fuses sigmoid and BCE; it replaces the deprecated
        # F.sigmoid and stays accurate when the sigmoid saturates.
        return F.binary_cross_entropy_with_logits(att, att_label)

    return loss
    
def re_weight_cross_entropy(y_pred, y_true, weight):
    """Return a per-element weighted cross entropy.

    Args:
        y_pred: B*Class*H*W logits.
        y_true: B*H*W integer labels (Long, max = Class - 1).
        weight: Tensor broadcastable to B*Class*H*W.

    Returns:
        Scalar tensor ``-mean(weight * one_hot(y_true) * log_softmax(y_pred))``.
        The mean runs over all B*Class*H*W elements, so with unit weights the
        result is the usual mean cross entropy divided by Class.
    """
    num_classes = y_pred.shape[1]
    target = F.one_hot(y_true, num_classes=num_classes).permute(0, 3, 1, 2).float()
    # log_softmax stays finite where softmax underflows to 0; log(softmax)
    # would give -inf there and 0 * -inf would turn the whole loss into NaN.
    log_prob = F.log_softmax(y_pred, dim=1)
    return -torch.mean(weight * target * log_prob)


def kl_loss(pred, target, sign=None):
    """Return the soft-target cross entropy between ``target`` and softmax(pred).

    Args:
        pred: B*C*H1*W1 logits.
        target: B*C*H1*W1 target distribution.
        sign: Optional bool mask, B*1*H*W. When given, only masked positions
            contribute and the mean is rescaled by numel(sign) / sum(sign).

    Returns:
        Scalar tensor; ``-mean(target * log(softmax(pred) + 1e-10)) * C`` in
        the unmasked case.
    """
    num_class = pred.shape[1]
    log_prob = torch.log(F.softmax(pred, dim=1) + 1e-10)

    if sign is None:
        return -torch.mean(target * log_prob) * num_class

    assert sign.shape[1] == 1
    mask = sign.float()
    masked_mean = torch.mean(mask * target * log_prob)
    # Rescale so the average is effectively taken over masked positions only.
    return -masked_mean * num_class * torch.numel(mask) / (torch.sum(mask) + 1e-10)
    


def mae_loss(pred, target, sign=None):
    """Return the mean absolute error between softmax(pred) and ``target``.

    Args:
        pred: B*C*H1*W1 logits; softmax is applied over dim 1.
        target: B*C*H1*W1 target values.
        sign: Optional bool mask, B*1*H*W. When given, only masked positions
            contribute and the mean is rescaled by numel(sign) / sum(sign).

    Returns:
        Scalar tensor.
    """
    prob = F.softmax(pred, dim=1)

    if sign is None:
        return torch.mean(torch.abs(prob - target))

    assert sign.shape[1] == 1
    mask = sign.float()
    masked_mean = torch.mean(torch.abs(prob - target) * mask)
    return masked_mean * torch.numel(mask) / (torch.sum(mask) + 1e-10)
    

def mae_loss_woSOFT(pred, target, sign=None):
    """Return the mean absolute error between ``pred`` and ``target``.

    Same as ``mae_loss`` but without the softmax: ``pred`` is compared to
    ``target`` as given.

    Args:
        pred: B*C*H1*W1 predictions.
        target: B*C*H1*W1 target values.
        sign: Optional bool mask, B*1*H*W. When given, only masked positions
            contribute and the mean is rescaled by numel(sign) / sum(sign).

    Returns:
        Scalar tensor.
    """
    # Unused value; the lookup is kept so an input without a channel axis
    # still fails here.
    _channels = pred.shape[1]

    if sign is None:
        return torch.mean(torch.abs(pred - target))

    assert sign.shape[1] == 1
    mask = sign.float()
    masked_mean = torch.mean(torch.abs(pred - target) * mask)
    return masked_mean * torch.numel(mask) / (torch.sum(mask) + 1e-10)
    
    
def layer_edge_cross_entropy(y_pred, y_true, ks=5, rate=0.5):
    """Return a cross entropy in which homogeneous pixels are relabelled to class 0.

    For every pixel the fraction of pixels in the surrounding ks*ks window
    that share its label is computed. Pixels whose fraction exceeds ``rate``
    get their target replaced by class 0; the remaining pixels keep their
    label.

    Args:
        y_pred: B*class*H*W logits.
        y_true: B*H*W integer labels (Long, max = class - 1).
        ks: Window size used to measure label homogeneity.
        rate: Homogeneity threshold above which a pixel is relabelled.

    Returns:
        Scalar tensor ``-mean(target * log_softmax(y_pred)) * class``.
    """
    num_classes = y_pred.shape[1]

    target = F.one_hot(y_true, num_classes=num_classes).float()  # B*H*W*class
    # The 4-D one-hot map is pooled over its H and W axes only (kernel size 1
    # along the class axis); padded positions count as zeros, which lowers
    # the ratio near the image border.
    same_label_ratio = F.avg_pool3d(
        target, kernel_size=(ks, ks, 1), stride=1, padding=(ks // 2, ks // 2, 0)
    )
    # One ratio per pixel: the value in that pixel's own class channel.
    homogeneous = same_label_ratio[target == 1] > rate
    homogeneous = homogeneous.reshape(y_pred.shape[0], y_pred.shape[2], y_pred.shape[3])

    replace_label = torch.zeros(num_classes, device=target.device)
    replace_label[0] = 1
    target[homogeneous] = replace_label

    # log_softmax stays finite where softmax underflows to 0; log(softmax)
    # would give -inf there and 0 * -inf would turn the whole loss into NaN.
    log_prob = F.log_softmax(y_pred, dim=1)
    target = target.permute(0, 3, 1, 2)

    return -torch.mean(target * log_prob) * num_classes

def focal_cross_entropy(y_pred, y_true, gamma=1):
    """Return a focal cross entropy rescaled to the plain cross-entropy value.

    Each element is weighted by ``(1 - p) ** gamma``. The weighted loss is
    then multiplied by the detached ratio plain / weighted, so the returned
    value equals the plain cross entropy while the gradient follows the
    focal weighting.

    Args:
        y_pred: B*Class*H*W logits.
        y_true: B*H*W integer labels (Long, max = Class - 1).
        gamma: Focusing exponent; 0 gives the plain cross entropy.

    Returns:
        Scalar tensor.
    """
    cls_num = y_pred.shape[1]
    target = F.one_hot(y_true, num_classes=cls_num).permute(0, 3, 1, 2).float()

    # log_softmax stays finite where softmax underflows to 0; log(softmax)
    # would give -inf there and 0 * -inf would turn the whole loss into NaN.
    log_prob = F.log_softmax(y_pred, dim=1)
    weight = (1 - log_prob.exp()) ** gamma

    plain_ce = -torch.mean(target * log_prob) * cls_num
    focal_ce = -torch.mean(weight * target * log_prob) * cls_num

    # Detached normaliser. The clamp avoids 0 / 0 when every prediction is
    # already perfect and the focal term vanishes.
    scale = plain_ce.detach() / focal_ce.detach().clamp_min(1e-12)
    return scale * focal_ce
    

def adaptive_edge_loss():
    """Build a cross entropy weighted by the local contrast of ``x``.

    Returns:
        A callable ``loss(x, y_pred, y_true)`` with ``x`` of shape B*C*H*W,
        ``y_pred`` the class logits and ``y_true`` a B*H*W label map.
    """

    def loss(x, y_pred, y_true):
        batch, height, width = y_true.shape

        # presumably differences to the 3x3 neighbours with reflection
        # padding -- TODO confirm against subtraction2_refpad.
        neighbour_diff = subtraction2_refpad(x, x, kernel_size=3, padding=1)
        contrast = torch.abs(neighbour_diff).sum(dim=(1, 2)).reshape(batch, 1, height, width)

        # NOTE(review): clamping to [1, 1] pins every finite weight to 1, so
        # the loss is currently unweighted -- confirm this is intended.
        pixel_weight = torch.clamp(torch.exp(contrast), 1, 1)

        return re_weight_cross_entropy(y_pred, y_true, pixel_weight)

    return loss



def edge_loss_with_weight():
    """Build an edge/area/prediction loss that up-weights pixels near edges.

    Returns:
        A callable ``loss(edge, area, y_pred, y_true)`` returning
        ``0.2 * edge_bce + 0.1 * area_ce + weighted_pred_ce``.
    """

    def loss(edge, area, y_pred, y_true):
        """Return the combined loss for one batch.

        Args:
            edge: Edge logits with the shape B*1*H*W of the edge target.
            area: Class logits for the area head.
            y_pred: Class logits for the final prediction.
            y_true: B*H*W integer label map.
        """
        B, H, W = y_true.shape

        y_true_edge = y_true.float().unsqueeze(dim=1)
        # presumably label differences to the 3x3 neighbours with reflection
        # padding -- TODO confirm against subtraction2_refpad. A pixel is an
        # edge when any of those differences is non-zero.
        edge_true = subtraction2_refpad(y_true_edge, y_true_edge, kernel_size=3, padding=1).abs().sum(dim=2) != 0
        edge_true = edge_true.float().view(B, 1, H, W)

        # Count the edge pixels in each 5x5 window. The kernel follows the
        # device and dtype of the data instead of assuming CUDA.
        kernel = torch.ones((1, 1, 5, 5), dtype=edge_true.dtype, device=edge_true.device)
        weight = F.conv2d(edge_true, kernel, stride=1, padding=2)
        # Keep the weight 4-D while it is combined with edge_true so both stay
        # aligned per sample; squeezing first made the sum broadcast to
        # B*B*H*W and mixed samples whenever B > 1. Edge pixels saturate at
        # the maximum weight of 10, all others lie in [1, 10].
        weight = (weight + edge_true * 25).clamp(max=10, min=1).view(B, H, W)

        edge_loss = F.binary_cross_entropy_with_logits(edge, edge_true, reduction="mean")
        area_loss = F.cross_entropy(area, y_true)
        pred_loss = F.cross_entropy(y_pred, y_true, reduction="none")  # B*H*W
        pred_loss = (pred_loss * weight).mean()

        return edge_loss * 0.2 + area_loss * 0.1 + pred_loss * 1

    return loss


def crossentropy_loss():
    """Build a plain mean cross-entropy loss.

    Returns:
        A callable ``loss(y_pred, y_true)`` where ``y_pred`` holds the class
        logits and ``y_true`` is an [N, H, W] label map (not one-hot).
    """

    def loss(y_pred, y_true):
        return F.cross_entropy(y_pred, y_true)

    return loss



def multi_jaccard(num_classes, rate=0.6, smooth=0.001):
    """Build a Jaccard metric reported separately for two groups of images.

    An image belongs to the first group when its largest non-background class
    covers more than ``rate * H * W`` pixels, otherwise to the second group.

    Args:
        num_classes: Number of classes including class 0, which is dropped.
        rate: Coverage threshold that separates the two groups.
        smooth: Value added to every element before summation.

    Returns:
        A callable ``j(y_pred, y_true)`` for [N, num_classes, H, W] inputs
        (``y_true`` one-hot) returning
        ``(jaccard_group1, jaccard_group2, n_group1, n_group2)``.
    """

    def j(y_pred, y_true):
        pred_fg = y_pred[:, 1:, :, :]
        true_fg = y_true[:, 1:, :, :]
        B, C, H, W = pred_fg.shape

        class_area = torch.sum(true_fg, dim=(2, 3), keepdim=True)
        largest_area = torch.max(class_area, axis=1, keepdim=True)[0]  # N*1*1*1
        dominated = largest_area > (rate * H * W)
        mixed = largest_area <= (rate * H * W)

        def group_jaccard(group):
            # Images outside the group are zeroed and contribute only through
            # the smoothing term.
            overlap = true_fg * pred_fg * group
            total = (true_fg + pred_fg) * group
            per_class = torch.sum(overlap + smooth, dim=(0, 2, 3)) / \
                torch.sum(total - overlap + smooth, dim=(0, 2, 3))
            return torch.sum(per_class) / (num_classes - 1)

        return group_jaccard(dominated), group_jaccard(mixed), torch.sum(dominated), torch.sum(mixed)

    return j





def multi_accuracy(num_classes, rate=0.6):
    """Build a pixel-accuracy metric reported separately for two image groups.

    An image belongs to the first group when its largest non-background class
    covers more than ``rate * H * W`` pixels, otherwise to the second group.

    Args:
        num_classes: Accepted for interface symmetry; not used.
        rate: Coverage threshold that separates the two groups.

    Returns:
        A callable ``acc(y_pred, y_true)`` for [N, num_classes, H, W] inputs
        (``y_true`` one-hot) returning
        ``(accuracy_group1, accuracy_group2, n_group1, n_group2)``.
    """

    def acc(y_pred, y_true):
        pred_fg = y_pred[:, 1:, :, :]
        true_fg = y_true[:, 1:, :, :]
        B, C, H, W = pred_fg.shape

        class_area = torch.sum(true_fg, dim=(2, 3), keepdim=True)
        largest_area = torch.max(class_area, axis=1, keepdim=True)[0]  # N*1*1*1
        dominated = largest_area > (rate * H * W)
        mixed = largest_area <= (rate * H * W)

        def group_accuracy(group):
            hits = torch.eq(
                torch.argmax(pred_fg, dim=1, keepdim=True),
                torch.argmax(true_fg, dim=1, keepdim=True),
            ) * group
            n_correct = torch.sum(hits).float()
            n_images = torch.sum(group)
            # The 1e-3 keeps an empty group from dividing by zero.
            return n_correct / (n_images * H * W + 1e-3), n_images

        accuracy1, counter1 = group_accuracy(dominated)
        accuracy2, counter2 = group_accuracy(mixed)
        return accuracy1, accuracy2, counter1, counter2

    return acc


def edge_inter_union():
    """Build a metric that counts intersection and union inside the edge band.

    Returns:
        A callable ``loss(y_pred_onehot, y_true_onehot)`` for B*C*H1*W1
        one-hot inputs returning the pair ``(intersection, union)`` summed
        over the band around label boundaries.
    """

    def loss(y_pred_onehot, y_true_onehot):
        expand_size = 5

        # A pixel is a boundary pixel when no class fills its whole 3x3
        # window; padded zeros make image-border pixels boundary pixels too.
        window_share = F.avg_pool2d(y_true_onehot.float(), kernel_size=3, stride=1, padding=1)
        boundary = (torch.max(window_share, dim=1, keepdim=True)[0] < 1).float()
        # Widen the boundary into a band with a max filter.
        band = F.max_pool2d(
            boundary, kernel_size=expand_size, stride=1, padding=expand_size // 2
        ).bool()

        overlap = y_pred_onehot * y_true_onehot
        combined = y_pred_onehot + y_true_onehot - overlap

        return torch.sum(overlap * band), torch.sum(combined * band)

    return loss


def jaccard(num_classes, smooth=0.001):
    """Build a mean Jaccard metric over the non-background classes.

    Args:
        num_classes: Number of classes including class 0, which is dropped.
        smooth: Value added to every element before summation.

    Returns:
        A callable ``j(y_pred, y_true)`` for [N, num_classes, H, W] inputs
        (``y_true`` one-hot) returning a scalar tensor.
    """

    def j(y_pred, y_true):
        pred_fg = y_pred[:, 1:, :, :]
        true_fg = y_true[:, 1:, :, :]

        overlap = true_fg * pred_fg
        union = true_fg + pred_fg - overlap

        reduce_dims = (0, 2, 3)  # everything except the class axis
        per_class = torch.sum(overlap + smooth, dim=reduce_dims) / \
            torch.sum(union + smooth, dim=reduce_dims)
        return torch.sum(per_class) / (num_classes - 1)

    return j



def accuracy():
    """Build a pixel-accuracy metric over the non-background classes.

    Returns:
        A callable ``acc(y_pred, y_true)`` for [N, num_classes, H, W] inputs
        (``y_true`` one-hot). The denominator is the pixel count plus one.
    """

    def acc(y_pred, y_true):
        # Class 0 is dropped before taking the arg-max on both sides.
        pred_cls = torch.argmax(y_pred[:, 1:, :, :], dim=1)
        true_cls = torch.argmax(y_true[:, 1:, :, :], dim=1)
        n_correct = torch.sum(torch.eq(pred_cls, true_cls).float())

        n_pixels = y_true.shape[0] * y_true.shape[2] * y_true.shape[3]
        return n_correct / (n_pixels + 1)

    return acc


def sr_loss(device, loss_type):
    """Build a super-resolution loss from per-NLCD-class label statistics.

    The statistics returned by ``load_nlcd_stats()`` (class weights, means
    and "vars") are moved to ``device`` once, when the loss is built.

    Args:
        device: Device the NLCD statistics tensors are moved to.
        loss_type: ``"superres-kld"`` selects the KL-divergence variant; any
            other value selects the interval-distance variant (commented in
            the code as following the ICLR paper).

    Returns:
        A callable ``(y_pred, y_true) -> scalar tensor``. ``y_true`` is
        indexed per NLCD class on dim 1 and used as a mask on ``y_pred``.
    """
    nlcd_class_weights, nlcd_means, nlcd_vars = load_nlcd_stats()
    nlcd_class_weights = torch.tensor(nlcd_class_weights).to(device)
    nlcd_means = torch.tensor(nlcd_means).to(device)
    nlcd_vars = torch.tensor(nlcd_vars).to(device)

    def ddist(prediction, c_interval_center, c_interval_radius):
        # Distance from ``prediction`` to the interval center +/- radius;
        # zero while the prediction lies inside the interval.
        return F.relu(torch.abs(prediction - c_interval_center) - c_interval_radius)

    def loss(y_pred, y_true):
        """Interval-distance super-resolution loss, averaged to a scalar."""
        super_res_crit = 0
        # Per-image mask total (+10); shape comments below assume a batch of
        # 16 images of 240x240 with 5 predicted classes.
        mask_size = torch.unsqueeze(
            torch.sum(y_true, dim=(1, 2, 3)) + 10, dim=-1)  # shape 16x1

        for nlcd_idx in range(nlcd_class_weights.shape[0]):
            c_mask = torch.unsqueeze(
                y_true[:, nlcd_idx, :, :], dim=1)  # shape 16x1x240x240
            # Small constant avoids dividing by zero for an absent class.
            c_mask_size = torch.sum(c_mask, dim=(2, 3)) + \
                0.000001  # shape 16x1

            # Note: the "vars" statistic is used as the interval radius.
            c_interval_center = nlcd_means[nlcd_idx]  # shape 5,
            c_interval_radius = nlcd_vars[nlcd_idx]  # shape 5,

            masked_probs = (
                y_pred * c_mask
            )  # (16x5x240x240) * (16x1x240x240) --> shape (16x5x240x240)

            # Mean mean of predicted distribution
            mean = (
                torch.sum(masked_probs, dim=(2, 3)) / c_mask_size
            )  # (16x5) / (16,1) --> shape 16x5

            # Mean var of predicted distribution
            var = torch.sum(masked_probs * (1.0 - masked_probs), dim=(2, 3)) / (
                c_mask_size * c_mask_size
            )  # (16x5) / (16,1) --> shape 16x5

            c_super_res_crit = torch.square(
                ddist(mean, c_interval_center, c_interval_radius)
            )  # calculate numerator of equation 7 in ICLR paper

            c_super_res_crit = c_super_res_crit / (
                var + (c_interval_radius * c_interval_radius) + 0.000001
            )  # calculate denominator

            c_super_res_crit = c_super_res_crit + torch.log(
                var + 0.03
            )  # calculate log term

            c_super_res_crit = (
                c_super_res_crit
                * (c_mask_size / mask_size)
                * nlcd_class_weights[nlcd_idx]
            )  # weight by the fraction of NLCD pixels and the NLCD class weight

            super_res_crit = super_res_crit + c_super_res_crit

        # Previous reduction, kept for reference:
        # super_res_crit = torch.sum(
        #     super_res_crit, dim=1
        # )  # sum superres loss across highres classes
        # return super_res_crit

        # Calculate the average superres loss. Added by Qiyuan.
        super_res_crit = super_res_crit / \
            nlcd_class_weights.shape[0]  # output 16x5
        super_res_crit = torch.mean(
            super_res_crit, dim=(0, 1))  # output a scalar

        return super_res_crit

    def kl_loss(y_pred, y_true):
        """KL-divergence super-resolution loss, averaged to a scalar.

        Unlike ``loss`` above, the NLCD class weights are not applied here;
        only their count is used for the final average.
        """
        overall_kld = 0
        mask_size = torch.unsqueeze(
            torch.sum(y_true, dim=(1, 2, 3)) + 10, dim=-1)  # shape 16x1

        for nlcd_idx in range(nlcd_class_weights.shape[0]):
            c_mask = torch.unsqueeze(
                y_true[:, nlcd_idx, :, :], dim=1)  # shape 16x1x240x240
            # Small constant avoids dividing by zero for an absent class.
            c_mask_size = torch.sum(c_mask, dim=(2, 3)) + \
                0.000001  # shape 16x1

            # Note: the "vars" statistic is used as a standard deviation.
            c_mean = nlcd_means[nlcd_idx]  # shape 5,
            c_sigma = nlcd_vars[nlcd_idx]  # shape 5,

            masked_probs = (
                y_pred * c_mask
            )  # (16x5x240x240) * (16x1x240x240) --> shape (16x5x240x240)

            # Mean mean of predicted distribution
            mean = (
                torch.sum(masked_probs, dim=(2, 3)) / c_mask_size
            )  # (16x5) / (16,1) --> shape 16x5

            # Mean var of predicted distribution
            var = torch.sum(masked_probs * (1.0 - masked_probs), dim=(2, 3)) / (
                c_mask_size * c_mask_size
            )  # (16x5) / (16,1) --> shape 16x5
            # NOTE(review): var is 0 for an image where this class is absent
            # and sqrt has an infinite gradient at 0 -- confirm this cannot
            # produce NaN gradients during training.
            std = torch.sqrt(var)

            # KL Divergence between two Gaussian distributions
            # KL(N(u1, s1^2) || N(u2, s2^2)) equals
            # log(s2 / s1) - 1/2 + (s1^2 + (u1 - u2)^2) / (2 * s2^2)
            # Here (u1, s1) are the NLCD statistics (c_mean, c_sigma) and
            # (u2, s2) the predicted (mean, std); eps guards the divisions
            # and the logarithm.

            eps = 1e-6
            kld = torch.log(((std + eps) / (c_sigma + eps)) + eps) - 0.5
            kld += (torch.square(c_sigma) +
                    torch.square(c_mean - mean)) / (2 * var + eps)

            # Weight by the fraction of pixels that belong to this NLCD class.
            kld = kld * (c_mask_size / mask_size)
            overall_kld += kld  # 16 * 5

        overall_kld = overall_kld / nlcd_class_weights.shape[0]
        overall_kld = torch.mean(overall_kld, dim=(0, 1))  # output a scalar

        return overall_kld

    return kl_loss if loss_type == "superres-kld" else loss
