"""An interface for score-based models."""

import torch
import math


class SDEModelFilter(torch.nn.Module):
    """Interface bundling the model components required by DPF.

    An instance must be able to provide:
        1. The drift function of the SDE (``drift``).
        2. The diffusion function of the SDE (``diffusion``).
        3. A backward drift (``bwd_drift_function``); presumably this is
           where the score of the density enters -- confirm against callers.
        4. The observation-noise function (``obs_noise``).
        5. An initial distribution (``init_dist``) to draw samples from.

    Components are supplied either by subclassing and defining these
    attributes as methods, or by calling :meth:`set_functions` on a plain
    instance. ``__init__`` deliberately does not pre-assign them, because an
    instance attribute would shadow methods defined by a subclass.
    """

    def __init__(self):
        super().__init__()

    def set_functions(self, drift, diffusion, bwd_drift_function, init_dist, obs_noise, device='cpu'):
        """Attach the model components to this instance.

        Use this when instantiating the class directly instead of subclassing.
        Any argument that is a ``torch.nn.Module`` is registered as a
        submodule by ``torch.nn.Module.__setattr__``.

        Args:
            drift: Callable ``drift(x, t)`` returning the forward drift.
            diffusion: Callable ``diffusion(t)`` returning the diffusion term.
            bwd_drift_function: Callable ``bwd_drift_function(x, t)``
                returning the backward drift.
            init_dist: Object with a ``sample(sample_shape)`` method, e.g. a
                ``torch.distributions.Distribution``.
            obs_noise: Callable ``obs_noise(t)`` returning the observation
                noise term.
            device: Stored as ``self.device``; no method of this class reads it.
        """
        self.drift = drift
        self.diffusion = diffusion
        self.bwd_drift_function = bwd_drift_function
        self.init_dist = init_dist
        self.obs_noise = obs_noise
        self.device = device

    @staticmethod
    def _expand_time(t, x):
        """Return ``t`` with one time value per batch element of ``x``.

        A scalar time (0-dim tensor or Python number) becomes a tensor of
        shape ``(x.shape[0], 1)``. A Python number is first converted to a
        tensor with ``x``'s dtype and device. Tensors with ``ndim >= 1`` are
        returned unchanged.
        """
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t.ndim == 0:
            # repeat (not expand) so the callee receives its own storage.
            t = t.reshape(1, 1).repeat(x.shape[0], 1)
        return t

    def eval_drift(self, t, x):
        """Evaluate the forward drift at time ``t`` and state ``x``.

        Args:
            t: Time; a scalar is broadcast to shape ``(x.shape[0], 1)``.
            x: State tensor whose first dimension is the batch dimension.

        Returns:
            ``drift(x, t)`` cast to ``x.dtype``.
        """
        t = self._expand_time(t, x)
        # NOTE: the eval_* methods take (t, x) while the stored callables
        # take (x, t); the argument order is swapped here on purpose.
        return self.drift(x, t).to(x.dtype)

    def eval_diffusion(self, t):
        """Evaluate the diffusion at time tensor ``t``, cast to ``t.dtype``."""
        return self.diffusion(t).to(t.dtype)

    def eval_obs_noise(self, t):
        """Evaluate the observation noise at time tensor ``t``, cast to ``t.dtype``."""
        return self.obs_noise(t).to(t.dtype)

    def eval_bwd_drift(self, t, x):
        """Evaluate the backward drift at time ``t`` and state ``x``.

        Args:
            t: Time; a scalar is broadcast to shape ``(x.shape[0], 1)``.
            x: State tensor whose first dimension is the batch dimension.

        Returns:
            ``bwd_drift_function(x, t)`` cast to ``x.dtype``.
        """
        t = self._expand_time(t, x)
        # Cast like eval_drift so a callable returning another dtype
        # (e.g. float64) does not silently promote the state.
        return self.bwd_drift_function(x, t).to(x.dtype)

    def generate_init_samples(self, n_samples):
        """Draw ``n_samples`` samples from the initial distribution.

        Returns ``init_dist.sample((n_samples,))``. For a torch
        ``Distribution``, that has shape ``(n_samples, *batch_shape, *event_shape)``.
        """
        return self.init_dist.sample((n_samples,))
