#!/usr/bin/env python3
import torch

from diffusers import DiffusionPipeline


class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    """Minimal pipeline that runs one UNet forward pass and one scheduler step.

    The pipeline takes no call-time arguments. It feeds random noise through
    the registered ``unet`` once, passes the prediction through a single
    ``scheduler.step`` call, and returns a tensor shaped like the scheduler
    output whose finite entries are all ``1``.
    """

    def __init__(self, unet, scheduler):
        """Register the ``unet`` and ``scheduler`` components on the pipeline.

        Args:
            unet: Model called as ``unet(sample, timestep)``; its output must
                expose a ``.sample`` attribute, and its ``config`` must
                provide ``in_channels`` and ``sample_size``.
            scheduler: Object whose ``step(model_output, timestep, sample)``
                result exposes a ``.prev_sample`` attribute.
        """
        super().__init__()

        self.register_modules(unet=unet, scheduler=scheduler)

    def __call__(self):
        """Run a single denoising step on random noise and return ones.

        Returns:
            A tensor created with ``torch.ones_like`` semantics from the
            scheduler's ``prev_sample`` (same shape and dtype), combined with
            ``prev_sample - prev_sample`` so the result still flows through
            the UNet and scheduler computation.
        """
        unet_config = self.unet.config
        side = unet_config.sample_size
        # NOTE(review): ``sample_size`` is used for both trailing dimensions,
        # which presumes it is a scalar - confirm against the UNet config.
        noise_shape = (1, unet_config.in_channels, side, side)
        sample = torch.randn(noise_shape)
        step_index = 1

        prediction = self.unet(sample, step_index).sample
        stepped = self.scheduler.step(prediction, step_index, sample).prev_sample

        # ``stepped - stepped`` is zero for every finite entry, so the sum is
        # a tensor of ones that nevertheless depends on the computed output
        # (non-finite entries in ``stepped`` would propagate as NaN).
        zeros_from_output = stepped - stepped
        return zeros_from_output + torch.ones_like(stepped)
