from functools import partial

import torch
from torch import nn
from diffusers import UNet2DModel


class Wrapper(nn.Module):
    """Adapt a model whose output object carries ``.sample`` into a tensor-returning module.

    The wrapped model is called as ``model(x, t)``. Only the ``sample`` attribute
    of its result is returned. In this file it wraps a diffusers ``UNet2DModel``
    (see ``celeba1_model``).

    Args:
        model: Module to wrap; registered as the ``model`` submodule.
    """

    def __init__(self, model: nn.Module):
        super(Wrapper, self).__init__()
        self.model = model

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on ``(x, t)`` and unwrap its ``sample`` field."""
        return getattr(self.model(x, t), "sample")


def celeba1_model(src: str) -> nn.Module:
    """Load a pretrained ``UNet2DModel`` from *src* and wrap it in ``Wrapper``.

    Args:
        src: Identifier or path accepted by ``UNet2DModel.from_pretrained``.

    Returns:
        A ``Wrapper`` whose ``forward(x, t)`` returns the UNet's ``.sample`` tensor.
    """
    return Wrapper(UNet2DModel.from_pretrained(src))


def main2(
    model_id: str = "google/ddpm-ema-celebahq-256",
    output_path: str = "ddpm_generated_image.png",
):
    """Sample one image with a pretrained ``DDIMPipeline`` and save it to disk.

    Args:
        model_id: Identifier or path passed to ``DDIMPipeline.from_pretrained``.
        output_path: File path the first generated image is saved to.
    """
    # DDIMPipeline is used here. DDPMPipeline or PNDMPipeline from the same
    # module can be swapped in; DDIM/PNDM are the faster-inference options.
    from diffusers import DDIMPipeline

    # Load model and scheduler.
    pipeline = DDIMPipeline.from_pretrained(model_id)

    # Run the pipeline in inference (sample random noise and denoise).
    image = pipeline().images[0]

    image.save(output_path)


def main():
    """Smoke-test the wrapped CelebA-HQ UNet on a random batch and print the output shape."""
    model_id = "google/ddpm-ema-celebahq-256"
    model_src = celeba1_model(model_id)
    # Be explicit about inference mode; also propagates to the wrapped UNet.
    model_src.eval()

    batch_size = 2
    # 3x256x256 presumably matches the "256" checkpoint resolution — TODO confirm
    # against the model config.
    x = torch.randn([batch_size, 3, 256, 256])
    # Upper bound 1000 is presumably the training timestep count — verify against
    # the checkpoint's scheduler config.
    t = torch.randint(0, 1000, (batch_size,), dtype=torch.long)

    # Pure inference: skip autograd graph construction to save memory.
    with torch.no_grad():
        pred = model_src(x, t)

    print(pred.shape)


if __name__ == "__main__":
    # Script entry point: run the UNet smoke test.
    main()
