import torch
import os
import argparse
import h5py
import numpy as np
from torchvision.utils import save_image, make_grid
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from models_catdiff import DiT_models
from download import find_model
from torch.utils.data import DataLoader, Dataset
from glob import glob
from PIL import Image
import torchvision.transforms.functional as TF

# import open_clip
from transformers import CLIPModel, CLIPProcessor
from torchvision.transforms.functional import to_pil_image

# ---------- dataset ---------- #
import os
import glob
from PIL import Image
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
import re

def natural_sort_key(s):
    """Return a sort key that orders embedded ASCII numbers numerically.

    The string is split on runs of ASCII digits. Digit runs become ``int``
    values and everything else becomes lower-cased text, so ``"img2.png"``
    sorts before ``"img10.png"``.

    Args:
        s: The string (typically a file name or path) to build a key for.

    Returns:
        A list alternating ``str`` and ``int`` items, always starting with a
        (possibly empty) ``str``.
    """
    # With a capturing group, re.split returns the captured digit runs at the
    # odd indices and the text between them at the even indices. Deciding by
    # position instead of str.isdigit() keeps non-ASCII "digits" (e.g. "²" or
    # fullwidth numerals) as text: int() may reject them, and converting them
    # would put an int in a text slot and break key comparison.
    parts = re.split(r"([0-9]+)", s)
    return [
        int(part) if index % 2 else part.lower()
        for index, part in enumerate(parts)
    ]

class PNGDataset(Dataset):
    """Dataset of grayscale PNG triples: (image, conditioning map, mask).

    Every ``*.png`` found in ``image_dir`` is paired with the file of the
    same basename in ``map_dir`` and in ``mask_dir``. Samples are ordered by
    ``natural_sort_key`` applied to the image path.

    Args:
        image_dir: Directory containing the input PNGs.
        map_dir: Directory containing a conditioning-map PNG for every image.
        mask_dir: Directory containing a mask PNG for every image. The
            default of ``None`` is kept for signature compatibility, but a
            directory must be supplied because ``__getitem__`` always
            returns a mask.

    Raises:
        ValueError: If ``mask_dir`` is ``None``.
        FileNotFoundError: If any image has no same-named file in
            ``map_dir`` or ``mask_dir``.
    """

    def __init__(self, image_dir, map_dir, mask_dir=None):
        if mask_dir is None:
            # Previously this surfaced as an opaque TypeError from
            # os.path.join(None, ...).
            raise ValueError("mask_dir is required: every sample includes a mask")

        # NOTE: relies on `glob` being the module (not the function) in the
        # enclosing file's namespace.
        self.image_paths = sorted(
            glob.glob(os.path.join(image_dir, "*.png")),
            key=natural_sort_key,
        )

        # Companion files are matched explicitly by basename.
        self.map_paths = self._companion_paths(map_dir, "map")
        self.mask_paths = self._companion_paths(mask_dir, "mask")

    def _companion_paths(self, directory, label):
        """Return same-named paths under ``directory``; raise if any is missing."""
        paths = [
            os.path.join(directory, os.path.basename(p)) for p in self.image_paths
        ]
        missing = [p for p in paths if not os.path.exists(p)]
        if missing:
            # An explicit exception (rather than assert) still fires under
            # `python -O` and tells the user which file to look for.
            raise FileNotFoundError(
                f"Mismatch detected in {label} paths! "
                f"{len(missing)} file(s) missing, e.g. {missing[0]}"
            )
        return paths

    @staticmethod
    def _load_gray(path):
        """Open ``path``, convert it to mode "L" and return ``TF.to_tensor`` of it."""
        # The context manager closes the file handle deterministically.
        with Image.open(path) as im:
            return TF.to_tensor(im.convert("L"))

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(image, map, mask)`` tensors for sample ``idx``."""
        img = self._load_gray(self.image_paths[idx])
        cond_map = self._load_gray(self.map_paths[idx])
        mask = self._load_gray(self.mask_paths[idx])
        return img, cond_map, mask

def _clip_image_features(clip_model, clip_processor, batch, device):
    """Return ``clip_model.get_image_features`` for every sample in ``batch``."""
    # Each sample has its leading axis squeezed and is converted to a PIL
    # image before being handed to the CLIP processor.
    # (assumes samples are single-channel (1, H, W) tensors — confirm if the
    # dataset changes)
    pil_images = [to_pil_image(sample.squeeze(0)) for sample in batch]
    inputs = clip_processor(images=pil_images, return_tensors="pt").to(device)
    return clip_model.get_image_features(**inputs)


def main(args):
    """Sample one image per dataset entry from a conditioned DiT checkpoint.

    For every (image, map, mask) triple in ``PNGDataset`` the map and mask are
    embedded with CLIP ("openai/clip-vit-large-patch14") and passed to the
    model as ``cond_map`` / ``cond_mask``. The sampled latent is decoded with
    the VAE, averaged over channels and written to disk next to copies of the
    inputs.

    Outputs are written under ``res_outs/<infer_model_name>/``:
    ``output_images_DGM``, ``input_images_DGM``, ``input_maps_DGM``,
    ``input_masks_DGM`` and ``path_correspondence.txt``, which maps every
    generated file name to its source file names.

    Args:
        args: Parsed CLI namespace. Reads ``seed``, ``image_size``, ``model``,
            ``ckpt``, ``num_sampling_steps``, ``vae``, ``image_dir``,
            ``map_dir``, ``mask_dir``, ``batch`` and ``infer_model_name``.
    """
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model (the latent grid is 8x smaller than the image):
    latent_size = args.image_size // 8
    model = DiT_models[args.model](input_size=latent_size, cond_channels=1).to(device)

    # Load checkpoint:
    state_dict = find_model(args.ckpt)
    model.load_state_dict(state_dict)
    model.eval()

    diffusion = create_diffusion(str(args.num_sampling_steps))
    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    ds = PNGDataset(args.image_dir, args.map_dir, args.mask_dir)
    # shuffle must stay False: the save loop below maps batch positions back
    # to dataset paths through a running index.
    dl = DataLoader(ds,
                    batch_size=args.batch,
                    shuffle=False,
                    # Pinned memory only helps host-to-GPU copies.
                    pin_memory=(device == "cuda"))

    # Create output directories:
    out_root = f"res_outs/{args.infer_model_name}"
    output_dir = f"{out_root}/output_images_DGM"
    image_dir = f"{out_root}/input_images_DGM"
    map_dir = f"{out_root}/input_maps_DGM"
    mask_dir = f"{out_root}/input_masks_DGM"
    for directory in (output_dir, image_dir, map_dir, mask_dir):
        os.makedirs(directory, exist_ok=True)

    img_counter = 0

    # (generated name, image name, map name, mask name) for verification
    path_records = []

    for imgs, maps, masks in dl:
        imgs, maps, masks = imgs.to(device), maps.to(device), masks.to(device)

        # NOTE(review): this VAE latent is not used by the sampler. The call
        # is kept because latent_dist.sample() presumably draws from the
        # global RNG; removing it would change `z` (and thus every output)
        # for a given --seed. Drop it if reproducibility with earlier runs
        # does not matter.
        maps_rgb = maps.repeat(1, 3, 1, 1)
        vae.encode(maps_rgb).latent_dist.sample()

        cond_map = _clip_image_features(clip_model, clip_processor, maps, device)
        cond_mask = _clip_image_features(clip_model, clip_processor, masks, device)

        z = torch.randn(imgs.size(0), 4, latent_size, latent_size, device=device)

        # Generate samples
        samples = diffusion.p_sample_loop(
            model.forward,
            z.shape,
            z,
            clip_denoised=False,
            model_kwargs={'cond_map': cond_map, 'cond_mask': cond_mask},
            progress=True,
            device=device
        )

        # 0.18215 is the latent scale factor; it is divided out before decoding.
        samples = vae.decode(samples / 0.18215).sample
        print(f"Min pixel value before mean: {samples.min()}")
        print(f"Max pixel before value: {samples.max()}")
        print(f"Mean before: {samples.mean()}")
        # Collapse the decoded channels to a single grayscale channel.
        samples = samples.mean(dim=1, keepdim=True)

        print(f"Min pixel value input mean: {imgs.min()}")
        print(f"Max pixel input value: {imgs.max()}")
        print(f"Mean input: {imgs.mean()}")

        # Save individual images
        for k in range(imgs.size(0)):
            global_idx = img_counter + k
            fname = f"{global_idx:04d}.png"

            # Record paths for verification
            path_records.append((
                fname,
                os.path.basename(ds.image_paths[global_idx]),
                os.path.basename(ds.map_paths[global_idx]),
                os.path.basename(ds.mask_paths[global_idx]),
            ))

            # NOTE(review): the generated sample is saved without
            # normalize=True, so save_image clips values outside [0, 1],
            # while the three inputs are min-max normalized per image.
            # Confirm this asymmetry is intended (see the min/max prints).
            save_image(samples[k], f"{output_dir}/{fname}")
            save_image(imgs[k], f"{image_dir}/{fname}", normalize=True)
            save_image(maps[k], f"{map_dir}/{fname}", normalize=True)
            save_image(masks[k], f"{mask_dir}/{fname}", normalize=True)

        img_counter += imgs.size(0)

    # Save path correspondence for verification
    with open(f"{out_root}/path_correspondence.txt", "w", encoding="utf-8") as f:
        for record in path_records:
            f.write(f"Generated: {record[0]} | Image: {record[1]} | Map: {record[2]} | Mask: {record[3]}\n")

    print("All images saved successfully with correspondence maintained.")

if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments), listed
    # in the order they appear in --help.
    # NOTE: main() does not read --cond_type, --cond-map-path or
    # --num-samples; they are accepted but have no effect on the run.
    option_table = [
        ("--batch", {"type": int, "default": 8}),
        ("--infer_model_name", {
            "type": str,
            "choices": ["model_a", "model_b", "model_c", "model_d", "model_e", "model_f"],
            "default": "model_f",
        }),
        ("--cond_type", {"type": str, "choices": ["masks", "maps"], "default": "masks"}),
        ("--model", {
            "type": str,
            "choices": list(DiT_models.keys()),
            "default": "DiT-L/4",
        }),
        ("--vae", {"type": str, "choices": ["ema", "mse"], "default": "ema"}),
        ("--image-size", {"type": int, "choices": [256, 512], "default": 256}),
        ("--num-sampling-steps", {"type": int, "default": 250}),
        ("--seed", {"type": int, "default": 0}),
        ("--ckpt", {
            "type": str,
            "required": True,
            "help": "Path to a DiT checkpoint.",
        }),
        ("--cond-map-path", {
            "type": str,
            "help": "Path to the HDF5 file containing conditioning maps.",
        }),
        ("--num-samples", {
            "type": int,
            "default": 6696,
            "help": "Number of samples to generate.",
        }),
        ("--image-dir", {
            "default": "path/to/images/dirs/images",
            "help": "Dir with grayscale input PNGs",
        }),
        ("--map-dir", {
            "default": "path/to/images/dirs/maps",
            "help": "Dir with conditioning map PNGs",
        }),
        ("--mask-dir", {
            "default": "path/to/images/dirs/masks",
            "help": "(Optional) mask PNGs",
        }),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    main(args)
