import os
import json
from typing import Optional

import torch
from torchvision.utils import save_image, make_grid
import argparse

try:
    from generative.gan.gan_train import Gen
except ModuleNotFoundError:
    import sys as _sys, os as _os
    _proj_root = _os.path.abspath(_os.path.join(_os.path.dirname(__file__), "..", ".."))
    if _proj_root not in _sys.path:
        _sys.path.insert(0, _proj_root)
    from generative.gan.gan_train import Gen


def _checkpoint_candidates(
    save_root: str,
    variant: str,
    dataset: Optional[str],
    ckpt_subdir: Optional[str],
) -> list:
    """Return checkpoint paths to try, most specific first.

    Order: the explicit ``ckpt_subdir`` override (relative to ``save_root``
    and to ``save_root/params``), then ``params/<dataset>/<variant>`` for the
    requested dataset followed by the other one, then the dataset-less
    layouts ``params/<variant>`` and ``params``.
    """
    params_dir = os.path.join(save_root, "params")
    candidates = []
    if ckpt_subdir:
        # An absolute ckpt_subdir wins over save_root in os.path.join.
        candidates.append(os.path.join(save_root, ckpt_subdir, "checkpoint.pth"))
        candidates.append(os.path.join(params_dir, ckpt_subdir, "checkpoint.pth"))
    if dataset in ("mnist", "cifar10"):
        other = "cifar10" if dataset == "mnist" else "mnist"
        dataset_order = (dataset, other)
    else:
        dataset_order = ("cifar10", "mnist")
    candidates.extend(
        os.path.join(params_dir, ds, variant, "checkpoint.pth") for ds in dataset_order
    )
    candidates.append(os.path.join(params_dir, variant, "checkpoint.pth"))
    candidates.append(os.path.join(params_dir, "checkpoint.pth"))
    return candidates


def _sample_unit_range(generator, labels, z_dim: int, device):
    """Draw one image per entry of ``labels`` and rescale it to [0, 1] on CPU.

    Assumes the generator emits values in [-1, 1] (e.g. a tanh head) -- TODO
    confirm against ``Gen``; ``save_image`` clamps anything outside [0, 1].
    """
    z = torch.randn(labels.shape[0], z_dim, device=device)
    return (generator(z, labels).cpu() + 1.0) * 0.5


@torch.no_grad()
def generate_per_class(
    save_root: str,
    ckpt_subdir: Optional[str],
    out_dir: Optional[str],
    gen_variant: str = "orig",
    n_per_class: int = 10,
    z_dim: Optional[int] = None,
    device_str: Optional[str] = None,
    dataset: Optional[str] = None,
) -> None:
    """Load a conditional generator checkpoint and save per-class samples.

    For each class label 0..9 this writes ``n_per_class`` single images to
    ``<out_dir>/class_XX/img_NN.png`` plus a one-row grid
    ``<out_dir>/class_XX_grid.png``, and finally a 10x10 grid of all classes
    to ``<out_dir>/grid_all.png``.

    Args:
        save_root: Root directory containing ``params/`` (checkpoints) and,
            by default, ``samples/`` (outputs).
        ckpt_subdir: Optional directory holding ``checkpoint.pth``, resolved
            relative to ``save_root`` and then ``save_root/params``; it is
            searched before the default layouts.
        out_dir: Output directory. Defaults to
            ``<save_root>/samples/<dataset>/<variant>`` where the dataset is
            taken from the checkpoint config, else from ``dataset``.
        gen_variant: Checkpoint variant sub-directory (case-insensitive).
        n_per_class: Number of images generated per class; must be >= 1.
        z_dim: Latent size; defaults to the checkpoint config (or 128).
        device_str: torch device string; auto-selects CUDA when available.
        dataset: Preferred dataset ("mnist" or "cifar10") to search first.

    Raises:
        ValueError: If ``n_per_class`` is smaller than 1.
        FileNotFoundError: If no candidate checkpoint exists; the message
            lists every path searched.
    """
    if n_per_class < 1:
        raise ValueError(f"n_per_class must be >= 1, got {n_per_class}")

    if device_str:
        device = torch.device(device_str)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    variant = (gen_variant or "orig").lower()
    candidates = _checkpoint_candidates(save_root, variant, dataset, ckpt_subdir)
    ckpt_path = next((cp for cp in candidates if os.path.isfile(cp)), None)
    if ckpt_path is None:
        searched = "\n  ".join(candidates)
        raise FileNotFoundError(f"Cannot find checkpoint; searched:\n  {searched}")

    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    cfg = ckpt.get("config", {})
    if z_dim is None:
        z_dim = int(cfg.get("z_dim", 128))

    if out_dir is None:
        ds_name = str(cfg.get("dataset", dataset or "mnist"))
        out_dir = os.path.join(save_root, "samples", ds_name, variant)
    os.makedirs(out_dir, exist_ok=True)

    print(f"Loading generator from checkpoint: {ckpt_path}")
    G = Gen(
        z_dim=z_dim,
        out_ch=int(cfg.get("channels", 1)),
        base_hw=int(cfg.get("base_hw", 7)),
        width_mult=float(cfg.get("width_mult", 1.0)),
    ).to(device)
    G.load_state_dict(ckpt["G"])
    G.eval()

    num_classes = 10  # labels 0..9, matching the MNIST / CIFAR-10 checkpoints
    # Keep the historical 2-digit padding, widening only when n_per_class > 100
    # so file names still sort in numeric order.
    idx_width = max(2, len(str(n_per_class - 1)))

    for c in range(num_classes):
        ys = torch.full((n_per_class,), c, dtype=torch.long, device=device)
        imgs = _sample_unit_range(G, ys, z_dim, device)

        class_dir = os.path.join(out_dir, f"class_{c:02d}")
        os.makedirs(class_dir, exist_ok=True)
        for i, img in enumerate(imgs):
            save_image(img, os.path.join(class_dir, f"img_{i:0{idx_width}d}.png"))

        grid = make_grid(imgs, nrow=n_per_class, normalize=False, padding=2)
        save_image(grid, os.path.join(out_dir, f"class_{c:02d}_grid.png"))

    # Overview grid: one row per class, 10 samples each.
    ys = torch.arange(num_classes, device=device).repeat_interleave(10)
    imgs = _sample_unit_range(G, ys, z_dim, device)
    grid_all = make_grid(imgs, nrow=10, normalize=False, padding=2)
    save_image(grid_all, os.path.join(out_dir, "grid_all.png"))

    print(f"Saved per-class {n_per_class} images and grids under: {out_dir}")


def main():
    """Parse command-line arguments and run ``generate_per_class``."""
    parser = argparse.ArgumentParser(
        description="Sample per-class images from a trained conditional GAN generator."
    )
    parser.add_argument(
        "--save_root",
        type=str,
        # abspath so a relative __file__ cannot yield "" and silently tie the
        # checkpoint search to the current working directory.
        default=os.path.dirname(os.path.abspath(__file__)),
        help="root directory containing params/ and samples/",
    )
    parser.add_argument("--ckpt_subdir", type=str, default=None)
    parser.add_argument("--out_dir", type=str, default=None, help="output directory")
    parser.add_argument("--gen_variant", type=str, default="orig", choices=["orig", "aug"])
    parser.add_argument("--per_class", type=int, default=10, help="images per class")
    parser.add_argument(
        "--z_dim", type=int, default=None, help="latent size (default: from checkpoint config)"
    )
    parser.add_argument(
        "--device",
        type=str,
        # None lets generate_per_class pick CUDA when available, else CPU;
        # a hard-coded "cuda:0" would crash on CPU-only machines.
        default=None,
        help="torch device, e.g. 'cuda:0' or 'cpu' (default: auto)",
    )
    parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "cifar10"])
    args = parser.parse_args()

    generate_per_class(
        save_root=args.save_root,
        ckpt_subdir=args.ckpt_subdir,
        out_dir=args.out_dir,
        gen_variant=args.gen_variant,
        n_per_class=args.per_class,
        z_dim=args.z_dim,
        device_str=args.device,
        dataset=args.dataset,
    )


if __name__ == "__main__":
    main()


