import os
import json
from typing import Optional

import torch
from torchvision.utils import save_image, make_grid

# Import the model class from the project package. If the ``generative``
# package cannot be found on the import path, add the directory two levels
# above this file to ``sys.path`` and retry the same import.
try:
    from generative.diffusion.diffusion_train import Conditional_UNet
except ModuleNotFoundError:
    import sys as _sys, os as _os
    # Absolute path of the directory two levels above this file's directory.
    _proj_root = _os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    # Only prepend once so repeated imports do not grow sys.path.
    if _proj_root not in _sys.path:
        _sys.path.insert(0, _proj_root)
    from generative.diffusion.diffusion_train import Conditional_UNet


def _resolve_checkpoint_path(
    save_root: str,
    ckpt_subdir: Optional[str],
    ckpt_path: Optional[str],
    variant: str,
    dataset: Optional[str],
) -> str:
    """Return the first existing checkpoint file among the known locations.

    An existing ``ckpt_path`` always wins. Otherwise candidates are probed in
    this order: ``ckpt_subdir`` (as a ``.pth`` file, or as a directory holding
    ``checkpoint.pth``), then ``<save_root>/params/<dataset>/<variant>/``
    for the requested dataset followed by the other known dataset, then
    ``<save_root>/params/<variant>/`` and ``<save_root>/params/``.

    Raises:
        FileNotFoundError: if none of the locations holds a file. The message
            lists every path that was probed.
    """
    if ckpt_path and os.path.isfile(ckpt_path):
        return ckpt_path

    known_datasets = ("mnist", "cifar10")
    params_dir = os.path.join(save_root, "params")
    candidates = []

    if ckpt_subdir:
        if ckpt_subdir.endswith(".pth") and os.path.isfile(ckpt_subdir):
            candidates.append(ckpt_subdir)
        else:
            if os.path.isabs(ckpt_subdir) or os.sep in ckpt_subdir:
                candidates.append(os.path.join(ckpt_subdir, "checkpoint.pth"))
            if dataset in known_datasets:
                candidates.append(
                    os.path.join(params_dir, dataset, variant, ckpt_subdir, "checkpoint.pth")
                )

    if dataset in known_datasets:
        other = "cifar10" if dataset == "mnist" else "mnist"
        dataset_order = (dataset, other)
    else:
        dataset_order = ("cifar10", "mnist")
    candidates.extend(
        os.path.join(params_dir, name, variant, "checkpoint.pth") for name in dataset_order
    )
    candidates.append(os.path.join(params_dir, variant, "checkpoint.pth"))
    candidates.append(os.path.join(params_dir, "checkpoint.pth"))

    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate

    searched = ([ckpt_path] if ckpt_path else []) + candidates
    raise FileNotFoundError(
        "Cannot find checkpoint under params/; searched: " + ", ".join(searched)
    )


def _ddim_timesteps(
    alphas_cumprod: torch.Tensor,
    T: int,
    steps: int,
    index_schedule: str,
    device: torch.device,
) -> torch.Tensor:
    """Return the descending, de-duplicated timestep indices to visit.

    With ``index_schedule == "alpha"`` the indices are those whose cumulative
    alpha is closest to ``steps`` targets evenly spaced from 1.0 down to the
    last cumulative alpha. Any other value selects ``steps`` indices evenly
    spaced (then rounded) from ``T - 1`` down to 0.
    """
    if index_schedule == "alpha":
        targets = torch.linspace(1.0, float(alphas_cumprod[-1].item()), steps, device=device)
        diffs = (alphas_cumprod.view(1, -1) - targets.view(-1, 1)).abs()
        idx = torch.argmin(diffs, dim=1)
        idx, _ = torch.sort(idx, descending=True)
    else:
        idx = torch.linspace(T - 1, 0, steps, device=device).round().long()
    # Rounding / nearest-match can map several targets to the same index.
    return torch.unique_consecutive(idx)


@torch.no_grad()
def sample_ddim(
    save_root: str,
    ckpt_subdir: Optional[str],
    ckpt_path: Optional[str],
    out_dir: Optional[str],
    gen_variant: str = "orig",
    steps: int = 200,
    n_per_class: int = 10,
    eta: float = 0.0,
    index_schedule: str = "linear",
    device_str: Optional[str] = None,
    dataset: Optional[str] = None,
) -> None:
    """Generate class-conditional images with a DDIM sampler and save them as PNGs.

    A checkpoint is located and loaded, a ``Conditional_UNet`` is rebuilt from
    the checkpoint's ``"config"`` entry, and ``n_per_class`` images are sampled
    for every class. Each class gets a ``class_XX`` directory of individual
    images plus a ``class_XX_grid.png``; when the model has at least 10 classes
    a 10x10 ``grid_all.png`` covering classes 0-9 is written as well.

    Args:
        save_root: Root directory containing ``params/`` and receiving
            ``samples/`` when ``out_dir`` is not given.
        ckpt_subdir: Optional checkpoint sub-directory (or ``.pth`` file).
        ckpt_path: Optional explicit checkpoint file; used when it exists.
        out_dir: Output directory. Defaults to
            ``<save_root>/samples/<config dataset>/<variant>/ddim``.
        gen_variant: Variant name used in the checkpoint and output paths.
        steps: Requested number of sampling steps (duplicates are dropped, so
            fewer may actually run).
        n_per_class: Number of images sampled per class.
        eta: Stochasticity factor scaling the per-step noise ``sigma``.
        index_schedule: ``"alpha"`` for alpha-spaced timesteps; any other value
            uses linearly spaced timesteps.
        device_str: Torch device string; CUDA is used when available if omitted.
        dataset: Only steers the checkpoint lookup order. The network
            architecture is chosen from the checkpoint config's ``"dataset"``.

    Raises:
        ValueError: if ``steps`` or ``n_per_class`` is smaller than 1.
        FileNotFoundError: if no checkpoint file can be located.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if n_per_class < 1:
        raise ValueError(f"n_per_class must be >= 1, got {n_per_class}")

    if device_str:
        device = torch.device(device_str)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    variant = (gen_variant or "orig").lower()
    ckpt_file = _resolve_checkpoint_path(save_root, ckpt_subdir, ckpt_path, variant, dataset)

    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Only load checkpoints that come from a trusted source.
    ckpt = torch.load(ckpt_file, map_location="cpu", weights_only=False)
    betas = ckpt["betas"].to(device)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    cfg = ckpt.get("config", {})
    t_emb_dim = int(cfg.get("t_emb_dim", 256))
    num_classes = int(cfg.get("num_classes", 10))
    channels = int(cfg.get("channels", 1))
    base_width = int(cfg.get("base_width", 64))
    ds_name = str(cfg.get("dataset", "mnist"))
    # Never index past the stored noise schedule, even if the config disagrees.
    num_train_steps = int(len(betas))
    T = min(int(cfg.get("T", num_train_steps)), num_train_steps)

    idx = _ddim_timesteps(alphas_cumprod, T, steps, index_schedule, device)
    steps = int(idx.numel())
    a_t = alphas_cumprod[idx]
    # The step after the last visited timestep is the clean image (alpha_bar = 1).
    a_next = torch.cat([alphas_cumprod[idx[1:]], torch.tensor([1.0], device=device)])

    if out_dir is None:
        out_dir = os.path.join(save_root, "samples", ds_name, variant, "ddim")
    os.makedirs(out_dir, exist_ok=True)

    # Rebuild the network; the cifar10 configuration enables extra attention blocks.
    model_kwargs = dict(
        t_emb_dim=t_emb_dim,
        num_classes=num_classes,
        in_ch=channels,
        out_ch=channels,
        base_width=base_width,
    )
    if ds_name == "cifar10":
        model_kwargs.update(
            enable_sa1=True,
            enable_sa2=True,
            enable_sa3=False,
            enable_sa4=True,
            enable_sa5=True,
            enable_sa6=False,
        )
    G = Conditional_UNet(**model_kwargs).to(device)
    G.load_state_dict(ckpt["model"])
    G.eval()

    def one_round(labels: torch.Tensor) -> torch.Tensor:
        """Sample one image per label; returns a CPU tensor scaled to [0, 1]."""
        b = labels.size(0)
        # Samples are generated at a fixed 32x32 resolution.
        x = torch.randn(b, channels, 32, 32, device=device)

        for i in range(steps):
            a_cur = a_t[i].view(1, 1, 1, 1)
            a_nxt = a_next[i].view(1, 1, 1, 1)
            t = idx[i].float().repeat(b)

            eps = G(x, t, labels)
            # Predicted clean image from the current sample and predicted noise.
            x0 = (x - torch.sqrt(1.0 - a_cur) * eps) / torch.sqrt(a_cur)
            if ds_name == "cifar10":
                x0 = x0.clamp(-1.0, 1.0)

            if i == steps - 1:
                x = x0
                continue

            sigma = eta * torch.sqrt((1.0 - a_cur / a_nxt) * (1.0 - a_nxt) / (1.0 - a_cur))
            noise = torch.randn_like(x)
            # Clamp for every dataset: with eta > 1 or float rounding this term
            # can dip below zero and sqrt would turn the whole batch into NaN.
            coef = (1.0 - a_nxt - sigma**2).clamp(min=0.0)
            x = torch.sqrt(a_nxt) * x0 + torch.sqrt(coef) * eps + sigma * noise

        # Map from [-1, 1] to [0, 1] for image saving.
        return (x.clamp(-1.0, 1.0).cpu() + 1.0) * 0.5

    for c in range(num_classes):
        ys = torch.full((n_per_class,), c, dtype=torch.long, device=device)
        imgs = one_round(ys)
        class_dir = os.path.join(out_dir, f"class_{c:02d}")
        os.makedirs(class_dir, exist_ok=True)
        for i, img in enumerate(imgs):
            save_image(img, os.path.join(class_dir, f"img_{i:02d}.png"))
        grid = make_grid(imgs, nrow=n_per_class, padding=2)
        save_image(grid, os.path.join(out_dir, f"class_{c:02d}_grid.png"))

    if num_classes >= 10:
        # One row per class 0-9, ten samples each.
        ys = torch.arange(10, device=device).repeat_interleave(10)
        imgs = one_round(ys)
        grid_all = make_grid(imgs, nrow=10, padding=2)
        save_image(grid_all, os.path.join(out_dir, "grid_all.png"))

    print(f"Saved DDIM-new samples under: {out_dir}", flush=True)


def main(argv: Optional[list] = None) -> None:
    """Parse command-line options and run ``sample_ddim``.

    ``--eta`` and ``--index_schedule`` default to dataset-specific values
    (cifar10: ``0.5`` / ``alpha``; otherwise ``0.0`` / ``linear``) only when
    they are not given on the command line, so an explicit ``--eta 0.0`` or
    ``--index_schedule linear`` is honoured for every dataset.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--save_root", type=str, default=os.path.dirname(__file__))
    parser.add_argument("--ckpt_subdir", type=str, default=None)
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument("--out_dir", type=str, default=None)
    parser.add_argument("--gen_variant", type=str, default="orig", choices=["orig", "aug"])
    parser.add_argument("--steps", type=int, default=250)
    parser.add_argument("--per_class", type=int, default=10)
    parser.add_argument(
        "--eta",
        type=float,
        default=None,
        help="DDIM noise factor (default: 0.5 for cifar10, 0.0 otherwise)",
    )
    parser.add_argument(
        "--index_schedule",
        type=str,
        default=None,
        choices=["linear", "alpha"],
        help="timestep spacing (default: alpha for cifar10, linear otherwise)",
    )
    # No hard-coded CUDA device: when omitted, sample_ddim picks CUDA if it is
    # available and falls back to CPU, so the script also runs on CPU-only hosts.
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="torch device string (default: cuda if available, else cpu)",
    )
    parser.add_argument("--dataset", type=str, default="mnist", choices=["mnist", "cifar10"])
    args = parser.parse_args(argv)

    # Fill in dataset-specific defaults only for options the user left unset.
    is_cifar10 = args.dataset == "cifar10"
    if args.eta is None:
        args.eta = 0.5 if is_cifar10 else 0.0
    if args.index_schedule is None:
        args.index_schedule = "alpha" if is_cifar10 else "linear"

    sample_ddim(
        save_root=args.save_root,
        ckpt_subdir=args.ckpt_subdir,
        ckpt_path=args.ckpt_path,
        out_dir=args.out_dir,
        gen_variant=args.gen_variant,
        steps=args.steps,
        n_per_class=args.per_class,
        eta=args.eta,
        index_schedule=args.index_schedule,
        device_str=args.device,
        dataset=args.dataset,
    )


if __name__ == "__main__":
    main()


