import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToPILImage

from datasets.utils import *

# --------------------------------------------------
# 1. Load a single image from CIFAR-10
# --------------------------------------------------
# Use the fully qualified `torchvision.datasets` path rather than the bare name
# `datasets`: the wildcard `from datasets.utils import *` executes after
# `from torchvision import datasets` and silently rebinds `datasets` if that
# module exports anything under that name.
cifar10_train = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True
)

# No `transform` was given, so indexing yields (PIL image, class index).
image, label = cifar10_train[0]
print("Label of this image:", label)
print("Image size:", image.size)  # PIL reports (width, height); expect (32, 32)

# --------------------------------------------------
# 2. Define the transformations we want to visualize
# --------------------------------------------------
# Names passed as `name=` to get_named_version. List order determines the
# order in which the transforms are applied and plotted further down.
# Grouping below follows the names only; it presumably mirrors the
# CIFAR-10-C corruption families — confirm against get_named_version.
transform_names = [
    # noise
    "gaussiannoise", "shotnoise", "impulsenoise",
    # blur
    "defocusblur", "glassblur", "motionblur", "zoomblur",
    # weather
    "snow", "frost", "fog", "brightness",
    # digital
    "contrast", "elastic", "pixelate", "jpegcompression",
    # extra corruptions
    "specklenoise", "gaussianblur", "spatter", "saturate",
]

# `get_named_version` is not defined in this file; presumably it is provided by
# the wildcard import `from datasets.utils import *` at the top (TODO: confirm,
# and prefer an explicit `from datasets.utils import get_named_version`).

# --------------------------------------------------
# 3. Apply each transformation at a chosen severity
# --------------------------------------------------
# Passed to every transform as `strength`; the original note gives the
# adjustable range as 1..5 (confirm in get_named_version).
severity = 3

# Build each named transform for 32-pixel inputs and apply it to the sample
# image, keeping (name, result) pairs in `transform_names` order.
# NOTE(review): the results are assumed to be PIL images (or anything that
# np.array() turns into an HxWxC image) — confirm in get_named_version.
transformed_images = [
    (name, get_named_version(size=32, name=name, strength=severity)(image))
    for name in transform_names
]

# --------------------------------------------------
# 4. Visualize results in a grid
# --------------------------------------------------
n_transforms = len(transformed_images)
cols = 5
# The grid holds the original image as well as every transformed one, so the
# row count must be based on n_transforms + 1. Sizing from n_transforms alone
# breaks whenever n_transforms is a multiple of `cols` (one cell too few).
n_panels = n_transforms + 1
rows = (n_panels + cols - 1) // cols  # ceiling division

plt.figure(figsize=(15, 3.5 * rows))

# plt.subplot indices are 1-based; slot 1 holds the untouched original.
panels = [("original", image), *transformed_images]
for idx, (title, img) in enumerate(panels, start=1):
    plt.subplot(rows, cols, idx)
    # NOTE(review): np.array() is enough for PIL images / HxWxC arrays; a CHW
    # tensor or a float array scaled to [0, 255] would need converting first.
    plt.imshow(np.array(img))
    plt.title(title, fontsize=10)
    plt.axis("off")

plt.tight_layout()
plt.show()
