#!/usr/bin/env python
import argparse
import os

import torch
from guided_diffusion import logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (add_dict_to_argparser, args_to_dict,
                                          create_gaussian_diffusion)

from tada.augment_pipes import EdmAugmentPipe
from tada.dist_util import setup_dist
from tada.image_datasets import load_data
from tada.script_util import (augment_defaults, create_model_and_diffusion,
                              model_and_diffusion_defaults, save_config,
                              save_env_vars, set_seed)
from tada.train_util import TrainLoop


def main():
    """Run one training job configured from the command line.

    Steps, in order: parse arguments, set up distributed state and logging,
    store the seed, dump the configuration and environment variables into the
    logger directory, apply the TF32/cuDNN switches, build the model and
    diffusion objects, optionally build a second diffusion with a different
    timestep respacing for sampling, build the data source, and finally hand
    everything to ``TrainLoop``.

    Returns:
        The value of ``args.log_dir``.
    """
    args = create_argparser().parse_args()

    # setup_dist() yields the device the model is moved to further down.
    device = setup_dist()
    logger.configure(dir=args.log_dir)
    # Store whatever set_seed returns back on args before the config is saved.
    args.seed = set_seed(args.seed)

    config_path = os.path.join(logger.get_dir(), "config.json")
    save_config(args, config_path)
    env_path = os.path.join(logger.get_dir(), "env.json")
    save_env_vars(env_path)

    # Backend switches taken directly from the CLI flags.
    tf32_enabled = args.use_tf32
    torch.backends.cuda.matmul.allow_tf32 = tf32_enabled
    torch.backends.cudnn.allow_tf32 = tf32_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    logger.log("creating model and diffusion...")
    model_kwargs = args_to_dict(args, model_and_diffusion_defaults().keys())
    model, diffusion = create_model_and_diffusion(
        **model_kwargs,
        augment_class=args.augment_class,
    )
    model.to(device)
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    # A separate diffusion for sampling is only built when --sample_steps is
    # non-empty; it reuses the training settings but with that respacing.
    sample_diffusion = None
    if args.sample_steps != "":
        sample_diffusion = create_gaussian_diffusion(
            steps=args.diffusion_steps,
            learn_sigma=args.learn_sigma,
            noise_schedule=args.noise_schedule,
            use_kl=args.use_kl,
            predict_xstart=args.predict_xstart,
            rescale_timesteps=args.rescale_timesteps,
            rescale_learned_sigmas=args.rescale_learned_sigmas,
            timestep_respacing=args.sample_steps,
        )

    logger.log("creating data loader...")
    augment_pipe_kwargs = args_to_dict(args, augment_defaults().keys())
    augment_pipe = None
    if args.augment_class == "edm":
        # The EDM pipe is passed to TrainLoop, and the data loader is then
        # told augment_class="none".
        # NOTE(review): presumably this avoids augmenting twice - confirm.
        # config.json was already written above with the original value.
        augment_pipe = EdmAugmentPipe(**augment_pipe_kwargs)
        args.augment_class = "none"

    loader_options = dict(
        data_path=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        diffusion=diffusion,
        schedule_sampler=schedule_sampler,
        class_cond=args.class_cond,
        random_flip=args.random_flip,
        seed=args.seed,
        augment_class=args.augment_class,
        augment_pipe_kwargs=augment_pipe_kwargs,
    )
    data = load_data(**loader_options)

    logger.log("training...")
    trainer = TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,
        total_steps=args.total_steps,
        sample_diffusion=sample_diffusion,
        augment_pipe=augment_pipe,
    )
    trainer.run_loop()
    return args.log_dir


def create_argparser():
    """Build the argument parser for the training script.

    The script-level defaults declared here are overlaid first with
    ``model_and_diffusion_defaults()`` and then with ``augment_defaults()``,
    so on a key collision the later source wins. The merged mapping is passed
    to ``add_dict_to_argparser`` together with a fresh parser.

    Returns:
        argparse.ArgumentParser: the populated parser.
    """
    script_defaults = {
        "data_dir": "",
        "schedule_sampler": "uniform",
        "lr": 1e-4,
        "weight_decay": 0.0,
        "lr_anneal_steps": 0,
        "batch_size": 64,
        "microbatch": -1,
        "ema_rate": "0.9999",  # comma-separated list of EMA values
        "log_interval": 200,
        "save_interval": 10000,
        "resume_checkpoint": "",
        "use_fp16": False,
        "fp16_scale_growth": 1e-3,
        "log_dir": "./logs",
        "total_steps": 300000,
        "cudnn_benchmark": True,
        "random_flip": True,
        "use_tf32": False,
        "seed": None,
        "sample_steps": "",
        "augment_class": "tada",
    }
    # Unpacking order matters: later mappings override earlier ones.
    merged_defaults = {
        **script_defaults,
        **model_and_diffusion_defaults(),
        **augment_defaults(),
    }
    cli = argparse.ArgumentParser()
    add_dict_to_argparser(cli, merged_defaults)
    return cli


# Script entry point: run training only when this file is executed directly.
# The value returned by main() (args.log_dir) is not used here.
if __name__ == "__main__":
    main()
