import torch
from torch import nn
import torch.nn.functional as F


class REG_Criterion(nn.Module):
    """Mean L1 point-regression loss.

    The absolute error between prediction and target is summed over the last
    axis, then averaged over every remaining element to give a scalar.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_point, gt_point):
        """Return the 0-d mean of the last-axis L1 distance between the inputs."""
        per_point_error = (pred_point - gt_point).abs().sum(dim=-1)
        return per_point_error.mean()


class CLS_Criterion(nn.Module):
    """Cross-entropy classification loss with 0.1 label smoothing.

    Targets are cast with ``.long()`` before being handed to
    ``nn.CrossEntropyLoss``, so integer-valued float targets are accepted.
    """

    def __init__(self):
        super().__init__()
        self.loc_loss = nn.CrossEntropyLoss(label_smoothing=0.1)

    def forward(self, pred_loc, gt_loc):
        """Return the smoothed cross-entropy of logits ``pred_loc`` vs. class targets ``gt_loc``."""
        class_targets = gt_loc.long()
        return self.loc_loss(pred_loc, class_targets)
    

class RSD_Criterion(nn.Module):
    """Combined L1 regression loss for a main and a student prediction.

    Besides the loss, ``forward`` records the per-sample median L1 error of the
    main prediction into a caller-owned ``values`` tensor during two recording
    windows of ``windows`` epochs each, starting at ``first_prune_epoch`` and
    ``second_prune_epoch``. From epoch ``epochs - 1`` onward, it also adds a
    term pulling the student prediction toward the (detached) main prediction.

    Args:
        first_prune_epoch: epoch at which the first recording window starts.
        second_prune_epoch: epoch at which the second recording window starts.
        windows: length, in epochs, of each recording window.
        reg_weight: weight of the main-prediction L1 term (default 1.0).
        stu_weight: weight of the student-prediction L1 term (default 1.0).
        dis_weight: weight of the student-to-main distillation term (default 0.1).
    """

    def __init__(self, first_prune_epoch, second_prune_epoch, windows,
                 *, reg_weight=1.0, stu_weight=1.0, dis_weight=0.1):
        super(RSD_Criterion, self).__init__()
        self.first_prune_epoch = first_prune_epoch
        self.second_prune_epoch = second_prune_epoch
        self.windows = windows
        self.reg_weight = reg_weight
        self.stu_weight = stu_weight
        self.dis_weight = dis_weight

    def _window_offset(self, epoch_nums):
        """Return the epoch offset inside the active recording window, or None.

        The first window takes priority when the two windows overlap.
        """
        for start in (self.first_prune_epoch, self.second_prune_epoch):
            if start <= epoch_nums < start + self.windows:
                return epoch_nums - start
        return None

    def forward(self, pred_point, gt_point, batch_size, epoch_nums, idx, values, pred_point_stu, epochs):
        """Compute the combined loss and, inside a recording window, update ``values``.

        Args:
            pred_point: main prediction; the L1 error is summed over its last axis.
            gt_point: target, same shape as ``pred_point``.
            batch_size: number of samples; the per-point errors are reshaped to
                ``(batch_size, -1)`` to take one median per sample.
            epoch_nums: current epoch index.
            idx: row indices into ``values`` for the samples in this batch
                (presumably dataset indices — confirm against the caller).
            values: tensor written IN PLACE at ``values[idx, offset]`` with the
                per-sample median error while a recording window is active.
            pred_point_stu: student prediction, same shape as ``gt_point``.
            epochs: total number of epochs; distillation is active when
                ``epoch_nums >= epochs - 1``.

        Returns:
            ``(loss_all, values)``, where ``loss_all`` is a 0-d tensor and
            ``values`` is the same object that was passed in.
        """
        loss_map = torch.sum(torch.abs(pred_point - gt_point), dim=-1)  # [B*N]

        offset = self._window_offset(epoch_nums)
        if offset is not None:
            # Detached: recording statistics must not enter the autograd graph.
            # No clone needed, since torch.median does not modify its input.
            per_sample = loss_map.detach().view(batch_size, -1)  # [B, N]
            # For an even N, torch.median returns the lower of the two middle values.
            values[idx, offset] = torch.median(per_sample, dim=1).values

        loss_reg = torch.mean(loss_map)
        loss_stu = torch.mean(torch.sum(torch.abs(pred_point_stu - gt_point), dim=-1))

        if epoch_nums >= epochs - 1:
            # The main prediction is detached, so only the student is pulled toward it.
            loss_dis = torch.mean(
                torch.sum(torch.abs(pred_point_stu - pred_point.detach()), dim=-1)
            )
        else:
            loss_dis = 0

        loss_all = (self.reg_weight * loss_reg
                    + self.stu_weight * loss_stu
                    + self.dis_weight * loss_dis)
        return loss_all, values