# -*- coding: utf-8 -*-

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class DetectionWeightedCeLoss(nn.Module):
    """Class-weighted cross-entropy + masked smooth-L1 detection loss.

    The classification term is a 2-class cross-entropy with class weights
    ``[1, pos_weights]``. The regression term is a summed smooth-L1 loss
    restricted to positive pixels and divided by the number of positive pixels.
    If the model output carries a non-None ``"cmt_loss"``, it is added with
    weight ``cmt_weights``.

    Parameters
    ----------
    args : dict
        Loss configuration. Required keys: ``"cls_weight"``, ``"box_weight"``,
        ``"pos_weights"``. Optional key: ``"cmt_weights"``, which is needed only
        when the model returns a non-None ``"cmt_loss"``.
    """

    def __init__(self, args):
        super().__init__()
        self.cls_weight = args["cls_weight"]
        self.box_weight = args["box_weight"]
        self.pos_weights = args["pos_weights"]
        # Optional: only required when the model emits a commitment loss.
        self.cmt_weights = args["cmt_weights"] if "cmt_weights" in args else None
        self.loss_dict = {}
        # Class weights [class 0, class 1]. Built on CPU and moved lazily to the
        # predictions' device in ``forward`` so the loss also works without CUDA.
        self.loss_func_cls = nn.CrossEntropyLoss(
            weight=torch.tensor([1.0, float(self.pos_weights)]))

    def _move_cls_weight_to(self, device):
        """Move the cross-entropy class weights to ``device`` if they live elsewhere."""
        weight = self.loss_func_cls.weight
        if weight is not None and weight.device != device:
            # nn.Module.to moves the registered weight buffer in place.
            self.loss_func_cls.to(device)

    def forward(self, output_dict, target_dict):
        """
        Compute the weighted detection loss.

        Parameters
        ----------
        output_dict : dict
            Model output. ``"cls"`` holds classification logits and ``"reg"``
            holds regression predictions. ``"cmt_loss"`` is optional (a tensor
            or None). The logits are presumably (b, 2, h, w) and the regression
            output (b, 6, h, w); confirm against the model.

        target_dict : dict
            ``"label_map"`` is a rank-4 tensor whose dim 1 has 7 channels:
            channel 0 holds the per-pixel class label (0/1) and channels 1-6
            hold the regression targets.

        Returns
        -------
        total_loss : torch.Tensor
            ``cls_weight * cls_loss + box_weight * reg_loss
            [+ cmt_weights * cmt_loss]``.

        Raises
        ------
        ValueError
            If the output contains a commitment loss but ``cmt_weights`` was
            not configured.
        """
        targets = target_dict["label_map"]
        cls_preds, loc_preds = output_dict["cls"], output_dict["reg"]
        # The commitment loss is optional; models without one may omit the key.
        cmt_loss = output_dict.get("cmt_loss")

        cls_targets, loc_targets = targets.split([1, 6], dim=1)

        pos_pixels = cls_targets.sum()
        has_pos = bool(pos_pixels > 0)
        if not has_pos:
            warnings.warn("DetectionWeightedCeLoss: label_map contains no positive "
                          "pixels; regression loss is not normalized.")

        # Masking both sides with the label map makes only positive pixels
        # contribute to the summed regression loss.
        loc_loss = F.smooth_l1_loss(cls_targets * loc_preds,
                                    cls_targets * loc_targets,
                                    reduction='sum')
        if has_pos:
            loc_loss = loc_loss / pos_pixels

        # (b, 1, h, w) -> (b, h, w) integer class indices for cross-entropy.
        cls_indices = cls_targets.flatten(0, 1).long()
        self._move_cls_weight_to(cls_preds.device)
        cls_loss = self.loss_func_cls(cls_preds, cls_indices)

        total_loss = self.cls_weight * cls_loss + self.box_weight * loc_loss
        if cmt_loss is not None:
            if self.cmt_weights is None:
                raise ValueError("output_dict contains 'cmt_loss' but 'cmt_weights' "
                                 "was not provided in the loss args")
            total_loss = total_loss + self.cmt_weights * cmt_loss

        self.loss_dict.update({'total_loss': total_loss,
                               'reg_loss': loc_loss,
                               'cls_loss': cls_loss,
                               'cmt_loss': cmt_loss})

        return total_loss

    def logging(self, epoch, batch_id, batch_len, writer, pbar=None):
        """
        Print out the loss values of the most recent ``forward`` call.

        Parameters
        ----------
        epoch : int
            Current epoch for training.
        batch_id : int
            The current batch (0-based; displayed as ``batch_id + 1``).
        batch_len : int
            Total number of batches in one epoch.
        writer : SummaryWriter
            Used to visualize on tensorboard (only ``add_scalar`` is called).
        pbar : optional
            Object exposing ``set_description(str)`` (e.g. a tqdm bar). When
            None, the message is printed instead.
        """
        total_loss = self.loss_dict['total_loss']
        reg_loss = self.loss_dict['reg_loss']
        cls_loss = self.loss_dict['cls_loss']
        cmt_loss = self.loss_dict.get('cmt_loss')
        step = epoch * batch_len + batch_id

        msg = ("[epoch %d][%d/%d], || Loss: %.4f || Conf Loss: %.4f"
               " || Loc Loss: %.4f" % (epoch, batch_id + 1, batch_len,
                                       float(total_loss), float(cls_loss),
                                       float(reg_loss)))
        if cmt_loss is not None:
            msg += " || Cmt Loss: %.4f" % float(cmt_loss)

        if pbar is None:
            print(msg)
        else:
            pbar.set_description(msg)

        if cmt_loss is not None:
            writer.add_scalar('Commitment_loss', float(cmt_loss), step)
        writer.add_scalar('Regression_loss', float(reg_loss), step)
        writer.add_scalar('Confidence_loss', float(cls_loss), step)
