"""
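Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.

The matching ground-truth images are saved in a second array with the same
ordering, and the class labels are saved as well for inpainting datasets.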
"""

import argparse
import os

import numpy as np
import torch
import torchvision.utils as vutils
import torch.distributed as dist

from ddbm import dist_util, logger
from ddbm.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
)
from ddbm.karras_diffusion import karras_sample

from datasets import load_data

from pathlib import Path
import torchvision

def _load_weights(model, model_path):
    """Load a checkpoint into ``model``, padding the label embedding if needed.

    If the checkpoint's ``label_emb.weight`` has a different shape from the
    model's, the model's current embedding is used as a base and its leading
    rows are overwritten with the checkpoint rows.
    NOTE(review): this assumes the checkpoint has no more rows than the model
    and the same trailing dimensions; otherwise the slice assignment raises.
    """
    print(f"正在加载权重: {model_path}")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location="cpu")

    if "label_emb.weight" in state_dict:
        old_emb = state_dict["label_emb.weight"]
        new_emb_shape = model.label_emb.weight.shape

        if old_emb.shape != new_emb_shape:
            print(f"检测到 Label Embedding 形状不匹配: {old_emb.shape} vs {new_emb_shape}")
            # Detach so the patched tensor carries no autograd history.
            fixed_emb = model.label_emb.weight.detach().clone().to(old_emb.device)
            fixed_emb[: old_emb.shape[0]] = old_emb
            state_dict["label_emb.weight"] = fixed_emb

    # strict=False tolerates key mismatches; report them instead of hiding them.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        logger.log(f"keys missing from checkpoint: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        logger.log(f"unexpected keys in checkpoint: {load_result.unexpected_keys}")
    print("权重加载成功！")


def _to_uint8_nhwc(images):
    """Convert a channel-first batch (assumed in [-1, 1]) to uint8 NHWC."""
    scaled = ((images + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    return scaled.permute(0, 2, 3, 1).contiguous()


def _all_gather_cat(tensor):
    """All-gather ``tensor`` from every rank and concatenate along dim 0.

    Every rank must pass a tensor of the same shape.
    """
    buffers = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(buffers, tensor)
    return torch.cat(buffers)


def main():
    """Sample images from a trained model and save them with their references.

    Configuration comes from the command line (see ``create_argparser``) and
    from two environment variables:

    * ``sample_dir`` (required): directory for logs, debug images and outputs.
    * ``num_samples`` (optional): overrides ``--num_samples`` when set.

    On rank 0 the following files are written to ``sample_dir``:
    ``samples_<shape>_nfe<nfe>.npz`` (generated images, uint8 NHWC),
    ``gt_<shape>_nfe<nfe>.npz`` (the matching ``x0`` references, same order),
    ``labels_nfe<nfe>.npz`` (inpainting datasets only), and a few debug PNGs.

    Raises:
        RuntimeError: if ``sample_dir`` is not set, or the dataloader yields
            no batches.
        NotImplementedError: if ``--split`` is neither ``train`` nor ``test``.
    """
    args = create_argparser().parse_args()
    args.use_fp16 = False

    dist_util.setup_dist()

    # Fail fast: without an output directory every later save would crash,
    # but only after the first batch had already been sampled.
    sample_dir = os.environ.get("sample_dir")
    if sample_dir is None:
        raise RuntimeError(
            "environment variable 'sample_dir' must be set to the output directory"
        )
    logger.configure(dir=str(sample_dir))

    # The environment variable, when present, takes precedence over the CLI.
    env_num_samples = os.environ.get("num_samples")
    if env_num_samples is not None:
        args.num_samples = int(env_num_samples)

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys()),
    )

    _load_weights(model, args.model_path)

    device = dist_util.dev()
    model = model.to(device)
    if args.use_fp16:
        model = model.half()
    model.eval()

    logger.log("sampling...")

    all_images = []
    all_gt_refs = []  # ground-truth references, aligned with all_images
    all_labels = []

    all_dataloaders = load_data(
        data_dir=args.data_dir,
        dataset=args.dataset,
        batch_size=args.batch_size,
        image_size=args.image_size,
        deterministic=False,
        include_test=(args.split == "test"),
        seed=args.seed,
        num_workers=args.num_workers,
        num_samples=args.num_samples,
    )

    if args.split == "train":
        dataloader = all_dataloaders[1]
    elif args.split == "test":
        dataloader = all_dataloaders[2]
    else:
        raise NotImplementedError

    is_inpaint = "inpaint" in args.dataset
    is_rank0 = dist.get_rank() == 0

    num = 0
    nfe = None
    for i, data in enumerate(dataloader):
        x0_image = data[0].to(device)
        y0_image = data[1].to(device)

        x0 = x0_image
        y0 = y0_image

        model_kwargs = {"xT": y0}

        if is_inpaint:
            _, mask, label = data[2]
            mask = mask.to(device)
            label = label.to(device)
            model_kwargs["y"] = label
        else:
            mask = None

        # Offset the seed by the first element of data[2] for this batch.
        indexes = data[2][0].numpy()
        sample, path, nfe, pred_x0, sigmas, _ = karras_sample(
            diffusion,
            model,
            y0,
            x0,
            steps=args.steps,
            mask=mask,
            model_kwargs=model_kwargs,
            device=device,
            clip_denoised=args.clip_denoised,
            sampler=args.sampler,
            churn_step_ratio=args.churn_step_ratio,
            eta=args.eta,
            order=args.order,
            seed=indexes + args.seed,
        )

        # Debug: references stacked above generations for the first batch,
        # to check by eye that the pairs are aligned.
        if i == 0 and is_rank0:
            debug_align_path = os.path.join(sample_dir, "debug_align_check.png")
            print(f">>> 🔥 [Debug] 生成对齐检查图: {debug_align_path}")
            ref_vis = (x0[:4].clone().detach().float() + 1) * 0.5
            gen_vis = (sample[:4].clone().detach().float() + 1) * 0.5
            comparison = torch.cat([ref_vis, gen_vis], dim=2)
            torchvision.utils.save_image(
                torch.clamp(comparison, 0, 1), debug_align_path, nrow=4
            )

        # Debug: snapshots of the sampling trajectory at fixed fractions.
        if i == 3 and is_rank0:
            debug_dir = os.path.join(sample_dir, "debug_xt")
            os.makedirs(debug_dir, exist_ok=True)
            total_steps = len(path)
            ratios = [0.2, 0.4, 0.6, 0.8, 1.0]
            for ratio in ratios:
                target_idx = max(0, min(int(total_steps * ratio) - 1, total_steps - 1))
                imgs = (path[target_idx][:16] + 1) * 0.5
                grid = torchvision.utils.make_grid(torch.clamp(imgs, 0, 1), nrow=4)
                filename = f"xt_progress_{int(ratio*100)}pct_step_{target_idx+1}.png"
                torchvision.utils.save_image(grid, os.path.join(debug_dir, filename))

        # Samples and references go through the same conversion and the same
        # gather so that row k of one array matches row k of the other.
        sample_uint8 = _to_uint8_nhwc(sample)
        ref_uint8 = _to_uint8_nhwc(x0)

        gathered_samples = _all_gather_cat(sample_uint8)
        gathered_refs = _all_gather_cat(ref_uint8)
        if is_inpaint:
            gathered_labels = _all_gather_cat(label)

        num += gathered_samples.shape[0]

        num_display = min(32, sample_uint8.shape[0])
        if i == 3 and is_rank0:
            grid_cols = int(np.sqrt(num_display))
            vutils.save_image(
                sample_uint8.permute(0, 3, 1, 2)[:num_display].float() / 255,
                f"{sample_dir}/sample_{i}.png",
                nrow=grid_cols,
            )
            vutils.save_image(
                x0_image[:num_display] / 2 + 0.5,
                f"{sample_dir}/x_{i}.png",
                nrow=grid_cols,
            )
            vutils.save_image(
                y0_image[:num_display] / 2 + 0.5,
                f"{sample_dir}/y_{i}.png",
                nrow=grid_cols,
            )

        all_images.append(gathered_samples.detach().cpu().numpy())
        all_gt_refs.append(gathered_refs.detach().cpu().numpy())
        if is_inpaint:
            all_labels.append(gathered_labels.detach().cpu().numpy())

        if is_rank0:
            logger.log(f"sampled {num} images")

    # np.concatenate on an empty list fails with an unhelpful message, and
    # nfe would never have been set; report the real cause instead.
    if not all_images:
        raise RuntimeError("dataloader yielded no batches; nothing was sampled")

    if is_rank0:
        arr = np.concatenate(all_images, axis=0)[: args.num_samples]
        ref_arr = np.concatenate(all_gt_refs, axis=0)[: args.num_samples]
        shape_str = "x".join([str(x) for x in arr.shape])

        # 1. Generated samples.
        out_path = os.path.join(sample_dir, f"samples_{shape_str}_nfe{nfe}.npz")
        np.savez(out_path, arr)
        logger.log(f"Samples saved to {out_path}")

        # 2. Matching ground-truth references.
        gt_out_path = os.path.join(sample_dir, f"gt_{shape_str}_nfe{nfe}.npz")
        np.savez(gt_out_path, ref_arr)
        logger.log(f"Ground Truths saved to {gt_out_path}")

        if is_inpaint:
            labels = np.concatenate(all_labels, axis=0)[: args.num_samples]
            np.savez(os.path.join(sample_dir, f"labels_nfe{nfe}.npz"), labels)

    dist.barrier()
    logger.log("sampling complete")

def create_argparser():
    """Build the command-line parser for the sampling script.

    Script-level defaults are listed first; values returned by
    ``model_and_diffusion_defaults()`` are merged on top and win on any
    shared key. One argument per resulting key is registered through
    ``add_dict_to_argparser``.
    """
    script_defaults = {
        "data_dir": "",
        "dataset": "edges2handbags",
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 32,
        "sampler": "heun",
        "split": "train",
        "churn_step_ratio": 0.0,
        "rho": 7.0,
        "steps": 40,
        "model_path": "",
        "exp": "",
        "seed": 42,
        "num_workers": 8,
        "eta": 1.0,
        "order": 1,
    }
    # Later entries override earlier ones, as dict.update would.
    merged_defaults = {**script_defaults, **model_and_diffusion_defaults()}

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, merged_defaults)
    return arg_parser

# Run sampling only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()