
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Standalone evaluation sampler for OT-Bridge (val split only).

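- No training, no metrics.
- Loads checkpoint (EMA if available).
- Reads model/schedule hyperparameters from options.pkl next to the checkpoint
  when present; the corresponding CLI flags are only used as fallbacks.
- Builds corruption (e.g., dogms) and runs DDPM sampling on val LMDB.
- Saves PNG grids (and optional per-sample GIFs).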

Example:
    CUDA_VISIBLE_DEVICES=0 python eval_val_only.py \
      --ckpt /path/to/runs/cag_test/latest.pt \
      --dataset-dir /mnt/CAG_Dataset/datasets/CAG_10K_256_lmdb \
      --out-dir ./eval_out/cag_test --name cag_test \
      --corrupt dogms --cond-x1 --nfe 999 --clip-denoise --save-gif
"""

import argparse
import logging
import pickle
from types import SimpleNamespace as NS
from pathlib import Path

import torch
import torchvision.utils as tu

# Optional: GIF writer
try:
    import imageio.v2 as imageio
except Exception:
    imageio = None

# Project imports (run from OTBridge project root or ensure PYTHONPATH)
from ot_bridge.network import Image256Net
from ot_bridge.diffusion import Diffusion
from ot_bridge import util
from torch_ema import ExponentialMovingAverage
from dataset import imagenet
from corruption import build_corruption

import numpy as np

from pathlib import Path
from types import SimpleNamespace as NS


def setup_logger():
    """Return the shared "eval_only" logger, attaching a console handler once.

    Repeated calls hand back the same logger object; the timestamped
    StreamHandler is installed only when the logger has no handlers yet,
    so log lines are never duplicated.
    """
    logger = logging.getLogger("eval_only")
    if logger.handlers:
        return logger

    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "[%(asctime)s] %(levelname)s - %(message)s", datefmt="%H:%M:%S"
    )
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger


def make_beta_schedule(n_timestep=1000, linear_start=1e-4, linear_end=2e-2):
    """Return a quadratic ("sqrt-linear") beta schedule as a NumPy float64 array.

    The square roots of ``linear_start`` and ``linear_end`` are interpolated
    linearly over ``n_timestep`` points and then squared. The arithmetic is
    done in torch float64 for precision and converted at the end, because
    ``Diffusion`` is fed a NumPy array.
    """
    sqrt_start = linear_start ** 0.5
    sqrt_end = linear_end ** 0.5
    sqrt_betas = torch.linspace(sqrt_start, sqrt_end, n_timestep, dtype=torch.float64)
    return sqrt_betas.pow(2).numpy()


def to_vis(x):
    """Map a tensor from the model range [-1, 1] to the image range [0, 1].

    Values outside [-1, 1] are clamped before rescaling.
    """
    return x.clamp(min=-1, max=1).add(1).div(2)


def save_png_grid(out_dir: Path, it: int, tag: str, img: torch.Tensor, nrow=10, log=None):
    """Tile a batch into one grid image and write ``<out_dir>/<it:07d>_<tag>.png``.

    ``img`` is mapped from [-1, 1] to [0, 1] with ``to_vis`` and laid out with
    ``nrow`` images per row. ``out_dir`` is created if needed; the written
    path is logged when ``log`` is given.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / f"{it:07d}_{tag}.png"
    tu.save_image(tu.make_grid(to_vis(img), nrow=nrow), target)
    if log:
        log.info(f"[save] {target}")


def save_traj_frames_and_gif(out_dir: Path, it: int, xs: torch.Tensor, pred_x0s: torch.Tensor,
                             k_list=(0,), every=1, fps=8, flip_time=True, log=None):
    """Dump per-step trajectory frames for selected samples and, if possible, a GIF.

    For each sample index ``k`` in ``k_list`` a folder ``<out_dir>/<it:07d>_sample<k>``
    is created holding, for every ``every``-th time index:
      * ``recon_t###.png``   - ``xs[k, t]`` mapped to [0, 1],
      * ``pred_x0_t###.png`` - ``pred_x0s[k, t]`` mapped to [0, 1],
      * ``pair_t###.png``    - both images side by side (concatenated along width).
    The pair frames are then assembled into ``traj_pair.gif`` when ``imageio`` is
    importable and more than one frame was written.

    Args:
        out_dir: root folder for the per-sample subfolders (created if missing).
        it: integer used as a zero-padded prefix of the folder names.
        xs, pred_x0s: trajectories indexed ``[B, T, C, H, W]`` (on CPU).
        k_list: batch indices to export.
        every: keep one frame out of every ``every`` time indices; values < 1
            are treated as 1.
        fps: GIF frame rate; imageio receives ``duration=1/fps``.
        flip_time: iterate time indices from last to first when True.
        log: optional logger.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    _, num_steps, _, _, _ = xs.shape  # also enforces the 5-D layout
    t_indices = list(range(num_steps))
    if flip_time:
        t_indices.reverse()
    # A slice step of 0 raises ValueError and a negative one silently reverses
    # the frame order, so clamp to at least 1.
    stride = max(int(every), 1)

    for k in k_list:
        sub = out_dir / f"{it:07d}_sample{k}"
        sub.mkdir(parents=True, exist_ok=True)
        pair_paths = []
        for i, t in enumerate(t_indices[::stride]):
            recon = to_vis(xs[k, t])
            pred = to_vis(pred_x0s[k, t])
            # singles
            tu.save_image(recon, sub / f"recon_t{i:03d}.png")
            tu.save_image(pred, sub / f"pred_x0_t{i:03d}.png")
            # side-by-side: dim=2 is the width axis of a single [C, H, W] frame
            pair = torch.cat([recon, pred], dim=2)
            pth = sub / f"pair_t{i:03d}.png"
            tu.save_image(pair, pth)
            pair_paths.append(pth)

        # Report the real reason a GIF is skipped (missing imageio vs. too few frames).
        if imageio is None:
            if log:
                log.info(f"[frames] saved under {sub} (imageio not found, skip GIF)")
            continue
        if len(pair_paths) < 2:
            if log:
                log.info(f"[frames] saved under {sub} (only {len(pair_paths)} frame(s), skip GIF)")
            continue

        imgs = [imageio.imread(p) for p in pair_paths]
        gif_path = sub / "traj_pair.gif"
        # NOTE(review): the unit of `duration` (seconds vs. milliseconds) has
        # differed between imageio versions / GIF backends - confirm against the
        # installed imageio if the GIF speed looks wrong.
        imageio.mimsave(gif_path, imgs, duration=1.0 / max(fps, 1))
        if log:
            log.info(f"[GIF] {gif_path}")


@torch.no_grad()
def ddpm_sampling(opt, net, ema, diffusion, x1, mask=None, cond=None, clip_denoise=False, nfe=None, log_count=10, log=None):
    """Run the bridge DDPM sampler from ``x1`` with the EMA weights swapped in.

    Args:
        opt: namespace providing ``interval``, ``device`` and ``ot_ode``.
        net: network called as ``net(xt, step, cond=cond)``.
        ema: EMA object; its ``average_parameters()`` context is active while sampling.
        diffusion: object exposing ``betas``, ``get_std_fwd`` and ``ddpm_sampling``.
        x1: starting (corrupted) batch; moved to ``opt.device``.
        mask: optional mask; where it is 1, ``x1`` is replaced by Gaussian noise.
        cond: optional conditioning tensor forwarded to ``net``.
        clip_denoise: clamp each predicted x0 to [-1, 1] when True.
        nfe: number of function evaluations; a falsy value means ``opt.interval - 1``.
        log_count: number of intermediate states to keep, capped at ``len(steps) - 1``.
        log: optional logger.

    Returns:
        ``(xs, pred_x0)``; both are asserted to have shape
        ``(B, capped_log_count, *x1.shape[1:])``.
    """
    num_fe = nfe or (opt.interval - 1)
    assert 0 < num_fe < opt.interval == len(diffusion.betas)
    steps = util.space_indices(opt.interval, num_fe + 1)

    # Pick which of the sampling steps get recorded in the returned trajectory.
    n_logged = min(len(steps) - 1, log_count)
    log_steps = [steps[j] for j in util.space_indices(len(steps) - 1, n_logged)]
    if log:
        log.info(f"[DDPM Sampling] steps={opt.interval}, nfe={num_fe}, log_steps={log_steps}!")

    x1 = x1.to(opt.device)
    if cond is not None:
        cond = cond.to(opt.device)
    if mask is not None:
        mask = mask.to(opt.device)
        x1 = (1. - mask) * x1 + mask * torch.randn_like(x1)

    def pred_x0_fn(xt, step):
        # Broadcast the scalar step to one long-typed entry per batch element.
        t = torch.full((xt.shape[0],), step, device=opt.device, dtype=torch.long)
        net_out = net(xt, t, cond=cond)
        std_fwd = diffusion.get_std_fwd(t, xdim=xt.shape[1:])
        x0_hat = xt - std_fwd * net_out
        return x0_hat.clamp(-1., 1.) if clip_denoise else x0_hat

    with ema.average_parameters():
        net.eval()
        xs, pred_x0 = diffusion.ddpm_sampling(
            steps, pred_x0_fn, x1, mask=mask, ot_ode=opt.ot_ode, log_steps=log_steps, verbose=True,
        )

    batch, *xdim = x1.shape
    assert xs.shape == pred_x0.shape == (batch, n_logged, *xdim)
    return xs, pred_x0


def build_opt_from_args(args, loaded_opt=None):
    """Assemble the minimal ``opt`` namespace used by the model, sampler and corruption.

    Resolution rules:
      * Model/schedule hyperparameters (``interval``, ``beta_max``, ``t0``, ``T``,
        ``use_fp16``, ``cond_x1``, ``image_size``) come from ``loaded_opt`` (the
        training options) when it defines them; otherwise the CLI value is used.
      * Sampling switches ``ot_ode`` / ``clip_denoise`` are enabled when either the
        CLI flag is passed or ``loaded_opt`` enables them. They are argparse
        ``store_true`` flags, so the CLI can only switch them on.
      * ``corrupt`` prefers a truthy ``args.corrupt``, then ``loaded_opt.corrupt``,
        then ``"dogms"``.
      * ``add_x1_noise`` comes from ``loaded_opt`` only (default ``False``).

    Args:
        args: parsed argparse namespace.
        loaded_opt: training options object, or ``None``; ``getattr`` on ``None``
            returns the supplied fallback for every field.

    Returns:
        ``types.SimpleNamespace`` holding the resolved fields.
    """
    def from_training(key, fallback):
        return getattr(loaded_opt, key, fallback)

    opt = NS()
    opt.device = torch.device("cuda" if torch.cuda.is_available() and not args.cpu else "cpu")

    # Model / schedule hyperparameters: training options win, CLI is the fallback.
    opt.interval = from_training("interval", args.interval)
    opt.beta_max = from_training("beta_max", args.beta_max)
    opt.t0 = from_training("t0", args.t0)
    opt.T = from_training("T", args.T)
    opt.use_fp16 = from_training("use_fp16", args.use_fp16)
    opt.cond_x1 = from_training("cond_x1", args.cond_x1)
    opt.image_size = from_training("image_size", args.image_size)

    # Sampling-only switches: OR the sources so an explicitly passed
    # --ot-ode / --clip-denoise is not silently overridden by options.pkl.
    opt.ot_ode = bool(args.ot_ode or from_training("ot_ode", False))
    opt.clip_denoise = bool(args.clip_denoise or from_training("clip_denoise", False))

    opt.log_dir = str(args.out_dir)
    opt.name = args.name or (Path(args.ckpt).parent.name if args.ckpt else "eval")
    opt.dataset_dir = args.dataset_dir
    opt.distributed = False
    opt.global_rank = 0
    opt.corrupt = args.corrupt or from_training("corrupt", "dogms")
    opt.add_x1_noise = from_training("add_x1_noise", False)
    return opt


def _parse_args():
    """Define and parse the command-line interface of the eval sampler."""
    parser = argparse.ArgumentParser("OT-Bridge Eval-Only Sampler (val split)")
    parser.add_argument("--ckpt", required=True, help="Path to checkpoint .pt (e.g., runs/<exp>/latest.pt)")
    parser.add_argument("--dataset-dir", required=True, help="LMDB root containing val_faster_imagefolder.lmdb(.pt)")
    parser.add_argument("--out-dir", default="./eval_out", help="Output folder for PNG/GIF")
    parser.add_argument("--name", default=None, help="Experiment name for output subdir")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--max-batches", type=int, default=0, help="0=all")
    parser.add_argument("--image-size", type=int, default=256)
    parser.add_argument("--interval", type=int, default=1000)
    parser.add_argument("--beta-max", type=float, default=10.0)
    parser.add_argument("--t0", type=float, default=0.02)
    parser.add_argument("--T", type=float, default=1.0)
    parser.add_argument("--nfe", type=int, default=None, help="Number of function evaluations; default interval-1")
    parser.add_argument("--use-fp16", action="store_true")
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--ot-ode", action="store_true")
    parser.add_argument("--clip-denoise", action="store_true")
    parser.add_argument("--cond-x1", action="store_true")
    # Default None (not "dogms") so build_opt_from_args can fall back to the
    # corruption recorded in options.pkl before using "dogms".
    parser.add_argument("--corrupt", default=None,
                        help="Corruption key, e.g., dogms or jpeg-<q> etc. "
                             "Defaults to the value in options.pkl, else dogms.")
    parser.add_argument("--save-gif", action="store_true")
    parser.add_argument("--k-samples", type=int, default=3, help="How many samples (rows) to export as GIF")
    return parser.parse_args()


def _load_training_options(args, log):
    """Resolve ``args.ckpt`` and load ``options.pkl`` from the checkpoint's folder.

    If ``args.ckpt`` is not an existing file, it is treated as a run directory and
    rewritten in place to ``<dir>/latest.pt``. Returns the unpickled options
    object, or ``None`` when the file is missing or cannot be loaded (logged).
    """
    ckpt_path = Path(args.ckpt)
    if ckpt_path.is_file():
        opt_dir = ckpt_path.parent
    else:
        opt_dir = ckpt_path
        args.ckpt = str(opt_dir / "latest.pt")
    opt_pkl = opt_dir / "options.pkl"
    if not opt_pkl.exists():
        return None
    # NOTE: pickle.load can execute arbitrary code; only point --ckpt at trusted runs.
    # Best-effort: any unpickling failure falls back to CLI options.
    try:
        with open(opt_pkl, "rb") as f:
            loaded_opt = pickle.load(f)
    except Exception as e:
        log.warning(f"Failed to load {opt_pkl}: {e}")
        return None
    log.info(f"Loaded options from {opt_pkl}")
    return loaded_opt


def _build_diffusion_and_net(opt, log):
    """Build the symmetric beta schedule, the ``Diffusion`` wrapper and ``Image256Net``."""
    betas = make_beta_schedule(n_timestep=opt.interval, linear_end=opt.beta_max / opt.interval)
    # Symmetric schedule like training: mirror the first half in time.
    # NOTE: the result has 2 * (interval // 2) entries, one short when interval is odd.
    half = betas[:opt.interval // 2]
    betas = np.concatenate([half, np.flip(half, axis=0)])
    diffusion = Diffusion(betas, opt.device)
    noise_levels = torch.linspace(opt.t0, opt.T, opt.interval, device=opt.device) * opt.interval
    net = Image256Net(log, noise_levels=noise_levels, use_fp16=opt.use_fp16,
                      cond=opt.cond_x1, image_size=opt.image_size).to(opt.device)
    return diffusion, net


def _load_checkpoint(ckpt_file, net, ema_decay, log):
    """Load weights from ``ckpt_file`` into ``net`` and return the EMA used for sampling.

    The EMA is created only *after* the raw ``net`` weights are loaded.
    ``ExponentialMovingAverage`` copies the parameters it is given at
    construction. Building it from the random initialization would make
    ``ema.average_parameters()`` swap in random weights whenever the checkpoint
    has no usable ``"ema"`` entry. A loadable ``"ema"`` entry then replaces
    those shadow parameters.
    """
    ckpt = torch.load(ckpt_file, map_location="cpu")
    if "net" in ckpt:
        net.load_state_dict(ckpt["net"])
        log.info(f"Loaded net weights from {ckpt_file}")
    else:
        log.warning(f"No 'net' entry in {ckpt_file}; sampling with randomly initialized weights")

    ema = ExponentialMovingAverage(net.parameters(), decay=ema_decay)
    if "ema" in ckpt:
        try:
            ema.load_state_dict(ckpt["ema"])
            log.info(f"Loaded EMA from {ckpt_file}")
        except Exception as e:
            # Best-effort: keep the shadow params equal to the loaded net weights.
            log.warning(f"EMA load failed: {e}")
    return ema


def main():
    """Sample the val split with a trained OT-Bridge checkpoint and save PNG grids / GIFs."""
    args = _parse_args()

    log = setup_logger()
    out_root = Path(args.out_dir) / (args.name or Path(args.ckpt).stem)
    images_dir = out_root / "images"
    frames_dir = out_root / "frames"
    images_dir.mkdir(parents=True, exist_ok=True)
    frames_dir.mkdir(parents=True, exist_ok=True)

    # Options: training options.pkl (if any) + CLI fallbacks.
    loaded_opt = _load_training_options(args, log)
    opt = build_opt_from_args(args, loaded_opt)
    log.info(f"Device: {opt.device}")

    # Diffusion + model + weights.
    diffusion, net = _build_diffusion_and_net(opt, log)
    ema = _load_checkpoint(args.ckpt, net, getattr(loaded_opt, "ema", 0.9999), log)

    # Build val dataset (full val).
    data_opt = NS(dataset_dir=Path(opt.dataset_dir), image_size=opt.image_size, device=opt.device)
    val_dataset = imagenet.build_lmdb_dataset(data_opt, log, train=False)

    from torch.utils.data import DataLoader
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers,
        pin_memory=(opt.device.type == "cuda"),  # pinning only helps host->GPU copies
        drop_last=False,
    )

    # Build corruption method (online).
    corrupt_method = build_corruption(opt, log, opt.corrupt)
    log.info(f"Using corruption: {opt.corrupt}")

    # Iterate val set; cap at the real batch count so the progress log is accurate.
    total = len(val_loader)
    max_batches = min(args.max_batches, total) if args.max_batches > 0 else total
    for bidx, (clean_img, _) in enumerate(val_loader):
        if bidx >= max_batches:
            break

        with torch.no_grad():
            corrupt_img = corrupt_method(clean_img.to(opt.device))
        mask = None
        # cond always sees the un-noised corrupted image.
        cond = corrupt_img.detach() if opt.cond_x1 else None
        x1 = corrupt_img
        if opt.add_x1_noise:
            # NOTE(review): add_x1_noise is read from options.pkl; this mirrors the
            # convention of adding unit noise to the bridge start x1 (after cond is
            # taken) - confirm against the training code.
            x1 = x1 + torch.randn_like(x1)

        # Sampling
        xs, pred_x0s = ddpm_sampling(
            opt, net, ema, diffusion,
            x1=x1, mask=mask, cond=cond,
            clip_denoise=opt.clip_denoise, nfe=args.nfe, log_count=10, log=log
        )

        # Move to CPU
        clean_img = clean_img.detach().cpu()
        corrupt_img = corrupt_img.detach().cpu()
        xs = xs.detach().cpu()
        pred_x0s = pred_x0s.detach().cpu()

        # Save grids; the batch index doubles as the filename step.
        it = bidx
        nrow = min(8, clean_img.shape[0])
        frame_shape = xs.shape[2:]  # (C, H, W) of one frame; no hard-coded channel count
        save_png_grid(images_dir, it, "image_clean", clean_img, nrow=nrow, log=log)
        save_png_grid(images_dir, it, "image_corrupt", corrupt_img, nrow=nrow, log=log)
        save_png_grid(images_dir, it, "traj_pred_clean",
                      pred_x0s.reshape(-1, *frame_shape), nrow=pred_x0s.shape[1], log=log)
        save_png_grid(images_dir, it, "traj_recon",
                      xs.reshape(-1, *frame_shape), nrow=xs.shape[1], log=log)

        # Optional GIFs
        if args.save_gif:
            k_list = list(range(min(args.k_samples, xs.shape[0])))
            every = max(xs.shape[1] // 24, 1)
            save_traj_frames_and_gif(frames_dir, it, xs, pred_x0s, k_list=k_list, every=every,
                                     fps=8, flip_time=True, log=log)

        log.info(f"[{bidx+1}/{max_batches}] saved")

    log.info(f"Done. Outputs at: {out_root}")


if __name__ == "__main__":
    # Run the sampler only when executed as a script, not when imported.
    main()
