from typing import Optional, Union

import torch
from torch import nn, Tensor


class DiffusionClassifierSingleHeadBase(nn.Module):
    """Base class for classifiers that score candidate classes with a per-class UNet loss.

    For every input sample, :meth:`unet_loss_without_grad` is evaluated once per
    entry of ``target_class``. The losses are negated, so the class with the
    lowest loss receives the largest logit. Subclasses must implement
    :meth:`unet_loss_without_grad`.

    During construction the module is switched to eval mode, gradients are
    disabled on all of its parameters, and it is moved to ``device``.
    """

    def __init__(
        self,
        unet: nn.Module,
        device=torch.device("cuda"),
        transform=lambda x: (x - 0.5) * 2,
        num_classes=10,
        target_class=None,
    ):
        """
        :param unet: network stored for subclasses to compute per-class losses.
        :param device: device the module is moved to; logits are created here too.
        :param transform: input transform stored for subclasses (this base class
            never applies it). The default maps [0, 1] onto [-1, 1].
        :param num_classes: number of classes; used to build the default
            ``target_class`` list.
        :param target_class: class ids to score, in output-column order.
            Defaults to ``list(range(num_classes))``.
        """
        super().__init__()
        self.unet = unet
        self.device = device
        self.transform = transform
        self._init()
        self.target_class = target_class if target_class is not None else list(range(num_classes))
        self.num_classes = num_classes

    def _init(self):
        """Freeze the module (eval mode, no parameter grads) and move it to ``self.device``."""
        self.eval().requires_grad_(False)
        self.to(self.device)

    def get_one_instance_prediction(self, x: Tensor) -> Tensor:
        """Score a single sample against every class in ``target_class``.

        :param x: single-sample batch with leading dimension 1. The original
            annotation gives its shape as ``1, C, H, D`` (TODO confirm the
            trailing dims).
        :return: 1-D tensor of length ``len(self.target_class)`` on
            ``self.device``. Entry ``i`` is the negated loss for
            ``self.target_class[i]``.
        """
        losses = [self.unet_loss_without_grad(x, class_id) for class_id in self.target_class]
        # Negate so that the lowest-loss class gets the greatest logit.
        return torch.tensor(losses, device=self.device) * -1

    def forward(self, x: Tensor) -> Tensor:
        """Classify a batch one sample at a time.

        :param x: batch of inputs; dimension 0 is the batch dimension.
        :return: logits (negated losses) of shape ``(N, len(self.target_class))``.
            An empty batch yields a ``(0, len(self.target_class))`` tensor.
        """
        if x.shape[0] == 0:
            # Tensor.split on a size-0 dim returns ONE empty chunk rather than
            # none, which would otherwise yield a spurious row of logits.
            return torch.empty(0, len(self.target_class), device=self.device)
        predictions = [self.get_one_instance_prediction(sample) for sample in x.split(1)]
        return torch.stack(predictions)  # N, len(target_class)

    def unet_loss_without_grad(self, x: Tensor, y: Optional[Union[int, Tensor]] = None) -> float:
        """Return the loss of ``x`` under class ``y``; a lower loss means a higher logit.

        Must be overridden by subclasses.

        :param x: single-sample batch, as passed by :meth:`get_one_instance_prediction`.
        :param y: class id to condition on (an entry of ``target_class``).
        :return: scalar loss for this sample/class pair.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement unet_loss_without_grad(x, y)"
        )
