from args import parse_args
from models import initialize_models
from data import prepare_dataloader
from sampling import distill_dataset
from accelerate import Accelerator
import wandb
import torch
from torch.optim import AdamW
from transformers import get_scheduler
from labml import monit


def main(args):
    """Set up models and data under Accelerate and run ``distill_dataset``.

    On the main process a Weights & Biases run is started whose config is
    filled from ``args``. The sampler is built by ``initialize_models``, the
    training dataloader is built by ``prepare_dataloader``, both are wrapped
    with ``accelerator.prepare`` and passed to ``distill_dataset``.

    Args:
        args: Parsed command-line namespace (from ``parse_args``). This block
            reads ``wandb_project``, ``wandb_name``, ``guidance_step_size``,
            ``time_travel``, ``batch_size``, ``pretrained_model_name_or_path``
            and ``image_size``; the called helpers may read more.

    Raises:
        Any exception raised during setup or distillation is re-raised
        unchanged after the W&B run (main process only) has been finished.
    """
    accelerator = Accelerator()

    # Always bind the name: only the main process creates a W&B run, the
    # other ranks pass ``None``. Previously non-main ranks raised
    # UnboundLocalError here.
    # NOTE(review): assumes initialize_models accepts None on non-main
    # ranks — confirm.
    wandb_run = None
    if accelerator.is_main_process:
        wandb_run = wandb.init(
            project=args.wandb_project,
            name=args.wandb_name,
            config={
                'guidance_step_size': args.guidance_step_size,
                'time_travel': args.time_travel,
                'batch_size': args.batch_size,
                'pretrained_model_name_or_path': args.pretrained_model_name_or_path,
                'image_size': args.image_size,
                'scheduler': 'DDIMScheduler',
            },
        )

    try:
        sampler = initialize_models(args, accelerator, wandb_run)
        # The main process executes this section first while the other ranks
        # wait; progress output is shown only on the main process.
        with accelerator.main_process_first(), monit.section(
            'Prepare Dataloader', is_silent=not accelerator.is_main_process
        ):
            train_dataloader = prepare_dataloader(args, accelerator)
        sampler, dataloader = accelerator.prepare(sampler, train_dataloader)
        distill_dataset(
            sampler=sampler,
            train_dataloader=dataloader,
            accelerator=accelerator,
            args=args,
        )
    except Exception:
        if wandb_run is not None:
            # Record the run as failed before propagating the error.
            wandb.finish(exit_code=1)
        # Bare raise re-raises the active exception without an extra frame.
        raise
    else:
        if wandb_run is not None:
            wandb.finish()


if __name__ == "__main__":
    # Script entry point: parse the command line, then launch the run.
    args = parse_args()
    main(args=args)