import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from torchvision.ops import sigmoid_focal_loss

class SegmentFocalLoss(nn.Module):
    """
    Sigmoid focal loss for per-pixel multi-class segmentation.

    Integer class labels are one-hot encoded to match the channel layout of
    the prediction, and ``torchvision.ops.sigmoid_focal_loss`` is applied to
    every channel. An optional auxiliary loss supplied by the model under
    ``output_dict["cmt_loss"]`` is added with weight ``cmt_weight`` when both
    the loss and the weight are available.

    Parameters
    ----------
    args : dict
        Loss configuration. Required keys: ``'gamma'`` (passed as ``gamma``
        to ``sigmoid_focal_loss``) and ``'seg_weight'``. Optional key:
        ``'cmt_weight'``.
    """

    def __init__(self, args):
        super().__init__()
        self.gamma = args['gamma']
        self.seg_weight = args['seg_weight']
        # None means "no auxiliary weight configured"; forward() then ignores
        # any 'cmt_loss' the model provides.
        self.cmt_weight = args.get('cmt_weight')
        # Most recent (detached) loss values, read by ``logging``.
        self.loss_dict = {}

    def forward(self, output_dict, target_map):
        """
        Compute the weighted segmentation loss for one batch.

        Parameters
        ----------
        output_dict : dict
            Model outputs. ``output_dict["seg"]`` holds the per-class logits
            with the class channel on dim 1, e.g. (B, C, H, W); its shape
            must match the one-hot target. ``output_dict["cmt_loss"]`` is an
            optional auxiliary loss.
        target_map : torch.Tensor
            Integer class labels, e.g. (B, H, W), with every value in
            ``[0, C)`` where C is ``output_dict["seg"].shape[1]``.

        Returns
        -------
        torch.Tensor
            ``seg_weight * seg_loss``, plus ``cmt_weight * cmt_loss`` when
            both are available.

        Raises
        ------
        RuntimeError
            If ``target_map`` contains a label ``>= C`` (raised by
            ``F.one_hot``).
        """
        output_map = output_dict["seg"]
        num_classes = output_map.shape[1]

        # An explicit num_classes gives exactly one channel per class, with
        # every class at its own index even when some classes are absent
        # from this batch. Inferring the width from the max label and then
        # inserting zero channels shifted present classes to wrong indices.
        one_hot_map = F.one_hot(target_map.long(), num_classes=num_classes)
        # (B, H, W, C) -> (B, C, H, W) to match the logits' layout.
        one_hot_map = one_hot_map.movedim(-1, 1).to(output_map.dtype)

        seg_loss = sigmoid_focal_loss(
            output_map, one_hot_map, gamma=self.gamma, reduction='mean'
        )

        total_loss = seg_loss * self.seg_weight
        cmt_loss = output_dict.get("cmt_loss")
        if cmt_loss is not None and self.cmt_weight is not None:
            total_loss = total_loss + cmt_loss * self.cmt_weight

        # Detach: the stored value is only used for logging, and keeping the
        # graph-attached tensor would hold this iteration's graph in memory.
        self.loss_dict.update({'seg_loss': seg_loss.detach()})

        return total_loss

    def logging(self, epoch, batch_id, batch_len, writer, pbar=None):
        """
        Report the segmentation loss from the most recent ``forward`` call.

        Must be called after ``forward``; otherwise ``loss_dict`` has no
        ``'seg_loss'`` entry and a ``KeyError`` is raised.

        Parameters
        ----------
        epoch : int
            Current training epoch.
        batch_id : int
            Zero-based index of the current batch within the epoch.
        batch_len : int
            Number of batches per epoch.
        writer : SummaryWriter
            TensorBoard writer; the loss is logged as ``'Segmentation_loss'``
            at global step ``epoch * batch_len + batch_id``.
        pbar : optional
            Progress bar exposing ``set_description``; when given, the
            message is shown there instead of being printed.
        """
        seg_loss = self.loss_dict['seg_loss'].item()
        message = "[epoch %d][%d/%d], || Segmentation Loss: %.4f" % (
            epoch, batch_id + 1, batch_len, seg_loss)

        if pbar is None:
            print(message)
        else:
            pbar.set_description(message)

        writer.add_scalar('Segmentation_loss', seg_loss,
                          epoch * batch_len + batch_id)