from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torch
import os

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def _build_diffusion():
    """Construct the U-Net denoiser and wrap it in a GaussianDiffusion process.

    NOTE(review): these hyperparameters presumably have to match the ones used
    to train the checkpoint, since ``load_state_dict`` (strict by default)
    rejects missing/unexpected keys and shape mismatches.
    """
    model = Unet(
        dim=128,
        channels=3,
        dim_mults=(1, 2, 4, 8),
    )
    return GaussianDiffusion(
        model,
        image_size=128,
        timesteps=1000,
        sampling_timesteps=100,
        auto_normalize=False,
    )


def _load_weights(diffusion, model_path):
    """Load the ``'model'`` entry of the checkpoint at *model_path* into *diffusion*.

    Presumably the checkpoint was written by denoising_diffusion_pytorch's
    ``Trainer`` (it stores the state dict under ``'model'``) -- verify.
    """
    # map_location="cpu" lets a checkpoint saved from a GPU run be loaded on a
    # CPU-only machine; the module is moved to the target device afterwards.
    checkpoint = torch.load(model_path, map_location="cpu")
    diffusion.load_state_dict(checkpoint['model'])


def _generate_samples(diffusion, num_samples, batch_size):
    """Draw *num_samples* samples from *diffusion*, at most *batch_size* per call.

    Returns the samples concatenated along dim 0, on the CPU.
    """
    chunks = []
    generated = 0
    # Every batch (including a trailing partial one) runs without autograd.
    with torch.no_grad():
        while generated < num_samples:
            n = min(batch_size, num_samples - generated)
            # Move each batch off the accelerator so device memory stays bounded
            # and the saved tensor can later be loaded on any machine.
            chunks.append(diffusion.sample(n).cpu())
            generated += n
            print(f"Generated {generated} samples")
    return torch.cat(chunks, dim=0)


def main(model_path='', sample_folder='', num_samples=601, batch_size=75):
    """Sample from a trained DDPM checkpoint and save the samples to disk.

    Args:
        model_path: path to the checkpoint file whose ``'model'`` key holds the
            GaussianDiffusion state dict. The empty default is a placeholder
            and must be filled in.
        sample_folder: directory in which ``ddpm_samples.pt`` is written; it is
            created if missing. An empty string means the current directory.
        num_samples: total number of samples to draw (default 601, i.e. eight
            batches of 75 plus one final sample, as originally hard-coded).
        batch_size: maximum number of samples per ``diffusion.sample`` call.

    Raises:
        ValueError: if ``num_samples`` or ``batch_size`` is not positive.
    """
    if num_samples <= 0 or batch_size <= 0:
        raise ValueError(
            f"num_samples and batch_size must be positive, "
            f"got num_samples={num_samples}, batch_size={batch_size}"
        )

    diffusion = _build_diffusion()
    _load_weights(diffusion, model_path)

    # Move to device and set to evaluation mode
    diffusion = diffusion.to(device)
    diffusion.eval()

    # NOTE(review): with auto_normalize=False the samples are presumably left in
    # the model's training data range rather than mapped to [0, 1] -- confirm.
    samples = _generate_samples(diffusion, num_samples, batch_size)
    print(f"Generated {samples.shape[0]} samples of shape {samples.shape[1:]}")

    # os.makedirs('') raises, so only create a directory when one is given.
    if sample_folder:
        os.makedirs(sample_folder, exist_ok=True)
    torch.save(samples, os.path.join(sample_folder, "ddpm_samples.pt"))

if __name__ == "__main__":
    # Run sampling only when executed as a script, not when imported.
    main()