from ..EDMDC import EDMEulerIntegralDC
from torch import nn, Tensor
import torch

__all__ = ["NoisedEDMEulerIntegralDC", "NoisedEDMEulerIntegralCondDenoiseDC"]


class NoisedEDMEulerIntegralDC(EDMEulerIntegralDC):
    """EDM Euler-integral diffusion classifier preceded by a one-step denoise.

    Each input is first passed once through ``uncond_edm`` (see
    ``one_step_denoise``). The result is then handed to the base class's
    ``get_one_instance_prediction``. ``cond_edm`` is the model given to the
    base class.

    :param uncond_edm: model used for the one-step denoise; it is called as
        ``uncond_edm(x, sigma_per_sample, y)``.
    :param cond_edm: model passed as the first argument to
        ``EDMEulerIntegralDC.__init__``.
    :param sigma: input noise level. It is doubled and stored as
        ``self.sigma``; presumably this converts a (0, 1)-scale std to the
        (-1, 1) scale used inside ``one_step_denoise`` (TODO confirm).
        ``self.sigma`` is the lower end of the integration timesteps and the
        default denoise level.
    """

    def __init__(self, uncond_edm: nn.Module, cond_edm: nn.Module, sigma=0.25, *args, **kwargs):
        # Assigned before the base __init__ because the timestep grid below reads it.
        self.sigma = sigma * 2
        super(NoisedEDMEulerIntegralDC, self).__init__(
            cond_edm, timesteps=torch.linspace(self.sigma, 3, 1001), *args, **kwargs
        )
        self.uncond_edm = uncond_edm

    def get_one_instance_prediction(self, x: Tensor) -> Tensor:
        """Denoise ``x`` once at ``self.sigma``, then defer to the base class.

        :param x: presumably a single image of shape (1, C, H, W) with values
            in (0, 1) -- TODO confirm against callers.
        :return: whatever ``EDMEulerIntegralDC.get_one_instance_prediction``
            returns for the denoised image.
        """
        x = self.one_step_denoise(x)
        return super().get_one_instance_prediction(x)

    def one_step_denoise(self, x: Tensor, normalize=True, sigma=None, y=None) -> Tensor:
        """Run ``uncond_edm`` once at a single noise level and return its output.

        :param x: input batch; expected in (0, 1) when ``normalize`` is True.
        :param normalize: if True, map ``x`` from (0, 1) to (-1, 1) before the
            model call. The output is mapped back with ``x0 / 2 + 0.5``.
        :param sigma: noise level given to the model for every sample in the
            batch. ``None`` (the default) means ``self.sigma``, so the denoise
            level follows the ``sigma`` passed to the constructor instead of a
            hard-coded constant. A per-sample tensor of shape (B,) also works.
        :param y: optional extra argument forwarded to ``uncond_edm`` unchanged.
        :return: model output, rescaled back when ``normalize`` is True.
        """
        if sigma is None:
            sigma = self.sigma
        x = (x - 0.5) * 2 if normalize else x
        # One noise level per sample; "+ sigma" broadcasts a scalar or a (B,) tensor.
        sigmas = torch.zeros((x.shape[0],), device=x.device) + sigma
        x0 = self.uncond_edm(x, sigmas, y)
        return x0 / 2 + 0.5 if normalize else x0


class NoisedEDMEulerIntegralCondDenoiseDC(EDMEulerIntegralDC):
    """Noised-input EDM diffusion classifier that denoises with class labels.

    The input is replicated once per class. The copies go through
    ``one_step_denoise`` (not defined in this class; presumably provided by
    ``EDMEulerIntegralDC``) with labels ``0 .. num_classes - 1``. Each target
    class is then scored on its own denoised copy.

    :param uncond_edm: stored as ``self.uncond_edm``; not used by the methods
        defined in this class.
    :param cond_edm: model passed as the first argument to
        ``EDMEulerIntegralDC.__init__``.
    :param sigma: doubled and stored as ``self.sigma``; it is the lower end of
        the integration timesteps.
    """

    def __init__(self, uncond_edm: nn.Module, cond_edm: nn.Module, sigma=0.25, *args, **kwargs):
        self.sigma = sigma * 2
        timestep_grid = torch.linspace(self.sigma, 3, 1001)
        super().__init__(cond_edm, *args, timesteps=timestep_grid, **kwargs)
        self.uncond_edm = uncond_edm
        # Inference only: eval mode, and requires_grad=False on every parameter.
        self.eval().requires_grad_(False)

    def get_one_instance_prediction(self, x: Tensor) -> Tensor:
        """Score a single input against every class in ``self.target_class``.

        :param x: presumably one image of shape (1, C, H, W) -- TODO confirm.
        :return: tensor on ``self.device`` holding the negated
            ``unet_loss_without_grad`` value for each target class, in
            ``self.target_class`` order. It is 1-D assuming each loss is a
            scalar.
        """
        n_classes = self.num_classes
        # One copy of x per class; copy i is denoised with label i.
        replicated = x.repeat(n_classes, 1, 1, 1)
        labels = torch.arange(n_classes, device=self.device)
        denoised = self.one_step_denoise(replicated, y=labels)
        per_class_loss = [
            self.unet_loss_without_grad(denoised[None, class_id, :, :, :], class_id)
            for class_id in self.target_class
        ]
        # Negate so that a smaller loss produces a larger value.
        return torch.tensor(per_class_loss, device=self.device) * -1
