import os
import argparse
import torch
from diffusers import DDPMPipeline, DDPMScheduler, MUXUNet2DModel, UNet2DModel
from torchvision.utils import save_image
from tqdm import tqdm
import numpy as np
import os
import re
import json
from PIL import Image
import torch.nn.functional as F

def find_latest_checkpoint(directory):
    """Return the name of the highest-numbered ``checkpoint-<N>`` entry in *directory*.

    Only names matching ``checkpoint-<digits>`` exactly are considered. Files and
    directories are both listed by ``os.listdir``.
    If several names parse to the same N (e.g. ``checkpoint-5`` and
    ``checkpoint-05``), the first one returned by ``os.listdir`` wins.

    Args:
        directory: path whose immediate children are scanned.

    Returns:
        The bare entry name (not a full path), or ``None`` if nothing matches.
    """
    # Pattern for "checkpoint-<number>" names.
    checkpoint_re = re.compile(r"^checkpoint-(\d+)$")

    def _step_of(entry):
        # Parsed step number, or None when the entry is not a checkpoint name.
        found = checkpoint_re.match(entry)
        return int(found.group(1)) if found else None

    numbered = [(entry, _step_of(entry)) for entry in os.listdir(directory)]
    numbered = [(entry, step) for entry, step in numbered if step is not None]
    if not numbered:
        return None
    # max() returns the first maximal pair, preserving first-seen-wins on ties.
    newest_entry, _ = max(numbered, key=lambda pair: pair[1])
    return newest_entry

def parse_args():
    """Parse the sampling script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``model_path`` (required), ``n_samples``,
        ``batch_size``, ``output_dir`` and ``num_inference_steps``.
    """
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword options), registered in declaration order.
    option_table = (
        ("--model_path", {"type": str, "required": True, "help": "Path to the trained model (output_dir)"}),
        ("--n_samples", {"type": int, "default": 64, "help": "Number of images to generate"}),
        ("--batch_size", {"type": int, "default": 400, "help": "Batch size for generation"}),
        ("--output_dir", {"type": str, "default": "gen_images", "help": "Directory to save generated images"}),
        ("--num_inference_steps", {"type": int, "default": 1000, "help": "Number of DDPM inference steps"}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def save_images(x, n_saved=0, output_dir=None):
    """Write a batch of images to ``output_dir`` as sequentially numbered PNGs.

    Args:
        x: image tensor indexed as (N, C, H, W); the ``permute(0, 2, 3, 1)``
            below relies on that layout. Values are clamped to [-1, 1] and
            mapped linearly to [0, 255] (truncated, not rounded).
        n_saved: number used for the first file name; files are named
            ``{number:05}.png`` and numbering continues upward.
        output_dir: directory to write into. Although it defaults to ``None``,
            ``os.path.join`` fails on ``None`` as soon as an image is written.

    Returns:
        ``n_saved`` advanced by the number of images written.
    """
    # Move to host memory, clamp, and rescale [-1, 1] -> [0, 1].
    scaled = (torch.clamp(x.detach().cpu(), -1., 1.) + 1.) / 2.
    # Image.fromarray takes channels-last (H, W, C) arrays.
    pixels = (255 * scaled.permute(0, 2, 3, 1).numpy()).astype(np.uint8)
    # Convert every image before writing any file.
    pil_images = list(map(Image.fromarray, pixels))
    for number, picture in enumerate(pil_images, start=n_saved):
        picture.save(os.path.join(output_dir, f"{number:05}.png"))
    return n_saved + len(pil_images)

def _split_into_upsampled_quadrants(batch):
    """Cut every image of ``batch`` into 4 quadrants and upsample each back to full size.

    Args:
        batch: image tensor indexed as (N, C, H, W); the ``shape[2]``/``shape[3]``
            slicing below relies on that layout. Odd H/W lose their last row/column.

    Returns:
        Tensor of shape (4 * N, C, 2 * (H // 2), 2 * (W // 2)). The tiles of
        sample ``i`` occupy rows ``4*i .. 4*i + 3`` in top-left, top-right,
        bottom-left, bottom-right order.
    """
    half_h = batch.shape[2] // 2
    half_w = batch.shape[3] // 2
    quadrants = [
        batch[:, :, top : top + half_h, left : left + half_w]
        for top in (0, half_h)
        for left in (0, half_w)
    ]
    # Stacking on dim=1 gives (N, 4, C, h, w); flattening the first two dims keeps
    # the four tiles of each sample adjacent (image-major order).
    tiles = torch.stack(quadrants, dim=1).flatten(0, 1)
    # F.interpolate treats each batch element independently, so one batched call
    # replaces a per-tile loop.
    return F.interpolate(
        tiles, size=(2 * half_h, 2 * half_w), mode="bilinear", align_corners=False
    )


@torch.no_grad()
def main():
    """Sample images from the newest EMA UNet checkpoint and save them to disk.

    Steps:
      1. Load the DDPM scheduler from ``<model_path>/scheduler``.
      2. Load the UNet from ``<model_path>/checkpoint-<N>/unet_ema``, using the
         largest N.
      3. Generate in batches. Each base sample is split into 4 quadrants, and
         each quadrant is bilinearly upsampled back to full size, so one base
         sample yields 4 output images.
      4. Save exactly ``n_samples`` images under
         ``<output_dir>/<model name>/<checkpoint>/``:
         - ``pil/``: one PNG per image
         - ``pt/``: a single float tensor of shape (n_samples, C, H, W)
         - ``numpz/``: a uint8 (n_samples, H, W, C) array in an ``.npz``

    Raises:
        ValueError: if ``--n_samples`` or ``--batch_size`` is not positive.
        FileNotFoundError: if ``model_path`` has no ``checkpoint-<N>`` entry.
        RuntimeError: if tiling yields an unexpected number of images.
    """
    tiles_per_sample = 4  # quadrants produced from each base sample

    args = parse_args()
    if args.n_samples < 1:
        raise ValueError(f"--n_samples must be positive, got {args.n_samples}")
    if args.batch_size < 1:
        raise ValueError(f"--batch_size must be positive, got {args.batch_size}")
    os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler and model.
    scheduler_path = os.path.join(args.model_path, "scheduler")
    scheduler = DDPMScheduler.from_pretrained(scheduler_path)
    print(scheduler_path)

    ckpt_name = find_latest_checkpoint(args.model_path)
    if ckpt_name is None:
        raise FileNotFoundError(
            f"No 'checkpoint-<N>' entry found in {args.model_path!r}"
        )
    ckpt_path = os.path.join(args.model_path, ckpt_name)
    unet_ema_path = os.path.join(ckpt_path, "unet_ema")
    unet = UNet2DModel.from_pretrained(unet_ema_path)
    # For MUX checkpoints: unet = MUXUNet2DModel.from_pretrained(unet_ema_path)

    # normpath strips a trailing separator so "runs/exp/" still yields "exp".
    model_short_name = os.path.basename(os.path.normpath(args.model_path))
    sample_output_path = os.path.join(args.output_dir, model_short_name, ckpt_name)
    pipeline = DDPMPipeline(unet=unet, scheduler=scheduler).to(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    nppath = os.path.join(sample_output_path, 'numpz')
    os.makedirs(nppath, exist_ok=True)
    pilpath = os.path.join(sample_output_path, 'pil')
    os.makedirs(pilpath, exist_ok=True)
    ptpath = os.path.join(sample_output_path, 'pt')
    os.makedirs(ptpath, exist_ok=True)

    # Generate in batches until exactly `total` images have been written.
    images = []  # uint8 channels-last chunks for the .npz
    pt = []      # float CPU chunks for the .pt
    saved_images = 0
    total = args.n_samples
    with tqdm(total=total, desc="Generating images") as pbar:
        while saved_images < total:
            remaining = total - saved_images
            # Sample only as many base images as still needed (ceil division);
            # each one contributes `tiles_per_sample` outputs.
            bs = min(args.batch_size, (remaining + tiles_per_sample - 1) // tiles_per_sample)
            outputs = pipeline(
                batch_size=bs,
                num_inference_steps=args.num_inference_steps,
                output_type="tensor",
            ).images
            # NOTE(review): the code below (and save_images) treats these values
            # as [-1, 1]; confirm what output_type="tensor" returns in this
            # diffusers build.
            outputs = _split_into_upsampled_quadrants(outputs)
            if outputs.shape[0] != bs * tiles_per_sample:
                raise RuntimeError(f"Unexpected output shape: {outputs.shape}")
            # Drop surplus tiles from the last batch; keep host copies only.
            outputs = outputs[:remaining].detach().cpu()

            saved_images = save_images(outputs, n_saved=saved_images, output_dir=pilpath)
            pt.append(outputs)
            as_uint8 = ((outputs + 1) * 127.5).clamp(0, 255).to(torch.uint8)
            images.append(as_uint8.permute(0, 2, 3, 1).contiguous())
            pbar.update(outputs.shape[0])

    # Concatenate (not stack) so batches of different sizes combine into one
    # (n_samples, C, H, W) tensor.
    pt_tensor = torch.cat(pt, dim=0)
    torch.save(pt_tensor, os.path.join(ptpath, f"{args.n_samples}-samples.pt"))

    img = torch.cat(images, dim=0).numpy()
    nppath = os.path.join(nppath, f"{args.n_samples}-samples.npz")
    print(f"Saving generated images shape {img.shape} to {nppath}")
    np.savez(nppath, img)

# Script entry point: run sampling only when executed directly, not when imported.
if __name__ == "__main__":
    main()
