import math

import torch
import torch.nn.functional as F
from torch.nn import Module


class StopPredictionInTop(Module):
    """Blend a per-sample MSE term with cross-entropy against one class.

    ``ce`` below is the unreduced cross-entropy of the logits against
    ``undesired_class``. Which formula a sample gets depends on whether
    ``undesired_class`` is among its ``top_indexes_to_check`` largest logits:

    * not among them: ``exp(k1 * mse) - ce``
    * among them:     ``log(1 + mse) - exp(k2 * ce)``
    """

    def __init__(self, undesired_class, k1=10, k2=22, top_indexes_to_check=3):
        super().__init__()
        self.undesired_class = undesired_class
        self.k1 = k1
        self.k2 = k2
        self.top_indexes_to_check = top_indexes_to_check

    def forward(self, classification, mse_loss):
        """Return the per-sample combined loss.

        Args:
            classification: logits indexed as ``(batch, classes)``.
            mse_loss: per-sample values; the result has the same shape and
                dtype (assumes its first dimension is the batch -- TODO confirm
                against callers).

        When the batch holds a single sample, the softmax probability of
        ``undesired_class`` is printed.
        """
        batch_size = classification.shape[0]
        if batch_size == 1:
            print(classification.softmax(dim=1)[0, self.undesired_class])

        targets = (
            torch.ones(batch_size, device=classification.device, dtype=torch.long)
            * self.undesired_class
        )
        ce_loss = F.cross_entropy(classification, targets, reduction="none")

        _, top_indices = torch.topk(classification, self.top_indexes_to_check, dim=1)
        in_top = (top_indices == self.undesired_class).any(dim=1)
        out_of_top = ~in_top

        # Masked assignment: each formula is evaluated only on the rows it
        # applies to.
        total_loss = torch.zeros_like(mse_loss)
        total_loss[out_of_top] = (
            torch.exp(self.k1 * mse_loss[out_of_top]) - ce_loss[out_of_top]
        )
        total_loss[in_top] = torch.log(1 + mse_loss[in_top]) - torch.exp(
            self.k2 * ce_loss[in_top]
        )
        return total_loss


class ScaleLoss(Module):
    """Combine ``exp(mse_scale * mse)`` with a scaled cross-entropy term.

    The cross-entropy is taken against the fixed class ``klass``. It is added
    when ``move_toward_classification`` is truthy and subtracted otherwise.
    """

    def __init__(self, mse_scale, ce_scale, klass, move_toward_classification=False):
        super().__init__()
        self.mse_scale = mse_scale
        self.ce_scale = ce_scale
        self.klass = klass
        self.move_toward_classification = move_toward_classification

    def forward(self, classification, mse_loss):
        """Return ``exp(mse_scale * mse_loss) +/- ce_scale * ce`` per sample.

        Args:
            classification: logits indexed as ``(batch, classes)``.
            mse_loss: values combined elementwise with the per-sample
                cross-entropy.

        When the batch holds a single sample, the softmax probability of
        ``klass`` is printed.
        """
        batch_size = classification.shape[0]
        targets = (
            torch.ones(batch_size, device=classification.device, dtype=torch.long)
            * self.klass
        )
        ce_loss = F.cross_entropy(classification, targets, reduction="none")

        if batch_size == 1:
            print(classification.softmax(dim=1)[0, self.klass])

        # Flip the sign so that the cross-entropy is subtracted when the loss
        # is not meant to move toward ``klass``.
        if not self.move_toward_classification:
            ce_loss = -ce_loss
        return torch.exp(self.mse_scale * mse_loss) + self.ce_scale * ce_loss


class ProbabilityScaleLoss(Module):
    """Weight an exponential MSE term against cross-entropy by a probability.

    The weight is ``prob_scale`` applied to the softmax probability of
    ``klass``. With weight ``w``, ``m = exp(mse_scale * mse)`` and ``ce`` the
    cross-entropy against ``klass``:

    * ``move_toward_classification`` truthy: ``w * m + (1 - w) * ce_scale * ce``
    * otherwise:                             ``(1 - w) * m - w * ce_scale * ce``
    """

    def __init__(
        self, mse_scale, ce_scale, klass, prob_scale, move_toward_classification=False
    ):
        super().__init__()
        self.mse_scale = mse_scale
        self.ce_scale = ce_scale
        self.klass = klass
        # Callable applied to the per-sample probability of ``klass``.
        self.prob_scale = prob_scale
        self.move_toward_classification = move_toward_classification

    def forward(self, classification, mse_loss):
        """Return the probability-weighted combination per sample.

        Args:
            classification: logits indexed as ``(batch, classes)``.
            mse_loss: values combined elementwise with the per-sample terms.

        When the batch holds a single sample, the softmax probability of
        ``klass`` is printed.
        """
        batch_size = classification.shape[0]
        targets = (
            torch.ones(batch_size, device=classification.device, dtype=torch.long)
            * self.klass
        )
        ce_loss = F.cross_entropy(classification, targets, reduction="none")

        if batch_size == 1:
            print(classification.softmax(dim=1)[0, self.klass])

        class_probability = torch.softmax(classification, dim=1)[:, self.klass]
        weight = self.prob_scale(class_probability)
        mse_term = torch.exp(self.mse_scale * mse_loss)

        if self.move_toward_classification:
            return weight * mse_term + (1 - weight) * self.ce_scale * ce_loss
        return (1 - weight) * mse_term - weight * self.ce_scale * ce_loss


class SplitSigmoidLoss(Module):
    """Map values through ``p**n / (p**n + (1 - p)**n)`` with a split exponent.

    Elements below ``epsilon`` use exponent ``n1``; all others use ``n2``.
    """

    def __init__(self, n1=2, n2=15, epsilon=0.09):
        super().__init__()
        self.epsilon = epsilon
        self.n1 = n1
        self.n2 = n2

    @staticmethod
    def _power_ratio(probs, exponent):
        """Return ``probs**exponent / (probs**exponent + (1 - probs)**exponent)``."""
        return probs**exponent / (probs**exponent + (1 - probs) ** exponent)

    def forward(self, probs):
        """Return the elementwise transform of ``probs``."""
        below_threshold = probs < self.epsilon
        return torch.where(
            below_threshold,
            self._power_ratio(probs, self.n1),
            self._power_ratio(probs, self.n2),
        )


class SplitLogLoss(Module):
    """Piecewise logarithmic transform of ``loss`` selected by ``ce``.

    Where ``ce < epsilon`` the result is ``log(loss**n1)``; elsewhere it is
    ``log(loss**n2)`` plus a constant offset. The offset equals
    ``log(epsilon**n1) - log(epsilon**n2)``, so both pieces take the same value
    at ``loss == epsilon``. The result is negated when ``negative_loss`` is
    truthy.
    """

    def __init__(self, n1=6, n2=2, epsilon=0.09, negative_loss=False):
        super().__init__()
        self.epsilon = epsilon
        self.n1 = n1
        self.n2 = n2
        self.negative_loss = negative_loss

    def forward(self, loss, ce):
        """Return the elementwise transform of ``loss``; ``ce`` picks the piece."""
        sign = -1 if self.negative_loss else 1
        log_eps_n1 = math.log(self.epsilon**self.n1)
        log_eps_n2 = math.log(self.epsilon**self.n2)
        offset = log_eps_n1 - log_eps_n2

        below_piece = torch.log(loss**self.n1)
        above_piece = torch.log(loss**self.n2) + offset
        return torch.where(ce < self.epsilon, below_piece, above_piece) * sign


class LossManipulation(Module):
    """Subtract a transformed cross-entropy from a transformed MSE term.

    ``mse_manipulation`` and ``ce_manipulation`` are callables invoked as
    ``f(loss, ce)``, where ``ce`` is the unreduced cross-entropy of the logits
    against the fixed class ``klass``.
    """

    def __init__(self, mse_manipulation, ce_manipulation, klass):
        super().__init__()
        self.mse_manipulation = mse_manipulation
        self.ce_manipulation = ce_manipulation
        self.klass = klass

    def forward(self, classification, mse_loss):
        """Return ``mse_manipulation(mse, ce) - ce_manipulation(ce, ce)``.

        Args:
            classification: logits indexed as ``(batch, classes)``.
            mse_loss: values handed to ``mse_manipulation`` together with the
                per-sample cross-entropy.

        When the batch holds a single sample, a one-line summary of the
        intermediate values is printed.
        """
        batch_size = classification.shape[0]
        targets = (
            torch.ones(batch_size, device=classification.device, dtype=torch.long)
            * self.klass
        )
        ce_loss = F.cross_entropy(classification, targets, reduction="none")

        total_mse_loss = self.mse_manipulation(mse_loss, ce_loss)
        total_ce_loss = self.ce_manipulation(ce_loss, ce_loss)
        total_loss = total_mse_loss - total_ce_loss

        if batch_size == 1:
            prob = classification.softmax(dim=1)[0, self.klass]
            summary = (
                f"Prob: {prob.item():.2f}, Loss: {total_loss.item():.2f}, "
                f"MSE: {mse_loss.item():.2f}, Final MSE: {total_mse_loss.item():.2f}, "
                f"CE: {ce_loss.item():.2f}, Final CE: {total_ce_loss.item():.2f}"
            )
            print(summary)
        return total_loss


class LossWithTargetManipulation(Module):
    """Combine transformed MSE, cross-entropy and target cross-entropy terms.

    Two unreduced cross-entropies are computed from the logits: one against
    ``klass`` and one against ``target``. The three manipulation callables are
    invoked as ``f(loss, ce)`` and the result is::

        mse_manipulation(mse, ce_klass)
        - ce_manipulation(ce_klass, ce_klass)
        + target_manipulation(ce_target, ce_target)
    """

    def __init__(
        self, mse_manipulation, ce_manipulation, target_manipulation, klass, target
    ):
        super().__init__()
        self.mse_manipulation = mse_manipulation
        self.ce_manipulation = ce_manipulation
        self.target_manipulation = target_manipulation
        self.klass = klass
        self.target = target

    def _cross_entropy_against(self, classification, class_index):
        """Return the per-sample cross-entropy of the logits vs. one class."""
        targets = (
            torch.ones(
                classification.shape[0], device=classification.device, dtype=torch.long
            )
            * class_index
        )
        return F.cross_entropy(classification, targets, reduction="none")

    def forward(self, classification, mse_loss):
        """Return the combined loss described in the class docstring.

        Args:
            classification: logits indexed as ``(batch, classes)``.
            mse_loss: values handed to ``mse_manipulation`` together with the
                per-sample cross-entropy against ``klass``.

        When the batch holds a single sample, a one-line summary of the
        intermediate values is printed.
        """
        ce_loss = self._cross_entropy_against(classification, self.klass)
        target_ce_loss = self._cross_entropy_against(classification, self.target)

        total_mse_loss = self.mse_manipulation(mse_loss, ce_loss)
        total_ce_loss = self.ce_manipulation(ce_loss, ce_loss)
        total_target_loss = self.target_manipulation(target_ce_loss, target_ce_loss)
        total_loss = total_mse_loss - total_ce_loss + total_target_loss

        if classification.shape[0] == 1:
            prob = classification.softmax(dim=1)[0, self.klass]
            # Every fragment except the last ends with ", " so that the
            # implicitly concatenated pieces stay separated in the output.
            print(
                f"Prob: {prob.item():.2f}, Loss: {total_loss.item():.2f}, "
                f"MSE: {mse_loss.item():.2f}, Final MSE: {total_mse_loss.item():.2f}, "
                f"CE: {ce_loss.item():.2f}, Final CE: {total_ce_loss.item():.2f}, "
                f"Target: {target_ce_loss.item():.2f}, "
                f"Final Target: {total_target_loss.item():.2f}"
            )
        return total_loss


class LogExpSplitLoss(Module):
    """Piecewise transform of ``loss``: logarithmic below, exponential above.

    Where ``ce < epsilon`` the result is ``log(loss**n1)``; elsewhere it is
    ``exp(-(loss**n2))`` plus a constant offset. The offset is chosen so that
    both pieces take the same value at ``loss == epsilon``.
    """

    def __init__(self, n1=1, n2=1, epsilon=3.0):
        super().__init__()
        self.epsilon = epsilon
        self.n1 = n1
        self.n2 = n2

    def forward(self, loss, ce):
        """Return the elementwise transform of ``loss``; ``ce`` picks the piece."""
        # Offset = (log piece at epsilon) - (exp piece at epsilon). Adding it
        # to the exp piece makes the two pieces meet at loss == epsilon. The
        # previous code subtracted in the opposite order, which left a jump at
        # the switch point.
        switch_distance = math.log(self.epsilon**self.n1) - math.exp(
            -(self.epsilon**self.n2)
        )
        return torch.where(
            ce < self.epsilon,
            torch.log(loss**self.n1),
            torch.exp(-(loss**self.n2)) + switch_distance,
        )


class SigmoidScaleLogExpLoss(Module):
    """Sigmoid-weighted blend of a logarithmic and an exponential transform.

    With ``s = sigmoid(b * (ce - epsilon))``,
    ``L = log(loss**log_factor + a)`` and ``E = exp(exp_factor * loss)``:

    * ``reverse`` truthy: ``(1 - s) * L + s * E``
    * otherwise:          ``(1 - s) * E + s * L``

    The result is negated when ``negative_loss`` is truthy.
    """

    def __init__(
        self,
        a=1e-6,
        b=10,
        log_factor=1,
        exp_factor=1,
        epsilon=3,
        reverse=False,
        negative_loss=False,
    ):
        super().__init__()
        self.a = a
        self.b = b
        self.exp_factor = exp_factor
        self.log_factor = log_factor
        self.epsilon = epsilon
        self.reverse = reverse
        self.negative_loss = negative_loss

    def forward(self, loss, ce):
        """Return the blended transform of ``loss``, weighted by ``ce``."""
        sign = -1 if self.negative_loss else 1
        gate = torch.sigmoid(self.b * (ce - self.epsilon))

        if self.reverse:
            blended = (1 - gate) * (
                torch.log(loss**self.log_factor + self.a)
            ) + gate * torch.exp(self.exp_factor * loss)
        else:
            blended = (1 - gate) * torch.exp(
                self.exp_factor * loss
            ) + gate * torch.log(loss**self.log_factor + self.a)
        return blended * sign


class SigmoidScaleLinearExpLoss(Module):
    """Sigmoid-weighted blend of a power term and an exponential transform.

    With ``s = sigmoid(b * (ce - epsilon))``, ``P = loss**linear_factor`` and
    ``E = exp(exp_factor * loss)``:

    * ``reverse`` truthy: ``(1 - s) * P + s * E``
    * otherwise:          ``(1 - s) * E + s * (P + a)``

    Note that ``a`` enters only the non-reversed blend. The result is negated
    when ``negative_loss`` is truthy.
    """

    def __init__(
        self,
        a=1e-6,
        b=10,
        linear_factor=1,
        exp_factor=1,
        epsilon=3,
        reverse=False,
        negative_loss=False,
    ):
        super().__init__()
        self.a = a
        self.b = b
        self.exp_factor = exp_factor
        self.linear_factor = linear_factor
        self.epsilon = epsilon
        self.reverse = reverse
        self.negative_loss = negative_loss

    def forward(self, loss, ce):
        """Return the blended transform of ``loss``, weighted by ``ce``."""
        sign = -1 if self.negative_loss else 1
        gate = torch.sigmoid(self.b * (ce - self.epsilon))

        if self.reverse:
            blended = (1 - gate) * (loss**self.linear_factor) + gate * torch.exp(
                self.exp_factor * loss
            )
        else:
            blended = (1 - gate) * torch.exp(self.exp_factor * loss) + gate * (
                loss**self.linear_factor + self.a
            )
        return blended * sign


class PolynomScaleSplitLoss(Module):
    """Sigmoid-weighted blend of ``loss**n1`` and ``-(loss**n2)``.

    With ``s = sigmoid(b * (ce - epsilon))`` the result is
    ``multiplier * ((1 - s) * loss**n1 + s * (-(loss**n2)))``.
    """

    def __init__(self, b=10, n1=1, n2=1, epsilon=3, multiplier=1):
        super().__init__()
        self.b = b
        self.n1 = n1
        self.n2 = n2
        self.epsilon = epsilon
        self.multiplier = multiplier

    def forward(self, loss, ce):
        """Return the blended transform of ``loss``, weighted by ``ce``."""
        gate = torch.sigmoid(self.b * (ce - self.epsilon))
        positive_part = (1 - gate) * (loss**self.n1)
        negative_part = gate * (-(loss**self.n2))
        return self.multiplier * (positive_part + negative_part)


class ZeroLoss(Module):
    """Loss term that ignores its inputs and is always zero."""

    def forward(self, *_inputs, **_options):
        """Return a new zero-dimensional tensor holding ``0.0``."""
        zero = torch.tensor(0.0)
        return zero


class MultiplierLoss(Module):
    """Scale a loss by a fixed factor, ignoring any further arguments."""

    def __init__(self, multiplier):
        super().__init__()
        self.multiplier = multiplier

    def forward(self, loss, *_inputs, **_options):
        """Return ``multiplier * loss``."""
        scaled = self.multiplier * loss
        return scaled
