from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import torch
# Denoising network: a Unet configured with dim 64, dim multipliers
# (1, 2, 4, 8) and flash attention enabled.
model = Unet(dim=64, dim_mults=(1, 2, 4, 8), flash_attn=True)

# Optional warm start from an earlier checkpoint's 'model' entry (disabled).
# NOTE(review): confirm the checkpoint keys match this Unet before enabling.
#model.load_state_dict(torch.load("./results/model-26.pt")['model'])

# Print the total number of parameter elements in the network.
n_params = sum(p.numel() for p in model.parameters())
print(n_params)
# Wrap the network in the Gaussian diffusion process (image_size 96).
diffusion = GaussianDiffusion(
    model,
    image_size=96,
    # number of diffusion steps
    timesteps=1000,
    # number of sampling timesteps; fewer than `timesteps`, using DDIM for
    # faster inference (see the DDIM paper citation)
    sampling_timesteps=250,
)

# Training harness around the diffusion model. The second positional argument
# is the data folder (presumably the STL-10 images -- verify the path).
trainer = Trainer(
    diffusion,
    './data/stl10',
    # --- optimisation ---
    train_batch_size=64,
    gradient_accumulate_every=2,  # gradient accumulation steps
    train_lr=8e-5,
    train_num_steps=700000,       # total training steps
    ema_decay=0.995,              # exponential moving average decay
    amp=True,                     # turn on mixed precision
    # --- evaluation ---
    calculate_fid=True,           # whether to calculate FID during training
    # number of samples for calculating FID
    # NOTE(review): 512 is a small sample for a stable FID estimate --
    # confirm this speed/accuracy trade-off is intended.
    num_fid_samples=512,
    # --- checkpointing / sampling ---
    save_and_sample_every=3000,
    save_best_and_latest_only=False,
    results_folder='./ddpm-stl-96/denoising_lib',
)
# Entry point. The guard keeps resuming/training from being triggered when
# this module is imported rather than executed -- in particular when it is
# re-imported by worker processes on platforms that start them with "spawn".
# Running the file as a script behaves exactly as before.
if __name__ == "__main__":
    # Resume from saved milestone 53 (presumably a checkpoint stored under
    # results_folder -- TODO confirm it exists before launching).
    trainer.load(53)
    # Run the training loop with the settings given to the Trainer.
    trainer.train()