#!/usr/bin/env python3
import os
import re
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
from diffusers import DDPMScheduler, UNet2DModel

# MUX 모델이 diffusers에 등록돼 있다면 임포트
try:
    from diffusers import MUXUNet2DModel
    HAS_MUX = True
except Exception:
    HAS_MUX = False


# -----------------------
# Utils
# -----------------------
def find_latest_checkpoint(directory: str):
    """
    output_dir 아래에서 가장 숫자가 큰 `checkpoint-XXXX` 디렉토리명을 반환
    """
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"{directory} is not a directory.")
    pattern = re.compile(r"^checkpoint-(\d+)$")
    max_num, latest = -1, None
    for name in os.listdir(directory):
        m = pattern.match(name)
        if m:
            num = int(m.group(1))
            if num > max_num:
                max_num, latest = num, name
    if latest is None:
        raise FileNotFoundError(f"No checkpoint-* found under: {directory}")
    return latest


def to_uint8_nhwc_from_minus1_1(x: torch.Tensor) -> np.ndarray:
    """
    x: BCHW, in [-1, 1]
    -> NHWC uint8 [0,255]
    """
    x = x.detach().cpu().clamp(-1., 1.)
    x = (x + 1.) / 2.
    x = x.permute(0, 2, 3, 1).numpy()
    x = (255 * x).astype(np.uint8)
    return x


def save_images_tensor(x: torch.Tensor, start_idx: int, output_dir: str) -> int:
    """
    x: BCHW, in [-1,1]
    start_idx: 파일명 시작 인덱스
    output_dir: png 저장 경로
    return: 새로운 start_idx
    """
    arr = to_uint8_nhwc_from_minus1_1(x)
    for i in range(arr.shape[0]):
        Image.fromarray(arr[i]).save(os.path.join(output_dir, f"{start_idx+i:05}.png"))
    return start_idx + arr.shape[0]


def load_unet_from_ckpt(ckpt_root: str, mux: bool):
    """
    ckpt_root/output_dir 구조에서 최신 checkpoint의 unet_ema를 로드
    mux=True면 MUXUNet2DModel 시도
    """
    ckpt_dirname = find_latest_checkpoint(ckpt_root)
    unet_ema_path = os.path.join(ckpt_root, ckpt_dirname, "unet_ema")
    if mux:
        if not HAS_MUX:
            raise ImportError("MUXUNet2DModel is not available in your diffusers install.")
        model = MUXUNet2DModel.from_pretrained(unet_ema_path)
    else:
        model = UNet2DModel.from_pretrained(unet_ema_path)
    return model


# -----------------------
# Args
# -----------------------
def parse_args():
    p = argparse.ArgumentParser(description="DDPM inference with step-wise K switching (K=4→K=2→K=1).")

    # ✅ 기본값을 네가 준 경로로 설정 (필요시 인자로 교체 가능)
    p.add_argument("--model_path_k1", type=str,
                   default="/data/baek/ddpm_checkpoint/Celaba",
                   help="K=1 모델 output_dir (체크포인트 루트)")
    p.add_argument("--model_path_k2", type=str,
                   default="/data/baek/ddpm_checkpoint/mux2_celaba",
                   help="K=2 모델 output_dir (체크포인트 루트, MUX)")
    p.add_argument("--model_path_k4", type=str,
                   default="/data/baek/ddpm_checkpoint/mux4_celeba",
                   help="K=4 모델 output_dir (체크포인트 루트, MUX)")

    # 샘플링/저장 관련
    p.add_argument("--n_samples", type=int, default=64, help="총 생성할 이미지 수")
    p.add_argument("--batch_size", type=int, default=400, help="배치 크기")
    p.add_argument("--output_dir", type=str, default="/data/minkyu/ddpm", help="부모 출력 디렉토리")
    p.add_argument("--run_name", type=str, default="mixedK", help="출력 하위 폴더명")

    # 스케줄러/해상도
    p.add_argument("--scheduler_path", type=str, default=None,
                   help="DDPMScheduler 디렉토리 (기본: K=1 모델의 scheduler)")
    p.add_argument("--num_inference_steps", type=int, default=1000,
                   help="DDPM 스텝 수. 1000 기준(0~999)으로 K 구간이 설정됨")

    # 기타
    p.add_argument("--fp16", action="store_true")
    p.add_argument("--seed", type=int, default=None)

    return p.parse_args()


# -----------------------
# Main
# -----------------------
@torch.no_grad()
def main():
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if (args.fp16 and device.type == "cuda") else torch.float32

    # 스케줄러 로드 (기본: K=1의 scheduler 사용)
    scheduler_root = args.scheduler_path or os.path.join(args.model_path_k1, "scheduler")
    scheduler = DDPMScheduler.from_pretrained(scheduler_root)

    # 모델 3개 로드
    unet_k1 = load_unet_from_ckpt(args.model_path_k1, mux=False).to(device)
    unet_k2 = load_unet_from_ckpt(args.model_path_k2, mux=True).to(device)
    unet_k4 = load_unet_from_ckpt(args.model_path_k4, mux=True).to(device)
    for m in (unet_k1, unet_k2, unet_k4):
        m.eval()

    # 출력 디렉토리
    sample_output_path = os.path.join(args.output_dir, args.run_name)
    pil_dir = os.path.join(sample_output_path, "pil")
    pt_dir = os.path.join(sample_output_path, "pt")
    npz_dir = os.path.join(sample_output_path, "numpz")
    os.makedirs(pil_dir, exist_ok=True)
    os.makedirs(pt_dir, exist_ok=True)
    os.makedirs(npz_dir, exist_ok=True)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    total_to_save = args.n_samples
    saved = 0
    all_pt_batches = []
    all_np_batches = []

    # 타임스텝 설정 (diffusers는 보통 큰 t→작은 t로 내려옴)
    scheduler.set_timesteps(args.num_inference_steps, device=device)

    # 구간 정의(0~999 기준):
    # 0–700  → t >= 300  → K=4
    # 700–900 → 100 <= t < 300 → K=2
    # 900–1000 → t < 100 → K=1
    def pick_unet(t_int: int):
        if t_int >=900:
            return unet_k1, 1
        elif t_int >= 600:
            return unet_k2, 2
        elif t_int >= 200:
            return unet_k4, 4
        else:
            return unet_k1, 1

    pbar = tqdm(total=total_to_save, desc="Generated images")
    while saved < total_to_save:
        bs = min(args.batch_size, total_to_save - saved)

        # 초기 노이즈
        if isinstance(unet_k1.config.sample_size, int):
            image_shape = (bs, unet_k1.config.in_channels,
                           unet_k1.config.sample_size, unet_k1.config.sample_size)
        else:
            image_shape = (bs, unet_k1.config.in_channels, *unet_k1.config.sample_size)
        latents = torch.randn(image_shape, device=device, dtype=dtype)

        # 타임스텝 루프 프로그레스바
        with tqdm(total=len(scheduler.timesteps),
                  desc=f"DDPM steps (batch {saved//bs + 1})",
                  leave=False) as tbar:
            for t in scheduler.timesteps:
                t_int = int(t)
                current_unet, current_k = pick_unet(t_int)

                # MUXUNet이 fusion_factor 필요시 여기 반영
                noise_pred = current_unet(latents, t).sample

                # DDPM 1스텝 업데이트
                step_out = scheduler.step(noise_pred, t, latents)
                latents = step_out.prev_sample

                tbar.update(1)  # 타임스텝 진행 표시

        # 최종 latents 저장
        saved = save_images_tensor(latents, saved, pil_dir)
        latent = latents.detach().cpu()
        all_pt_batches.extend([latents])
        latent = ((latent + 1) * 127.5).clamp(0, 255).to(torch.uint8)
        latent = latent.permute(0, 2, 3, 1).contiguous()
        all_np_batches.extend([latent])

        pbar.update(bs)  # 전체 이미지 생성 수 진행 표시

    pbar.close()


    # pt / npz 저장
    pt_tensor = torch.stack(all_pt_batches)  # [num_batches, B, C, H, W]
    torch.save(pt_tensor, os.path.join(pt_dir, f"{args.n_samples}-samples.pt"))

    imgs = np.concatenate(all_np_batches, axis=0)[:args.n_samples]
    np.savez(os.path.join(npz_dir, f"{args.n_samples}-samples.npz"), imgs)
    print(f"Saved {imgs.shape} images to {npz_dir}")


if __name__ == "__main__":
    main()
