import os

# Expose only physical GPU 2 to this process. This is deliberately done BEFORE
# importing denoising_diffusion_pytorch: that package is built on PyTorch, and
# the CUDA runtime reads CUDA_VISIBLE_DEVICES only once, when it is first
# initialised. If anything on the import path touches CUDA (e.g. an
# availability check in a dependency), an assignment made after the import
# would be silently ignored and the wrong / all GPUs would be used.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # physical GPU index to use

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer  # noqa: E402

def main():
    """Build the U-Net denoiser and Gaussian diffusion wrapper, then train.

    Every hyperparameter is hard-coded below, grouped per component:
    the denoising network, the diffusion process, and the training loop.
    """
    # -- denoising network ---------------------------------------------------
    # NOTE(review): flash_attn=True may require a recent PyTorch build —
    # confirm against the installed denoising_diffusion_pytorch version.
    unet_kwargs = dict(
        dim=64,
        dim_mults=(1, 2, 4, 8),
        channels=1,                # one input channel (e.g. grayscale)
        self_condition=False,
        flash_attn=True,
    )
    denoiser = Unet(**unet_kwargs)

    # -- diffusion process ---------------------------------------------------
    diffusion_kwargs = dict(
        image_size=256,
        timesteps=1000,            # number of diffusion steps
        sampling_timesteps=250,    # fewer steps at sampling time (DDIM-style)
        # Automatic [0, 1] -> [-1, 1] rescaling is disabled; presumably the
        # .npz data is already in the model's expected range — TODO confirm.
        auto_normalize=False,
    )
    diffusion = GaussianDiffusion(denoiser, **diffusion_kwargs)

    # -- training loop -------------------------------------------------------
    trainer_kwargs = dict(
        # data locations (train / validation)
        trn_npz='./data/trn/',
        tst_npz='./data/val/',
        # optimisation
        train_batch_size=12,
        test_batch_size=12,
        train_lr=8e-5,
        train_num_steps=500000,
        gradient_accumulate_every=2,   # presumably 12 * 2 samples per optimizer step
        ema_decay=0.995,               # EMA decay for the averaged weights
        amp=True,                      # mixed-precision training
        # monitoring / checkpointing
        calculate_fid=False,
        save_and_sample_every=1000,
        num_samples=9,
        results_folder='./results/',
    )
    trainer = Trainer(diffusion, **trainer_kwargs)

    trainer.train()

if __name__ == "__main__":
    # Launch training only when executed as a script, never on import.
    main()
