import os
import argparse
import logging
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from copy import deepcopy
from time import time
from omegaconf import OmegaConf

from diffusion import create_diffusion
from diffusion.rectified_flow_ori import RectifiedFlow
from download import find_model
from utils import (
    update_ema, requires_grad, cleanup, setup_ddp, setup_exp_dir,
    setup_data, instantiate_from_config, get_lr_scheduler_config
)


def save_checkpoint(checkpoint_dir, step, model, ema, opt, args, logger):
    """Write a training checkpoint to disk from the rank-0 process.

    Processes whose distributed rank is not 0 return immediately without
    touching the filesystem.

    Args:
        checkpoint_dir: Directory the checkpoint file is written into.
        step: Integer step counter; used zero-padded to 7 digits as file name.
        model: Wrapped model exposing the underlying network as ``.module``.
        ema: EMA copy of the network (its ``state_dict()`` is stored).
        opt: Optimizer (its ``state_dict()`` is stored).
        args: Run configuration object, stored as-is under the ``"args"`` key.
        logger: Logger used to report the absolute path of the saved file.
    """
    if dist.get_rank() != 0:
        return

    payload = dict(
        model=model.module.state_dict(),
        ema=ema.state_dict(),
        opt=opt.state_dict(),
        args=args,
    )
    target = os.path.join(checkpoint_dir, f"{step:07d}.pt")
    torch.save(payload, target)
    logger.info(f"Saved checkpoint to {os.path.abspath(target)}")


def main(config_path):
    """Run distributed training as described by an OmegaConf config file.

    Builds the model from ``config.model`` (optionally restoring weights),
    keeps an EMA copy, wraps the model in DDP, selects either the rectified
    flow or the diffusion objective, and trains with AdamW on latents produced
    by a frozen ``DinoDecoder`` encoder. Gradient accumulation, optional
    gradient clipping, optional warmup / per-epoch LR schedulers, periodic
    logging and periodic checkpointing are all driven by the config.

    Args:
        config_path: Path to a config file readable by ``OmegaConf.load``.
            The sections read here are ``basic``, ``model``, ``optim`` and
            ``lr_sheduler`` (the key is spelled exactly like that).
    """
    # -----------------------------
    # Setup
    # -----------------------------
    config = OmegaConf.load(config_path)
    rank, device, seed = setup_ddp(config.basic)
    logger, writer, checkpoint_dir = setup_exp_dir(rank, config)

    # -----------------------------
    # Model & Diffusion
    # -----------------------------
    model = instantiate_from_config(config.model)
    if config.model.ckpt is not None:
        state_dict = find_model(config.model.ckpt, is_train=True)
        model.load_state_dict(state_dict)
        logger.info(f"Loaded model checkpoint from {config.model.ckpt}")

    logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    # The EMA copy is taken before DDP wrapping and never receives gradients.
    ema = deepcopy(model).to(device)
    requires_grad(ema, False)

    model = DDP(model.to(device), device_ids=[rank])
    ema.eval()

    # Configs without an "rf" key fall back to the diffusion objective.
    if not hasattr(config.basic, "rf"):
        config.basic.rf = False

    if config.basic.rf:
        logger.info("Training with Rectified Flow...")
        diffusion = RectifiedFlow(model)
    else:
        diffusion = create_diffusion(
            timestep_respacing="",
            predict_xstart=config.basic.predict_xstart
        )

    # -----------------------------
    # Optimizer & LR Scheduler
    # -----------------------------
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=config.optim.base_learning_rate,
        weight_decay=config.optim.weight_decay,
        betas=config.optim.betas,
    )
    max_grad_norm = config.basic.clip_grad_norm

    # Warmup scheduler (stepped once per optimizer step, see training loop).
    # warmup_steps stays at -1 when no warmup is configured.
    warmup_steps = -1
    if config.lr_sheduler.get("warmup", None):
        warmup_cfg = get_lr_scheduler_config(config.lr_sheduler.warmup, opt)
        lr_scheduler_warmup = instantiate_from_config(warmup_cfg)
        warmup_steps = config.lr_sheduler.warmup.params.warmup_steps
    else:
        lr_scheduler_warmup = None

    # Epoch scheduler (stepped once per epoch). Milestones are given in the
    # config as fractions of the total epoch count and converted to epochs.
    use_epoch_lr_scheduler = False
    if config.lr_sheduler.get("train_epoch", None):
        milestones = [
            int(m * config.basic.epochs)
            for m in config.lr_sheduler.train_epoch.params["milestones"]
        ]
        config.lr_sheduler.train_epoch.params["milestones"] = milestones
        epoch_cfg = get_lr_scheduler_config(config.lr_sheduler.train_epoch, opt)
        lr_scheduler_train_epoch = instantiate_from_config(epoch_cfg)
        use_epoch_lr_scheduler = True
    else:
        lr_scheduler_train_epoch = None

    logger.info(f"GPU NUM: {dist.get_world_size()}")

    # -----------------------------
    # Dataset
    # -----------------------------
    dataset, sampler, loader = setup_data(rank, config.basic)
    logger.info(f"Dataset contains {len(dataset):,} images ({config.basic.data_path})")

    # DinoDecoder encoder (frozen; only used under no_grad to produce latents).
    from SVG.svg_diffusion.ldm.models._decoder import DinoDecoder
    encoder_config = OmegaConf.load(config.basic.encoder_config)
    # Placed on this process's own device, like the model and EMA above,
    # instead of relying on the implicit current CUDA device.
    svg_autoencoder = DinoDecoder(
        ddconfig=encoder_config.model.params.ddconfig,
        dinoconfig=encoder_config.model.params.dinoconfig,
        lossconfig=encoder_config.model.params.lossconfig,
        embed_dim=encoder_config.model.params.embed_dim,
        ckpt_path=encoder_config.ckpt_path,
        extra_vit_config=encoder_config.model.params.extra_vit_config,
    ).to(device).eval()

    # Initialize EMA with synced weights
    update_ema(ema, model.module, decay=0)
    model.train()

    # -----------------------------
    # Training Loop
    # -----------------------------
    timestep_start = config.basic.get("timestep_start", 0)
    timestep_end = getattr(diffusion, "num_timesteps", None)
    if not (config.basic.rf or getattr(config.basic, "rf_ori", False)):
        logger.info(f"Training diffusion steps from {timestep_start} to {timestep_end}")

    # global_step counts micro-batches, train_steps counts optimizer steps,
    # log_steps counts micro-batches since the last log line.
    global_step, train_steps, log_steps = 0, 0, 0
    running_loss = {}
    accum_iter = config.basic.accum_iter
    log_every = config.basic.log_every
    ckpt_every = config.basic.ckpt_every

    # Latent normalisation statistics are only needed when feature_norm is on,
    # so runs without it do not require the stats file to exist.
    feature_norm = config.basic.feature_norm
    dinov3_mean = dinov3_std = None
    if feature_norm:
        z_channels = encoder_config.model.params.ddconfig.z_channels
        # map_location="cpu" keeps every rank from first restoring the tensors
        # onto the device they were saved from before moving them to `device`.
        dinov3_stats = torch.load("dinov3_sp_stats.pt", map_location="cpu")
        # Stats are truncated along their third axis to the latent channel count.
        dinov3_mean = dinov3_stats["dinov3_sp_mean"].to(device)[:, :, :z_channels]
        dinov3_std = dinov3_stats["dinov3_sp_std"].to(device)[:, :, :z_channels]

    start_time = time()

    logger.info(f"Training for {config.basic.epochs} epochs...")
    for epoch in range(config.basic.epochs):
        sampler.set_epoch(epoch)
        logger.info(f"Beginning epoch {epoch}...")

        for x, y in loader:
            if hasattr(model.module, "training_iters"):
                model.module.training_iters += 1

            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                # Encode to a (B, D, H, W) latent and flatten the spatial grid
                # into a token axis: (B, H*W, D).
                x = svg_autoencoder.encode(x)
                B, D, H, W = x.shape
                x = x.view(B, D, H * W).permute(0, 2, 1).contiguous()

            if feature_norm:
                x = (x - dinov3_mean) / dinov3_std

            if config.basic.rf:
                loss_dict = diffusion.forward(x, y, config.basic.shift)
            else:
                # NOTE(review): timestep_start is only logged above; sampling
                # always starts at 0 — confirm whether that is intended.
                t = torch.randint(0, diffusion.num_timesteps, (x.size(0),), device=device)
                loss_dict = diffusion.training_losses(model, x, t, model_kwargs=dict(y=y))

            # "RaceAll" is a log payload, not a loss term: always remove it so
            # it never reaches the numeric accumulation below, and only print
            # it every (log_every * 10) micro-batches.
            race_info = loss_dict.pop("RaceAll", None)
            if race_info is not None and global_step % (log_every * 10) == 0:
                logger.info(race_info)

            # NOTE(review): the loss is not divided by accum_iter before
            # backward(), so gradients are summed (not averaged) over the
            # accumulation window, which also interacts with the clipping
            # threshold. Left unchanged to keep existing runs reproducible.
            loss = loss_dict["loss"].mean()
            loss.backward()

            # Optimizer step once every accum_iter micro-batches.
            if (global_step + 1) % accum_iter == 0:
                if max_grad_norm:
                    clip_grad_norm_(model.parameters(), max_grad_norm)
                opt.step()
                opt.zero_grad()
                update_ema(ema, model.module)
                train_steps += 1
                if lr_scheduler_warmup is not None and train_steps <= warmup_steps:
                    lr_scheduler_warmup.step()

            # Accumulate loss (pre-divided so one optimizer step sums to a mean).
            for k, v in loss_dict.items():
                running_loss[k] = running_loss.get(k, 0) + v.mean().item() / accum_iter

            log_steps += 1
            global_step += 1

            # Logging
            if global_step % (log_every * accum_iter) == 0:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)

                log_msg = f"(Global Step={global_step:08d}, Train Step={train_steps:08d}) "
                for k, v in running_loss.items():
                    # Average over the logging window, then over all ranks.
                    avg_loss = torch.tensor(v / log_every, device=device)
                    dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                    avg_loss = avg_loss.item() / dist.get_world_size()
                    log_msg += f"{k}: {avg_loss:.4f}, "
                    if rank == 0:
                        writer.add_scalar(k, avg_loss, train_steps)

                log_msg += f"LR: {opt.param_groups[0]['lr']}, Steps/Sec: {steps_per_sec:.2f}"
                logger.info(log_msg)

                if rank == 0:
                    writer.add_scalar("lr", opt.param_groups[0]["lr"], train_steps)

                running_loss.clear()
                log_steps, start_time = 0, time()

            # Checkpoint
            if global_step % (ckpt_every * accum_iter) == 0:
                save_checkpoint(checkpoint_dir, train_steps, model, ema, opt, config, logger)
                logger.info(f"Batch size per device = {y.size(0)}")
                # Hold the other ranks until rank 0 has finished writing.
                dist.barrier()

        if use_epoch_lr_scheduler:
            lr_scheduler_train_epoch.step()
            logger.info(f"Adjusted lr to {opt.param_groups[0]['lr']}")

    # -----------------------------
    # Finalize
    # -----------------------------
    model.eval()
    logger.info("Training complete!")
    cleanup()


if __name__ == "__main__":
    # Command-line entry point: the only option is the mandatory path to the
    # training config, which is handed straight to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    args = cli.parse_args()
    main(args.config)
