import os
import json
from typing import Optional

import torch
from torchvision.utils import save_image, make_grid

# Import the U-Net class from the training module. If the ``generative``
# package cannot be found (for example when this file is executed directly
# rather than as part of the package), put the directory two levels above
# this file on ``sys.path`` and retry the same import.
try:
    from generative.diffusion.diffusion_train import Conditional_UNet
except ModuleNotFoundError:
    import sys as _sys, os as _os
    # Two directories up from this file; presumably the project root that
    # contains the ``generative`` package -- verify against the repo layout.
    _proj_root = _os.path.abspath(_os.path.join(_os.path.dirname(__file__), "..", ".."))
    if _proj_root not in _sys.path:
        # Prepend so this location takes priority over other sys.path entries.
        _sys.path.insert(0, _proj_root)
    from generative.diffusion.diffusion_train import Conditional_UNet


def _fm_checkpoint_candidates(
    save_root: str,
    ckpt_subdir: Optional[str],
    dataset: Optional[str],
    variant: str,
) -> list:
    """Return the fallback checkpoint paths to probe, highest priority first."""
    known_datasets = ("mnist", "cifar10")
    params_dir = os.path.join(save_root, "params")
    candidates = []

    if ckpt_subdir:
        if ckpt_subdir.endswith(".pth") and os.path.isfile(ckpt_subdir):
            # ``ckpt_subdir`` may point straight at an existing checkpoint file.
            candidates.append(ckpt_subdir)
        else:
            if os.path.isabs(ckpt_subdir) or os.sep in ckpt_subdir:
                # Path-like value: treat it as a directory holding checkpoint.pth.
                candidates.append(os.path.join(ckpt_subdir, "checkpoint.pth"))
            if dataset in known_datasets:
                candidates.append(
                    os.path.join(params_dir, dataset, variant, ckpt_subdir, "checkpoint.pth")
                )

    # Per-dataset defaults: the requested dataset first, then the other one.
    if dataset in known_datasets:
        other = "cifar10" if dataset == "mnist" else "mnist"
        dataset_order = (dataset, other)
    else:
        dataset_order = ("cifar10", "mnist")
    candidates.extend(
        os.path.join(params_dir, name, variant, "checkpoint.pth") for name in dataset_order
    )

    # Dataset-agnostic locations come last.
    candidates.append(os.path.join(params_dir, variant, "checkpoint.pth"))
    candidates.append(os.path.join(params_dir, "checkpoint.pth"))
    return candidates


@torch.no_grad()
def sample_fm(
    save_root: str,
    ckpt_subdir: Optional[str],
    ckpt_path: Optional[str],
    out_dir: Optional[str],
    steps: int = 50,
    gen_variant: str = "orig",
    n_per_class: int = 10,
    device_str: Optional[str] = None,
    dataset: Optional[str] = None,
) -> None:
    """Generate class-conditional samples from a checkpoint and save them as PNGs.

    Starting from Gaussian noise, ``x`` is advanced with ``steps`` explicit
    Euler updates ``x <- x + h * G(x, t * time_scale, label)`` over
    ``t = 0, h, ..., 1 - h`` with ``h = 1 / steps`` (presumably flow-matching
    sampling, going by the ``fm`` name). The result is clamped to ``[-1, 1]``
    and rescaled to ``[0, 1]`` before saving.

    For every class ``c`` the function writes ``class_XX/img_YY.png`` files and
    a ``class_XX_grid.png`` under ``out_dir``. When the checkpoint config has at
    least 10 classes it also writes ``grid_all.png`` with 10 samples for each of
    the first 10 classes.

    Args:
        save_root: Root directory containing ``params/``; also the base of the
            default output directory.
        ckpt_subdir: Optional run name, directory path, or existing ``.pth``
            file used when building fallback checkpoint locations.
        ckpt_path: Explicit checkpoint file. If it is not an existing file the
            fallback locations are searched instead.
        out_dir: Output directory. Defaults to
            ``<save_root>/samples/<config dataset>/<variant>``.
        steps: Number of Euler steps; must be >= 1.
        gen_variant: Variant name used in checkpoint/output paths (lower-cased;
            falsy values become ``"orig"``).
        n_per_class: Number of images generated per class; must be >= 1.
        device_str: Torch device string. When falsy, CUDA is used if available,
            otherwise the CPU.
        dataset: Dataset hint (``"mnist"`` or ``"cifar10"``) used only to order
            the checkpoint search. The model architecture is selected from the
            checkpoint's own ``config["dataset"]`` entry, not from this value.

    Raises:
        ValueError: If ``steps`` or ``n_per_class`` is smaller than 1.
        FileNotFoundError: If no checkpoint file exists at any probed location.
    """
    # Fail fast with a clear message instead of a ZeroDivisionError further down
    # (1 / steps, or make_grid with nrow=0).
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps}")
    if n_per_class < 1:
        raise ValueError(f"n_per_class must be a positive integer, got {n_per_class}")

    if device_str:
        device = torch.device(device_str)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    variant = (gen_variant or "orig").lower()
    if ckpt_path and os.path.isfile(ckpt_path):
        ckpt_path_resolved = ckpt_path
    else:
        candidates = _fm_checkpoint_candidates(save_root, ckpt_subdir, dataset, variant)
        ckpt_path_resolved = next((cp for cp in candidates if os.path.isfile(cp)), None)
        if ckpt_path_resolved is None:
            tried = ([ckpt_path] if ckpt_path else []) + candidates
            raise FileNotFoundError(
                "Cannot find checkpoint under params/ (tried: " + ", ".join(tried) + ")"
            )

    # NOTE(security): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints that come from a trusted source.
    ckpt = torch.load(ckpt_path_resolved, map_location="cpu", weights_only=False)

    # Hyper-parameters stored with the checkpoint; defaults apply when absent.
    cfg = ckpt.get("config", {})
    t_emb_dim = int(cfg.get("t_emb_dim", 256))
    num_classes = int(cfg.get("num_classes", 10))
    image_size = int(cfg.get("image_size", 32))
    time_scale = float(cfg.get("time_scale", 1000.0))
    channels = int(cfg.get("channels", 1))
    base_width = int(cfg.get("base_width", 64))
    ds_name = str(cfg.get("dataset", "mnist"))

    if out_dir is None:
        out_dir = os.path.join(save_root, "samples", ds_name, variant)
    os.makedirs(out_dir, exist_ok=True)

    unet_kwargs = dict(
        t_emb_dim=t_emb_dim,
        num_classes=num_classes,
        in_ch=channels,
        out_ch=channels,
        base_width=base_width,
    )
    if ds_name == "cifar10":
        # CIFAR-10 checkpoints are built with explicit ``enable_sa*`` flags
        # (presumably self-attention stages -- confirm in Conditional_UNet);
        # other datasets rely on the constructor defaults.
        unet_kwargs.update(
            enable_sa1=True,
            enable_sa2=True,
            enable_sa3=False,
            enable_sa4=True,
            enable_sa5=True,
            enable_sa6=False,
        )
    G = Conditional_UNet(**unet_kwargs).to(device)
    G.load_state_dict(ckpt["model"])
    G.eval()

    h = 1.0 / float(steps)
    ts = torch.linspace(0.0, 1.0 - h, steps, device=device)  # t_0..t_{N-1}

    def one_round(labels: torch.Tensor) -> torch.Tensor:
        """Integrate one batch from noise; return CPU images scaled to [0, 1]."""
        b = labels.size(0)
        x = torch.randn(b, channels, image_size, image_size, device=device)
        for t in ts:
            # The network receives the time multiplied by ``time_scale``,
            # repeated once per batch element.
            t_scaled = (t * time_scale).repeat(b)
            v = G(x, t_scaled, labels)
            x = x + h * v
        return (x.clamp(-1, 1).cpu() + 1.0) * 0.5

    for c in range(num_classes):
        ys = torch.full((n_per_class,), c, dtype=torch.long, device=device)
        imgs = one_round(ys)
        class_dir = os.path.join(out_dir, f"class_{c:02d}")
        os.makedirs(class_dir, exist_ok=True)
        for i, img in enumerate(imgs):
            save_image(img, os.path.join(class_dir, f"img_{i:02d}.png"))
        grid = make_grid(imgs, nrow=n_per_class, padding=2)
        save_image(grid, os.path.join(out_dir, f"class_{c:02d}_grid.png"))

    if num_classes >= 10:
        # Overview grid: one row per class for the first 10 classes, 10 fresh
        # samples each (independent of ``n_per_class``).
        ys = torch.arange(10, device=device).repeat_interleave(10)
        imgs = one_round(ys)
        grid_all = make_grid(imgs, nrow=10, padding=2)
        save_image(grid_all, os.path.join(out_dir, "grid_all.png"))

    print(f"Saved FM-new samples under: {out_dir}", flush=True)


def main(argv: Optional[list] = None) -> None:
    """Parse command-line options and run ``sample_fm``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of a plain ``main()``
            call.

    When ``--steps`` is not given, 200 steps are used for ``cifar10`` and 50
    otherwise. When ``--device`` is not given, ``sample_fm`` picks CUDA if it
    is available and the CPU otherwise.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--save_root", type=str, default=os.path.dirname(__file__))
    parser.add_argument("--ckpt_subdir", type=str, default=None)
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument("--out_dir", type=str, default=None)
    # ``None`` means "not specified", so an explicit ``--steps 50`` can be told
    # apart from the default and is no longer overridden for cifar10.
    parser.add_argument(
        "--steps",
        type=int,
        default=None,
        help="number of sampling steps (default: 200 for cifar10, 50 otherwise)",
    )
    parser.add_argument("--gen_variant", type=str, default="orig", choices=["orig", "aug"])
    parser.add_argument("--per_class", type=int, default=10)
    # Leaving the device unset lets sample_fm fall back to the CPU on hosts
    # without CUDA instead of failing on a hard-coded "cuda:0".
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="torch device string (default: cuda if available, else cpu)",
    )
    parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "cifar10"])
    args = parser.parse_args(argv)

    # Dataset-dependent default for the number of sampling steps.
    steps = args.steps
    if steps is None:
        steps = 200 if args.dataset == "cifar10" else 50

    sample_fm(
        save_root=args.save_root,
        ckpt_subdir=args.ckpt_subdir,
        ckpt_path=args.ckpt_path,
        out_dir=args.out_dir,
        steps=steps,
        gen_variant=args.gen_variant,
        n_per_class=args.per_class,
        device_str=args.device,
        dataset=args.dataset,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()


