import torch
import torch.nn.functional as F
from denflow.utils_closed_form import get_full_velocity_field


class OPTIMAL_DENOISER(torch.nn.Module):
    """Closed-form denoiser built on ``get_full_velocity_field``.

    No learnable parameters are defined here. A reference batch ``x1`` is taken
    from the first batch of the training loader (see
    ``prepare_optimal_denoiser``) and passed to ``get_full_velocity_field`` as
    ``x1prime`` every time a velocity is requested.
    """

    def __init__(self, device, args):
        """
        Args:
            device: Device the reference batch ``x1`` is moved to.
            args: Config namespace; ``args.dim_image`` and
                ``args.num_channels`` are read and stored.
        """
        super().__init__()
        self.d = args.dim_image
        self.num_channels = args.num_channels
        self.device = device
        self.args = args
        # NOTE(review): the three attributes below are not read anywhere in
        # this class; presumably consumed by external sampling code — verify.
        self.sigmamin = 0.0
        self.time_start = 0.0
        self.time_end = 1.0

    def prepare_optimal_denoiser(self, data_loaders):
        """Store the first training batch as the reference set ``self.x1``.

        Args:
            data_loaders: Mapping with a ``'train'`` iterable whose items are
                indexable; element ``[0]`` of the first item is used.

        Raises:
            ValueError: If the ``'train'`` loader yields no batches.
        """
        train_loader = data_loaders['train']
        # Only the first batch is used as the reference set.
        # NOTE(review): with a shuffling loader this set differs between runs
        # — confirm that is intended.
        first_batch = next(iter(train_loader), None)
        if first_batch is None:
            raise ValueError(
                "prepare_optimal_denoiser: 'train' loader yielded no batches"
            )
        self.x1 = first_batch[0].to(self.device)

    def get_denoiser(self, xt, t):
        """Return ``xt + (1 - t) * v(xt, t)``.

        Assuming the linear path ``x_t = (1 - t) x_0 + t x_1`` with velocity
        ``x_1 - x_0`` (TODO confirm against ``get_full_velocity_field``), this
        is the estimate of ``x_1`` given ``x_t``.

        Args:
            xt: Tensor of shape ``[B, ...]`` (e.g. ``[B, C, H, W]``).
            t: Per-sample times with ``B`` elements (e.g. ``[B]``), or a
                scalar (Python number or 0-dim tensor) shared by the batch.

        Returns:
            ``xt`` combined with the velocity; same shape as ``xt`` when
            ``get_velocity`` returns a tensor of that shape.
        """
        B = xt.shape[0]
        t = torch.as_tensor(t, device=xt.device)
        if t.ndim == 0:
            # A shared scalar time is materialized as one entry per sample so
            # both the reshape below and get_velocity receive a [B] tensor.
            t = t.expand(B).contiguous()
        # reshape (unlike view) also accepts non-contiguous t.
        t_b = t.reshape(B, *([1] * (xt.ndim - 1)))  # [B,1,...,1] for broadcasting
        return xt + (1 - t_b) * self.get_velocity(xt, t)

    def get_velocity(self, xt, t):
        """Velocity at ``(xt, t)`` from ``get_full_velocity_field`` using the
        stored reference batch ``self.x1``.

        Raises:
            RuntimeError: If ``prepare_optimal_denoiser`` has not been called.
        """
        # getattr (rather than initializing x1 in __init__) keeps
        # hasattr(self, "x1") False until preparation, as before.
        x1 = getattr(self, "x1", None)
        if x1 is None:
            raise RuntimeError(
                "OPTIMAL_DENOISER.get_velocity called before "
                "prepare_optimal_denoiser()"
            )
        return get_full_velocity_field(t, xt, x1prime=x1)
