# Inspired by https://github.com/w86763777/pytorch-ddpm/tree/master

# Authors: Kilian Fatras
#          Alexander Tong


import copy
import os
import time

import torch
from absl import app, flags
from torchvision import datasets, transforms
from tqdm import trange

from torchcfm.conditional_flow_matching import DGSWPConditionalFlowMatcher
from torchcfm.models.unet.unet import UNetModelWrapper
from torchcfm.utils_cifar import generate_samples, ema, infiniteloop

# Command-line configuration. Values are read through FLAGS once absl has
# parsed the command line (see the ``app.run`` call at the bottom of the file).
FLAGS = flags.FLAGS

# --- Output -----------------------------------------------------------------
flags.DEFINE_string(
    "output_dir",
    "./results/",
    help="output_directory",
)

# --- UNet -------------------------------------------------------------------
flags.DEFINE_integer(
    "num_channel",
    128,
    help="base channel of UNet",
)

# --- Training ---------------------------------------------------------------
flags.DEFINE_float(
    "lr",
    2e-4,
    help="target learning rate",
)
flags.DEFINE_float(
    "grad_clip",
    1.0,
    help="gradient norm clipping",
)
# Lipman et al. use 400k steps but double the batch size. One training step
# of this script performs ``dgswp_ratio`` gradient updates (see ``train``).
flags.DEFINE_integer(
    "total_steps",
    40001,
    help="total training steps",
)
# The scheduler is stepped once per gradient update in ``train``, so the
# warmup length is counted in gradient updates.
flags.DEFINE_integer(
    "warmup",
    5000,
    help="learning rate warmup",
)
# Lipman et al. use a batch size of 128.
flags.DEFINE_integer(
    "batch_size",
    128,
    help="batch size",
)
flags.DEFINE_integer(
    "num_workers",
    4,
    help="workers of Dataloader",
)
flags.DEFINE_float(
    "ema_decay",
    0.9999,
    help="ema decay rate",
)
flags.DEFINE_bool(
    "parallel",
    False,
    help="multi gpu training",
)

# --- DGSWP-specific ---------------------------------------------------------
flags.DEFINE_integer(
    "dgswp_ratio",
    10,
    help="number of batches used to compute the DGSWP transport map",
)
flags.DEFINE_integer(
    "dgswp_init_steps",
    1000,
    help="number of optim steps to initialize the DGSWP net",
)
flags.DEFINE_integer(
    "dgswp_steps",
    1,
    help="number of optim steps per batch for the DGSWP net",
)
flags.DEFINE_float(
    "dgswp_lr",
    1e-2,
    help="DGSWP learning rate",
)
flags.DEFINE_float(
    "dgswp_sigma",
    5e-2,
    help="DGSWP sigma",
)

# --- Evaluation -------------------------------------------------------------
# ``save_step`` is compared against the outer training-step counter.
flags.DEFINE_integer(
    "save_step",
    2000,
    help="frequency of saving checkpoints, 0 to disable during training",
)


# Pick the compute device once at import time: CUDA first, then Apple's MPS
# backend, then the CPU.
# MPS availability is queried through ``torch.backends.mps``, the documented
# check that ships with every PyTorch release that has the MPS backend;
# ``torch.mps.is_available`` is missing from older releases and would raise
# AttributeError on machines without CUDA.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


def warmup_lr(step):
    """Return the learning-rate multiplier for linear warmup.

    Passed to ``LambdaLR`` as ``lr_lambda``. The multiplier rises linearly
    from 0 to 1 over ``FLAGS.warmup`` scheduler steps and stays at 1 after.

    Args:
        step: Number of scheduler steps taken so far.

    Returns:
        The factor by which the base learning rate is multiplied.
    """
    warmup_steps = FLAGS.warmup
    if warmup_steps <= 0:
        # Warmup disabled: use the full learning rate from the first step
        # instead of dividing by zero (LambdaLR evaluates this at step 0).
        return 1.0
    return min(step, warmup_steps) / warmup_steps


def train(argv):
    """Train a CIFAR-10 flow-matching UNet with DGSWP conditional flow matching.

    Every outer step loads ``batch_size * dgswp_ratio`` images, draws Gaussian
    noise of the same shape, asks the flow matcher to precompute its map on
    the whole group, and then performs ``dgswp_ratio`` gradient updates on
    consecutive slices of ``batch_size`` samples. An exponential moving
    average (EMA) copy of the network is updated after every gradient update.
    Every ``save_step`` outer steps, samples from both networks and a
    checkpoint are written under ``<output_dir>/dgswpcfm/``.

    Args:
        argv: Leftover command-line arguments passed by ``absl.app.run``;
            not used.

    Raises:
        ValueError: If ``batch_size * dgswp_ratio`` exceeds the size of the
            training set, so the data loader would contain no full batch.
    """
    print(
        "lr, total_steps, ema decay, save_step:",
        FLAGS.lr,
        FLAGS.total_steps,
        FLAGS.ema_decay,
        FLAGS.save_step,
    )

    # DATASETS/DATALOADER
    dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )
    # One loaded batch feeds ``dgswp_ratio`` gradient updates, hence the
    # enlarged batch size.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size * FLAGS.dgswp_ratio,
        shuffle=True,
        num_workers=FLAGS.num_workers,
        drop_last=True,
    )
    if len(dataloader) == 0:
        # With drop_last=True an oversized batch leaves nothing to iterate
        # over; fail early rather than waiting for a batch that never comes.
        raise ValueError(
            "batch_size * dgswp_ratio is larger than the dataset; "
            "the dataloader would yield no batches"
        )

    datalooper = infiniteloop(dataloader)

    # MODELS
    net_model = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(
        device
    )  # new dropout + bs of 128

    # The EMA network starts as an exact copy of the trained network. It must
    # exist before the DataParallel wrapping, the per-update ``ema`` call and
    # the checkpointing below, all of which reference it.
    ema_model = copy.deepcopy(net_model)
    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
    if FLAGS.parallel:
        print(
            "Warning: parallel training is performing slightly worse than single GPU training due to statistics computation in dataparallel. We recommend to train over a single GPU, which requires around 8 Gb of GPU memory."
        )
        net_model = torch.nn.DataParallel(net_model)
        ema_model = torch.nn.DataParallel(ema_model)

    # show model size (the count is divided by 1024 * 1024, not 10**6)
    model_size = sum(param.numel() for param in net_model.parameters())
    print("Model params: %.2f M" % (model_size / 1024 / 1024))

    #################################
    #           DGSWP-CFM
    #################################

    # Small MLP mapping a flattened 3x32x32 image to a single scalar; it is
    # handed to the flow matcher as its ``model``.
    dgswp_model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(in_features=3 * 32 * 32, out_features=256),
        torch.nn.SELU(),
        torch.nn.Linear(in_features=256, out_features=256),
        torch.nn.SELU(),
        torch.nn.Linear(in_features=256, out_features=1),
    ).to(device)

    FM = DGSWPConditionalFlowMatcher(
        sigma=FLAGS.dgswp_sigma,
        lr=FLAGS.dgswp_lr,
        init_steps=FLAGS.dgswp_init_steps,
        fine_tuning_steps=FLAGS.dgswp_steps,
        model=dgswp_model,
        device=device,
    )

    # Joined rather than concatenated so an ``output_dir`` without a trailing
    # slash still resolves to a sub-directory. The trailing separator is kept
    # because ``generate_samples`` receives this directory as a plain string.
    savedir = os.path.join(FLAGS.output_dir, "dgswpcfm/")
    os.makedirs(savedir, exist_ok=True)

    # ``duration`` accumulates wall-clock training time only: the clock is
    # restarted after each sampling/checkpointing phase.
    t0 = time.time()
    duration = 0.0
    with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar:
        for step in pbar:
            x1 = next(datalooper).to(device)
            x0 = torch.randn_like(x1)

            # Computed once on the whole group of ``dgswp_ratio`` batches.
            FM.precompute_map(x0, x1)

            for chunk in range(FLAGS.dgswp_ratio):
                optim.zero_grad()

                # Indices of the ``chunk``-th slice of ``batch_size`` samples.
                batch_indices = slice(
                    chunk * FLAGS.batch_size, (chunk + 1) * FLAGS.batch_size
                )
                t, xt, ut = FM.sample_location_and_conditional_flow_from_indices(
                    batch_indices
                )

                vt = net_model(t, xt)
                loss = torch.mean((vt - ut) ** 2)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip)  # new
                optim.step()
                sched.step()
                ema(net_model, ema_model, FLAGS.ema_decay)  # new

            # sample and Saving the weights
            if FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
                duration += time.time() - t0
                generate_samples(net_model, FLAGS.parallel, savedir, step, net_="normal")
                generate_samples(ema_model, FLAGS.parallel, savedir, step, net_="ema")
                torch.save(
                    {
                        "net_model": net_model.state_dict(),
                        "ema_model": ema_model.state_dict(),
                        "sched": sched.state_dict(),
                        "optim": optim.state_dict(),
                        "step": step,
                        "time": duration,
                    },
                    os.path.join(savedir, f"dgswpcfm_cifar10_weights_step_{step}.pt"),
                )
                t0 = time.time()


# Script entry point: absl parses the command-line flags defined above and
# then calls ``train`` with the remaining arguments.
if __name__ == "__main__":
    app.run(train)
