import os
import os.path as osp
import random
from re import T

import numpy as np
import PIL.Image as PImage
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from models import build_vae_var
from utils.plot import plot_multi_pdf


class CustomImageDataset(Dataset):
    """Dataset over the image files found directly inside one directory.

    Only regular files whose name ends in one of ``IMAGE_EXTENSIONS`` are
    used; the comparison ignores case, so ``photo.JPG`` is picked up just like
    ``photo.jpg``. Sub-directories are not traversed. Files are kept in sorted
    path order so that indexing is deterministic across runs.

    Args:
        image_dir: Directory containing the image files.
        transform: Optional callable applied to each RGB PIL image.

    Raises:
        FileNotFoundError: If the directory contains no matching image file.
    """

    # File suffixes (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        with os.scandir(image_dir) as entries:
            self.image_files = sorted(
                entry.path
                for entry in entries
                if entry.is_file()
                and entry.name.lower().endswith(self.IMAGE_EXTENSIONS)
            )
        if not self.image_files:
            raise FileNotFoundError(f"No images found in {image_dir}")

    def __len__(self):
        """Return the number of image files discovered at construction."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, 0)``.

        The second element is a constant placeholder label so that batches
        unpack as ``(images, labels)``.
        """
        img_path = self.image_files[idx]
        # The context manager closes the underlying file as soon as the
        # converted copy has been produced.
        with PImage.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, 0  # Return a dummy label


def setup_environment(gpu_id=0, seed=42):
    """Seed the random number generators and select the compute device.

    Seeds Python's ``random``, NumPy and PyTorch with ``seed``. When CUDA is
    available, ``gpu_id`` is made the current CUDA device and returned;
    otherwise the function falls back to the CPU.

    Args:
        gpu_id: Index of the CUDA device to use when CUDA is available.
        seed: Seed applied to all three random number generators.

    Returns:
        The selected ``torch.device``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Only touch the CUDA runtime when it is actually usable; calling
    # set_device on a machine without CUDA raises and would make the CPU
    # fallback unreachable.
    if torch.cuda.is_available():
        torch.cuda.set_device(gpu_id)
        device = torch.device(f"cuda:{gpu_id}")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")
    return device


def load_vae_model(device, finetuned=False):
    """Build the VAE, load its checkpoint and freeze it for inference.

    With ``finetuned=False`` the base checkpoint is looked up in the model
    directory and downloaded with ``wget`` when it is missing. With
    ``finetuned=True`` the fine-tuned checkpoint path is used instead; that
    file cannot be downloaded, so it must already exist.

    Args:
        device: Device passed to ``build_vae_var`` when constructing the model.
        finetuned: Load the fine-tuned checkpoint instead of the base one.

    Returns:
        The VAE in eval mode with ``requires_grad`` disabled on all parameters.

    Raises:
        FileNotFoundError: If ``finetuned`` is set and the fine-tuned
            checkpoint does not exist.
        subprocess.CalledProcessError: If downloading the base checkpoint
            fails.
    """
    # Local import keeps this block self-contained; only the download path
    # needs it.
    import subprocess

    data_dir = "[VAR_MODEL_PATH]"
    hf_home = "https://huggingface.co/FoundationVision/var/resolve/main"
    vae_ckpt = "vae_ch160v4096z32.pth"

    if finetuned:
        vae_ckpt_path = "[VAR_CKPT_PATH]"
        # Downloading the base checkpoint would not create this file, so fail
        # early with a clear message instead.
        if not osp.exists(vae_ckpt_path):
            raise FileNotFoundError(
                f"Fine-tuned VAE checkpoint not found: {vae_ckpt_path}"
            )
    else:
        vae_ckpt_path = osp.join(data_dir, vae_ckpt)
        if not osp.exists(vae_ckpt_path):
            print("Downloading VAE checkpoint...")
            # Argument list (no shell) avoids quoting problems in paths, and
            # check=True surfaces a failed download immediately.
            subprocess.run(
                ["wget", "-P", data_dir, f"{hf_home}/{vae_ckpt}"], check=True
            )

    # The build_vae_var function returns both VAE and VAR, we only need VAE.
    vae, _ = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4, device=device, num_classes=1000
    )
    print(f"Loading VAE model from: {vae_ckpt_path}")
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider weights_only=True if the installed torch
    # supports it and the checkpoint is a plain state dict - confirm first.
    vae.load_state_dict(torch.load(vae_ckpt_path, map_location="cpu"), strict=True)
    vae.eval()
    for p in vae.parameters():
        p.requires_grad_(False)
    print("VAE model loaded successfully.")
    return vae


def create_image_loader(data_path, image_size, batch_size, resize=False, num_workers=4):
    """Create a non-shuffling DataLoader over the images in a directory.

    Every image is converted to a tensor and normalized with mean 0.5 and
    std 0.5 per channel. When ``resize`` is true the image is first passed
    through ``Resize(image_size)`` and ``CenterCrop(image_size)``.

    Args:
        data_path: Directory holding the image files.
        image_size: Size given to the resize and center-crop transforms.
        batch_size: Number of images per batch.
        resize: Whether to resize and center-crop before tensor conversion.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches, or ``None``
        when ``data_path`` is not a directory or contains no images.
    """
    # isdir (rather than exists) also rejects a path that points at a regular
    # file, which would otherwise crash while listing the directory.
    if not osp.isdir(data_path):
        print(f"Warning: Image directory not found at: {data_path}")
        return None

    transformations = []
    if resize:
        transformations.extend(
            [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
            ]
        )
    transformations.extend(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    transform = transforms.Compose(transformations)

    try:
        dataset = CustomImageDataset(image_dir=data_path, transform=transform)
    except FileNotFoundError as e:
        # Raised by the dataset when the directory holds no image files.
        print(f"Warning: {e}")
        return None

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    print(f"Found {len(dataset)} images in {data_path}")
    return loader


def calculate_codebook_loss(vae, images_tensor):
    """
    Compute the per-image codebook loss (MSE between f and fhat).

    ``f`` and ``fhat`` are the third and fourth values returned by
    ``vae.img_to_reconstructed_img_with_token_maps(images_tensor, last_one=True)``
    (presumably the features before and after quantization - confirm against
    the VAE implementation). The whole computation runs with gradients
    disabled.

    Args:
        vae: The loaded VAE model.
        images_tensor: A batch of images as a tensor.

    Returns:
        A tensor with one value per image: the squared error between ``fhat``
        and ``f`` averaged over dimensions 1, 2 and 3. Note this is a vector
        over the batch, not a single scalar.
    """
    with torch.no_grad():
        outputs = vae.img_to_reconstructed_img_with_token_maps(
            images_tensor, last_one=True
        )
        _, _, features, quantized = outputs
        # Keep the element-wise errors so they can be reduced per image.
        squared_error = F.mse_loss(quantized, features, reduction="none")
        per_image_loss = squared_error.mean(dim=(1, 2, 3))
    return per_image_loss


def _collect_losses(vae, loader, device, desc):
    """Return the per-image codebook losses for every batch in ``loader``."""
    losses = []
    for image_batch, _ in tqdm(loader, desc=desc):
        batch_loss = calculate_codebook_loss(vae, image_batch.to(device))
        losses.extend(batch_loss.cpu().numpy())
    return losses


def main():
    """Compare the codebook loss of generated images against real images.

    Loads the VAE, computes the per-image codebook MSE for the generated image
    directory and prints its average. If no generated image could be
    processed the function stops there. Otherwise the same is done for the
    real image directory (with resize and center-crop enabled) and both loss
    distributions are passed to ``plot_multi_pdf``.
    """
    # Configuration
    GPU_ID = 0
    SEED = 42
    GENERATED_IMAGE_DIR = "[GENERATED_IMAGE_DIR]"
    REAL_IMAGE_DIR = "[REAL_IMAGE_DIR]"
    BATCH_SIZE = 128
    IMAGE_SIZE = 256
    FINETUNED = True
    device = setup_environment(GPU_ID, SEED)
    vae = load_vae_model(device, FINETUNED)

    # Process generated images
    generated_losses = []
    generated_image_loader = create_image_loader(
        GENERATED_IMAGE_DIR, IMAGE_SIZE, BATCH_SIZE, resize=False
    )
    # Compare against None explicitly: truthiness on a DataLoader goes
    # through len(), which is not the question being asked here.
    if generated_image_loader is not None:
        generated_losses = _collect_losses(
            vae, generated_image_loader, device, "Processing generated images"
        )

    if not generated_losses:
        print("No generated images found or processed. Exiting.")
        return

    avg_generated_loss = np.mean(generated_losses)
    print(f"\nCompleted processing {len(generated_losses)} generated images.")
    print(f"Average Generated Codebook MSE Loss: {avg_generated_loss:.6f}")

    # Process real images
    real_image_loader = create_image_loader(
        REAL_IMAGE_DIR, IMAGE_SIZE, BATCH_SIZE, resize=True
    )
    if real_image_loader is None:
        print("Skipping real image processing and plotting.")
        return

    real_losses = _collect_losses(
        vae, real_image_loader, device, "Processing real images"
    )
    if not real_losses:
        print("No real images were processed.")
        return

    avg_real_loss = np.mean(real_losses)
    print(f"\nCompleted processing {len(real_losses)} real images.")
    print(f"Average Real Codebook MSE Loss: {avg_real_loss:.6f}")

    # Plot the distributions
    plot_multi_pdf(
        data_list=[np.array(real_losses), np.array(generated_losses)],
        label_list=["Real Images", "Generated Images"],
        title="Codebook Loss Distribution",
        xlabel="MSE Loss",
        ylabel="PDF",
        save_dir="./VAR/",
    )


# Script entry point: run the comparison only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main() 