import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp

from config.registry import Registry
from .region_loss import (
    BCEWithRegionKL, BCEWithRegionMSE, DiceWithRegionKL, BCEWithRegionLog,
    PartialBCEWithRegionKL, PartialBCEWithRegionKLv2, BCEWithDiceWithLogitsLoss, BCEWithRegionMAE
)
from .base import KlLoss, WeightedBCE
from .region_loss import DiceWithLogitsLoss
from .region_loss2 import CEWithRegionL1, CEWithRegionDice, CEWithRegionKL, CEWithRegionL2
from .dice import DiceLoss, GeneralisedDiceLoss

logger = logging.getLogger(__name__)

# Maps a loss name (the value of cfg.LOSS.NAME) to a builder callable that
# receives the full config object and returns the loss module.
LOSS_REGISTRY = Registry("loss")

LOSS_REGISTRY.register("mse", lambda cfg: nn.MSELoss())
LOSS_REGISTRY.register("kl", lambda cfg: KlLoss())
LOSS_REGISTRY.register(
    "cross_entropy",
    lambda cfg: nn.CrossEntropyLoss(
        # Per-class weights are optional: an empty/unset CLASS_WEIGHTS means
        # unweighted. torch.tensor(..., dtype=float32) is used instead of the
        # legacy torch.FloatTensor constructor, which interprets a bare int as
        # a *size* and silently returns an uninitialized tensor.
        weight=(
            torch.tensor(cfg.LOSS.CLASS_WEIGHTS, dtype=torch.float32)
            if cfg.LOSS.CLASS_WEIGHTS
            else None
        ),
        ignore_index=cfg.LOSS.IGNORE_INDEX,
    ),
)
LOSS_REGISTRY.register("bce_logit", lambda cfg: nn.BCEWithLogitsLoss())
LOSS_REGISTRY.register("weighted_bce", lambda cfg: WeightedBCE())
LOSS_REGISTRY.register(
    "soft_bce_logit",
    # cfg.LOSS.LABEL_SMOOTHING is forwarded as smp's `smooth_factor`.
    lambda cfg: smp.losses.SoftBCEWithLogitsLoss(smooth_factor=cfg.LOSS.LABEL_SMOOTHING),
)

# Dice variants built from the local `.dice` module. "dice_logit" and
# "logdice_logit" use binary mode; "dice_loss" / "logdice_loss" use multiclass
# mode and pass IGNORE_INDEX through. The "log" variants set log_loss=True.
LOSS_REGISTRY.register("dice_logit", lambda cfg: DiceLoss(mode="binary"))
LOSS_REGISTRY.register("generalised_dice_logit", lambda cfg: GeneralisedDiceLoss())
LOSS_REGISTRY.register("logdice_logit", lambda cfg: DiceLoss(mode="binary", log_loss=True))
LOSS_REGISTRY.register(
    "dice_loss",
    lambda cfg: DiceLoss(
        mode="multiclass",
        log_loss=False,
        ignore_index=cfg.LOSS.IGNORE_INDEX,
    ),
)
LOSS_REGISTRY.register(
    "logdice_loss",
    lambda cfg: DiceLoss(
        mode="multiclass",
        log_loss=True,
        ignore_index=cfg.LOSS.IGNORE_INDEX,
    ),
)

# Cross-entropy combined with a region term; the region weight is configured
# through cfg.LOSS.ALPHA / ALPHA_FACTOR / ALPHA_STEP_SIZE (see region_loss2).
LOSS_REGISTRY.register(
    "ce_region_dice",
    lambda cfg: CEWithRegionDice(
        alpha=cfg.LOSS.ALPHA,
        factor=cfg.LOSS.ALPHA_FACTOR,
        step_size=cfg.LOSS.ALPHA_STEP_SIZE,
        ignore_index=cfg.LOSS.IGNORE_INDEX,
    ),
)
LOSS_REGISTRY.register(
    "ce_region_kl",
    lambda cfg: CEWithRegionKL(
        alpha=cfg.LOSS.ALPHA,
        factor=cfg.LOSS.ALPHA_FACTOR,
        step_size=cfg.LOSS.ALPHA_STEP_SIZE,
        temp=cfg.LOSS.TEMP,
        ignore_index=cfg.LOSS.IGNORE_INDEX,
    ),
)

# Focal losses from segmentation_models_pytorch.
LOSS_REGISTRY.register(
    "binary_focal_loss", lambda cfg: smp.losses.FocalLoss(mode="binary")
)
LOSS_REGISTRY.register(
    "focal_loss",
    lambda cfg: smp.losses.FocalLoss(
        mode="multiclass",
        ignore_index=cfg.LOSS.IGNORE_INDEX,
    ),
)


def _ce_region_kwargs(cfg):
    """Return the constructor kwargs shared by the CE + region L1/L2 losses.

    Both builders below read exactly the same fields from ``cfg.LOSS``; keeping
    them in one place prevents the two from drifting apart.
    """
    loss_cfg = cfg.LOSS
    return {
        "alpha": loss_cfg.ALPHA,
        "factor": loss_cfg.ALPHA_FACTOR,
        "step_size": loss_cfg.ALPHA_STEP_SIZE,
        "temp": loss_cfg.TEMP,
        "ignore_index": loss_cfg.IGNORE_INDEX,
        "background_index": loss_cfg.BACKGROUND_INDEX,
        # Optional per-class weights; empty/unset means unweighted. Uses
        # torch.tensor rather than the legacy torch.FloatTensor constructor,
        # which treats a bare int as a size and returns uninitialized memory.
        "weight": (
            torch.tensor(loss_cfg.CLASS_WEIGHTS, dtype=torch.float32)
            if loss_cfg.CLASS_WEIGHTS
            else None
        ),
    }


@LOSS_REGISTRY.register("ce_region_l1")
def build_ce_region_l1(cfg):
    """Build ``CEWithRegionL1`` from the settings in ``cfg.LOSS``."""
    return CEWithRegionL1(**_ce_region_kwargs(cfg))


@LOSS_REGISTRY.register("ce_region_l2")
def build_ce_region_l2(cfg):
    """Build ``CEWithRegionL2`` from the settings in ``cfg.LOSS``."""
    return CEWithRegionL2(**_ce_region_kwargs(cfg))


@LOSS_REGISTRY.register("bce_logdice_logit")
def build_bce_logdice(cfg):
    """Build ``BCEWithDiceWithLogitsLoss`` with ``log_loss=True``.

    The alpha/factor/step_size values are read from ``cfg.LOSS.ALPHA``,
    ``cfg.LOSS.ALPHA_FACTOR`` and ``cfg.LOSS.ALPHA_STEP_SIZE``.
    """
    loss_cfg = cfg.LOSS
    alpha_settings = {
        "alpha": loss_cfg.ALPHA,
        "factor": loss_cfg.ALPHA_FACTOR,
        "step_size": loss_cfg.ALPHA_STEP_SIZE,
    }
    return BCEWithDiceWithLogitsLoss(log_loss=True, **alpha_settings)


def _bce_region_kwargs(cfg):
    """Return the constructor kwargs shared by every BCE + region loss below.

    All six builders take the same four values from ``cfg.LOSS``; centralising
    them avoids six copies of the same argument list drifting apart.
    """
    loss_cfg = cfg.LOSS
    return {
        "alpha": loss_cfg.ALPHA,
        "factor": loss_cfg.ALPHA_FACTOR,
        "step_size": loss_cfg.ALPHA_STEP_SIZE,
        "temp": loss_cfg.TEMP,
    }


@LOSS_REGISTRY.register("bce_region_kl")
def build_bce_region_kl(cfg):
    """Build ``BCEWithRegionKL`` from the settings in ``cfg.LOSS``."""
    return BCEWithRegionKL(**_bce_region_kwargs(cfg))


@LOSS_REGISTRY.register("partial_bce_region_kl")
def build_partial_bce_kl(cfg):
    """Build ``PartialBCEWithRegionKL`` from the settings in ``cfg.LOSS``."""
    return PartialBCEWithRegionKL(**_bce_region_kwargs(cfg))


@LOSS_REGISTRY.register("partial_bce_region_kl_v2")
def build_partial_bce_kl_v2(cfg):
    """Build ``PartialBCEWithRegionKLv2`` from the settings in ``cfg.LOSS``."""
    return PartialBCEWithRegionKLv2(**_bce_region_kwargs(cfg))


@LOSS_REGISTRY.register("bce_region_log")
def build_bce_region_log(cfg):
    """Build ``BCEWithRegionLog`` from the settings in ``cfg.LOSS``."""
    return BCEWithRegionLog(**_bce_region_kwargs(cfg))


@LOSS_REGISTRY.register("bce_region_mse")
def build_bce_region_mse(cfg):
    """Build ``BCEWithRegionMSE`` from the settings in ``cfg.LOSS``."""
    return BCEWithRegionMSE(**_bce_region_kwargs(cfg))


@LOSS_REGISTRY.register("bce_region_mae")
def build_bce_region_mae(cfg):
    """Build ``BCEWithRegionMAE`` from the settings in ``cfg.LOSS``."""
    return BCEWithRegionMAE(**_bce_region_kwargs(cfg))


@LOSS_REGISTRY.register("dice_region_kl")
def build_dice_region_kl(cfg):
    """Build ``DiceWithRegionKL``, passing only ``cfg.LOSS.ALPHA`` (positionally).

    NOTE(review): unlike the BCE region builders, no factor/step_size/temp is
    forwarded here — confirm against DiceWithRegionKL's signature if needed.
    """
    alpha = cfg.LOSS.ALPHA
    return DiceWithRegionKL(alpha)


def get_loss_func(cfg) -> nn.Module:
    """Build the loss module named by ``cfg.LOSS.NAME``.

    Looks the name up in ``LOSS_REGISTRY``, calls the registered builder with
    the full ``cfg``, and sets a ``name`` attribute on the result holding the
    registry key. Behaviour for an unregistered name is whatever
    ``Registry.get`` does (not visible in this module).

    Args:
        cfg: config object exposing ``cfg.LOSS.NAME`` plus whichever
            ``cfg.LOSS`` fields the selected builder reads.

    Returns:
        The constructed loss module.
    """
    loss_name = cfg.LOSS.NAME
    builder = LOSS_REGISTRY.get(loss_name)
    loss_func = builder(cfg)
    loss_func.name = loss_name
    # Lazy %-style args: the (possibly long) module repr is only rendered when
    # INFO records are actually emitted.
    logger.info("Successfully build loss func : %s", loss_func)

    return loss_func
