import os
import os.path as osp

import numpy as np
import PIL.Image as PImage
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from models import build_vae_var


class ImageFolderDataset(Dataset):
    """Dataset over the image files found directly inside one directory.

    Files are selected by extension (case-insensitive) and served in sorted
    filename order. Sub-directories are not traversed, and directory entries
    whose names merely look like images are ignored.

    Args:
        image_dir: Directory to scan for images.
        transform: Optional callable applied to each RGB ``PIL.Image``
            before it is returned.

    Raises:
        FileNotFoundError: If ``image_dir`` contains no matching image file.
    """

    # Extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".gif")

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # Require a regular file: a directory named e.g. "thumbs.png" would
        # otherwise be listed here and only fail later inside __getitem__.
        self.image_files = sorted(
            f
            for f in os.listdir(image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
            and osp.isfile(osp.join(image_dir, f))
        )
        if not self.image_files:
            raise FileNotFoundError(f"No images found in {image_dir}")

    def __len__(self):
        """Return the number of image files found at construction time."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, filename)`` for the ``idx``-th file in sorted order.

        The image is converted to RGB and, if a transform was given, passed
        through it. ``filename`` is the bare file name (no directory part).
        """
        filename = self.image_files[idx]
        img_path = osp.join(self.image_dir, filename)
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent, fully loaded image.
        with PImage.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, filename


def setup_environment(gpu_id):
    """Pick the compute device for this run and announce it on stdout.

    Args:
        gpu_id: Index of the CUDA device to use when CUDA is available.

    Returns:
        ``torch.device("cuda:<gpu_id>")`` if CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        device_name = f"cuda:{gpu_id}"
    else:
        # No CUDA runtime: gpu_id is ignored and everything runs on the CPU.
        device_name = "cpu"

    device = torch.device(device_name)
    print(f"Using device: {device}")
    return device


def load_vae_model(device, use_finetuned):
    """Build the VAE and load fine-tuned or pre-trained weights into it.

    Args:
        device: Device handed to ``build_vae_var`` when constructing the model.
        use_finetuned: If true, load the fine-tuned checkpoint from its fixed
            path. Otherwise load the pre-trained checkpoint, downloading it
            with ``wget`` first when it is not present locally.

    Returns:
        The VAE in eval mode with ``requires_grad`` disabled on all parameters.

    Raises:
        FileNotFoundError: If the fine-tuned checkpoint is missing, or if the
            pre-trained checkpoint is missing and could not be downloaded.
    """
    # Local import: the file's top-level imports do not include subprocess.
    import subprocess

    if use_finetuned:
        ckpt_path = "./VAR/var_generated/vqvae_finetuned_lpips0_mse0_feat1_epoch150_encoder_only.pth"
        print("Loading FINE-TUNED VAE model.")
    else:
        # Update this path if your original checkpoint is stored elsewhere
        data_dir = "[VAR_MODEL_PATH]"
        ckpt_path = osp.join(data_dir, "vae_ch160v4096z32.pth")
        print("Loading PRE-TRAINED VAE model.")

    if not osp.exists(ckpt_path):
        if use_finetuned:
            # The fine-tuned weights are a local artifact; there is nowhere
            # to download them from.
            raise FileNotFoundError(f"Finetuned model not found at: {ckpt_path}")

        # Attempt to download the original model
        print("Original VAE checkpoint not found. Attempting to download...")
        hf_home = "https://huggingface.co/FoundationVision/var/resolve/main"
        os.makedirs(data_dir, exist_ok=True)
        # Argument-list form (no shell) so that spaces or shell metacharacters
        # in data_dir cannot split or alter the command.
        try:
            completed = subprocess.run(
                ["wget", "-P", data_dir, f"{hf_home}/vae_ch160v4096z32.pth"],
                check=False,
            )
        except OSError as err:
            # Typically raised when the wget executable itself is unavailable.
            raise FileNotFoundError(
                f"Download failed. Checkpoint not found at: {ckpt_path}"
            ) from err
        if completed.returncode != 0:
            # A non-zero exit may leave a partial file behind; do not try to
            # load it as if the download had succeeded.
            raise FileNotFoundError(
                f"Download failed (wget exit code {completed.returncode}); "
                f"checkpoint unavailable at: {ckpt_path}"
            )
        if not osp.exists(ckpt_path):
            raise FileNotFoundError(f"Download failed. Checkpoint not found at: {ckpt_path}")

    # Build the VAE architecture (second return value of build_vae_var is unused here)
    vae, _ = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4, device=device
    )
    # SECURITY NOTE(review): torch.load unpickles the file, and the pre-trained
    # checkpoint may have just been downloaded. Consider passing
    # weights_only=True if the installed torch version supports it.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    vae.load_state_dict(state_dict, strict=True)

    # Inference only: switch to eval mode and freeze every parameter.
    vae.eval()
    for p in vae.parameters():
        p.requires_grad_(False)
    print(f"VAE model loaded successfully from: {ckpt_path}")
    return vae


def _to_uint8_images(batch):
    """Convert a (N, C, H, W) batch in [-1, 1] to a (N, H, W, C) uint8 array.

    Values are clamped to the valid range and rounded to the nearest integer
    level (a plain cast would truncate and bias the result darker).
    """
    pixels = batch.detach().add(1).div(2).clamp(0, 1).mul(255).round()
    return pixels.to(torch.uint8).permute(0, 2, 3, 1).contiguous().cpu().numpy()


def main():
    """Reconstruct every image in INPUT_DIR with the VAE and save the results.

    Each input image is resized to IMAGE_SIZE x IMAGE_SIZE, normalized to
    [-1, 1], passed through the VAE, converted back to 8-bit RGB, and written
    to OUTPUT_DIR under its original filename. All settings are the constants
    defined at the top of this function.
    """
    # --- Configuration ---
    # Directory of images to reconstruct.
    INPUT_DIR = "./VAR/var_generated/"
    # Directory to save reconstructed images.
    OUTPUT_DIR = "./VAR/reconstructed_var_generated/"
    # Use the fine-tuned VAE model instead of the pre-trained one.
    USE_FINETUNED_MODEL = True
    GPU_ID = 0
    BATCH_SIZE = 16
    IMAGE_SIZE = 256

    device = setup_environment(GPU_ID)
    vae = load_vae_model(device, USE_FINETUNED_MODEL)

    # Create output directory
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"Reconstructed images will be saved to: {OUTPUT_DIR}")

    # --- Setup Data Loader ---
    # Normalize with mean=std=0.5 maps ToTensor's [0, 1] range onto [-1, 1].
    transform = transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    dataset = ImageFolderDataset(image_dir=INPUT_DIR, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
        pin_memory=(device.type == "cuda"),
    )

    # --- Reconstruction Loop ---
    vae.to(device)
    for image_batch, filename_batch in tqdm(loader, desc="Reconstructing images"):
        image_batch = image_batch.to(device)

        with torch.no_grad():
            # `img_to_reconstructed_img` is expected to return the final image
            # in [-1, 1] when last_one=True (per the original note here); the
            # conversion below clamps defensively regardless.
            recon_batch = vae.img_to_reconstructed_img(image_batch, last_one=True)

        # De-normalize the whole batch at once (single device-to-host copy),
        # then save each image under its original filename.
        recon_images = _to_uint8_images(recon_batch)
        for recon_image, original_filename in zip(recon_images, filename_batch):
            pil_image = PImage.fromarray(recon_image)
            output_path = osp.join(OUTPUT_DIR, original_filename)
            pil_image.save(output_path)

    print(f"\nReconstruction complete for {len(dataset)} images.")


# Script entry point: run the reconstruction only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()