# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
A minimal single-GPU training script for DiT (runs on the device selected by
args.gpu; the model is not wrapped in PyTorch DDP).
"""
import os
import torch
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from torch.utils.data import DataLoader
from collections import OrderedDict
from copy import deepcopy
from diffusers.models import AutoencoderKL

from models.model import DiT_models
from utils.config import get_args
from models.targets_ST import get_targets_ST, loss_ST
from models.targets_ST_CSL import get_targets_ST_CSL, loss_ST_CSL
from data_loaders.data_loader import CelebAHQ256

#################################################################################
#                             Training Helper Functions                         #
#################################################################################

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Blend the current model's parameters into the EMA model, in place.

    For every named parameter of ``model`` the matching parameter of
    ``ema_model`` (looked up by name) becomes
    ``decay * ema + (1 - decay) * param``. Only parameters are visited;
    buffers are not touched.
    """
    averaged = dict(ema_model.named_parameters())
    latest = tuple(model.named_parameters())
    fresh_weight = 1 - decay

    # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
    for key, current in latest:
        shadow = averaged[key]
        shadow.mul_(decay)
        shadow.add_(current.data, alpha=fresh_weight)


def requires_grad(model, flag=True):
    """
    Switch gradient tracking on or off for every parameter of ``model``.

    Each tensor yielded by ``model.parameters()`` gets its ``requires_grad``
    attribute set to ``flag``. Nothing is returned.
    """
    for tensor in model.parameters():
        tensor.requires_grad = flag


def _select_train_fns(train_type):
    """
    Return the (get_targets_fn, loss_fn) pair for the given training type.

    Raises:
        ValueError: if train_type is neither 'ST' nor 'ST-CSL'. Failing here
            avoids an unbound-variable error later inside the training loop.
    """
    if train_type == 'ST':
        return get_targets_ST, loss_ST
    if train_type == 'ST-CSL':
        return get_targets_ST_CSL, loss_ST_CSL
    raise ValueError(
        f"Unknown train_type {train_type!r}; expected 'ST' or 'ST-CSL'."
    )


def main(args):
    """
    Train a DiT model on VAE latents of CelebAHQ256 and save periodic checkpoints.

    Attributes read from ``args`` in this function: train_type, seed,
    image_size, model, device, vae, lr, weight_decay, batch_size, epochs,
    ckpt_every, checkpointDir and bootstrap_every. ``args.latent_size`` is
    set here (image_size // 8). ``args`` is also forwarded to the target and
    loss functions, which may read further attributes.

    Raises:
        ValueError: if args.train_type is not a supported training type.
        AssertionError: if CUDA is unavailable or image_size is not a multiple of 8.
        RuntimeError: if an epoch yields no batches.
    """
    # Validate the training type first so a typo fails before any model is built.
    get_targets_fn, loss_fn = _select_train_fns(args.train_type)

    assert torch.cuda.is_available()
    torch.manual_seed(args.seed)

    assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    args.latent_size = args.image_size // 8
    model = DiT_models[args.model](input_size=args.latent_size, num_classes=1).to(args.device)

    ema = deepcopy(model).to(args.device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(args.device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    train_dataset = CelebAHQ256(image_folder="./dataset/CelebAHQ256/")
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        drop_last=True,
    )

    # Prepare models for training:
    update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    for epoch in range(1, args.epochs + 1):
        batch_counter = 0
        # Per-epoch totals, averaged after the epoch. loss_fn receives this
        # dict; presumably it adds each batch's values to it (confirm in
        # models/targets_ST*.py).
        loss_dict = {"train_loss": 0, "shortcut_loss": 0, "uniformity_loss": 0}
        for x, y in train_loader:
            batch_counter += 1
            x = x.to(args.device)
            y = y.to(args.device)
            with torch.no_grad():
                # Map input images to latent space + normalize latents:
                x = vae.encode(x).latent_dist.sample().mul_(0.18215)
                inputs = get_targets_fn(args, model, x, y)

            opt.zero_grad()
            # Re-enter train mode every step, in case target generation
            # switched the model to eval mode (not visible from this file).
            model.train()
            loss = loss_fn(args, inputs, model, loss_dict)
            loss.backward()
            opt.step()
            update_ema(ema, model)

        if batch_counter == 0:
            # With drop_last=True a dataset smaller than one batch yields nothing;
            # report that instead of dividing by zero below.
            raise RuntimeError(
                "train_loader produced no batches; with drop_last=True the dataset "
                "must contain at least args.batch_size samples."
            )

        for loss_type in loss_dict:
            loss_dict[loss_type] = loss_dict[loss_type] / batch_counter
        print(f"Epoch = {epoch}; ", loss_dict)

        if epoch % args.ckpt_every == 0:
            checkpoint = {
                "model": model.state_dict(),
                "ema": ema.state_dict(),
                "opt": opt.state_dict(),
                "args": args
            }

            checkpoint_path = f"{args.checkpointDir}/model_{args.model}_epoch-{epoch}_bst-every-{args.bootstrap_every}.pt"
            torch.save(checkpoint, checkpoint_path)


if __name__ == "__main__":
    # Run only when executed as a script. This keeps importing the module free
    # of side effects, and is needed when DataLoader worker processes are
    # started with the "spawn" method, since they re-import the main module.
    args = get_args()
    args.device = torch.device(f"cuda:{args.gpu}")
    # Checkpoints are grouped by training type; main() writes into this directory.
    args.checkpointDir = f"checkpoints/{args.train_type}"
    os.makedirs(args.checkpointDir, exist_ok=True)
    main(args)


