"""
Training script for the Generator Matching framework.
"""
import hydra
from omegaconf import DictConfig, OmegaConf

import copy
from copy import deepcopy
import logging
import os
from pathlib import Path
from collections import OrderedDict
import json

import torch
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from accelerate import DistributedDataParallelKwargs
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed

from t2i.models.mmdit import MMDiT
from loss import SILoss_GM
from t2i.utils import load_encoders, preprocess_raw_image

from i2t.model.cond_transformer import Transformer

from dataset import MSCOCO256FeaturesGM
from sampling_gm import euler_sampler
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers.models import AutoencoderKL
import wandb
import math
from torchvision.utils import make_grid

# Module-level logger from accelerate's `get_logger`. Note that main() binds its
# own local `logger` (via create_logger) on the main process.
logger = get_logger(__name__)

# Path passed to AutoencoderKL.from_pretrained in main(). The value looks like a
# redacted placeholder and must be set to a real VAE location — TODO confirm.
VAE_PATH = "*********"


def array2grid(x):
    """
    Tile a batch of images into a single uint8 numpy image for logging.

    Args:
        x: Batch of images with values nominally in [0, 1] (values outside are
            clamped). Presumably (N, C, H, W) as accepted by ``make_grid`` —
            TODO confirm.

    Returns:
        A channels-last uint8 numpy array containing a roughly square grid
        (``round(sqrt(N))`` images per row).
    """
    images_per_row = round(math.sqrt(x.size(0)))
    grid = make_grid(x.clamp(0, 1), nrow=images_per_row, value_range=(0, 1))
    # Map [0, 1] -> [0, 255]; adding 0.5 before the integer cast rounds to nearest.
    scaled = grid.mul(255).add_(0.5).clamp_(0, 255)
    channels_last = scaled.permute(1, 2, 0)
    return channels_last.to('cpu', torch.uint8).numpy()


@torch.no_grad()
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
    """
    Draw one sample from a diagonal Gaussian given its stacked moments.

    Args:
        moments: Tensor whose dim 1 holds the mean followed by the log-variance;
            it is split into two equal halves with ``torch.chunk``.
        latents_scale: Multiplier applied to the sample (scalar or tensor
            broadcastable against the sample).
        latents_bias: Offset added after scaling (scalar or broadcastable tensor).

    Returns:
        ``(mean + std * eps) * latents_scale + latents_bias`` with
        ``eps ~ N(0, I)`` drawn by ``torch.randn_like``; no autograd graph is
        recorded.
    """
    mean, logvar = torch.chunk(moments, 2, dim=1)
    # Bound the log-variance before exponentiating it.
    logvar = torch.clamp(logvar, -30.0, 20.0)
    std = torch.exp(0.5 * logvar)
    z = mean + std * torch.randn_like(mean)
    return z * latents_scale + latents_bias


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Every EMA parameter is updated in place as
    ``ema = decay * ema + (1 - decay) * param``, so ``decay=0`` copies
    ``model``'s weights into ``ema_model``.

    Args:
        ema_model: Unwrapped module holding the EMA weights.
        model: Module being trained. It may be wrapped (e.g. by DDP) so that its
            parameter names start with ``"module."``; that leading prefix is
            stripped before looking up the matching EMA parameter.
        decay: EMA decay factor.

    Raises:
        KeyError: if a (prefix-stripped) parameter of ``model`` has no
            counterpart in ``ema_model``.
    """
    wrapper_prefix = "module."
    ema_params = OrderedDict(ema_model.named_parameters())

    for name, param in model.named_parameters():
        # Strip only *leading* wrapper prefixes. A blanket str.replace would also
        # mangle inner names such as "fc_module.weight" -> "fc_weight".
        while name.startswith(wrapper_prefix):
            name = name[len(wrapper_prefix):]
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def create_logger(logging_dir):
    """
    Create a logger that writes to a log file and stdout.

    ``logging.basicConfig`` is a no-op when the root logger already has
    handlers, and Hydra's job logging installs some before ``main`` runs. In
    that case the previous implementation never created ``log.txt``. This
    version always attaches the file handler to the root logger. It adds a
    stdout handler only when root has no handlers yet, so existing console
    output is not duplicated.

    Args:
        logging_dir: Existing directory in which ``log.txt`` is created/appended.

    Returns:
        The ``logging`` logger for this module; its records propagate to the
        root handlers configured here.
    """
    formatter = logging.Formatter(
        fmt='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
    )
    root = logging.getLogger()

    handlers = [logging.FileHandler(os.path.join(logging_dir, "log.txt"))]
    if not root.handlers:
        # Nothing configured yet: also log to the console, as basicConfig would.
        handlers.insert(0, logging.StreamHandler())

    for handler in handlers:
        handler.setFormatter(formatter)
        root.addHandler(handler)
    root.setLevel(logging.INFO)

    return logging.getLogger(__name__)


def requires_grad(model, flag=True):
    """
    Enable or disable gradient tracking on every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        flag: Value assigned to each parameter's ``requires_grad`` attribute.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = flag


#################################################################################
#                                  Training Loop                                #
#################################################################################
@hydra.main(config_path="configs", config_name="gm_config")
def main(cfg: DictConfig):
    """
    Jointly train the text-to-image (MMDiT) and image-to-text (Transformer) models.

    Hydra builds ``cfg`` from ``configs/gm_config`` (plus command-line overrides).
    The function sets up an ``Accelerator`` and loads the representation
    encoder(s), the CLIP tokenizer/text model and the VAE. It builds both
    trainable models and their EMA copies, then optimizes them together with one
    AdamW optimizer on the ``total_loss`` returned by ``SILoss_GM``. Checkpoints
    and sample grids are written periodically.

    Args:
        cfg: Hydra/OmegaConf configuration for the run.
    """
    # -------------------------------
    # set accelerator
    # -------------------------------
    logging_dir = Path(cfg.output_dir, cfg.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=cfg.output_dir, logging_dir=logging_dir
    )

    # NOTE(review): find_unused_parameters=True presumably because not every
    # parameter receives a gradient on every step — TODO confirm it is needed.
    ddp_kwargs = DistributedDataParallelKwargs(
        find_unused_parameters=True,
        gradient_as_bucket_view=True
    )
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.accelerator.gradient_accumulation_steps,
        mixed_precision=cfg.accelerator.mixed_precision,
        log_with=cfg.accelerator.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[ddp_kwargs],
    )

    # All models below (frozen and trainable) are cast to this dtype.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # -------------------------------
    # set seed and create output directory
    # -------------------------------
    # `save_dir`, `checkpoint_dir` and the local `logger` are bound only on the
    # main process; every later use of them is guarded by is_main_process.
    if accelerator.is_main_process:
        os.makedirs(cfg.output_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        save_dir = os.path.join(cfg.output_dir, cfg.exp_name)
        os.makedirs(save_dir, exist_ok=True)
        # Keep the resolved config with the experiment outputs rather than in
        # whatever the current working directory happens to be.
        OmegaConf.save(cfg, os.path.join(save_dir, "config.yaml"))
        checkpoint_dir = f"{save_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(save_dir)
        logger.info(f"Experiment directory created at {save_dir}")
    device = accelerator.device
    if torch.backends.mps.is_available():
        accelerator.native_amp = False
    if cfg.seed is not None:
        # Offset by rank so each process draws different random numbers.
        set_seed(cfg.seed + accelerator.process_index)

    # -------------------------------
    # Import encoder
    # -------------------------------
    assert cfg.t2i.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = cfg.t2i.resolution // 8

    if cfg.t2i.enc_type != 'None':
        encoders, encoder_types, architectures = load_encoders(cfg.t2i.enc_type, device, cfg.t2i.enc_path)
    else:
        encoders, encoder_types, architectures = [None], [None], [None]
    z_dims = [encoder.embed_dim for encoder in encoders] if cfg.t2i.enc_type != 'None' else [0]

    encoder = encoders[0]
    encoder_type = encoder_types[0]
    # With enc_type == 'None' there is no encoder module to move.
    if encoder is not None:
        encoder.to(device, dtype=weight_dtype)

    # -------------------------------
    # Load Clip Model
    # -------------------------------
    tokenizer = CLIPTokenizer.from_pretrained(cfg.i2t.clip_path)
    clip_model = CLIPTextModel.from_pretrained(cfg.i2t.clip_path, use_safetensors=False)
    clip_model = clip_model.to(device)
    clip_model.eval()
    clip_model.requires_grad_(False)
    clip_model.to(dtype=weight_dtype)

    # Read before any <MASK> token is added.
    vocab_size = tokenizer.vocab_size

    if cfg.i2t.p0_dist == "masked":
        tokenizer.add_tokens(["<MASK>"])
        clip_model.resize_token_embeddings(len(tokenizer))
        mask_token_id = tokenizer.convert_tokens_to_ids("<MASK>")

        # Get embedding layer
        embedding_layer = clip_model.get_input_embeddings()

        # NOTE(review): `weight[mask_token_id]` is a new tensor produced by
        # indexing, not the embedding Parameter itself. Setting requires_grad on
        # it (and handing it to the optimizer below) does not make the <MASK>
        # row of the embedding receive gradients. If the weight already requires
        # grad, this assignment may even raise. TODO confirm the intended way
        # to train this row.
        embedding_layer.weight[mask_token_id].requires_grad = True

    # -------------------------------
    # Define T2I model
    # -------------------------------
    model_t2i = MMDiT(
        input_size=latent_size,
        z_dims=z_dims,
        encoder_depth=cfg.t2i.encoder_depth,
    )

    model_t2i = model_t2i.to(device, dtype=weight_dtype)
    ema_t2i = deepcopy(model_t2i).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema_t2i, False)

    if accelerator.is_main_process:
        logger.info(f"MMDiT Parameters: {sum(p.numel() for p in model_t2i.parameters()):,}")

    # Load pre-trained model
    if cfg.t2i.pretrained_model_path is not None:
        ckpt = torch.load(cfg.t2i.pretrained_model_path, map_location='cpu')
        model_t2i.load_state_dict(ckpt['model'])

        if accelerator.is_main_process:
            logger.info(f"Loaded pre-trained MMDiT model from {cfg.t2i.pretrained_model_path}")

    # -------------------------------
    # Define I2T model
    # -------------------------------
    model_text = Transformer(
        config=cfg.i2t.params,
        vocab_size=vocab_size,
        masked=True if cfg.i2t.p0_dist == "masked" else False,
    )

    model_text = model_text.to(device, dtype=weight_dtype)
    ema_i2t = deepcopy(model_text).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema_i2t, False)

    if accelerator.is_main_process:
        logger.info(f"Text Model Parameters: {sum(p.numel() for p in model_text.parameters()):,}")

    # Load pre-trained model
    if cfg.i2t.pretrained_model_path is not None:
        ckpt = torch.load(cfg.i2t.pretrained_model_path, map_location='cpu')
        model_text.load_state_dict(ckpt['model'])

        if accelerator.is_main_process:
            logger.info(f"Loaded pre-trained text model from {cfg.i2t.pretrained_model_path}")

    # -------------------------------
    # Load VAE
    # -------------------------------
    vae = AutoencoderKL.from_pretrained(VAE_PATH).to(device, dtype=weight_dtype)

    # Per-channel latent normalisation: z_model = z_vae * scale + bias.
    latents_scale = torch.tensor(
        [0.18215, 0.18215, 0.18215, 0.18215]
        ).view(1, 4, 1, 1).to(device, dtype=weight_dtype)
    latents_bias = torch.tensor(
        [0., 0., 0., 0.]
        ).view(1, 4, 1, 1).to(device, dtype=weight_dtype)

    # -------------------------------
    # Define loss functions
    # -------------------------------
    loss_fn = SILoss_GM(
        t2i_path_type=cfg.t2i.prob_path,
        i2t_path_type=cfg.i2t.prob_path,
        i2t_path_exp=cfg.i2t.prob_path_exp,
        weighting=cfg.t2i.weighting,
        text_loss=cfg.i2t.loss,
        image_loss_weight=cfg.training.image_loss_weight,
        text_loss_weight=cfg.training.text_loss_weight,
        proj_coeff=cfg.t2i.proj_coeff,
        attn_coeff=cfg.t2i.attn_coeff,
        vocab_size=vocab_size,
        mask_token_id=mask_token_id if cfg.i2t.p0_dist == "masked" else None,
        cfg_prob=cfg.training.cfg_prob,
        prompt_prob=cfg.training.prompt_prob,
        text_guidance_prob=cfg.training.text_guidance_prob
    )

    # -------------------------------
    # Setup optimizer: AdamW with separate learning rates for the two models.
    # -------------------------------
    if cfg.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    if cfg.i2t.p0_dist == "masked":
        t2i_parameters = list(model_t2i.parameters()) + [embedding_layer.weight[mask_token_id]]
    else:
        t2i_parameters = list(model_t2i.parameters())
    print(f"Number of trainable T2I parameters: {sum(p.numel() for p in t2i_parameters):,}")
    optimizer = torch.optim.AdamW(
        [
            # T2I model (plus the <MASK> embedding row in the masked setting)
            {
                "params": t2i_parameters,
                "lr": cfg.training.t2i_learning_rate,
            },
            # Full text model
            {
                "params": list(model_text.parameters()),
                "lr": cfg.training.i2t_learning_rate,
            },
        ],
        betas=(cfg.training.adam_beta1, cfg.training.adam_beta2),
        weight_decay=cfg.training.adam_weight_decay,
        eps=cfg.training.adam_epsilon,
    )

    # -------------------------------
    # Setup data
    # -------------------------------
    train_dataset = MSCOCO256FeaturesGM(path=cfg.dataset.params.data_dir)
    # The configured batch size is global; split it across processes.
    local_batch_size = int(cfg.dataset.dataloader.batch_size // accelerator.num_processes)
    train_dataloader = DataLoader(
        train_dataset.train,
        batch_size=local_batch_size,
        shuffle=cfg.dataset.dataloader.shuffle,
        num_workers=cfg.dataset.dataloader.num_workers,
        pin_memory=True,
        drop_last=True
    )
    if accelerator.is_main_process:
        logger.info(f"Dataset contains {len(train_dataset.train):,} images ({cfg.dataset.params.data_dir})")

    # -------------------------------
    # Prepare models for training
    # -------------------------------
    update_ema(ema_t2i, model_t2i, decay=0)  # Ensure EMA is initialized with synced weights
    model_t2i.train()  # important! This enables embedding dropout for classifier-free guidance
    ema_t2i.eval()  # EMA model should always be in eval mode

    update_ema(ema_i2t, model_text, decay=0)  # Ensure EMA is initialized with synced weights
    model_text.train()  # important! This enables embedding dropout for classifier-free guidance
    ema_i2t.eval()  # EMA model should always be in eval mode

    # -------------------------------
    # Resume from checkpoint
    # -------------------------------
    global_step = 0
    if cfg.resume_step > 0:
        ckpt_name = str(cfg.resume_step).zfill(7) + '.pt'
        ckpt = torch.load(
            f'{cfg.ckpt_dir}/checkpoints/{ckpt_name}',
            map_location='cpu',
            )
        model_t2i.load_state_dict(ckpt['model_t2i'])
        model_text.load_state_dict(ckpt['model_i2t'])
        ema_t2i.load_state_dict(ckpt['ema_t2i'])
        ema_i2t.load_state_dict(ckpt['ema_i2t'])
        # NOTE(review): optimizer state and the saved CLIP weights ("model_clip")
        # are not restored (neither is saved with optimizer state) — AdamW
        # restarts its moment estimates on resume.
        global_step = ckpt['steps']

    # -------------------------------
    # Prepare everything with Accelerator
    # -------------------------------
    if cfg.i2t.p0_dist == "masked":
        model_t2i, clip_model, model_text, optimizer, train_dataloader = accelerator.prepare(
            model_t2i, clip_model, model_text, optimizer, train_dataloader
        )
    else:
        model_t2i, model_text, optimizer, train_dataloader = accelerator.prepare(
            model_t2i, model_text, optimizer, train_dataloader
        )

    if accelerator.is_main_process:
        accelerator.init_trackers(
            project_name="GM_train",
            init_kwargs={
                "wandb": {"name": f"{cfg.exp_name}"}
            },
        )

    # -------------------------------
    # Progress bar
    # -------------------------------
    progress_bar = tqdm(
        range(0, cfg.training.max_train_steps),
        initial=global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # -------------------------------
    # Fixed batch (first dataloader batch) used for the periodic sample grids
    # -------------------------------
    sample_batch_size = 64 // accelerator.num_processes
    _, gt_xs, gt_prompt, _ = next(iter(train_dataloader))
    gt_xs = gt_xs[:sample_batch_size].to(dtype=weight_dtype)
    gt_prompt = gt_prompt[:sample_batch_size]

    gt_xs = sample_posterior(
        gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
        )
    gt_prompt_tok = tokenizer(gt_prompt, truncation=True, max_length=77, return_length=True,
                              return_overflowing_tokens=False, padding="max_length", return_tensors="pt")["input_ids"].to(device)

    # Create sampling noise:
    xT = torch.randn((sample_batch_size, 4, latent_size, latent_size), device=device).to(dtype=weight_dtype)
    xT_text = loss_fn.source_distribution_text.sample_like(gt_prompt_tok.to(device)).to(device)

    # -------------------------------
    # Training loop
    # -------------------------------
    # Defined up front so the per-iteration logging never reads an unbound name
    # on micro-batches that do not synchronise gradients.
    grad_norm = torch.zeros((), device=device)
    for epoch in range(cfg.training.epochs):
        model_t2i.train()
        model_text.train()
        for raw_image, x, context_raw, context in train_dataloader:
            raw_image = raw_image.to(device, dtype=weight_dtype)
            x = x.squeeze(dim=1).to(device, dtype=weight_dtype)
            context = context.to(device, dtype=weight_dtype)

            context_raw_tok = tokenizer(
                context_raw,
                truncation=True,
                max_length=77,
                return_length=True,
                return_overflowing_tokens=False,
                padding="max_length",
                return_tensors="pt"
            )["input_ids"].to(device)

            # Extract encoder features
            z = None
            with torch.no_grad():
                x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias).to(dtype=weight_dtype)
                zs = []
                proj_s = []
                with accelerator.autocast():
                    for encoder, encoder_type, arch in zip(encoders, encoder_types, architectures):
                        if encoder is None:
                            # enc_type == 'None': no representation features to extract.
                            continue
                        raw_image_ = preprocess_raw_image(
                            raw_image, encoder_type, resolution=cfg.t2i.resolution
                            ).to(dtype=weight_dtype)
                        z = encoder.forward_features_attn(raw_image_)
                        if 'dinov2' in encoder_type:
                            proj = z['x_norm_patchtokens']
                            proj_s.append(proj)
                            z = z['attn_act'] # list
                            zs.append(z)

            # One accumulate() call covering both models. Nesting two accumulate()
            # contexts advances Accelerate's accumulation counter twice per batch,
            # so gradients were synced (and the optimizer stepped) too often.
            # Passing several models requires an accelerate release whose
            # accumulate() accepts *models.
            with accelerator.accumulate(model_t2i, model_text):
                denoising_loss, attn_loss, proj_loss, image_loss, text_loss, total_loss = loss_fn(
                    image_model=model_t2i,
                    text_model=model_text,
                    latents=x,
                    images=raw_image.to(dtype=torch.float32),
                    captions=context_raw_tok,
                    clip_model=clip_model,
                    encoder=encoder,
                    encoder_type=encoder_type,
                    vae=vae,
                    latents_bias=latents_bias,
                    latents_scale=latents_scale,
                    device=device,
                    zs=zs,
                    proj_s=proj_s,
                    denoise_only=(global_step >= cfg.t2i.early_stop_point),
                    empty_context=train_dataset.empty_context,
                )

                ## optimization
                accelerator.backward(total_loss)
                if accelerator.sync_gradients:
                    params_to_clip = list(model_t2i.parameters()) + list(model_text.parameters())
                    grad_norm = accelerator.clip_grad_norm_(params_to_clip, cfg.training.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                if accelerator.sync_gradients:
                    update_ema(ema_t2i, model_t2i) # change ema function
                    update_ema(ema_i2t, model_text)

            # Step-level work only happens when an optimizer step actually
            # completed; otherwise global_step is unchanged and checkpointing /
            # sampling would repeat on every accumulation micro-batch.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % cfg.training.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # unwrap_model works whether or not prepare() wrapped the
                        # model; `.module` exists only when it did (e.g. DDP).
                        checkpoint = {
                            "model_t2i": accelerator.unwrap_model(model_t2i).state_dict(),
                            "model_i2t": accelerator.unwrap_model(model_text).state_dict(),
                            "ema_t2i": ema_t2i.state_dict(),
                            "ema_i2t": ema_i2t.state_dict(),
                            "args": cfg,
                            "steps": global_step,
                            "epoch": epoch
                        }
                        if cfg.i2t.p0_dist == "masked":
                            checkpoint["model_clip"] = accelerator.unwrap_model(clip_model).state_dict()
                        checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
                        torch.save(checkpoint, checkpoint_path)
                        logger.info(f"Saved checkpoint to {checkpoint_path}")

                # NOTE(review): sampling uses the training models (not the EMA
                # copies) while they are in train mode — TODO confirm intended.
                if global_step == 1 or global_step % cfg.sampling.sampling_steps == 0:
                    with torch.no_grad():
                        samples_image, samples_text = euler_sampler(
                            image_model=model_t2i,
                            text_model=model_text,
                            initial_image=xT,
                            initial_text=xT_text,
                            vae=vae,
                            latents_scale=latents_scale,
                            latents_bias=latents_bias,
                            image_encoder=encoder,
                            tokenizer=tokenizer,
                            text_encoder=clip_model,
                            path=loss_fn.text_path,
                            num_steps=cfg.sampling.num_steps,
                            cfg_scale=cfg.sampling.cfg_scale,
                            guidance_low=cfg.sampling.guidance_low,
                            guidance_high=cfg.sampling.guidance_high,
                            encoder_type=encoder_type,
                            extra_text_steps=cfg.sampling.extra_text_steps,
                            masked=True if cfg.i2t.p0_dist == "masked" else False
                        )
                        # Undo the latent normalisation, decode, and map [-1, 1] -> [0, 1].
                        samples_image = vae.decode((samples_image.to(dtype=weight_dtype) - latents_bias) / latents_scale).sample
                        gt_samples_image = vae.decode((gt_xs.to(dtype=weight_dtype) - latents_bias) / latents_scale).sample
                        samples_image = (samples_image + 1) / 2.
                        gt_samples_image = (gt_samples_image + 1) / 2.

                        samples_text = tokenizer.batch_decode(
                            samples_text,
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=True
                        )

                    # gather() is collective: every process must reach it.
                    out_samples = accelerator.gather(samples_image.to(torch.float32))
                    gt_samples = accelerator.gather(gt_samples_image.to(torch.float32))
                    printed_samples = [f"Sample {i}: {text}" for i, text in enumerate(samples_text)]
                    logging.info(f"Generated samples at step {global_step}:\n" + "\n".join(printed_samples))

                    accelerator.log({"samples": wandb.Image(array2grid(out_samples)),
                                     "gt_samples": wandb.Image(array2grid(gt_samples))})
                    logging.info("Generating EMA samples done.")

            logs = {
                "total_loss": accelerator.gather(total_loss).mean().detach().item(),
                "proj_loss": accelerator.gather(proj_loss).mean().detach().item(),
                "attn_loss": accelerator.gather(attn_loss).mean().detach().item(),
                "denoising_loss": accelerator.gather(denoising_loss).mean().detach().item(),
                "grad_norm": accelerator.gather(grad_norm).mean().detach().item(),
                "image_loss": accelerator.gather(image_loss).mean().detach().item(),
                "text_loss": accelerator.gather(text_loss).mean().detach().item(),
                "epoch": epoch
                }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= cfg.training.max_train_steps:
                break
        if global_step >= cfg.training.max_train_steps:
            break

    model_t2i.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
    model_text.eval()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Done!")
    accelerator.end_training()


if __name__ == "__main__":
    # main() is called without arguments: the @hydra.main decorator supplies
    # `cfg` from configs/gm_config together with any command-line overrides.
    main()