#!/usr/bin/env python3
import os
import re
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
from diffusers import DDPMScheduler, UNet2DModel

# MUX 모델이 diffusers에 등록돼 있다면 임포트
try:
    from diffusers import MUXUNet2DModel
    HAS_MUX = True
except Exception:
    HAS_MUX = False


# -----------------------
# Utils
# -----------------------
def find_latest_checkpoint(directory: str):
    """
    output_dir 아래에서 가장 숫자가 큰 `checkpoint-XXXX` 디렉토리명을 반환
    """
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"{directory} is not a directory.")
    pattern = re.compile(r"^checkpoint-(\d+)$")
    max_num, latest = -1, None
    for name in os.listdir(directory):
        m = pattern.match(name)
        if m:
            num = int(m.group(1))
            if num > max_num:
                max_num, latest = num, name
    if latest is None:
        raise FileNotFoundError(f"No checkpoint-* found under: {directory}")
    return latest


def to_uint8_nhwc_from_minus1_1(x: torch.Tensor) -> np.ndarray:
    """
    x: BCHW, in [-1, 1]
    -> NHWC uint8 [0,255]
    """
    x = x.detach().cpu().clamp(-1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(0, 2, 3, 1).numpy()
    x = (255 * x).astype(np.uint8)
    return x


def save_images_tensor(x: torch.Tensor, start_idx: int, output_dir: str) -> int:
    """
    x: BCHW, in [-1,1]
    start_idx: 파일명 시작 인덱스
    output_dir: png 저장 경로
    return: 새로운 start_idx
    """
    arr = to_uint8_nhwc_from_minus1_1(x)
    for i in range(arr.shape[0]):
        Image.fromarray(arr[i]).save(os.path.join(output_dir, f"{start_idx+i:05}.png"))
    return start_idx + arr.shape[0]


def make_grid_pil(bchw: torch.Tensor, nrow: int = 5, pad: int = 2, pad_val: int = 255) -> Image.Image:
    """
    bchw [-1,1] 텐서를 0~255 NHWC로 바꾸고 nrow x ncol 그리드로 PIL 이미지 반환
    nrow=5, B=20이면 4x5 그리드
    """
    arr = to_uint8_nhwc_from_minus1_1(bchw)  # [B,H,W,C], uint8
    B, H, W, C = arr.shape
    ncol = (B + nrow - 1) // nrow  # 행 수
    grid_h = ncol * H + (ncol + 1) * pad
    grid_w = nrow * W + (nrow + 1) * pad
    canvas = np.full((grid_h, grid_w, C), pad_val, dtype=np.uint8)

    idx = 0
    for r in range(ncol):
        for c in range(nrow):
            if idx >= B:
                break
            y0 = pad + r * (H + pad)
            x0 = pad + c * (W + pad)
            canvas[y0:y0+H, x0:x0+W, :] = arr[idx]
            idx += 1

    return Image.fromarray(canvas)


def save_grid(bchw: torch.Tensor, out_path: str, nrow: int = 5):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    grid = make_grid_pil(bchw, nrow=nrow)
    grid.save(out_path)


def load_unet_from_ckpt(ckpt_root: str, mux: bool):
    """
    ckpt_root/output_dir 구조에서 최신 checkpoint의 unet_ema를 로드
    mux=True면 MUXUNet2DModel 시도
    """
    ckpt_dirname = find_latest_checkpoint(ckpt_root)
    unet_ema_path = os.path.join(ckpt_root, ckpt_dirname, "unet_ema")
    if mux:
        if not HAS_MUX:
            raise ImportError("MUXUNet2DModel is not available in your diffusers install.")
        model = MUXUNet2DModel.from_pretrained(unet_ema_path)
    else:
        model = UNet2DModel.from_pretrained(unet_ema_path)
    return model


# -----------------------
# Args
# -----------------------
def parse_args():
    p = argparse.ArgumentParser(description="DDPM inference with step-wise K switching (K=4→K=2→K=1) and progressive grids every 50 steps.")

    # ✅ 기본 경로
    p.add_argument("--model_path_k1", type=str,
                   default="/data/baek/ddpm_checkpoint/Celaba",
                   help="K=1 모델 output_dir (체크포인트 루트)")
    p.add_argument("--model_path_k2", type=str,
                   default="/data/baek/ddpm_checkpoint/mux2_celaba",
                   help="K=2 모델 output_dir (체크포인트 루트, MUX)")
    p.add_argument("--model_path_k4", type=str,
                   default="/data/baek/ddpm_checkpoint/mux4_celeba",
                   help="K=4 모델 output_dir (체크포인트 루트, MUX)")

    # 샘플링/저장 관련
    p.add_argument("--n_samples", type=int, default=20, help="총 생성할 이미지 수 (그리드는 20 권장)")
    p.add_argument("--batch_size", type=int, default=None, help="배치 크기 (기본: n_samples로 자동 설정)")
    p.add_argument("--output_dir", type=str, default="/data/minkyu/ddpm", help="부모 출력 디렉토리")
    p.add_argument("--run_name", type=str, default="mixedK", help="출력 하위 폴더명")

    # 스케줄러/해상도
    p.add_argument("--scheduler_path", type=str, default=None,
                   help="DDPMScheduler 디렉토리 (기본: K=1 모델의 scheduler)")
    p.add_argument("--num_inference_steps", type=int, default=1000,
                   help="DDPM 스텝 수. 1000 기준(0~999)으로 K 구간이 설정됨")

    # 그리드 저장 인터벌
    p.add_argument("--save_every", type=int, default=50, help="몇 스텝마다 프로그레시브 그리드 저장할지")

    # 기타
    p.add_argument("--fp16", action="store_true")
    p.add_argument("--seed", type=int, default=None)

    return p.parse_args()


# -----------------------
# Main
# -----------------------
@torch.no_grad()
def main():
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if (args.fp16 and device.type == "cuda") else torch.float32

    # 스케줄러 로드 (기본: K=1의 scheduler 사용)
    scheduler_root = args.scheduler_path or os.path.join(args.model_path_k1, "scheduler")
    scheduler = DDPMScheduler.from_pretrained(scheduler_root)

    # 모델 3개 로드
    unet_k1 = load_unet_from_ckpt(args.model_path_k1, mux=False).to(device)
    unet_k2 = load_unet_from_ckpt(args.model_path_k2, mux=True).to(device)
    unet_k4 = load_unet_from_ckpt(args.model_path_k4, mux=True).to(device)
    for m in (unet_k1, unet_k2, unet_k4):
        m.eval()

    # 출력 디렉토리
    sample_output_path = os.path.join(args.output_dir, args.run_name)
    pil_dir = os.path.join(sample_output_path, "pil")
    pt_dir = os.path.join(sample_output_path, "pt")
    npz_dir = os.path.join(sample_output_path, "numpz")
    prog_dir = os.path.join(sample_output_path, "progress")  # 🔥 프로그레시브 그리드
    os.makedirs(pil_dir, exist_ok=True)
    os.makedirs(pt_dir, exist_ok=True)
    os.makedirs(npz_dir, exist_ok=True)
    os.makedirs(prog_dir, exist_ok=True)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    # 배치 크기 고정: 그리드 목적이라 n_samples와 동일 권장
    batch_size = args.batch_size or args.n_samples

    total_to_save = args.n_samples
    saved = 0
    all_pt_batches = []
    all_np_batches = []

    # 타임스텝 설정 (diffusers는 보통 큰 t→작은 t로 내려옴)
    scheduler.set_timesteps(args.num_inference_steps, device=device)
    timesteps = scheduler.timesteps  # e.g., tensor([999, 998, ..., 0])

    # K-스위칭 규칙 (원 코드 기준)
    def pick_unet(t_int: int):
        # if t_int >= 500:
        #     return unet_k4, 4
        # elif t_int >= 250:
        #     return unet_k2, 2
        # else:
        #     return unet_k1, 1
        return unet_k4, 4

    pbar = tqdm(total=total_to_save, desc="Generated images")
    while saved < total_to_save:
        bs = min(batch_size, total_to_save - saved)

        # 초기 노이즈
        if isinstance(unet_k1.config.sample_size, int):
            image_shape = (bs, unet_k1.config.in_channels,
                           unet_k1.config.sample_size, unet_k1.config.sample_size)
        else:
            image_shape = (bs, unet_k1.config.in_channels, *unet_k1.config.sample_size)
        latents = torch.randn(image_shape, device=device, dtype=dtype)

        # 🔎 초기 상태(노이즈)도 보고 싶다면 주석 해제
        save_grid(latents, os.path.join(prog_dir, f"progress_idx0000_t{int(timesteps[0]):04}.png"), nrow=5)

        # 타임스텝 루프
        with tqdm(total=len(timesteps),
                  desc=f"DDPM steps (batch {saved//bs + 1})",
                  leave=False) as tbar:
            for idx, t in enumerate(timesteps):
                t_int = int(t)
                current_unet, current_k = pick_unet(t_int)

                # U-Net 추론
                noise_pred = current_unet(latents, t).sample

                # DDPM 1스텝 업데이트
                step_out = scheduler.step(noise_pred, t, latents)
                latents = step_out.prev_sample  # 여전히 [-1,1] 범위의 x_{t-1} 형태

                # 🔥 50 스텝마다 프로그레시브 그리드 저장
                if (idx % args.save_every) == 0 or (idx == len(timesteps) - 1):
                    out_name = os.path.join(
                        prog_dir, f"progress_idx{idx:04}_t{t_int:04}.png"
                    )
                    save_grid(latents, out_name, nrow=5)

                tbar.update(1)

        # 최종 결과 개별 저장(원래 로직 유지)
        saved = save_images_tensor(latents, saved, pil_dir)
        latent_cpu = latents.detach().cpu()
        all_pt_batches.append(latents)

        # npz용 NHWC uint8
        latent_uint8 = ((latent_cpu + 1) * 127.5).clamp(0, 255).to(torch.uint8)
        latent_uint8 = latent_uint8.permute(0, 2, 3, 1).contiguous().numpy()
        all_np_batches.append(latent_uint8)

        pbar.update(bs)

    pbar.close()

    # pt / npz 저장
    pt_tensor = torch.stack(all_pt_batches)  # [num_batches, B, C, H, W]
    torch.save(pt_tensor, os.path.join(pt_dir, f"{args.n_samples}-samples.pt"))

    imgs = np.concatenate(all_np_batches, axis=0)[:args.n_samples]
    np.savez(os.path.join(npz_dir, f"{args.n_samples}-samples.npz"), imgs)
    print(f"Saved {imgs.shape} images to {npz_dir}")
    print(f"Progressive grids saved to {prog_dir}")


if __name__ == "__main__":
    main()
