import torch
import torch.nn as nn
from torch_scatter import scatter_mean
from core.losses.loss_utils import sample_zero_centered_gaussian, OT_path, VP_path


class DynamicsLossBase(nn.Module):
    """Common parent for losses computed on the dynamics outputs ``(dx_t, dh_t)``.

    The base only remembers the requested ``reduction`` mode; concrete
    subclasses must override :meth:`forward`.
    """

    def __init__(self, reduction="mean"):
        super(DynamicsLossBase, self).__init__()
        # Reduction mode string; how (or whether) it is honored is up to subclasses.
        self.reduction = reduction

    def forward(self, t, dx_t, dh_t, z_x, z_h, x, h, segment_ids, *args, **kwargs):
        """Abstract hook; always raises ``NotImplementedError``."""
        raise NotImplementedError()


class Discrete_Diffusion_loss(DynamicsLossBase):
    """Discrete-time diffusion loss on coordinates (x) and features (h).

    Continuous times ``t`` are snapped to integer steps ``round(t * timesteps)``.
    For every step ``t_int > 0`` the squared error between the network outputs
    and ``z_x`` / ``z_h`` is weighted by ``SNR(s) - SNR(t)``, where
    ``s = t_int - 1`` and ``SNR = exp(-gamma)`` with ``gamma = scheduler(step)``.
    Entries with ``t_int == 0`` contribute zero here; the dedicated
    ``zeroth_term_loss`` is not implemented.
    """

    def __init__(self, scheduler, reduction="mean", timesteps=1000):
        """
        Args:
            scheduler: callable applied to integer-valued step tensors; its
                output is used as ``gamma`` (``exp(-gamma)`` is treated as SNR).
            reduction: ``"mean"`` averages the per-segment losses into a scalar;
                any other value returns the per-segment losses unchanged.
            timesteps: number of discrete diffusion steps used to snap ``t``.
        """
        super().__init__(reduction=reduction)
        self.scheduler = scheduler
        self.timesteps = timesteps

    def zeroth_term_loss(
        self, t, dx_t, dh_t, z_x, z_h, x, h, segment_ids, *args, **kwargs
    ):
        """Loss for the ``t_int == 0`` term; not implemented yet."""
        raise NotImplementedError

    def forward(self, t, dx_t, dh_t, z_x, z_h, x, h, segment_ids, *args, **kwargs):
        """Compute the SNR-weighted squared-error loss.

        Args:
            t: tensor of times, scaled by ``timesteps`` and rounded to steps.
                NOTE(review): assumed to broadcast against the per-row error
                (e.g. shape ``[N]``); a ``[N, 1]`` tensor would broadcast the
                product to ``[N, N]`` -- confirm with callers.
            dx_t, dh_t: network outputs compared against ``z_x`` / ``z_h``.
            z_x, z_h: regression targets (presumably the sampled noise).
            x, h: unused by this loss; accepted for a uniform interface.
            segment_ids: per-row segment index used by ``scatter_mean``.

        Returns:
            A scalar if ``reduction == "mean"``, else the per-segment losses.
        """
        t_int = torch.round(t * self.timesteps).float()
        # s is the previous step (the step interval is 1). Clamp at 0 so the
        # scheduler is never evaluated at step -1: the t_int == 0 entries are
        # masked out below, but masking by multiplication would not cancel an
        # inf/NaN produced by an out-of-range scheduler query.
        s_int = torch.clamp(t_int - 1, min=0)
        t_is_not_zero = (t_int != 0).float()

        gamma_t = self.scheduler(t_int)  # gamma = -log SNR, SNR = (1 - sigma^2) / sigma^2
        gamma_s = self.scheduler(s_int)

        # exp(-gamma) = SNR, so this is the difference SNR(s) - SNR(t), not a ratio.
        # NOTE(review): EDM (Hoogeboom et al., 2022) weights the epsilon loss by
        # SNR(s) / SNR(t) - 1 = exp(gamma_t - gamma_s) - 1; confirm which is intended.
        snr_weight = torch.exp(-gamma_s) - torch.exp(-gamma_t)

        # Squared error summed over the last dim of each output. This must be a
        # parenthesized tensor expression: the original wrapped it in a Python
        # list, which cannot be multiplied by a tensor.
        squared_error = torch.sum((dx_t - z_x) ** 2, dim=-1) + torch.sum(
            (dh_t - z_h) ** 2, dim=-1
        )
        loss = snr_weight * squared_error * t_is_not_zero

        # Average the rows that share a segment id.
        loss = scatter_mean(loss, segment_ids, dim=0)

        if self.reduction == "mean":
            loss = loss.mean()

        return loss



class Continous_Diffusion_loss(DynamicsLossBase):
    """Placeholder for a continuous-time diffusion loss.

    No loss is implemented yet: calling the module falls through to the base
    class ``forward``, which raises ``NotImplementedError``.
    """

    def __init__(self, reduction="mean"):
        super(Continous_Diffusion_loss, self).__init__(reduction=reduction)

    def forward(self, t, dx_t, dh_t, z_x, z_h, x, h, segment_ids, *args, **kwargs):
        # Explicitly defer to the abstract base implementation (raises).
        return DynamicsLossBase.forward(
            self, t, dx_t, dh_t, z_x, z_h, x, h, segment_ids, *args, **kwargs
        )


class FM_loss(DynamicsLossBase):
    """Flow-matching regression loss on coordinates (x) and features (h).

    The network outputs ``dx_t`` / ``dh_t`` are regressed (squared error summed
    over the last dim) onto the target vector fields returned by the configured
    probability paths. The summed error is then averaged per segment.
    """

    def __init__(
        self,
        probability_path_x=None,
        probability_path_h=None,
        reduction: str = "mean",
    ):
        """
        Args:
            probability_path_x: object exposing ``target_field(z, x, t)`` for the
                x channel; defaults to a fresh ``OT_path()``.
            probability_path_h: object exposing ``target_field(z, h, t)`` for the
                h channel; defaults to a fresh ``VP_path()``.
            reduction: stored on the module but not applied by ``forward``.
        """
        super().__init__(reduction=reduction)
        # Build defaults per instance. Instances written as default arguments
        # would be created once at import time and shared by every FM_loss.
        self.p_x = OT_path() if probability_path_x is None else probability_path_x
        self.p_h = VP_path() if probability_path_h is None else probability_path_h

    def forward(self, t, dx_t, dh_t, z_x, z_h, x, h, segment_ids, *args, **kwargs):
        """Compute the per-segment flow-matching loss.

        Args:
            t: time passed through to the paths' ``target_field``.
            dx_t, dh_t: predicted vector fields for x and h.
            z_x, z_h: samples passed to ``target_field`` together with x / h.
            x, h: data passed to ``target_field``.
            segment_ids: per-row segment index used by ``scatter_mean``.

        Returns:
            Per-segment mean losses (no further reduction is applied).
        """
        target_x_field = self.p_x.target_field(z_x, x, t)
        target_h_field = self.p_h.target_field(z_h, h, t)
        loss = torch.sum((dx_t - target_x_field) ** 2, dim=-1) + torch.sum(
            (dh_t - target_h_field) ** 2, dim=-1
        )
        # NOTE(review): unlike Discrete_Diffusion_loss, self.reduction is
        # deliberately not applied here, so per-segment losses are returned
        # regardless of the reduction setting; callers presumably reduce them.
        return scatter_mean(loss, segment_ids, dim=0)
