# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torchvision
import math

from .new_vgg_loss import NewVGGLoss


class JointsMSELoss(nn.Module):
    """Per-joint heatmap MSE (scaled by 0.5), averaged over joints.

    Optionally multiplies both prediction and target heatmaps by a per-joint
    weight, and optionally restricts the loss to labeled samples.
    """

    def __init__(self, use_target_weight):
        super(JointsMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight, *args, **kwargs):
        """Compute the loss.

        Args:
            output: predicted heatmaps; reshaped to (batch, num_joints, -1).
            target: ground-truth heatmaps with the same layout as ``output``.
            target_weight: indexed as ``target_weight[:, idx]`` per joint —
                presumably (batch, num_joints, 1); TODO confirm with callers.
            is_labeled (keyword, optional): per-sample boolean mask (anything
                ``torch.as_tensor`` accepts). When some samples are unlabeled,
                only labeled samples contribute to the loss.

        Returns:
            0-dim tensor: mean over joints of the per-joint mean loss.
        """
        is_labeled = kwargs.get('is_labeled', None)
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape(
            (batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        # Normalize the label mask once: bool dtype, same device as the
        # predictions, flattened to (batch,). A fully-labeled mask is
        # equivalent to no mask.
        mask = None
        if is_labeled is not None:
            mask = torch.as_tensor(
                is_labeled, dtype=torch.bool, device=output.device
            ).reshape(-1)
            if bool(mask.all()):
                mask = None

        loss = 0
        for idx in range(num_joints):
            # squeeze(1) (not squeeze()) keeps the batch dim when batch == 1,
            # so the per-sample mask below still lines up with dim 0.
            heatmap_pred = heatmaps_pred[idx].squeeze(1)
            heatmap_gt = heatmaps_gt[idx].squeeze(1)
            if self.use_target_weight:
                joint_loss = 0.5 * self.criterion(
                    heatmap_pred.mul(target_weight[:, idx]),
                    heatmap_gt.mul(target_weight[:, idx])
                )
            else:
                joint_loss = 0.5 * self.criterion(heatmap_pred, heatmap_gt)

            if mask is None:
                joint_loss = joint_loss.mean()
            elif bool(mask.any()):
                # Mask unlabeled samples.
                joint_loss = joint_loss[mask].mean()
            else:
                # No labeled sample: return a 0-dim zero on the right device
                # that stays attached to the graph, so backward() still works.
                joint_loss = joint_loss.sum() * 0.0

            loss += joint_loss

        return loss / num_joints


class JointsOHKMMSELoss(nn.Module):
    """Heatmap MSE with online hard keypoint mining (OHKM).

    For each sample, only the ``topk`` joints with the largest mean loss
    contribute; the result is averaged over those joints and over the batch.
    """

    def __init__(self, use_target_weight, topk=8):
        super(JointsOHKMMSELoss, self).__init__()
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk

    def ohkm(self, loss):
        """Average the hardest joints per sample, then average over the batch.

        Args:
            loss: (batch, num_joints) per-joint losses.

        Returns:
            0-dim tensor.
        """
        # If fewer joints exist than ``topk``, every joint counts as "hard"
        # (instead of torch.topk raising).
        k = min(self.topk, loss.size(1))
        hardest, _ = torch.topk(loss, k=k, dim=1, sorted=False)
        return hardest.mean(dim=1).mean()

    def forward(self, output, target, target_weight):
        """Compute the OHKM loss.

        Args:
            output: predicted heatmaps; reshaped to (batch, num_joints, -1).
            target: ground-truth heatmaps with the same layout as ``output``.
            target_weight: indexed as ``target_weight[:, idx]`` per joint —
                presumably (batch, num_joints, 1); TODO confirm with callers.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        heatmaps_pred = output.reshape(
            (batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        per_joint = []
        for idx in range(num_joints):
            # squeeze(1) keeps the batch dim, so mean(dim=1) below is valid
            # even when batch == 1.
            heatmap_pred = heatmaps_pred[idx].squeeze(1)
            heatmap_gt = heatmaps_gt[idx].squeeze(1)
            if self.use_target_weight:
                joint_weight = target_weight[:, idx]
                heatmap_pred = heatmap_pred.mul(joint_weight)
                heatmap_gt = heatmap_gt.mul(joint_weight)
            joint_loss = 0.5 * self.criterion(heatmap_pred, heatmap_gt)
            per_joint.append(joint_loss.mean(dim=1))

        # (batch, num_joints) matrix of per-joint mean losses.
        loss = torch.stack(per_joint, dim=1)
        return self.ohkm(loss)


class PerceptualLoss(nn.Module):
    """MSE between VGG16 convolutional features of two image batches.

    The VGG16 backbone is a fixed feature extractor: its weights are frozen,
    but gradients still flow back to the input images.
    """

    def __init__(self, cfg):
        # ``cfg`` is accepted for interface parity with the other losses and
        # is not read here.
        super(PerceptualLoss, self).__init__()
        self.vgg16 = torchvision.models.vgg16(pretrained=True)
        self.vgg16.eval()
        # Freeze the backbone so backward() does not allocate/accumulate
        # .grad on VGG weights that are never meant to be trained.
        for param in self.vgg16.parameters():
            param.requires_grad_(False)
        self.criterion = nn.MSELoss(reduction='mean')

    def forward(self, images, pred_images):
        """Return the mean squared difference of the VGG16 ``features`` maps.

        Args:
            images: reference image batch (fed to ``vgg16.features``).
            pred_images: predicted image batch with the same layout.
        """
        ref_feats = self.vgg16.features(images)
        pred_feats = self.vgg16.features(pred_images)

        n = ref_feats.size(0)
        return self.criterion(pred_feats.reshape(n, -1),
                              ref_feats.reshape(n, -1))


class DoubleLoss(nn.Module):
    """Supervised heatmap loss plus a weighted unsupervised image loss.

    The unsupervised term is chosen from ``cfg.LOSS``:
      * ``NEW_VGG_LOSS``: ``NewVGGLoss`` only;
      * else ``L2_LOSS``: pixel MSE only;
      * else: 0.5 * perceptual (VGG16 features) + 0.5 * pixel MSE.
    """

    def __init__(self, cfg):
        super(DoubleLoss, self).__init__()
        self.sup_task_loss = JointsMSELoss(cfg.LOSS.USE_TARGET_WEIGHT)

        self.new_vgg_loss = cfg.LOSS.NEW_VGG_LOSS
        if cfg.LOSS.NEW_VGG_LOSS:
            self.unsup_task_loss1 = NewVGGLoss(cfg)
        else:
            self.unsup_task_loss1 = PerceptualLoss(cfg)

        self.l2_loss = cfg.LOSS.L2_LOSS
        self.unsup_task_loss2 = nn.MSELoss(reduction='mean')
        self.unsup_loss_weight = cfg.LOSS.UNSUP_LOSS_WEIGHT
        self.crop = cfg.LOSS.CROP
        self.crop_margin = cfg.LOSS.CROP_MARGIN
        self.weight_sin = cfg.LOSS.WEIGHT_SIN
        # Base weight that schedule_loss_weight scales from.
        self.ori_unsup_loss_weight = self.unsup_loss_weight

    def schedule_loss_weight(self, cur_epoch, total_epoch):
        """Ramp the unsupervised weight as ``sin(cur/total * pi/2) * base``.

        No-op unless ``cfg.LOSS.WEIGHT_SIN`` was set.
        """
        if not self.weight_sin:
            return
        coef = math.sin(cur_epoch / total_epoch * math.pi / 2)
        self.unsup_loss_weight = coef * self.ori_unsup_loss_weight

    @staticmethod
    def _crop_weight(images, joints_pred, margin):
        """Return a boolean mask keeping each image's padded joint bounding box.

        Two options were considered: really cropping (images in the batch then
        differ in size and need a slow Python loop) or zeroing everything
        outside the box with a mask (fast, possibly slightly worse). The mask
        is used.

        Args:
            images: tensor unpacked as (n, c, h, w).
            joints_pred: x read from ``[:, 0]`` and y from ``[:, 1]``, each
                reduced over dim 1 — presumably (n, 2, num_joints) in heatmap
                coordinates; TODO confirm.
            margin: pixels added on every side of the box before clamping.

        Returns:
            Boolean tensor of shape (n, 1, h, w).
        """
        n, _, h, w = images.size()
        # NOTE(review): the factor 4 presumably maps heatmap coordinates to
        # image pixels (heatmap stride 4); confirm against the model config.
        joints = joints_pred.detach() * 4
        xs, ys = joints[:, 0], joints[:, 1]

        def _bound(values, offset, size):
            # Shift, clamp into the image, and shape for broadcasting.
            return (values + offset).clamp(0, size - 1).reshape(n, 1, 1, 1)

        min_x = _bound(xs.min(dim=1)[0], -margin, w)
        max_x = _bound(xs.max(dim=1)[0], margin, w)
        min_y = _bound(ys.min(dim=1)[0], -margin, h)
        max_y = _bound(ys.max(dim=1)[0], margin, h)

        coord_x = images.new_tensor(range(w)).reshape(1, 1, 1, w)
        coord_y = images.new_tensor(range(h)).reshape(1, 1, h, 1)
        return ((coord_x >= min_x) & (coord_x <= max_x)
                & (coord_y >= min_y) & (coord_y <= max_y))

    def forward(self, output, target, target_weight, images, pred_images,
                joints_pred=None, is_labeled=None):
        """Return ``sup_loss + unsup_loss_weight * unsup_loss``.

        Args:
            output, target, target_weight: forwarded to the supervised loss.
            images, pred_images: reference and reconstructed images.
            joints_pred: required only when ``cfg.LOSS.CROP`` is enabled.
            is_labeled: optional per-sample mask for the supervised loss.

        Raises:
            ValueError: cropping is enabled but ``joints_pred`` is None.
        """
        if self.crop:
            if joints_pred is None:
                raise ValueError(
                    'joints_pred is required when cfg.LOSS.CROP is enabled')
            weight = self._crop_weight(images, joints_pred, self.crop_margin)
            images = images * weight
            pred_images = pred_images * weight

        sup_task_loss = self.sup_task_loss(output, target, target_weight,
                                           is_labeled=is_labeled)
        if self.new_vgg_loss:
            unsup_task_loss = self.unsup_task_loss1(images, pred_images)
        elif self.l2_loss:
            unsup_task_loss = self.unsup_task_loss2(images, pred_images)
        else:
            unsup_task_loss = (0.5 * self.unsup_task_loss1(images, pred_images)
                               + 0.5 * self.unsup_task_loss2(images, pred_images))
        return sup_task_loss + self.unsup_loss_weight * unsup_task_loss


class FinetuneUnsupLoss(nn.Module):
    """Unsupervised image loss only (no heatmap term).

    The loss is ``NewVGGLoss`` if ``cfg.LOSS.NEW_VGG_LOSS``, else pixel MSE if
    ``cfg.LOSS.L2_LOSS``, else the VGG16 perceptual loss.
    """

    def __init__(self, cfg):
        super(FinetuneUnsupLoss, self).__init__()
        if cfg.LOSS.NEW_VGG_LOSS:
            self.unsup_task_loss1 = NewVGGLoss(cfg)
        elif cfg.LOSS.L2_LOSS:
            self.unsup_task_loss1 = nn.MSELoss(reduction='mean')
        else:
            self.unsup_task_loss1 = PerceptualLoss(cfg)

        self.crop = cfg.LOSS.CROP
        self.crop_margin = cfg.LOSS.CROP_MARGIN

    @staticmethod
    def _crop_weight(images, joints_pred, margin):
        """Return a boolean mask keeping each image's padded joint bounding box.

        Masking (zeroing pixels outside the box) is used instead of really
        cropping, because real crops give differently-sized images that would
        need a slow per-sample loop.

        Args:
            images: tensor unpacked as (n, c, h, w).
            joints_pred: x read from ``[:, 0]`` and y from ``[:, 1]``, each
                reduced over dim 1 — presumably (n, 2, num_joints) in heatmap
                coordinates; TODO confirm.
            margin: pixels added on every side of the box before clamping.

        Returns:
            Boolean tensor of shape (n, 1, h, w).
        """
        n, _, h, w = images.size()
        # NOTE(review): the factor 4 presumably maps heatmap coordinates to
        # image pixels (heatmap stride 4); confirm against the model config.
        joints = joints_pred.detach() * 4
        xs, ys = joints[:, 0], joints[:, 1]

        def _bound(values, offset, size):
            # Shift, clamp into the image, and shape for broadcasting.
            return (values + offset).clamp(0, size - 1).reshape(n, 1, 1, 1)

        min_x = _bound(xs.min(dim=1)[0], -margin, w)
        max_x = _bound(xs.max(dim=1)[0], margin, w)
        min_y = _bound(ys.min(dim=1)[0], -margin, h)
        max_y = _bound(ys.max(dim=1)[0], margin, h)

        coord_x = images.new_tensor(range(w)).reshape(1, 1, 1, w)
        coord_y = images.new_tensor(range(h)).reshape(1, 1, h, 1)
        return ((coord_x >= min_x) & (coord_x <= max_x)
                & (coord_y >= min_y) & (coord_y <= max_y))

    def forward(self, output, target, target_weight, images, pred_images, joints_pred=None):
        """Return the unsupervised loss between ``images`` and ``pred_images``.

        ``output``, ``target`` and ``target_weight`` are accepted for interface
        parity with the other losses and are ignored.

        Raises:
            ValueError: cropping is enabled but ``joints_pred`` is None.
        """
        if self.crop:
            if joints_pred is None:
                raise ValueError(
                    'joints_pred is required when cfg.LOSS.CROP is enabled')
            weight = self._crop_weight(images, joints_pred, self.crop_margin)
            images = images * weight
            pred_images = pred_images * weight

        return self.unsup_task_loss1(images, pred_images)

class FinetuneDoubleLoss(nn.Module):
    """
    This loss is the same as double loss except we only compute sup loss for
        the last sample in every batch. Cropping is not supported.
    """
    def __init__(self, cfg):
        super(FinetuneDoubleLoss, self).__init__()
        # Validate the config before building any sub-loss. An ``assert``
        # would be stripped under ``python -O``.
        if cfg.LOSS.CROP:
            raise ValueError('FinetuneDoubleLoss does not support cfg.LOSS.CROP')
        self.sup_task_loss = JointsMSELoss(cfg.LOSS.USE_TARGET_WEIGHT)

        self.new_vgg_loss = cfg.LOSS.NEW_VGG_LOSS
        if cfg.LOSS.NEW_VGG_LOSS:
            self.unsup_task_loss1 = NewVGGLoss(cfg)
        else:
            self.unsup_task_loss1 = PerceptualLoss(cfg)

        self.l2_loss = cfg.LOSS.L2_LOSS
        self.unsup_task_loss2 = nn.MSELoss(reduction='mean')
        self.unsup_loss_weight = cfg.LOSS.UNSUP_LOSS_WEIGHT
        # Not read by this class (cropping is rejected above); kept so the
        # attribute set matches the other losses.
        self.crop_margin = cfg.LOSS.CROP_MARGIN

    def forward(self, output, target, target_weight, images, pred_images, joints_pred=None):
        """Return ``sup_loss(last sample) + unsup_loss_weight * unsup_loss``.

        ``joints_pred`` is accepted for interface parity and ignored.
        """
        # Indexing with the list [-1] keeps a batch dimension of size 1.
        sup_task_loss = self.sup_task_loss(output[[-1]], target[[-1]], target_weight[[-1]])
        if self.new_vgg_loss:
            unsup_task_loss = self.unsup_task_loss1(images, pred_images)
        elif self.l2_loss:
            unsup_task_loss = self.unsup_task_loss2(images, pred_images)
        else:
            unsup_task_loss = 0.5*self.unsup_task_loss1(images, pred_images) + \
                0.5*self.unsup_task_loss2(images, pred_images)
        return sup_task_loss + self.unsup_loss_weight * unsup_task_loss

class FinetuneSupLoss(nn.Module):
    """
    Supervised heatmap loss computed on the last sample of every batch only;
        no unsupervised image term is added.
    """

    def __init__(self, cfg):
        super(FinetuneSupLoss, self).__init__()
        self.sup_task_loss = JointsMSELoss(cfg.LOSS.USE_TARGET_WEIGHT)

    def forward(self, output, target, target_weight, *args, **kwargs):
        """Return the supervised loss of the final sample.

        Extra positional/keyword arguments are accepted and ignored so this
        module can be swapped in wherever the other losses are used.
        """
        # A list index (rather than the integer -1) keeps a batch dim of 1.
        last = [-1]
        last_output = output[last]
        last_target = target[last]
        last_weight = target_weight[last]
        return self.sup_task_loss(last_output, last_target, last_weight)