from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from ..base.mixin import PropertyMixIn


class Diffusion(ABC, PropertyMixIn, nn.Module):
    """Abstract base class for diffusion wrappers built around an ``nn.Module``.

    The wrapped network is stored as the ``model`` submodule. Concrete
    subclasses must implement :meth:`sample`; because that method is
    abstract (and the metaclass is ``ABCMeta`` via ``ABC``), this class
    cannot be instantiated directly.
    """

    def __init__(self, model: nn.Module, **kwargs):
        """Store ``model`` on the instance.

        Args:
            model: Network to wrap; assigned to ``self.model``.
            **kwargs: Accepted and ignored here, so callers and subclasses
                may pass extra configuration without this initializer
                rejecting it.
        """
        super(Diffusion, self).__init__()
        # The base initializer must run first: nn.Module refuses submodule
        # attribute assignment until its internal registries exist.
        self.model = model

    @abstractmethod
    def sample(self, *args, **kwargs):
        """Abstract sampling hook; subclasses define its arguments and result."""

    def round_sigma(self, sigma):
        """Return ``sigma`` converted with ``torch.as_tensor``.

        This base implementation performs no rounding or discretization of
        its own; it only normalizes the input to a tensor. ``torch.as_tensor``
        returns an existing tensor unchanged when no dtype/device conversion
        is requested.
        """
        sigma_tensor = torch.as_tensor(sigma)
        return sigma_tensor
