import os
import argparse
import random
import yaml
import shutil
from pathlib import Path
from math import inf
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from tqdm import tqdm
from PIL import Image
from diffusers.training_utils import compute_snr
from accelerate import Accelerator
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.training_utils import EMAModel
from cleanfid import fid
from models.unet import UNET_MODELS
from models.vae_2 import VAE_MODELS
from data.dataset_loader import get_train_val_dataloaders
from utils.nd import NestedDropout
from accelerate.utils import DistributedDataParallelKwargs
import glob


def estimate_global_std(vae, dataloader, device):
    """Return the root-mean-square value of the VAE latents over *dataloader*.

    Every batch is encoded and one latent is drawn from ``latent_dist``; the
    per-batch mean of ``z**2`` is weighted by the batch size, so the result is
    ``sqrt(E[z**2])`` over all images seen.  It is used as the global latent
    "std" (it coincides with the std only when the latents are zero-mean).

    Args:
        vae: autoencoder whose ``encode(x).latent_dist.sample()`` yields latents.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device the images are moved to before encoding.

    Returns:
        The global RMS of the sampled latents as a Python float.

    Raises:
        ValueError: if *dataloader* yields no images.
    """
    vae.eval()
    n, moment_2 = 0, 0.0
    with torch.no_grad():
        for imgs, _ in tqdm(dataloader, desc="Computing VAE scale factor", leave=False):
            z = vae.encode(imgs.to(device)).latent_dist.sample()
            # Batch-size-weighted running sum of E[z²].
            moment_2 += (z ** 2).mean().item() * imgs.size(0)
            n += imgs.size(0)
    if n == 0:
        raise ValueError("estimate_global_std: dataloader yielded no images")
    return (moment_2 / n) ** 0.5  # global σ (RMS of the latents)

def count_parameters(module: torch.nn.Module) -> int:
    """Return the number of scalar weights in *module* that require gradients."""
    trainable = filter(lambda param: param.requires_grad, module.parameters())
    return sum(map(torch.Tensor.numel, trainable))

def print_model_parameters(model):
    """Print the total and trainable parameter counts of *model*.

    Three lines are printed: the total count, the trainable count (both with
    thousands separators) and the total again expressed in millions.
    """
    sizes_and_flags = [(p.numel(), p.requires_grad) for p in model.parameters()]
    n_total = sum(size for size, _ in sizes_and_flags)
    n_trainable = sum(size for size, trainable in sizes_and_flags if trainable)
    report = (
        f"Total parameters: {n_total:,}",
        f"Trainable parameters: {n_trainable:,}",
        f"Parameters: {n_total:,} ({n_total/1e6:.2f} M)",
    )
    for line in report:
        print(line)


def save_images_to_folder(images, folder_path, prefix="img"):
    """Save a batch of image tensors to *folder_path* as PNG files for clean-fid.

    Args:
        images: tensor laid out as (N, C, H, W) (the permute below relies on
            this).  Values are assumed to lie in [-1, 1] or [0, 1]; a batch whose
            minimum is negative is treated as [-1, 1] and rescaled to [0, 1].
        folder_path: output directory, created if missing.
        prefix: filename prefix; files are named ``{prefix}_{i:06d}.png``.
    """
    os.makedirs(folder_path, exist_ok=True)

    # 1) Move the whole batch to the CPU once.  Detach so .numpy() below works
    #    even for tensors that are part of an autograd graph.
    images = images.detach().cpu()
    if images.min() < 0:
        images = (images + 1.0) / 2.0

    # 2) Quantize to uint8 with rounding (not truncation), matching
    #    torchvision.utils.save_image; truncation would bias every pixel down.
    images_uint8 = (images * 255.0 + 0.5).clamp(0, 255).to(torch.uint8)

    # 3) Convert to NumPy once: (N, C, H, W) -> (N, H, W, C).
    images_np = images_uint8.permute(0, 2, 3, 1).numpy()

    # 4) Save each image.  Pillow infers the mode from the uint8 array shape
    #    ((H, W) -> "L", (H, W, 3) -> "RGB"); the explicit `mode` argument of
    #    Image.fromarray is deprecated in recent Pillow releases.
    for i, img_arr in enumerate(images_np):
        if img_arr.shape[2] == 1:
            img_arr = img_arr.squeeze(2)  # grayscale: drop the channel axis
        img_pil = Image.fromarray(img_arr)
        img_path = os.path.join(folder_path, f"{prefix}_{i:06d}.png")
        img_pil.save(img_path)
# -----------------------------------------------------------------------------
# Trainer
# -----------------------------------------------------------------------------
class LDMTrainer:
    """Latent-Diffusion UNet trainer encapsulating the full training workflow.

    Workflow (driven by ``__init__`` and ``train``):
      1. pick a device, read the YAML config, seed RNGs and build dataloaders;
      2. load a frozen pre-trained VAE, build the UNet, AdamW, a cosine LR
         schedule and an EMA of the UNet weights; the UNet, optimizer, training
         loader and LR scheduler are wrapped by ``Accelerator.prepare``;
      3. resume from ``<task_name>/resume_checkpoint.pth`` if present, otherwise
         from the best-FID checkpoint if one exists;
      4. train the UNet to predict the added noise (epsilon) on scaled VAE
         latents with Min-SNR-gamma loss weighting, validating with the EMA
         weights (val MSE every epoch, FID every ``fid_interval`` epochs) and
         keeping the ``k_best`` checkpoints per metric on disk.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args
        self._setup_device()
        self._load_config()
        self._seed_everything()
        self._build_dataloaders()
        self._build_models_and_optim()
        self._setup_monitoring()

        # ---- resume from previous state if available ----
        self.global_step = 0
        self.start_epoch = 1
        saved_scale = self._resume_if_available()

        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        # Latent scale factor.  Reuse the value stored in the resumed checkpoint so
        # the UNet keeps seeing identically scaled latents; otherwise estimate it
        # (after Accelerator.prepare, see _estimate_latent_std).
        if saved_scale is not None:
            self.scale = float(saved_scale)
            self.global_std = 1.0 / self.scale
        else:
            self.global_std = self._estimate_latent_std()
            self.scale = 1.0 / self.global_std

    # ------------------------------- setup -------------------------------
    def _setup_device(self):
        """Select CUDA, Apple MPS or CPU (optionally restricting visible GPUs)."""
        if self.args.gpus:
            os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpus
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print(f"Using GPU(s): {self.args.gpus or 'all available'}")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
            print("Using Apple M-series (MPS)")
        else:
            self.device = torch.device("cpu")
            print("Using CPU – expect slow training!")

    def _load_config(self):
        """Read the YAML config and cache the training hyper-parameters."""
        with Path(self.args.config_path).open() as f:
            cfg = yaml.safe_load(f)
        self.ds_cfg      = cfg["dataset_params"]
        self.tr_cfg      = cfg["train_params"]
        self.denoise_steps    = self.tr_cfg["denoising_timesteps"]
        self.lr               = self.tr_cfg["lr"]
        self.warmup_steps     = self.tr_cfg.get("num_warmup_steps", 0)
        self.epochs           = self.tr_cfg["epochs"]
        self.log_epochs       = self.tr_cfg.get("log_epochs", 10)
        self.batch_size       = self.tr_cfg.get("batch_size", 64)
        self.val_batch        = self.tr_cfg.get("val_batch_size", 64)
        self.acc_steps        = self.tr_cfg.get("acc_steps", 2)
        self.val_samples      = self.ds_cfg.get("val_num_samples", 0)
        self.fid_interval     = self.tr_cfg.get("fid_interval", 10)
        self.ckpt_name        = self.tr_cfg.get("ldm_ckpt_name", "unet.pt")
        self.alpha            = self.tr_cfg.get('alpha', 0.25)      # weight of the nested-dropout loss
        self.gamma            = self.tr_cfg.get('snr_gamma', 5.0)   # Min-SNR clamp value

    def _seed_everything(self):
        """Seed torch, NumPy and ``random`` (and all CUDA devices when on GPU)."""
        seed = self.tr_cfg.get("seed", 42)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        if self.device.type == "cuda":
            torch.cuda.manual_seed_all(seed)

    def _build_dataloaders(self):
        """Create the training and validation dataloaders from the dataset config."""
        self.train_dl, self.val_loader = get_train_val_dataloaders(
            self.ds_cfg,
            train=True,
            train_batch_size=self.batch_size,
            val_batch_size=self.val_batch,
            val_num_samples=self.val_samples,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )

    def load_autoencoder(self, ckpt_path: str, model: torch.nn.Module, device: torch.device):
        """Load either a raw state-dict or a full checkpoint with 'model_state' into *model*.

        NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        ckpt = torch.load(ckpt_path, map_location=device)
        state = ckpt.get("model_state", ckpt)
        model.load_state_dict(state)
        return model

    def _build_models_and_optim(self):
        """Build the accelerator, frozen VAE, UNet, optimizer, schedulers and EMA."""
        # ------------------ Accelerator ------------------
        # Created first so every model below is placed on this process's device.
        # Before the Accelerator exists `torch.device("cuda")` means cuda:0 on every
        # rank, which would put the VAE on the wrong GPU in multi-GPU runs.
        ddp_kwargs = DistributedDataParallelKwargs(
            gradient_as_bucket_view=True,
            find_unused_parameters=False,
        )
        self.accelerator = Accelerator(
            gradient_accumulation_steps=self.acc_steps,
            kwargs_handlers=[ddp_kwargs],
        )
        self.device = self.accelerator.device
        print(
            f"Using device: {self.device} | "
            f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', 'None')}"
        )

        # ------------------ VAE (frozen) ------------------
        vae_model = VAE_MODELS()
        self.vae = vae_model.create_autoencoder_from_dataset(self.ds_cfg).to(self.device)
        self.vae = self.load_autoencoder(self.args.vae_ckpt_path, self.vae, self.device)
        for param in self.vae.parameters():
            param.requires_grad = False
        self.vae.eval()

        # Encode one dummy image to obtain the latent shape; it sizes the nested
        # dropout and is the shape template used when sampling.
        with torch.no_grad():
            dummy = torch.zeros(
                1,
                self.ds_cfg["im_channels"],
                self.ds_cfg.get("im_size", 32),
                self.ds_cfg.get("im_size", 32),
                device=self.device,
            )
            self.z_dummy = self.vae.encode(dummy).latent_dist.sample()
        k = int(np.prod(self.z_dummy.shape[1:]))
        self.nd = NestedDropout(k, self.tr_cfg.get("drop_p", 1e-3), self.device).to(self.device)

        # ------------------ UNet ------------------
        unet_model = UNET_MODELS()
        self.model = unet_model.create_unet_from_dataset(self.ds_cfg)
        print_model_parameters(self.model)

        # gradient checkpointing (the method name differs between implementations)
        if getattr(self.model, "enable_gradient_checkpointing", None):
            self.model.enable_gradient_checkpointing()
        elif getattr(self.model, "gradient_checkpointing_enable", None):
            self.model.gradient_checkpointing_enable()

        # ------------------ Optimizer & LR schedule ------------------
        self.optimizer = AdamW(self.model.parameters(), lr=self.lr)

        # Computed from the loader length before prepare() shards it.
        total_steps = math.ceil(self.epochs * len(self.train_dl) / self.acc_steps)
        self.lr_scheduler = get_cosine_schedule_with_warmup(self.optimizer, self.warmup_steps, total_steps)

        # The validation loader is deliberately NOT prepared: validation runs only
        # on the local main process and must see the whole validation set rather
        # than a single process's shard (batches are moved to the device manually).
        self.model, self.optimizer, self.train_dl, self.lr_scheduler = self.accelerator.prepare(
            self.model, self.optimizer, self.train_dl, self.lr_scheduler
        )

        # ------------------ Noise Schedulers & EMA ------------------
        self.noise_sched = DDPMScheduler(
            num_train_timesteps=self.denoise_steps,
            beta_schedule="scaled_linear",
            prediction_type="epsilon",
            clip_sample=False,
        )

        self.noise_sched_sample = DPMSolverMultistepScheduler.from_config(
            self.noise_sched.config,
            solver_order=2,
            use_karras_sigmas=True,
            algorithm_type="dpmsolver++",
        )

        self.ema_model = EMAModel(
            self.model.parameters(),
            decay=0.9999,
            use_ema_warmup=True,
            inv_gamma=1.0,
            power=0.75,
            min_decay=0.0,
        )

    def _setup_monitoring(self):
        """Create the task directory, TensorBoard writer and top-k score/checkpoint state."""
        self.task_dir = Path(self.tr_cfg["task_name"])
        self.task_dir.mkdir(parents=True, exist_ok=True)
        (self.task_dir / "config_used.yaml").write_text(
            yaml.safe_dump({"dataset_params": self.ds_cfg, "train_params": self.tr_cfg})
        )

        # Archive this training script next to the run outputs (best effort).
        script_src = Path(__file__)
        script_dst = self.task_dir / script_src.name
        try:
            shutil.copy(script_src, script_dst)
            print(f"  Archived training script → {script_dst}")
        except OSError as e:
            print(f"  Warning: failed to archive script: {e}")

        self.tb = SummaryWriter(self.task_dir / "tb_logs")
        self.tb.add_scalar("Params/total",
            count_parameters(self.accelerator.unwrap_model(self.model)), 0)

        # Top-k tracking for MSE and FID: lists of (score, epoch), best (lowest) first.
        self.top_mse_scores = []
        self.top_fid_scores = []
        self.k_best = 3

        # Legacy bests
        self.best_mse = inf
        self.best_fid = inf

        # Load existing top_scores.yaml if present
        scores_file = self.task_dir / "top_scores.yaml"
        if scores_file.exists():
            with open(scores_file, 'r') as f:
                data = yaml.safe_load(f) or {}  # an empty file loads as None
            self.top_mse_scores = [(float(s), int(e)) for s, e in data.get('top_mse_scores', [])]
            self.top_fid_scores = [(float(s), int(e)) for s, e in data.get('top_fid_scores', [])]
            if self.top_mse_scores:
                self.best_mse = self.top_mse_scores[0][0]
            if self.top_fid_scores:
                self.best_fid = self.top_fid_scores[0][0]
            print(f"Loaded top scores: MSE={len(self.top_mse_scores)}, FID={len(self.top_fid_scores)}")

        # Seed the in-memory (score, epoch, ckpt) lists from disk for resume.
        self._top_fid_ckpts = self._load_ranked_ckpts("fid", self.top_fid_scores)
        self._top_mse_ckpts = self._load_ranked_ckpts("mse", self.top_mse_scores)

    def _load_ranked_ckpts(self, metric, scores):
        """Reload the ``{metric}_rank_{i}_{ckpt_name}`` files listed in *scores*.

        Missing files are skipped.  Returns a list of (score, epoch, ckpt) tuples.
        """
        loaded = []
        for rank, (score, epoch) in enumerate(scores, start=1):
            path = self.task_dir / f"{metric}_rank_{rank}_{self.ckpt_name}"
            if not path.exists():
                continue
            # Rank 1 goes straight to the training device (it may be used to
            # resume); the others stay on the CPU to save accelerator memory.
            loc = self.device if rank == 1 else "cpu"
            loaded.append((score, epoch, torch.load(path, map_location=loc)))
        return loaded

    def _load_training_state(self, ckpt):
        """Load model, EMA, optimizer and LR-scheduler state from *ckpt*."""
        self.accelerator.unwrap_model(self.model).load_state_dict(ckpt["model_state"])
        self.ema_model.load_state_dict(ckpt["ema_state"])
        # The EMA shadow weights keep the device they were loaded on; move them to
        # the training device so later EMA steps do not mix devices.
        self.ema_model.to(device=self.device)
        self.optimizer.load_state_dict(ckpt["optimizer_state"])
        self.lr_scheduler.load_state_dict(ckpt["scheduler_state"])

    def _resume_if_available(self):
        """Restore training state from disk if possible.

        Prefers ``resume_checkpoint.pth``.  If it is absent but a best-FID
        checkpoint was found, training continues from that checkpoint's epoch so
        the epoch counter agrees with the restored LR-scheduler position.

        Returns:
            The latent scale stored in the restored checkpoint, or None (nothing
            restored, or a checkpoint written before the scale was recorded).
        """
        resume_path = self.task_dir / "resume_checkpoint.pth"
        if resume_path.exists():
            ckpt = torch.load(resume_path, map_location=self.device)
            self._load_training_state(ckpt)
            # restore metrics & counters
            self.best_mse     = ckpt.get("best_mse", self.best_mse)
            self.best_fid     = ckpt.get("best_fid", self.best_fid)
            self.global_step  = ckpt.get("global_step", 0)
            self.start_epoch  = ckpt.get("epoch", 0) + 1
            print(f"Resuming from epoch {self.start_epoch}, global_step {self.global_step}")
            print(f"Resumed best-FID ({self.best_fid:.2f}) ")
            return ckpt.get("latent_scale")

        if self._top_fid_ckpts:
            best_score, best_epoch, best_ck = self._top_fid_ckpts[0]
            self.best_fid = float(best_score)
            self._load_training_state(best_ck)
            self.global_step = best_ck.get("global_step", 0)
            self.start_epoch = int(best_epoch) + 1
            print(f"Resumed best-FID ({self.best_fid:.2f}) from fid_rank_1")
            print(f"Resuming from epoch {self.start_epoch}, global_step {self.global_step}")
            return best_ck.get("latent_scale")

        return None

    def _estimate_latent_std(self):
        """Estimate the global latent std, consistent across all processes.

        After ``Accelerator.prepare`` each process iterates only its shard of
        ``train_dl``, so the per-process second moments are gathered and averaged
        (exact when shards are equally sized).  Every rank thus uses the same scale.
        """
        local_std = estimate_global_std(self.vae, self.train_dl, self.device)
        if self.accelerator.num_processes == 1:
            return local_std
        moment_2 = torch.tensor([local_std ** 2], device=self.device)
        moment_2 = self.accelerator.gather(moment_2).mean()
        return float(moment_2.sqrt().item())

    # ------------------------------- checkpoints -------------------------------
    def _save_top_scores(self):
        """Persist the two lists `top_mse_scores` and `top_fid_scores` to YAML."""
        data = {
            # use lists, not tuples, and ensure Python float/int
            "top_mse_scores": [[float(s), int(e)] for s, e in self.top_mse_scores],
            "top_fid_scores": [[float(s), int(e)] for s, e in self.top_fid_scores],
        }
        with open(self.task_dir / "top_scores.yaml", "w") as f:
            yaml.safe_dump(data, f)

    @staticmethod
    def _snapshot(obj):
        """Return a copy of *obj* whose tensors are detached, independent CPU copies.

        Recurses through dicts, lists and tuples; other leaves are returned as-is.
        ``state_dict()`` results (model, EMA, optimizer) reference the live tensors
        that training updates in place, so top-k entries kept in memory must be
        snapshotted or later re-saves would write the current weights instead.
        """
        if torch.is_tensor(obj):
            return obj.detach().to("cpu", copy=True)
        if isinstance(obj, dict):
            out = obj.__class__()  # keeps OrderedDict for module state_dicts
            for key, value in obj.items():
                out[key] = LDMTrainer._snapshot(value)
            if hasattr(obj, "_metadata"):  # version info attached by Module.state_dict()
                out._metadata = obj._metadata
            return out
        if isinstance(obj, list):
            return [LDMTrainer._snapshot(v) for v in obj]
        if isinstance(obj, tuple):
            return tuple(LDMTrainer._snapshot(v) for v in obj)
        return obj

    def _update_and_save_top_k(self, metric: str, score: float, epoch: int, ckpt: dict):
        """Insert a checkpoint into the top-k list for *metric* and rewrite its files.

        Args:
            metric: ``"fid"`` or ``"mse"``; lower scores are better.
            score: metric value for this checkpoint.
            epoch: epoch the checkpoint was taken at.
            ckpt: checkpoint dict that must not alias live training tensors
                (see ``_snapshot``); it gains ``{metric}_score`` and ``epoch`` keys.
        """
        attr = f"_top_{metric}_ckpts"
        entries = getattr(self, attr, [])
        entries.append((score, epoch, ckpt))
        entries.sort(key=lambda x: x[0])
        entries = entries[: self.k_best]
        setattr(self, attr, entries)

        scores = [(s, e) for s, e, _ in entries]
        setattr(self, f"top_{metric}_scores", scores)
        if scores:
            setattr(self, f"best_{metric}", scores[0][0])

        # Write every rank to a temp file and move it into place atomically, so an
        # interrupted save never leaves the run without its previous checkpoints.
        kept = set()
        for rank, (s, e, c) in enumerate(entries, start=1):
            c[f"{metric}_score"] = s
            c["epoch"] = e
            out_path = self.task_dir / f"{metric}_rank_{rank}_{self.ckpt_name}"
            tmp_path = out_path.with_name(out_path.name + ".tmp")
            torch.save(c, tmp_path)
            os.replace(tmp_path, out_path)
            kept.add(os.path.abspath(str(out_path)))
            print(f"  Saved {metric} rank {rank} (epoch {e}, {metric.upper()}={s:.4f}) → {out_path}")

        # Remove rank files that no longer correspond to a kept entry.
        pattern = str(self.task_dir / f"{metric}_rank_*_{self.ckpt_name}")
        for path in glob.glob(pattern):
            if os.path.abspath(path) not in kept:
                os.remove(path)

        self._save_top_scores()

    def _update_and_save_top_fid(self, score: float, epoch: int, ckpt: dict):
        """Maintain the best-k FID checkpoints on disk, bumping/dropping as needed."""
        self._update_and_save_top_k("fid", score, epoch, ckpt)

    def _update_and_save_top_mse(self, score: float, epoch: int, ckpt: dict):
        """Maintain the best-k MSE checkpoints on disk, bumping/dropping as needed."""
        self._update_and_save_top_k("mse", score, epoch, ckpt)

    # ------------------------------- latents & sampling -------------------------------
    def encode_images(self, images):
        """Encode images to (scaled) latent space by sampling the VAE posterior."""
        with torch.no_grad():
            latents = self.vae.encode(images).latent_dist.sample()
        return latents * self.scale

    def decode_images(self, latents):
        """Undo the latent scaling and decode latents back to image space."""
        with torch.no_grad():
            latents = latents / self.scale
            images = self.vae.decode(latents).sample
        return images

    def generate_images(self, latent_dummy, vae, unet, noise_scheduler,
                        batch_size, device, num_inference_steps=100):
        """Sample *batch_size* images via the reverse diffusion process.

        Args:
            latent_dummy: tensor whose shape after the batch axis is the latent shape.
            vae: autoencoder (switched to eval mode; decoding uses ``self.vae``).
            unet: noise-prediction model, called as ``unet(latents, t).sample``.
            noise_scheduler: diffusers scheduler used for the reverse steps.
            batch_size: number of images to generate.
            device: device the latents are created on.
            num_inference_steps: number of reverse-diffusion steps.

        Returns:
            Decoded images mapped from [-1, 1] to [0, 1] and clamped.
        """
        vae.eval()
        unet.eval()
        latent_shape = latent_dummy.shape[1:]
        latents = torch.randn((batch_size, *latent_shape), device=device)
        noise_scheduler.set_timesteps(num_inference_steps, device=device)
        # Scale the initial noise / model inputs the way the scheduler expects.
        latents = latents * noise_scheduler.init_noise_sigma
        with torch.no_grad():
            for t in noise_scheduler.timesteps:
                model_input = noise_scheduler.scale_model_input(latents, t)
                noise_pred = unet(model_input, t).sample
                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
        images = self.decode_images(latents)
        return ((images + 1) / 2).clamp(0, 1)

    # ------------------------------- loss -------------------------------
    @staticmethod
    def _min_snr_weights(snr, gamma):
        """Min-SNR-gamma loss weights for epsilon prediction: ``min(snr, gamma) / snr``.

        Timesteps with ``snr == 0`` get weight 1.  The guard is applied after the
        division; applying it before (as min(snr, gamma) = 1) still yields 1/0 = inf.
        """
        weights = torch.clamp(snr, max=gamma) / snr
        weights[snr == 0] = 1.0
        return weights

    @staticmethod
    def _weighted_mse(pred, target, weights):
        """Per-sample MSE (averaged over all non-batch dims), weighted, then averaged."""
        per_sample = F.mse_loss(pred, target, reduction="none")
        per_sample = per_sample.mean(dim=list(range(1, per_sample.ndim)))
        return (weights * per_sample).mean()

    def _training_loss(self, images):
        """Compute the Min-SNR-weighted epsilon-prediction loss for one batch.

        When ``train_params.nd`` is set, the loss mixes the plain prediction with a
        prediction on nested-dropout-perturbed latents using weight ``alpha``.
        """
        latents = self.encode_images(images)
        # ─── Sample noise and timesteps ───
        noise = torch.randn_like(latents)
        t = torch.randint(0, self.denoise_steps, (images.size(0),), device=self.device)
        noisy = self.noise_sched.add_noise(latents, noise, t)
        weights = self._min_snr_weights(compute_snr(self.noise_sched, t), self.gamma)

        # ─── Model predictions & loss computation ───
        if self.tr_cfg.get("nd"):
            noisy_nd = self.nd(noisy)
            pred = self.model(sample=noisy, timestep=t).sample
            pred_nd = self.model(sample=noisy_nd, timestep=t).sample
            loss_normal = self._weighted_mse(pred, noise, weights)
            loss_nd = self._weighted_mse(pred_nd, noise, weights)
            return (1 - self.alpha) * loss_normal + self.alpha * loss_nd

        pred = self.model(sample=noisy, timestep=t).sample
        return self._weighted_mse(pred, noise, weights)

    # ------------------------------- training loop -------------------------------
    def train(self):
        """Train from ``start_epoch`` to ``epochs`` (inclusive).

        After each epoch the local main process validates (and every
        ``log_epochs`` epochs logs samples); final weights are saved at the end.
        """
        for ep in range(self.start_epoch, self.epochs + 1):
            self.model.train()
            pbar = tqdm(
                self.train_dl,
                desc=f"Epoch {ep}/{self.epochs}",
                disable=not self.accelerator.is_local_main_process,
            )

            for images, _ in pbar:
                images = images.to(self.device)
                with self.accelerator.accumulate(self.model):
                    loss = self._training_loss(images)

                    # ─── Backward + optimizer/EMA update (on gradient-sync steps) ───
                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
                        self.optimizer.step()
                        self.lr_scheduler.step()
                        self.optimizer.zero_grad()
                        self.ema_model.step(self.model.parameters())
                        self.global_step += 1

                    # ─── TensorBoard logging (every micro-batch) ───
                    if self.accelerator.is_local_main_process:
                        loss_value = loss.item()
                        self.tb.add_scalar("Loss/train_total", loss_value, self.global_step)
                        self.tb.add_scalar("LR", self.lr_scheduler.get_last_lr()[0], self.global_step)
                        pbar.set_postfix({"loss": f"{loss_value:.4f}"})

            # ─── End of epoch: validation & sample logging ───
            if self.accelerator.is_local_main_process:
                self._validate_and_log(ep)
                if ep % self.log_epochs == 0:
                    self._log_samples(ep)

        # ─── After all epochs: save final checkpoints ───
        if self.accelerator.is_local_main_process:
            self._save_final()

    def convert_to_rgb(self, images):
        """Repeat single-channel (N, 1, H, W) images to three channels; others unchanged."""
        if images.shape[1] == 1:
            return images.repeat(1, 3, 1, 1)
        return images

    # ------------------------------- validation & logging -------------------------------
    def _compute_val_mse(self, model):
        """Sample-weighted mean epsilon-MSE of *model* over the validation loader.

        Fresh noise and timesteps are drawn for every batch.  Returns None when
        the loader yields no samples.
        """
        model.eval()
        total, count = 0.0, 0
        with torch.no_grad():
            for vimgs, _ in self.val_loader:
                vimgs = vimgs.to(self.device)
                vz = self.encode_images(vimgs)
                vnoise = torch.randn_like(vz)
                vt = torch.randint(0, self.denoise_steps, (vz.size(0),), device=self.device)
                noisy_vz = self.noise_sched.add_noise(vz, vnoise, vt)
                vp = model(sample=noisy_vz, timestep=vt).sample
                total += F.mse_loss(vp, vnoise).item() * vz.size(0)
                count += vz.size(0)
        return total / count if count else None

    def _compute_fid(self, epoch, model):
        """Sample one fake image per validation image with EMA weights and return clean-FID.

        Real/fake PNGs go to temporary folders under ``task_dir`` that are always
        removed afterwards; the original (non-EMA) weights are always restored.
        """
        real_folder = self.task_dir / f"temp_real_epoch_{epoch}"
        fake_folder = self.task_dir / f"temp_fake_epoch_{epoch}"
        for folder in (real_folder, fake_folder):
            if folder.exists():
                shutil.rmtree(folder)

        try:
            # ---- switch to EMA weights for sampling ----
            self.ema_model.store(self.model.parameters())
            self.ema_model.copy_to(self.model.parameters())
            try:
                model.eval()
                with torch.no_grad():
                    for batch_count, (real_imgs, _) in enumerate(self.val_loader):
                        real_imgs = real_imgs.to(self.device)
                        real_rgb = self.convert_to_rgb(((real_imgs + 1) / 2).clamp(0, 1))
                        fake_01 = self.generate_images(
                            self.z_dummy,
                            self.vae, model,
                            self.noise_sched_sample,
                            real_imgs.size(0), self.device)
                        fake_rgb = self.convert_to_rgb(fake_01)
                        save_images_to_folder(real_rgb, real_folder, f"real_{batch_count}")
                        save_images_to_folder(fake_rgb, fake_folder, f"fake_{batch_count}")
            finally:
                # restore original weights (FID itself only needs the saved PNGs)
                self.ema_model.restore(self.model.parameters())

            return fid.compute_fid(str(real_folder), str(fake_folder), mode='clean',
                                   device=self.device)
        finally:
            shutil.rmtree(real_folder, ignore_errors=True)
            shutil.rmtree(fake_folder, ignore_errors=True)

    def _save_metric_checkpoints(self, epoch, val_mse, fid_score):
        """Update the top-k MSE/FID checkpoints and, if this epoch is the best FID, the resume file."""
        # One CPU snapshot shared by both top-k lists; each list gets its own
        # top-level dict so the per-metric keys added on save do not collide.
        snapshot = self._snapshot({
            "model_state":      self.accelerator.unwrap_model(self.model).state_dict(),
            "ema_state":        self.ema_model.state_dict(),
            "optimizer_state":  self.optimizer.state_dict(),
            "scheduler_state":  self.lr_scheduler.state_dict(),
            "global_step":      self.global_step,
            "latent_scale":     self.scale,
        })

        # 1) update & save top-k MSE
        self._update_and_save_top_mse(val_mse, epoch, dict(snapshot))
        # 2) update & save top-k FID
        self._update_and_save_top_fid(fid_score, epoch, dict(snapshot))

        # 3) if this epoch is now the #1 FID, write resume_checkpoint.pth
        if self.top_fid_scores and self.top_fid_scores[0][1] == epoch:
            resume_ckpt = {
                **snapshot,
                "epoch":    epoch,
                "best_mse": self.best_mse,
                "best_fid": self.best_fid,
            }
            resume_path = self.task_dir / "resume_checkpoint.pth"
            tmp_path = resume_path.with_name(resume_path.name + ".tmp")
            torch.save(resume_ckpt, tmp_path)
            os.replace(tmp_path, resume_path)
            print(f"  Saved resume checkpoint → {self.task_dir/'resume_checkpoint.pth'}")

    def _validate_and_log(self, epoch: int):
        """Log EMA validation MSE; every ``fid_interval`` epochs also FID + checkpoints.

        Called only on the local main process, so evaluation uses the unwrapped
        model (no distributed wrapper collectives on a single rank).
        """
        eval_model = self.accelerator.unwrap_model(self.model)

        # --- validation MSE with EMA weights ---
        self.ema_model.store(self.model.parameters())
        self.ema_model.copy_to(self.model.parameters())
        try:
            avg_val = self._compute_val_mse(eval_model)
        finally:
            # restore original (non-EMA) weights
            self.ema_model.restore(self.model.parameters())

        if avg_val is None:
            print(f"Epoch {epoch}: validation loader is empty – skipping validation")
            return

        self.tb.add_scalar("Val/mse", avg_val, epoch)

        # --- on FID interval, compute FID and checkpoint both metrics ---
        if epoch % self.fid_interval != 0:
            return

        fid_score = self._compute_fid(epoch, eval_model)
        self.tb.add_scalar("Val/FID", fid_score, epoch)
        print(f"Epoch {epoch} → Val MSE: {avg_val:.4f}, FID: {fid_score:.2f}")
        self._save_metric_checkpoints(epoch, avg_val, fid_score)

    def _log_samples(self, epoch: int):
        """Log a grid of 32 EMA samples to TensorBoard."""
        self.ema_model.store(self.model.parameters())
        self.ema_model.copy_to(self.model.parameters())
        try:
            with torch.no_grad():
                samples = self.generate_images(
                    self.z_dummy,
                    self.vae, self.accelerator.unwrap_model(self.model),
                    self.noise_sched_sample,
                    32, self.device,
                )
        finally:
            self.ema_model.restore(self.model.parameters())
        grid = make_grid(samples, nrow=8, normalize=True)
        self.tb.add_image("Images/Generated", grid, epoch)

    def _save_final(self):
        """Save the final UNet and EMA state dicts and close TensorBoard."""
        print("\nTraining complete – saving final checkpoints …")
        torch.save(self.accelerator.unwrap_model(self.model).state_dict(), self.task_dir / "unet_final.pt")
        torch.save(self.ema_model.state_dict(), self.task_dir / "ema_final.pt")
        self.tb.close()
        print(" Saved: unet_final.pt  |  ema_final.pt")



if __name__ == "__main__":
    # Script entry point: parse the CLI options, build the trainer, run training.
    arg_parser = argparse.ArgumentParser("Train Latent Diffusion Model (UNet stage)")
    arg_parser.add_argument(
        "--config", "-c",
        dest="config_path",
        required=True,
        help="Path to YAML config",
    )
    arg_parser.add_argument(
        "--vae_ckpt_path", "-v",
        required=True,
        help="Pre-trained VAE checkpoint",
    )
    arg_parser.add_argument(
        "--gpus", "-g",
        default=None,
        help="Comma-separated GPU ids (optional)",
    )
    LDMTrainer(arg_parser.parse_args()).train()