from __future__ import print_function, division

from .base_loss import BaseLoss
from . import OPENOCC_LOSS
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel
from torchvision import transforms
import numpy as np
import sys, os, pdb

class ForkedPdb(pdb.Pdb):
    """A Pdb subclass that may be used from a forked multiprocessing child.

    During each interaction ``sys.stdin`` is temporarily rebound to a freshly
    opened ``/dev/stdin`` so the debugger prompt can read input, and the
    previous ``sys.stdin`` is restored afterwards.

    Typical use: ``ForkedPdb().set_trace()``.
    """

    def interaction(self, *args, **kwargs):
        """Run ``pdb.Pdb.interaction`` with ``sys.stdin`` bound to ``/dev/stdin``.

        All arguments are forwarded unchanged to ``pdb.Pdb.interaction``.
        The previous ``sys.stdin`` is restored even if opening the device or
        the interaction itself raises.
        """
        saved_stdin = sys.stdin
        try:
            # The context manager closes the handle once the interaction is
            # over; previously it was left open and leaked on every call.
            with open('/dev/stdin') as forked_stdin:
                sys.stdin = forked_stdin
                pdb.Pdb.interaction(self, *args, **kwargs)
        finally:
            sys.stdin = saved_stdin

@OPENOCC_LOSS.register_module()
class PlaneLoss(BaseLoss):
    """Regression loss over three feature planes (xy, xz, yz).

    For each of the first three entries of ``logits``/``labels`` the loss is
    ``mse_loss + l1_loss``; the three per-plane values are summed.

    Args:
        weight: forwarded to ``BaseLoss.__init__``.
        reg_fff: stored on the instance; only referenced by a disabled
            mask-based sparsity regulariser, so it currently has no effect.
        input_dict: mapping from the ``plane_loss`` argument names to keys of
            the caller's data (presumably resolved by ``BaseLoss`` -- TODO
            confirm). Defaults to mask/logits/labels below.
        **kwargs: accepted and ignored.
    """

    def __init__(self, weight=1.0, reg_fff=1e-6, input_dict=None, **kwargs):
        super().__init__(weight)

        default_mapping = {
            'mask': 'hexplane_mask',
            'logits': 'logits',
            'labels': 'labels'
        }
        self.input_dict = default_mapping if input_dict is None else input_dict
        self.reg_fff = reg_fff
        self.loss_func = self.plane_loss

    def plane_loss(self, mask, logits, labels):
        """Return the summed MSE + L1 loss of the xy, xz and yz planes.

        Args:
            mask: currently unused (only the disabled mask-weighted variant
                read it).
            logits: indexable with 0, 1, 2 -> predicted xy, xz, yz planes.
                An earlier note gives example shapes
                [12, 8, 100, 100], [12, 8, 100, 16], [12, 8, 100, 16]
                -- TODO confirm against the caller.
            labels: targets, indexed the same way as ``logits``.

        NOTE: previous experiments (pure L1, an added cosine-similarity term,
        and a mask-weighted L1 with a ``reg_fff``-scaled sparsity penalty)
        are disabled and not reproduced here.
        """
        # Index order 0, 1, 2 corresponds to the xy, xz and yz planes.
        per_plane = [
            F.mse_loss(logits[idx], labels[idx]) + F.l1_loss(logits[idx], labels[idx])
            for idx in range(3)
        ]
        xy_loss, xz_loss, yz_loss = per_plane
        return xy_loss + xz_loss + yz_loss


@OPENOCC_LOSS.register_module()
class ReconLoss(BaseLoss):
    """Cross-entropy reconstruction loss over channel-last logits.

    Args:
        weight: forwarded to ``BaseLoss.__init__``.
        ignore_label: label value excluded from the loss (passed to
            ``F.cross_entropy`` as ``ignore_index``).
        use_weight: when true, per-class weights are applied.
        cls_weight: optional sequence of per-class weights; when omitted and
            ``use_weight`` is true, inverse-frequency weights are computed
            from the labels of each call.
        input_dict: mapping from the ``recon_loss`` argument names to keys of
            the caller's data (presumably resolved by ``BaseLoss`` -- TODO
            confirm). Defaults to logits/labels.
        **kwargs: accepted and ignored.
    """

    def __init__(self, weight=1.0, ignore_label=-100, use_weight=False, cls_weight=None, input_dict=None, **kwargs):
        super().__init__(weight)

        if input_dict is None:
            self.input_dict = {
                'logits': 'logits',
                'labels': 'labels'
            }
        else:
            self.input_dict = input_dict
        self.loss_func = self.recon_loss
        self.ignore = ignore_label
        self.use_weight = use_weight
        self.cls_weight = torch.tensor(cls_weight) if cls_weight is not None else None

    def recon_loss(self, logits, labels):
        """Return the (optionally class-weighted) cross-entropy.

        Args:
            logits: tensor whose last axis is the class axis and whose first
                axis is the batch axis (the original code assumed the 6-D
                layout bs, F, H, W, D, C; any rank >= 2 is accepted now).
            labels: integer class indices with the shape of ``logits`` minus
                its last axis; entries equal to ``ignore_label`` are skipped.
        """
        num_classes = logits.shape[-1]
        weight = None
        if self.use_weight:
            if self.cls_weight is not None:
                weight = self.cls_weight
            else:
                # Drop ignored entries first: F.one_hot rejects negative
                # values such as the default ignore_label of -100.
                valid_labels = labels[labels != self.ignore]
                cls_freq = F.one_hot(valid_labels, num_classes=num_classes).sum(dim=0)  # C
                weight = 1.0 / cls_freq.clamp_min(1) * valid_labels.numel() / num_classes
            # cls_weight is built on the CPU in __init__; cross_entropy needs
            # the weights on the same device (and dtype) as the logits.
            weight = weight.to(device=logits.device, dtype=logits.dtype)

        # Move the class axis from last to second position, as expected by
        # F.cross_entropy; equals permute(0, 5, 1, 2, 3, 4) for 6-D input.
        last_axis = logits.dim() - 1
        class_first = logits.permute(0, last_axis, *range(1, last_axis))
        rec_loss = F.cross_entropy(class_first, labels, ignore_index=self.ignore, weight=weight)
        return rec_loss
    
@OPENOCC_LOSS.register_module()
class LovaszLoss(BaseLoss):
    """Lovasz-Softmax loss wrapper for 5-D channel-last predictions.

    Args:
        weight: forwarded to ``BaseLoss.__init__``.
        input_dict: mapping from the ``lovasz_loss`` argument names to keys of
            the caller's data (presumably resolved by ``BaseLoss`` -- TODO
            confirm). Defaults to logits/labels.
        **kwargs: accepted and ignored.
    """

    def __init__(self, weight=1.0, input_dict=None, **kwargs):
        super().__init__(weight)

        default_mapping = {
            'logits': 'logits',
            'labels': 'labels'
        }
        self.input_dict = default_mapping if input_dict is None else input_dict
        self.loss_func = self.lovasz_loss

    def lovasz_loss(self, logits, labels):
        """Return ``lovasz_softmax`` of the channel-first view of ``logits``.

        ``logits`` must be 5-D with the class axis last; it is permuted to
        class-second before being handed to ``lovasz_softmax``.

        NOTE(review): no softmax is applied here, whereas ``lovasz_softmax``
        documents its input as class probabilities (the disabled "occworld"
        variant flattened dims 0-1 and applied ``softmax(dim=1)``) -- confirm
        that the caller already provides probabilities.
        """
        channels_first = logits.permute(0, 4, 1, 2, 3)
        return lovasz_softmax(channels_first, labels)

    
@OPENOCC_LOSS.register_module()
class ImgLoss(BaseLoss):
    """Mean-squared-error loss between predicted and reference images.

    Args:
        weight: forwarded to ``BaseLoss.__init__``.
        input_dict: mapping from the ``img_loss`` argument names to keys of
            the caller's data (presumably resolved by ``BaseLoss`` -- TODO
            confirm). Defaults to logits/labels.
        img_size: size handed to ``transforms.Resize`` for the reference
            images; defaults to the previously hard-coded ``(299, 297)``.
        debug_dir: optional directory. When set, a few label/prediction
            pairs are written there as PNG files on every call. Earlier
            versions always wrote to ``./test_results/ttttttttt`` and failed
            when that directory was missing or fewer than five images were
            present; pass that path explicitly to get the dumps back.
        **kwargs: accepted and ignored.
    """

    def __init__(self, weight=1.0, input_dict=None, img_size=(299, 297), debug_dir=None, **kwargs):
        super().__init__(weight)

        if input_dict is None:
            self.input_dict = {
                'logits': 'logits',
                'labels': 'labels'
            }
        else:
            self.input_dict = input_dict

        self.loss_func = self.img_loss
        self.debug_dir = debug_dir

        # Resize the reference images, then convert them with ToTensor.
        self.processor = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor()
        ])

    def _dump_debug_images(self, logits, labels, indices=(0, 2, 4)):
        """Save label/prediction pairs at ``indices`` as PNGs in ``debug_dir``."""
        os.makedirs(self.debug_dir, exist_ok=True)
        to_pil = transforms.ToPILImage()
        for idx in indices:
            # Skip indices that the current batch does not contain.
            if idx >= len(labels) or idx >= len(logits):
                continue
            to_pil(labels[idx].detach().cpu()).save(
                os.path.join(self.debug_dir, "label{}.png".format(idx)))
            to_pil(logits[idx].detach().cpu()).save(
                os.path.join(self.debug_dir, "logit{}.png".format(idx)))

    def img_loss(self, logits, labels):
        """Return ``F.mse_loss`` between ``logits`` and the processed labels.

        Args:
            logits: predicted image tensor; earlier notes mention values in
                0..1 and a stacked label shape of [18, 3, 299, 297] -- TODO
                confirm the exact layout against the caller.
            labels: iterable of images accepted by ``Resize`` + ``ToTensor``
                (presumably PIL images); they are stacked into one tensor
                and moved to the device of ``logits``.
        """
        processed = [self.processor(one_label) for one_label in labels]
        labels = torch.stack(processed).to(logits.device)    # 0 to 1

        if self.debug_dir is not None:
            self._dump_debug_images(logits, labels)

        loss = F.mse_loss(logits, labels)
        return loss

"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""

import torch
from torch.autograd import Variable
import torch.nn.functional as F
try:
    from itertools import  ifilterfalse
except ImportError: # py3k
    from itertools import  filterfalse as ifilterfalse


def lovasz_grad(gt_sorted):
    """
    Gradient of the Lovasz extension w.r.t. sorted errors (Alg. 1 in paper).
      gt_sorted: 1-D tensor of 0/1 ground-truth values, ordered by
                 decreasing prediction error.
    Returns a float tensor of the same length: the Jaccard loss after the
    first k elements minus the loss after the first k - 1 elements.
    """
    n_items = len(gt_sorted)
    total_fg = gt_sorted.sum()
    running_fg = gt_sorted.float().cumsum(0)
    running_bg = (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - (total_fg - running_fg) / (total_fg + running_bg)
    if n_items <= 1:
        # nothing to difference for the empty / 1-pixel case
        return jaccard
    # turn the cumulative Jaccard losses into per-position increments
    jaccard[1:n_items] = jaccard[1:n_items] - jaccard[0:-1]
    return jaccard

# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: class probabilities (between 0 and 1) with the class axis in
              second position, e.g. [B, C, H, W]; see flatten_probas for the
              accepted layouts (a [B, H, W] input is treated as a
              single-channel sigmoid output).
      labels: ground truth labels (between 0 and C - 1), e.g. [B, H, W]
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      per_image: compute the loss per image and average, instead of over the whole batch
      ignore: void class label, removed before the loss is computed
    """
    if not per_image:
        flat_probas, flat_labels = flatten_probas(probas, labels, ignore)
        return lovasz_softmax_flat(flat_probas, flat_labels, classes=classes)

    # One loss per sample (re-adding the batch axis), averaged lazily by mean().
    per_sample_losses = (
        lovasz_softmax_flat(
            *flatten_probas(single_prob.unsqueeze(0), single_lab.unsqueeze(0), ignore),
            classes=classes)
        for single_prob, single_lab in zip(probas, labels)
    )
    return mean(per_sample_losses)


def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] tensor, class probabilities at each prediction (between 0 and 1)
      labels: [P] tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.

    Returns the mean of the per-class losses; the Python float 0. when
    probas is empty.
    Raises ValueError when a single-channel (sigmoid) input is combined with
    more than one class to average.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return 0.
    C = probas.size(1)
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    # Validate against the resolved class list. The old check used
    # len(classes), which is 7 for the default string 'present' and therefore
    # rejected every single-channel input.
    if C == 1 and len(class_to_sum) > 1:
        raise ValueError('Sigmoid output possible only with 1 class')
    losses = []
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        # a single-channel input always uses its only column
        class_pred = probas[:, 0] if C == 1 else probas[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
    return mean(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
      probas: [B, C, *spatial] tensor with any number of spatial axes
              (e.g. [B, C, H, W] or [B, C, L, H, W]); a 3-D [B, H, W] input
              is treated as the output of a sigmoid layer (C = 1).
      labels: tensor with B * prod(spatial) elements.
      ignore: optional void label; matching positions are dropped.
    Returns (probas [P, C], labels [P]).
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer: add the missing class axis
        probas = probas.unsqueeze(1)
    C = probas.size(1)
    # Move the class axis last and collapse everything else. For 4-D input
    # this equals permute(0, 2, 3, 1); higher ranks (3D segmentation) no
    # longer need a dedicated branch.
    probas = probas.permute(0, *range(2, probas.dim()), 1).reshape(-1, C)  # P, C
    # reshape (unlike view) also accepts non-contiguous label tensors
    labels = labels.reshape(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    return probas[valid], labels[valid]

# --------------------------- HELPER FUNCTIONS ---------------------------

def isnan(x):
    """Return the result of ``x != x``, which is true only for NaN values."""
    # NaN is the only value that compares unequal to itself; the result type
    # is whatever ``!=`` yields for the type of ``x``.
    return x != x
    
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
      l: iterable of values supporting ``+`` and ``/`` (numbers, tensors, ...)
      ignore_nan: drop the items for which isnan() is true before averaging
      empty: value returned for an empty input, or 'raise' to raise
             ValueError instead
    A single-element input returns that element unchanged.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        # Out-of-place add: ``acc += v`` would modify the caller's first
        # element in place when it is a tensor or array.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n
