from typing import List

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from src.metrics.utils.losses import sigmoid_focal_loss
from torchmetrics import Metric, MeanSquaredError, MetricCollection
import time


class HaltingLoss(nn.Module):
    """
    Binary classification loss on the halting signal.

    The predicted halt logits are compared against the (0/1) ground-truth
    halt signal, either with plain binary cross entropy on logits or, when
    ``use_focal_loss`` is set, with the project's ``sigmoid_focal_loss``.
    """

    def __init__(self, use_focal_loss: bool = False):
        super().__init__()
        # Selects the focal variant in `forward`; BCE-with-logits otherwise.
        self.use_focal_loss = use_focal_loss

    def forward(
            self,
            pred_halt_logits: Tensor,
            true_halt_signal: Tensor,
            reduce: bool=True, ret_log: bool=False
        ):
        """Compute the halting loss.

        Parameters
        ----------
        pred_halt_logits : Tensor
            Raw (pre-sigmoid) halt predictions.
        true_halt_signal : Tensor
            Ground-truth halt signal; cast to float before use.
        reduce : bool, optional
            If True the loss is averaged ('mean' reduction); otherwise the
            un-reduced ('none') loss is returned. By default True.
        ret_log : bool, optional
            If True, also return a dict with the detached loss for logging.
            By default False.

        Returns
        -------
        Tensor or tuple(Tensor, dict)
            The loss, optionally paired with the logging dict.
        """
        reduction_mode = 'mean' if reduce else 'none'
        float_targets = true_halt_signal.float()

        if self.use_focal_loss:
            # NOTE(review): the meaning of alpha=-1 / gamma=1 is defined by
            # the project-local helper -- confirm against its implementation.
            loss = sigmoid_focal_loss(
                input=pred_halt_logits,
                target=float_targets,
                reduction=reduction_mode,
                alpha=-1,
                gamma=1,
            )
        else:
            loss = F.binary_cross_entropy_with_logits(
                input=pred_halt_logits,
                target=float_targets,
                reduction=reduction_mode,
            )

        if not ret_log:
            return loss

        # Detached copy so logging never holds on to the autograd graph.
        return loss, {"train_halting/halt_signal_loss": loss.detach()}

class DistributionReinsertionLoss(nn.Module):
    """
    KL-divergence loss between a predicted categorical distribution
    (given as logits over the last dimension) and a target distribution
    on the number of added nodes.
    """

    def forward(
            self,
            pred_params: Tensor,
            true_params: Tensor,
            reduce: bool=True, ret_log: bool=False
        ):
        """Compute KL(true || predicted) over the last dimension.

        Parameters
        ----------
        pred_params : Tensor
            Unnormalised logits; normalised here with ``log_softmax`` over
            the last dimension.
        true_params : Tensor
            Target distribution, expected to be standard probabilities.
            Values are clipped from below at 1e-8 (without renormalising)
            to account for zero probabilities.
        reduce : bool, optional
            If True, return a scalar using the 'batchmean' reduction of
            ``F.kl_div``. If False, return the point-wise divergence summed
            over the last dimension. By default True.
        ret_log : bool, optional
            If True, also return a dict with the detached loss for logging.
            By default False.

        Returns
        -------
        Tensor or tuple(Tensor, dict)
            The loss, optionally paired with the logging dict.
        """
        # F.kl_div expects its `input` as log-probabilities.
        log_pred = torch.log_softmax(pred_params, dim=-1)
        safe_target = torch.clip(true_params, min=1e-8)

        # NOTE: an alternative based on F.cross_entropy against
        # true_params.argmax(dim=-1) is intentionally not used here.
        if reduce:
            loss: Tensor = F.kl_div(
                input=log_pred,
                target=safe_target,
                reduction='batchmean',
            )
        else:
            pointwise: Tensor = F.kl_div(
                input=log_pred,
                target=safe_target,
                reduction='none',
            )
            # Collapse the category axis so each distribution yields one value.
            loss = pointwise.sum(dim=-1)

        if ret_log:
            return loss, {"train_reinsertion/kl_div_nodes_dist": loss.detach()}
        return loss

class RegressionReinsertionLoss(nn.Module):
    """
    This variant only uses the MSE loss for predicting the correct number
    of remaining nodes
    """

    def forward(
            self,
            pred_params: Tensor,
            true_params: Tensor,
            reduce: bool=True, ret_log: bool=False
        ):
        """Compute the mean squared error between prediction and target.

        Both arguments are handed directly to ``F.mse_loss`` and therefore
        must be tensors (not lists of tensors).

        Parameters
        ----------
        pred_params : Tensor
            Predicted number of remaining nodes.
        true_params : Tensor
            True number of remaining nodes.
        reduce : bool, optional
            If True, return the mean over all elements; otherwise return the
            element-wise squared errors. By default True.
        ret_log : bool, optional
            If True, also return a dict with the detached loss for logging.
            By default False.

        Returns
        -------
        Tensor or tuple(Tensor, dict)
            The loss, optionally paired with the logging dict.
        """
        support_term: Tensor = F.mse_loss(
            input=pred_params,
            target=true_params,
            reduction='mean' if reduce else 'none',
        )

        if ret_log:
            # Detached so the logged value does not keep the graph alive.
            to_log = {
                "train_reinsertion/reins_support_term": support_term.detach(),
            }
            return support_term, to_log

        return support_term

class ReinsertionLoss(nn.Module):
    """
    Combined KL + support loss for the reinsertion head.

    pred params:
        pred_logit = l
        pred_mean = mu
    true params:
        true_prob = q
        true_num_experiments = n (described in the original notes as n0 - nt)

    KL divergence term (as computed here, entropy term included):
        KL_div = n * [softplus(l) - q*l + q*log(q) + (1-q)*log(1-q)]
        The logs are evaluated at q + 1e-7 and (1-q) + 1e-7 so that q = 0 and
        q = 1 stay finite.

    Support prediction term:
        L_support = (mu / (sigmoid(l) + 1e-6) - n)^2
        i.e. an MSE between the implied and the true number of experiments.

    Total:
        L_reintegration = KL_div + lambda_train * L_support
    """

    def __init__(self, lambda_train: float = 1.):
        super().__init__()
        # Weight of the support (MSE) term relative to the KL term.
        self.lambda_train = lambda_train


    def forward(
            self,
            pred_logit: Tensor, pred_mean: Tensor,
            true_prob: Tensor, true_num_experiments: Tensor,
            reduce: bool=True, ret_log: bool=False
        ):
        """Compute the reinsertion loss.

        Parameters
        ----------
        pred_logit : Tensor
            Predicted logit l.
        pred_mean : Tensor
            Predicted mean mu.
        true_prob : Tensor
            Target probability q.
        true_num_experiments : Tensor
            Target number of experiments n; scales the KL term and is the
            regression target of the support term.
        reduce : bool, optional
            If True both terms are averaged to scalars, otherwise they are
            kept element-wise. By default True.
        ret_log : bool, optional
            If True, also return a dict with the detached individual terms
            and total. By default False.

        Returns
        -------
        Tensor or tuple(Tensor, dict)
            The total loss, optionally paired with the logging dict.
        """
        complement_prob = 1. - true_prob

        # softplus(l) - q*l is the only part that depends on the prediction.
        prediction_part = F.softplus(pred_logit) - true_prob * pred_logit
        # q*log(q) + (1-q)*log(1-q) depends only on the target; the epsilon
        # keeps the logs finite at q = 0 and q = 1.
        target_part = (
            true_prob * torch.log(true_prob+1e-7)
            + complement_prob * torch.log(complement_prob+1e-7)
        )

        kl_value: Tensor = true_num_experiments * (prediction_part + target_part)
        if reduce:
            kl_value = kl_value.mean()

        # Implied number of experiments mu / sigmoid(l); epsilon guards the
        # division when the sigmoid underflows towards zero.
        implied_experiments = pred_mean / (torch.sigmoid(pred_logit) + 1e-6)

        support_value: Tensor = F.mse_loss(
            input=implied_experiments,
            target=true_num_experiments,
            reduction='mean' if reduce else 'none',
        )

        loss: Tensor = kl_value + self.lambda_train * support_value

        if not ret_log:
            return loss

        return loss, {
            "train_reinsertion/reins_kl_div_term": kl_value.detach(),
            "train_reinsertion/reins_support_term": support_value.detach(),
            "train_reinsertion/total_loss": loss.detach(),
        }


class ReinsertionLoss2(nn.Module):
    """
    Reinsertion loss with a ratio regression head and a stop classifier.

    pred params:
        pred_logit = l
        pred_ratio_logit = predicted ratio, in logit space
        pred_stop_logit = logit of the "should stop" prediction
    true params:
        true_prob = q
        true_ratio = a (a probability-like ratio; mapped to logit space here)
        should_stop = 0/1 stop indicator

    NOTE(review): the original notes relate the ratios to node counts via
    N = nt*(1-b)/b and n0 = nt*(1-a)/a; that mapping is not visible in this
    code -- confirm against the data pipeline.

    KL divergence term (as computed here, entropy term included, no scaling):
        KL_div = softplus(l) - q*l + q*log(q) + (1-q)*log(1-q)
        with the logs evaluated at q + 1e-7 and (1-q) + 1e-7.

    Support prediction term:
        L_support = (pred_ratio_logit - logit(a))^2, counted only where
        should_stop == 0.

    Stop term:
        binary cross entropy (on logits) between pred_stop_logit and
        should_stop.

    Total:
        L_reintegration = KL_div + lambda_train * L_support + L_stop
    """

    def __init__(self, lambda_train = 1.):
        super().__init__()
        # Weight of the support (ratio regression) term.
        self.lambda_train = lambda_train


    def forward(
            self,
            pred_logit: Tensor, pred_ratio_logit: Tensor, pred_stop_logit: Tensor,
            true_prob: Tensor, true_ratio: Tensor, should_stop: Tensor,
            reduce: bool=True, ret_log: bool=False
        ):
        """Compute the reinsertion loss.

        Parameters
        ----------
        pred_logit : Tensor
            Predicted logit l for the KL term.
        pred_ratio_logit : Tensor
            Predicted ratio in logit space.
        pred_stop_logit : Tensor
            Predicted logit of the stop decision.
        true_prob : Tensor
            Target probability q.
        true_ratio : Tensor
            Target ratio; converted with ``torch.logit(..., eps=1e-4)``.
        should_stop : Tensor
            Stop indicator; cast to float. Entries equal to 1 are excluded
            from the support term.
        reduce : bool, optional
            If True every term is reduced to a scalar (the support term is
            averaged over the non-stopping entries only); otherwise all
            terms stay element-wise. By default True.
        ret_log : bool, optional
            If True, also return a dict with the detached individual terms
            and total. By default False.

        Returns
        -------
        Tensor or tuple(Tensor, dict)
            The total loss, optionally paired with the logging dict.
        """
        stop_target = should_stop.float()
        continue_mask = 1. - stop_target

        # Bernoulli KL: prediction-dependent part plus target-only entropy
        # part (epsilon keeps the logs finite at q = 0 and q = 1).
        complement_prob = 1. - true_prob
        prediction_part = F.softplus(pred_logit) - true_prob * pred_logit
        target_part = (
            true_prob * torch.log(true_prob+1e-7)
            + complement_prob * torch.log(complement_prob+1e-7)
        )
        kl_value: Tensor = prediction_part + target_part

        # Squared error in logit space, zeroed wherever the target says stop.
        ratio_target = torch.logit(true_ratio, eps=1e-4)
        support_value: Tensor = F.mse_loss(
            input=pred_ratio_logit,
            target=ratio_target,
            reduction='none',
        )
        support_value = support_value * continue_mask

        stop_value: Tensor = F.binary_cross_entropy_with_logits(
            input=pred_stop_logit,
            target=stop_target,
            reduction='mean' if reduce else 'none',
        )

        if reduce:
            kl_value = kl_value.mean()
            # Masked mean; epsilon avoids 0/0 when every entry should stop.
            support_value = support_value.sum() / (continue_mask.sum() + 1e-6)

        loss: Tensor = kl_value + self.lambda_train * support_value + stop_value

        if not ret_log:
            return loss

        return loss, {
            "train_reinsertion/reins_kl_div_term": kl_value.detach(),
            "train_reinsertion/reins_support_term": support_value.detach(),
            "train_reinsertion/should_stop_term": stop_value.detach(),
            "train_reinsertion/total_loss": loss.detach(),
        }