import os
import json
from typing import Optional

import torch
from torchvision.utils import save_image, make_grid

try:
    from generative.diffusion.diffusion_train import Conditional_UNet
except ModuleNotFoundError:
    import sys as _sys, os as _os
    _proj_root = _os.path.abspath(_os.path.join(_os.path.dirname(__file__), "..", ".."))
    if _proj_root not in _sys.path:
        _sys.path.insert(0, _proj_root)
    from generative.diffusion.diffusion_train import Conditional_UNet


def _resolve_checkpoint_path(
    save_root: str,
    ckpt_subdir: Optional[str],
    ckpt_path: Optional[str],
    variant: str,
    dataset: Optional[str],
) -> str:
    """Return the first existing checkpoint file among the known locations.

    Search order:
      1. ``ckpt_path`` itself, if it is an existing file.
      2. ``ckpt_subdir`` when it is an existing ``.pth`` file; otherwise
         ``<ckpt_subdir>/checkpoint.pth`` (only when ``ckpt_subdir`` is absolute or
         contains a path separator) and
         ``<save_root>/params/<dataset>/<variant>/<ckpt_subdir>/checkpoint.pth``
         (only when ``dataset`` is "mnist" or "cifar10").
      3. ``<save_root>/params/<name>/<variant>/checkpoint.pth`` for the requested
         dataset first and then the other one (cifar10 then mnist when ``dataset``
         is not one of the two).
      4. ``<save_root>/params/<variant>/checkpoint.pth`` and
         ``<save_root>/params/checkpoint.pth``.

    Raises:
        FileNotFoundError: none of the locations exists; the message lists every
            path that was tried.
    """
    if ckpt_path and os.path.isfile(ckpt_path):
        return ckpt_path

    known_datasets = ("mnist", "cifar10")
    params_dir = os.path.join(save_root, "params")
    candidates = []

    if ckpt_subdir:
        if ckpt_subdir.endswith(".pth") and os.path.isfile(ckpt_subdir):
            candidates.append(ckpt_subdir)
        else:
            if os.path.isabs(ckpt_subdir) or os.sep in ckpt_subdir:
                candidates.append(os.path.join(ckpt_subdir, "checkpoint.pth"))
            if dataset in known_datasets:
                candidates.append(os.path.join(params_dir, dataset, variant, ckpt_subdir, "checkpoint.pth"))

    if dataset in known_datasets:
        other = "cifar10" if dataset == "mnist" else "mnist"
        dataset_order = (dataset, other)
    else:
        dataset_order = ("cifar10", "mnist")
    candidates.extend(os.path.join(params_dir, name, variant, "checkpoint.pth") for name in dataset_order)
    candidates.append(os.path.join(params_dir, variant, "checkpoint.pth"))
    candidates.append(os.path.join(params_dir, "checkpoint.pth"))

    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate

    # Report an explicit ckpt_path too: it was given but is not an existing file.
    tried = ([ckpt_path] if ckpt_path else []) + candidates
    raise FileNotFoundError("Cannot find checkpoint under params/; tried: " + ", ".join(tried))


@torch.no_grad()
def sample_ddpm(
    save_root: str,
    ckpt_subdir: Optional[str],
    ckpt_path: Optional[str],
    out_dir: Optional[str],
    gen_variant: str = "orig",
    n_per_class: int = 10,
    device_str: Optional[str] = None,
    dataset: Optional[str] = None,
) -> None:
    """Sample class-conditional images from a DDPM checkpoint and save them as PNGs.

    For every class ``c`` in ``range(num_classes)`` this writes ``n_per_class``
    images to ``<out_dir>/class_<cc>/img_<ii>.png`` plus a one-row grid
    ``<out_dir>/class_<cc>_grid.png``. When the checkpoint has at least 10 classes
    an additional 10x10 overview ``<out_dir>/grid_all.png`` (10 samples for each of
    the first 10 classes) is written.

    Args:
        save_root: Root directory containing ``params/`` (checkpoints) and under
            which ``samples/`` is created when ``out_dir`` is None.
        ckpt_subdir: Optional run sub-directory (or direct ``.pth`` path) used when
            searching for the checkpoint.
        ckpt_path: Optional explicit checkpoint file; used if it exists.
        out_dir: Output directory. Defaults to
            ``<save_root>/samples/<dataset>/<variant>/ddpm``.
        gen_variant: Variant sub-directory name (lower-cased; None/empty -> "orig").
        n_per_class: Number of images to draw per class (must be >= 1).
        device_str: Torch device string. When falsy, CUDA is used if available,
            otherwise CPU.
        dataset: Dataset hint ("mnist" or "cifar10"). Drives the checkpoint search
            and is the fallback dataset name when the checkpoint config has no
            ``dataset`` entry.

    Raises:
        ValueError: ``n_per_class`` < 1, or the configured number of steps ``T``
            exceeds the length of the stored beta schedule.
        FileNotFoundError: no checkpoint file could be located.
    """
    if n_per_class < 1:
        # make_grid(nrow=0) would otherwise fail much later with an obscure error.
        raise ValueError(f"n_per_class must be a positive integer, got {n_per_class}")

    device = torch.device(device_str) if device_str else torch.device("cuda" if torch.cuda.is_available() else "cpu")

    variant = (gen_variant or "orig").lower()
    ckpt_path_resolved = _resolve_checkpoint_path(save_root, ckpt_subdir, ckpt_path, variant, dataset)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ckpt = torch.load(ckpt_path_resolved, map_location="cpu", weights_only=False)

    # Diffusion schedule and the derived posterior q(x_{t-1} | x_t, x_0) coefficients.
    betas = ckpt["betas"].to(device)  # (T,)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=device), alphas_cumprod[:-1]], dim=0)
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    # posterior_variance[0] is exactly 0, hence the clamp before taking the log.
    posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min=1e-20))
    posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)

    cfg = ckpt.get("config") or {}
    t_emb_dim = int(cfg.get("t_emb_dim", 256))
    num_classes = int(cfg.get("num_classes", 10))
    T = int(cfg.get("T", len(betas)))
    if T > len(betas):
        # The sampling loop indexes the schedule at T-1; fail before doing any work.
        raise ValueError(f"Checkpoint config has T={T} but the beta schedule only has {len(betas)} steps")
    # The checkpoint config is authoritative; the `dataset` argument is only a
    # fallback so that a config without a dataset entry does not silently mean MNIST.
    ds_name = str(cfg.get("dataset") or dataset or "mnist")

    if out_dir is None:
        out_dir = os.path.join(save_root, "samples", ds_name, variant, "ddpm")
    os.makedirs(out_dir, exist_ok=True)

    # Build the network with the layout selected by the dataset name, then load weights.
    channels = int(cfg.get("channels", 1))
    base_width = int(cfg.get("base_width", 64))
    if ds_name == "cifar10":
        G = Conditional_UNet(
            t_emb_dim=t_emb_dim,
            num_classes=num_classes,
            in_ch=channels,
            out_ch=channels,
            base_width=base_width,
            enable_sa1=True,
            enable_sa2=True,
            enable_sa3=False,
            enable_sa4=True,
            enable_sa5=True,
            enable_sa6=False,
        ).to(device)
    else:
        G = Conditional_UNet(
            t_emb_dim=t_emb_dim,
            num_classes=num_classes,
            in_ch=channels,
            out_ch=channels,
            base_width=base_width,
        ).to(device)
    G.load_state_dict(ckpt["model"])
    G.eval()

    def one_round(labels: torch.Tensor) -> torch.Tensor:
        """Run the full reverse process for one batch of labels.

        Returns a CPU tensor of shape (len(labels), channels, 32, 32) with values
        mapped from [-1, 1] to [0, 1].
        """
        b = labels.size(0)
        # Sampling resolution is fixed at 32x32.
        x = torch.randn(b, channels, 32, 32, device=device)
        for i in range(T - 1, -1, -1):
            # The model receives the timestep as a float tensor of shape (b,).
            t = torch.full((b,), float(i), device=device)
            pred_noise = G(x, t, labels)
            if ds_name == "cifar10":
                # Posterior-mean form: estimate x0 from the predicted noise, then
                # sample from q(x_{t-1} | x_t, x0) using the posterior variance.
                a_bar_t = alphas_cumprod[i].view(1, 1, 1, 1)
                x0 = (x - torch.sqrt(1.0 - a_bar_t) * pred_noise) / torch.sqrt(a_bar_t)
                pmc1 = posterior_mean_coef1[i].view(1, 1, 1, 1)
                pmc2 = posterior_mean_coef2[i].view(1, 1, 1, 1)
                post_mean = pmc1 * x0 + pmc2 * x
                if i > 0:
                    noise = torch.randn_like(x)
                    post_std = torch.exp(0.5 * posterior_log_variance_clipped[i]).view(1, 1, 1, 1)
                    x = post_mean + post_std * noise
                else:
                    # No noise is added on the final step.
                    x = post_mean
            else:
                # Epsilon form with sigma_t = sqrt(beta_t); no noise on the final step.
                alpha_t = alphas[i].view(1, 1, 1, 1)
                alpha_bar_t = alphas_cumprod[i].view(1, 1, 1, 1)
                beta_t = betas[i].view(1, 1, 1, 1)
                noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
                x = (1.0 / torch.sqrt(alpha_t)) * (x - ((1.0 - alpha_t) / torch.sqrt(1.0 - alpha_bar_t)) * pred_noise) \
                    + torch.sqrt(beta_t) * noise
        return (x.clamp(-1, 1).cpu() + 1.0) * 0.5

    # Per-class samples: individual images plus a single-row grid per class.
    for c in range(num_classes):
        ys = torch.full((n_per_class,), c, dtype=torch.long, device=device)
        imgs = one_round(ys)
        class_dir = os.path.join(out_dir, f"class_{c:02d}")
        os.makedirs(class_dir, exist_ok=True)
        for i in range(n_per_class):
            save_image(imgs[i], os.path.join(class_dir, f"img_{i:02d}.png"))
        grid = make_grid(imgs, nrow=n_per_class, padding=2)
        save_image(grid, os.path.join(out_dir, f"class_{c:02d}_grid.png"))

    # Overview grid: one row per class for the first 10 classes, 10 samples each.
    if num_classes >= 10:
        ys = torch.arange(10, device=device).repeat_interleave(10)
        imgs = one_round(ys)
        grid_all = make_grid(imgs, nrow=10, padding=2)
        save_image(grid_all, os.path.join(out_dir, "grid_all.png"))

    print(f"Saved DDPM-new samples under: {out_dir}", flush=True)


def main():
    """Command-line entry point: parse arguments and run ``sample_ddpm``.

    ``--device`` defaults to None so that ``sample_ddpm`` picks CUDA when it is
    available and falls back to CPU otherwise, instead of failing on machines
    without a GPU. ``--save_root`` defaults to the directory containing this file.
    """
    import argparse

    # abspath first: __file__ may be a bare relative name, whose dirname is "".
    default_root = os.path.dirname(os.path.abspath(__file__))

    parser = argparse.ArgumentParser(description="Sample images from a trained conditional DDPM checkpoint.")
    parser.add_argument("--save_root", type=str, default=default_root)
    parser.add_argument("--ckpt_subdir", type=str, default=None)
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument("--out_dir", type=str, default=None)
    parser.add_argument("--gen_variant", type=str, default="orig", choices=["orig", "aug"])
    parser.add_argument("--per_class", type=int, default=10)
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Torch device string, e.g. 'cuda:0' or 'cpu' (default: CUDA if available, else CPU).",
    )
    parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "cifar10"])
    args = parser.parse_args()

    sample_ddpm(
        save_root=args.save_root,
        ckpt_subdir=args.ckpt_subdir,
        ckpt_path=args.ckpt_path,
        out_dir=args.out_dir,
        gen_variant=args.gen_variant,
        n_per_class=args.per_class,
        device_str=args.device,
        dataset=args.dataset,
    )


# Run the command-line sampler only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


