# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class How2commDetLoss(nn.Module):
    """Weighted detection loss plus a communication loss term.

    total = cls_weight * cls_loss + box_weight * reg_loss
            + commu_weight * commu_loss

    Parameters
    ----------
    args : dict
        Must provide ``cls_weight``, ``box_weight``, ``pos_weights`` and
        ``commu_weight``.
    """

    def __init__(self, args):
        super(How2commDetLoss, self).__init__()
        self.cls_weight = args["cls_weight"]
        self.box_weight = args["box_weight"]
        self.pos_weights = args["pos_weights"]
        self.commu_weight = args["commu_weight"]
        # Filled by ``forward`` and read by ``logging``.
        self.loss_dict = {}
        # Cross-entropy class weights: 1 for class 0, ``pos_weights`` for
        # class 1 (the positive class). The weight is created on the CPU and
        # moved to the logits' device in ``forward``. It is not pinned to the
        # default CUDA device here, so the loss also works on CPU-only
        # machines and when the model runs on a non-default GPU.
        self.loss_func_cls = nn.CrossEntropyLoss(
            weight=torch.Tensor([1., self.pos_weights]))

    def forward(self, output_dict, target_dict):
        """
        Compute the weighted classification, regression and communication loss.

        Parameters
        ----------
        output_dict : dict
            Model outputs:
            ``cls``: two-class logits, (B, 2, H, W).
            ``reg``: box regression, assumed (B, 6, H, W) to match the
            targets.
            ``commu_loss``: communication loss, a scalar tensor or number.

        target_dict : dict
            Must contain ``label_map`` of shape (B, 7, H, W). Channel 0 is
            the binary positive mask; channels 1-6 are the regression targets.

        Returns
        -------
        total_loss : torch.Tensor
            Total loss. The individual terms are also stored in
            ``self.loss_dict``.
        """
        targets = target_dict["label_map"]
        cls_preds, loc_preds = output_dict["cls"], output_dict["reg"]
        commu_loss = output_dict["commu_loss"]

        cls_targets, loc_targets = targets.split([1, 6], dim=1)
        pos_pixels = cls_targets.sum()

        # Regression is supervised only on positive pixels: the (B, 1, H, W)
        # mask broadcasts over the 6 regression channels on both sides.
        loc_loss = F.smooth_l1_loss(cls_targets * loc_preds,
                                    cls_targets * loc_targets,
                                    reduction='sum')
        # Average over positive pixels. Without positives the masked sum is
        # already 0, so it is left as is to avoid dividing by zero.
        if pos_pixels > 0:
            loc_loss = loc_loss / pos_pixels

        # Keep the class-weight buffer on the same device as the logits
        # (Module.to moves buffers in place and is a no-op when they match).
        self.loss_func_cls.to(cls_preds.device)
        # (B, 1, H, W) -> (B, H, W) integer class indices for cross-entropy.
        cls_targets = cls_targets.flatten(0, 1).long()
        cls_loss = self.loss_func_cls(cls_preds, cls_targets)

        total_loss = (self.cls_weight * cls_loss
                      + self.box_weight * loc_loss
                      + self.commu_weight * commu_loss)

        self.loss_dict.update({'total_loss': total_loss,
                               'reg_loss': loc_loss,
                               'cls_loss': cls_loss,
                               'commu_loss': commu_loss,
                               'cmt_loss': None})

        return total_loss

    def logging(self, epoch, batch_id, batch_len, writer, pbar=None):
        """
        Print out the loss function for current iteration.

        Parameters
        ----------
        epoch : int
            Current epoch for training.
        batch_id : int
            The current batch.
        batch_len : int
            Total batch length in one iteration of training.
        writer : SummaryWriter
            Used to visualize on tensorboard.
        pbar : optional
            Progress bar exposing ``set_description``. When given, the
            message is shown there instead of being printed.
        """
        total_loss = self.loss_dict['total_loss']
        reg_loss = self.loss_dict['reg_loss']
        cls_loss = self.loss_dict['cls_loss']

        message = ("[epoch %d][%d/%d], || Loss: %.4f || Conf Loss: %.4f"
                   " || Loc Loss: %.4f" % (
                       epoch, batch_id + 1, batch_len,
                       total_loss.item(), cls_loss.item(), reg_loss.item()))
        if pbar is None:
            print(message)
        else:
            pbar.set_description(message)

        # Global step across epochs for tensorboard curves.
        step = epoch * batch_len + batch_id
        writer.add_scalar('Regression_loss', reg_loss.item(), step)
        writer.add_scalar('Confidence_loss', cls_loss.item(), step)