import os
import os.path as osp
import random

import numpy as np
import PIL.Image as PImage
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from models import build_vae_var
from utils.plot import plot_multi_pdf


class CustomImageDataset(Dataset):
    """Custom dataset for loading images from a flat directory.

    Only files directly inside ``image_dir`` whose name ends in ``.png``,
    ``.jpg`` or ``.jpeg`` (case-insensitive) are used, in sorted path order.
    Every item is paired with a dummy label ``0``.
    """

    # Compared against the lower-cased file name so e.g. "IMG_01.JPG" matches.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir: Directory to scan (not recursive).
            transform: Optional callable applied to each loaded RGB PIL image.

        Raises:
            FileNotFoundError: If ``image_dir`` does not exist (raised by
                ``os.listdir``) or contains no matching image files.
        """
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = sorted(
            osp.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        if not self.image_files:
            raise FileNotFoundError(f"No images found in {image_dir}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, 0)`` for the ``idx``-th file, converted to RGB."""
        img_path = self.image_files[idx]
        # Use a context manager so the underlying file is closed immediately;
        # convert() returns a new, fully loaded image independent of the file.
        with PImage.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, 0  # Return a dummy label


def setup_environment(gpu_id=0, seed=42):
    """Seed the random number generators and select the compute device.

    Args:
        gpu_id: Index of the CUDA device to use when CUDA is available.
        seed: Seed applied to Python's ``random``, NumPy and PyTorch.

    Returns:
        ``torch.device("cuda:<gpu_id>")`` if CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices
    if torch.cuda.is_available():
        # Only touch the CUDA runtime when it exists: calling set_device on a
        # CPU-only machine raises and would defeat the CPU fallback below.
        torch.cuda.set_device(gpu_id)
        device = torch.device(f"cuda:{gpu_id}")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")
    return device


def load_vae_model(device, finetuned=False):
    """Load the VAR VQ-VAE, downloading the released checkpoint if needed.

    Args:
        device: Device passed to ``build_vae_var``; the model is built there.
        finetuned: If True, load the local fine-tuned checkpoint at
            ``./VAR/var_generated/vqvae_finetuned.pth`` instead of the released
            base checkpoint. The fine-tuned file is never downloaded.

    Returns:
        The VAE in eval mode with every parameter's ``requires_grad`` disabled.

    Raises:
        FileNotFoundError: If ``finetuned`` is True and the fine-tuned
            checkpoint does not exist. Errors from downloading the base
            checkpoint propagate from ``torch.hub.download_url_to_file``.
    """
    data_dir = "[VAR_MODEL_PATH]"
    hf_home = "https://huggingface.co/FoundationVision/var/resolve/main"
    vae_ckpt = "vae_ch160v4096z32.pth"

    if finetuned:
        vae_ckpt_path = "./VAR/var_generated/vqvae_finetuned.pth"
        if not osp.exists(vae_ckpt_path):
            # Downloading the base weights cannot help: the fine-tuned path is
            # what gets loaded below, so fail early with a clear message.
            raise FileNotFoundError(
                f"Fine-tuned VAE checkpoint not found: {vae_ckpt_path}"
            )
    else:
        vae_ckpt_path = osp.join(data_dir, vae_ckpt)
        if not osp.exists(vae_ckpt_path):
            print("Downloading VAE checkpoint...")
            os.makedirs(data_dir, exist_ok=True)
            # Download in-process rather than via a shell-interpolated wget
            # call, so failures raise instead of being silently ignored.
            torch.hub.download_url_to_file(f"{hf_home}/{vae_ckpt}", vae_ckpt_path)

    # The build_vae_var function returns both VAE and VAR, we only need VAE.
    vae, _ = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4, device=device, num_classes=1000
    )
    # Load on CPU first; load_state_dict copies into the parameters on `device`.
    vae.load_state_dict(torch.load(vae_ckpt_path, map_location="cpu"), strict=True)
    vae.eval()
    for p in vae.parameters():
        p.requires_grad_(False)
    print("VAE model loaded successfully.")
    return vae


def create_image_loader(data_path, image_size, batch_size, resize=False, num_workers=4):
    """Create a DataLoader for images in a directory.

    Args:
        data_path: Flat directory containing the images.
        image_size: Target size for ``Resize`` + ``CenterCrop``; only used
            when ``resize`` is True.
        batch_size: Number of images per batch.
        resize: If True, resize the shorter side to ``image_size`` and
            center-crop to ``image_size`` x ``image_size``. Otherwise images
            are used as-is, so they presumably must already share one size
            to be stacked into a batch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A non-shuffled DataLoader yielding ``(images, dummy_labels)`` with
        pixel values normalized from [0, 1] to [-1, 1], or ``None`` if
        ``data_path`` is not a directory or contains no images.
    """
    # isdir (not exists): a regular file would otherwise crash in os.listdir.
    if not osp.isdir(data_path):
        print(f"Warning: Image directory not found at: {data_path}")
        return None

    transformations = []
    if resize:
        transformations.extend(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
            ]
        )
    transformations.extend(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    transform = transforms.Compose(transformations)

    try:
        dataset = CustomImageDataset(image_dir=data_path, transform=transform)
    except FileNotFoundError as e:
        print(f"Warning: {e}")
        return None

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        # Pinned memory only speeds up host-to-GPU copies; skip it without CUDA.
        pin_memory=torch.cuda.is_available(),
    )
    print(f"Found {len(dataset)} images in {data_path}")
    return loader


def calculate_codebook_loss_per_scale(vae, images_tensor):
    """
    Calculates the codebook loss (MSE between f and fhat) for each scale.

    For every entry ``fhat_i`` of the ``fhat_list`` returned by
    ``vae.img_to_reconstructed_img_with_token_maps(..., last_one=False)``,
    the element-wise squared error against ``f`` is averaged over all
    non-batch dimensions, giving one value per image.

    Args:
        vae: The loaded VAE model.
        images_tensor: A batch of images as a tensor (batch dimension first).

    Returns:
        A list with one 1-D tensor per scale; each tensor holds the mean
        squared error of every image in the batch.
    """
    with torch.no_grad():
        _, _, f, fhat_list = vae.img_to_reconstructed_img_with_token_maps(
            images_tensor, last_one=False
        )
        # f and each fhat_i are expected to share a shape; F.mse_loss would
        # broadcast (with a warning) otherwise.
        return [
            F.mse_loss(fhat_i, f, reduction="none").flatten(start_dim=1).mean(dim=1)
            for fhat_i in fhat_list
        ]



def _collect_losses_per_scale(vae, loader, device, desc):
    """Run every batch of ``loader`` through the VAE and gather per-image losses.

    Returns:
        A list with one inner list per scale; each inner list holds one loss
        value per image, in loader order. Empty if the loader yields nothing.
    """
    losses_per_scale = []
    for image_batch, _ in tqdm(loader, desc=desc):
        batch_losses_per_scale = calculate_codebook_loss_per_scale(
            vae, image_batch.to(device)
        )
        if not losses_per_scale:
            # The number of scales comes from the VAE, so size the
            # accumulator from the first batch.
            losses_per_scale = [[] for _ in batch_losses_per_scale]
        for scale_losses, batch_losses in zip(losses_per_scale, batch_losses_per_scale):
            scale_losses.extend(batch_losses.cpu().numpy())
    return losses_per_scale


def _print_loss_summary(losses_per_scale, kind):
    """Print the number of processed images and the mean loss of every scale."""
    print(f"\nCompleted processing {len(losses_per_scale[0])} {kind} images.")
    for scale_idx, losses in enumerate(losses_per_scale, start=1):
        print(f"  Scale {scale_idx} - Avg Loss: {np.mean(losses):.6f}")


def main():
    """Compare per-scale codebook losses of generated and real images.

    Both image sets are encoded with the VQ-VAE. The mean loss per scale is
    printed for each set, and one distribution plot per scale is saved under
    ``./VAR/`` comparing real and generated images.
    """
    # Configuration
    GPU_ID = 0
    SEED = 42
    GENERATED_IMAGE_DIR = "./VAR/var_generated/"
    FINETUNED = True
    REAL_IMAGE_DIR = "[REAL_IMAGE_DIR]"
    BATCH_SIZE = 128
    IMAGE_SIZE = 256

    device = setup_environment(GPU_ID, SEED)
    # Honour the FINETUNED switch when choosing which checkpoint to load.
    vae = load_vae_model(device, finetuned=FINETUNED)

    # Process generated images (used at their native size, no resizing).
    generated_image_loader = create_image_loader(
        GENERATED_IMAGE_DIR, IMAGE_SIZE, BATCH_SIZE, resize=False
    )
    generated_losses_per_scale = []
    if generated_image_loader is not None:
        generated_losses_per_scale = _collect_losses_per_scale(
            vae, generated_image_loader, device, "Processing generated images"
        )

    if not generated_losses_per_scale:
        print("No generated images found or processed. Exiting.")
        return
    _print_loss_summary(generated_losses_per_scale, "generated")

    # Process real images (resized and center-cropped to IMAGE_SIZE).
    real_image_loader = create_image_loader(
        REAL_IMAGE_DIR, IMAGE_SIZE, BATCH_SIZE, resize=True
    )
    if real_image_loader is None:
        print("Skipping real image processing and plotting.")
        return

    real_losses_per_scale = _collect_losses_per_scale(
        vae, real_image_loader, device, "Processing real images"
    )
    if not real_losses_per_scale:
        print("No real images were processed.")
        return
    _print_loss_summary(real_losses_per_scale, "real")

    # Plot the distributions for each scale
    for scale_idx, (real_losses, generated_losses) in enumerate(
        zip(real_losses_per_scale, generated_losses_per_scale), start=1
    ):
        plot_multi_pdf(
            data_list=[np.array(real_losses), np.array(generated_losses)],
            label_list=["Real Images", "Generated Images"],
            title=f"Codebook Loss Distribution (Scale {scale_idx})",
            xlabel="MSE Loss",
            ylabel="PDF",
            save_dir="./VAR/",
        )
        print(f"Plotted distribution for scale {scale_idx}.")


if __name__ == "__main__":
    # Script entry point: run the analysis only when executed directly.
    main()