import argparse
import itertools
import json
import os
import random
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from config import get_dataset_config, get_unet_config
from denoising_pyramid import *
from diffusers import DDIMScheduler
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid, save_image
from tqdm.auto import tqdm

import wandb
from nn_baselines.src.diffusion_utils import GaussianDiffusionSampler
from nn_baselines.src.training_utils import load_model, sample_images
from data import get_dataset_loader

# Number of images sampled per experiment in run_experiment; also used as the
# row width (nrow) of the comparison grids that run_experiment saves.
N_BATCH = 6


def get_dataset_config(dataset_name: str) -> dict:
    """Look up the image geometry and kernel schedule for a dataset.

    Args:
        dataset_name: One of "mnist", "fashion_mnist", "cifar10", "ffhq",
            "celeba_hq" or "afhq".

    Returns:
        A freshly built dict with the keys "img_size", "in_channels",
        "out_channels" and "kernel_size_schedule".

    Raises:
        KeyError: If ``dataset_name`` is not one of the names listed above.
    """
    # Each entry is (image side length, channel count, kernel size schedule).
    # Input and output channel counts are identical for every dataset.
    specs = {
        "mnist": (28, 1, [28, 28, 23, 17, 17, 13, 9, 9]),
        "fashion_mnist": (28, 1, [28, 28, 23, 17, 17, 13, 9, 9]),
        "cifar10": (32, 3, [32, 32, 32, 29, 25, 17, 13, 9, 7, 3]),
        "ffhq": (64, 3, [64, 45, 33, 25, 17, 9, 5, 3]),
        "celeba_hq": (64, 3, [64, 64, 64, 45, 45, 25, 17, 17, 9, 3]),
        "afhq": (64, 3, [64, 64, 45, 33, 25, 17, 17, 9, 9, 3]),
    }
    side, channels, schedule = specs[dataset_name]
    return {
        "img_size": side,
        "in_channels": channels,
        "out_channels": channels,
        "kernel_size_schedule": schedule,
    }


def get_unet_config(dataset_name: str, num_images: int) -> Dict:
    """Return the configuration used to load a pre-trained UNet baseline.

    The checkpoint is chosen from a hard-coded table keyed by dataset name and
    by training-subset size, where the key ``-1`` denotes the model trained on
    the full dataset.

    Args:
        dataset_name: Key into the checkpoint table and into
            ``get_dataset_config``.
        num_images: Requested training-set size. A negative value selects the
            full-dataset checkpoint. Otherwise the smallest available subset
            size that is ``>= num_images`` is used; when the request exceeds
            every subset size the full-dataset checkpoint is used (falling
            back to the largest subset only if no full-dataset entry exists).

    Returns:
        A dict of training/sampling hyper-parameters together with the
        checkpoint directory ("save_weight_dir") and checkpoint file name
        ("training_load_weight" / "test_load_weight").

    Raises:
        KeyError: If ``dataset_name`` has no entry in the checkpoint table.
        ValueError: If the dataset's image size has no architecture defined.
    """
    # Checkpoint table: {dataset: {subset size: (directory, checkpoint file)}}.
    # The trailing comments inside each entry are the original author's notes
    # on alternative runs.
    model_paths = {
        "mnist": {
            -1: (
                "trained_models/unet/unet_mnist_-1_noattn_20250513_194551",
                "ckpt_epoch_200.pt",
            ),
            # unet_mnist_-1_noattn_20250513_195201 200
        },
        "fashion_mnist": {
            -1: (
                "trained_models/unet/unet_fashion_mnist_-1_noattn_20250513_194633",
                "ckpt_epoch_200.pt",
            ),
            # unet_fashion_mnist_-1_noattn_20250514_001525 200
        },
        "cifar10": {
            -1: (
                "trained_models/unet/unet_cifar10_-1_noattn_20250313_232926",
                "ckpt_epoch_200.pt",
            ),
            100: (
                "trained_models/unet/unet_cifar10_100_noattn_20250313_232926",
                "ckpt_epoch_70000.pt",
            ),
            1000: (
                "trained_models/unet/unet_cifar10_1000_noattn_20250313_232925",
                "ckpt_epoch_10000.pt",
            ),
            10000: (
                "trained_models/unet/unet_cifar10_10000_noattn_20250312_035606",
                "ckpt_epoch_1000.pt",
            ),
            # unet_cifar10_-1_noattn_20250512_160306 200
        },
        "ffhq": {
            -1: ("trained_models/unet/unet_ffhq_-1_noattn", "ckpt_epoch_200.pt"),
            # broken ?
        },
        "celeba_hq": {
            -1: (
                "trained_models/unet/unet_celeba_hq_-1_noattn_20250514_030749",
                "ckpt_epoch_200.pt",
            ),
            # unet_celeba_hq_-1_noattn_20250514_030841 200
        },
        "afhq": {
            -1: (
                "trained_models/unet/unet_afhq_-1_noattn_20250514_002002",  # broken: unet_afhq_-1_noattn_20250513_195651
                "ckpt_epoch_200.pt",
            ),
            # unet_afhq_-1_noattn_20250515_004233 200
        },
    }

    dataset_paths = model_paths[dataset_name]

    # -1 means "full dataset", i.e. the LARGEST training set, so it must not
    # take part in the numeric ordering of the real subset sizes.
    subset_sizes = sorted(size for size in dataset_paths if size != -1)
    if num_images < 0:
        model_size = -1  # Use full dataset model
    else:
        # A request larger than every subset is best served by the model
        # trained on the full dataset, when one exists.
        fallback = -1 if -1 in dataset_paths else subset_sizes[-1]
        model_size = next(
            (size for size in subset_sizes if size >= num_images),
            fallback,
        )
    print(
        f"Using trained UNET with dataset size: {'full' if model_size == -1 else model_size}"
    )

    # Get dataset-specific configuration
    dataset_config = get_dataset_config(dataset_name)
    img_size = dataset_config["img_size"]

    # Architecture depends only on the image side length.
    if img_size == 28:  # mnist, fashion_mnist
        channel = 64
        channel_mult = [1, 2, 2]
    elif img_size in (32, 64):  # cifar10 (32); ffhq, celeba_hq, afhq (64)
        channel = 128
        channel_mult = [1, 2, 3, 4]
    else:
        raise ValueError(f"Unsupported image size: {img_size}")

    checkpoint_dir, checkpoint_file = dataset_paths[model_size]

    return {
        "epoch": 200,
        "batch_size": 32,
        "T": 1000,
        "channel": channel,
        "random_seed": 42,
        "eval_random_seed": 42,
        "subset_size": model_size,
        "channel_mult": channel_mult,
        "attn": [],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.0,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": img_size,
        "grad_clip": 1.0,
        "device": "cuda",
        "dataset_root": "data/",
        "checkpoint_freq": 20,
        "use_wandb": True,
        "sample_freq": 20,
        "training_load_weight": checkpoint_file,
        "save_weight_dir": checkpoint_dir,
        "test_load_weight": checkpoint_file,
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8,
        "model_type": "unet",
        "in_channels": dataset_config["in_channels"],
        "out_channels": dataset_config["out_channels"],
        "dataset_name": dataset_name,
    }


@torch.no_grad()
def sample_from_unet(
    noise_sample: torch.Tensor,
    num_images: int,
    device: str = "cuda",
    num_steps: int = 1000,
    dataset_name: str = "cifar10",
) -> Tuple[
    torch.Tensor,
    List[torch.Tensor],
    List[torch.Tensor],
    List[torch.Tensor],
    torch.Tensor,
]:
    """Run a full DDIM sampling trajectory with the pre-trained UNet.

    The UNet checkpoint is selected via ``get_unet_config`` and loaded on every
    call.

    Args:
        noise_sample: Starting noise; it is multiplied by the scheduler's
            ``init_noise_sigma`` before the first step. Presumably a batch of
            images already on ``device`` -- confirm against the caller.
        num_images: Training-set size used to pick the UNet checkpoint.
        device: Device the model is loaded on and the timestep is moved to.
        num_steps: Number of DDIM inference steps.
        dataset_name: Dataset whose UNet checkpoint should be used.

    Returns:
        A 5-tuple ``(final_sample, trajectory_noisy, trajectory_x0,
        trajectory_eps, timesteps)`` where the three lists hold, per step, the
        input to the model, the scheduler's predicted clean sample, and the raw
        model output; ``timesteps`` is ``scheduler.timesteps``.
    """
    # Initialize model and load weights
    config = get_unet_config(dataset_name, num_images)
    model = load_model(config, device)
    model.eval()

    # Linear-beta, epsilon-prediction DDIM scheduler built from the UNet config
    scheduler = DDIMScheduler(
        beta_start=config["beta_1"],
        beta_end=config["beta_T"],
        beta_schedule="linear",
        prediction_type="epsilon",
    )
    scheduler.set_timesteps(num_steps)

    # Per-step records of the trajectory
    trajectory_noisy = []
    trajectory_x0 = []
    trajectory_eps = []

    # Start from the supplied noise, scaled by the scheduler's initial sigma
    cur_img = noise_sample * scheduler.init_noise_sigma

    for timestep in scheduler.timesteps:
        # The model receives the timestep as a length-1 tensor on `device`
        model_output = model(cur_img, timestep.to(device)[None])

        step_output = scheduler.step(
            model_output=model_output,
            timestep=timestep,
            sample=cur_img,
            generator=None,
        )
        pred_x0 = step_output.pred_original_sample

        # Record the state BEFORE the update, so entry i pairs the model input
        # at step i with the predictions made from it.
        trajectory_noisy.append(cur_img.clone())
        trajectory_x0.append(pred_x0.clone())
        trajectory_eps.append(model_output.clone())

        cur_img = step_output.prev_sample

    return cur_img, trajectory_noisy, trajectory_x0, trajectory_eps, scheduler.timesteps


@torch.no_grad()
def generate_unet_single_x0_preds(
    trajectory_noisy: List[torch.Tensor],
    num_images: int,
    device: str = "cuda",
    num_steps: int = 1000,
    dataset_name: str = "cifar10",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Evaluate the UNet once on each noisy image of an existing trajectory.

    Unlike ``sample_from_unet`` the UNet does not follow its own trajectory
    here: entry ``i`` of ``trajectory_noisy`` is paired with the ``i``-th DDIM
    timestep and denoised in a single step.

    Args:
        trajectory_noisy: Noisy images, one per sampling step. Entries are
            paired with ``scheduler.timesteps`` via ``zip``, so any surplus
            entries on either side are silently ignored.
        num_images: Training-set size used to pick the UNet checkpoint.
        device: Device the model is loaded on and the timestep is moved to.
        num_steps: Number of DDIM inference steps used to build the timesteps.
        dataset_name: Dataset whose UNet checkpoint should be used.

    Returns:
        ``(x0_predictions, eps_predictions)``: per step, the scheduler's
        predicted clean sample and the raw model output.
    """
    # Initialize model and load weights
    config = get_unet_config(dataset_name, num_images)
    model = load_model(config, device)
    model.eval()

    # Linear-beta, epsilon-prediction DDIM scheduler built from the UNet config
    scheduler = DDIMScheduler(
        beta_start=config["beta_1"],
        beta_end=config["beta_T"],
        beta_schedule="linear",
        prediction_type="epsilon",
    )
    scheduler.set_timesteps(num_steps)

    x0_predictions = []
    eps_predictions = []
    for t, noisy_img in zip(scheduler.timesteps, trajectory_noisy):
        model_output = model(noisy_img, t.to(device)[None])

        # The scheduler step is used only for its x0 estimate; prev_sample is
        # discarded because the next input comes from `trajectory_noisy`.
        step_output = scheduler.step(
            model_output=model_output,
            timestep=t,
            sample=noisy_img,
            generator=None,
        )
        x0_predictions.append(step_output.pred_original_sample.clone())
        eps_predictions.append(model_output.clone())
    return x0_predictions, eps_predictions


def find_closest_images(
    samples: torch.Tensor, dataloader: DataLoader, config: Dict
) -> torch.Tensor:
    """Return, for each sample, the dataset image with the smallest mean squared error.

    Args:
        samples: Batch of images to match; the distance reduction over
            ``dim=(2, 3, 4)`` implies a 4-D ``(N, C, H, W)`` tensor.
        dataloader: Iterable of batches. A batch may be a tensor or a
            list/tuple whose first element is the image tensor (e.g. an
            ``(image, label)`` pair).
        config: Provides the target ``resolution``, either as a dict key or as
            an attribute (the form ``run_experiment`` passes in).

    Returns:
        A tensor with one reference image per sample, on ``samples.device``.

    Note:
        The whole dataset is concatenated and a
        ``(num_refs, num_samples, C, H, W)`` difference tensor is built, so
        memory grows with dataset size times sample count.
    """
    # Collect every reference image, dropping labels when batches are tuples
    ref_batches = []
    for batch in dataloader:
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        ref_batches.append(batch)

    ref_imgs = torch.cat(ref_batches, dim=0)
    ref_imgs = ref_imgs.to(samples.device)

    # The annotation promises a Dict, but callers also pass attribute-style
    # config objects; support both instead of failing on plain dicts.
    if isinstance(config, dict):
        resolution = config["resolution"]
    else:
        resolution = config.resolution

    # Only the last (width) dimension is checked before resizing
    if ref_imgs.shape[-1] != resolution:
        ref_imgs = torch.nn.functional.interpolate(
            ref_imgs,
            size=(resolution, resolution),
            mode="bilinear",
            align_corners=False,
        )

    # distances[r, s] = mean squared difference between reference r and sample s
    distances = torch.mean(
        (samples.unsqueeze(0) - ref_imgs.unsqueeze(1)) ** 2,
        dim=(2, 3, 4),
    )
    closest_indices = torch.argmin(distances, dim=0)
    return ref_imgs[closest_indices]


def normalize(images: torch.Tensor) -> torch.Tensor:
    """Return ``images`` unchanged.

    This is a pass-through hook at the point where images are prepared for
    visualization. No rescaling is performed; the min/max rescaling that used
    to live here is disabled.
    """
    return images


def calculate_r2_score(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return ``1 - mean(R²)`` of predictions ``x`` against targets ``y``.

    Each sample is flattened and its coefficient of determination is computed
    as ``R² = 1 - SS_res / SS_tot``, with ``SS_res = sum((x - y)²)`` and
    ``SS_tot = sum((y - mean(y))²)``. The per-sample values are averaged over
    the batch.

    Args:
        x: Predictions; the first dimension is the batch.
        y: Targets with the same batch layout as ``x``.

    Returns:
        ``1 - mean(R²)``: 0.0 for a perfect match and 1.0 when ``x`` is no
        better than predicting the per-sample mean of ``y``. The value is
        inf/nan if a target sample is constant (``SS_tot == 0``).

    NOTE(review): despite the name, the returned value is one minus the mean
    R², i.e. the mean fraction of unexplained variance. Kept as-is because
    logged metrics depend on it -- confirm this is intended.
    """
    x_flat = x.reshape(x.size(0), -1)
    y_flat = y.reshape(y.size(0), -1)

    ss_res = torch.sum((x_flat - y_flat) ** 2, dim=1)

    # Total sum of squares computed directly. Multiplying torch.var (which
    # applies Bessel's correction by default) by the element count would
    # overestimate it by a factor of n / (n - 1).
    y_centered = y_flat - y_flat.mean(dim=1, keepdim=True)
    ss_tot = torch.sum(y_centered ** 2, dim=1)

    r2 = 1 - (ss_res / ss_tot)

    return 1 - r2.mean().item()


def calculate_mse(x: torch.Tensor, y: torch.Tensor) -> float:
    """Return the batch average of the per-sample mean squared error.

    Both inputs are flattened to ``(batch, -1)``; the squared error is
    averaged within each sample first and then across the batch.
    """
    batch_x, batch_y = x.shape[0], y.shape[0]
    squared_error = (x.reshape(batch_x, -1) - y.reshape(batch_y, -1)).pow(2)
    per_sample = squared_error.mean(dim=1)
    return per_sample.mean().item()


def calculate_l2_distances(samples: torch.Tensor, closest_imgs: torch.Tensor) -> float:
    """Return the batch-averaged mean squared difference to the matched images.

    The squared difference is averaged over dimensions 1-3 of each sample
    (so the inputs must be 4-D), then over the batch. No square root is taken.
    """
    squared_diff = (samples - closest_imgs).pow(2)
    per_sample = squared_diff.mean(dim=(1, 2, 3))
    return per_sample.mean().item()


@torch.no_grad()
def run_experiment(
    config: Dict,
    dataloader: DataLoader,
    num_images: int,
    tags: List[str] = None,
    seed: int = 42,
):
    """Run one ablation experiment and compare it against the UNet baseline.

    The model selected by ``tags`` is trained on ``dataloader`` and sampled
    with ``N_BATCH`` images. The pre-trained UNet is then (a) sampled from the
    same noise and (b) evaluated step by step on this model's noisy
    trajectory. For every step, images and metrics are written to
    ``ablation/<run_name>_<timestamp>/`` and logged to wandb; a final
    ``metrics.json`` holds per-step values, summary statistics and timing.

    Args:
        config: Model configuration. It is written to ``config.txt``, passed
            to ``wandb.init`` and then re-read from ``wandb.config`` to
            construct the model.
        dataloader: Training data, also used for nearest-image lookups.
        num_images: Training-set size used to select the UNet checkpoint.
        tags: wandb tags; must contain one of "wiener", "kamb", "ours",
            "optimal", "unet" or "niedoba". The first tag becomes the run-name
            prefix. ``None`` is treated as an empty list.
        seed: Seed for torch, random and numpy, also passed to ``model.sample``.

    Raises:
        ValueError: If ``tags`` contains no known model type. This is raised
            after the output directory and the wandb run have been created.
    """
    # A list default would be shared between calls; normalise None instead.
    tags = [] if tags is None else tags

    # Set random seed for reproducibility
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Create timestamp for this run
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Initialize wandb
    project_name = f"denoising-pyramid-ablation-{config.get('dataset_name', 'default')}"

    # The first tag names the model type in the run name
    model_type = next(iter(tags), "unknown")
    run_name = f"{model_type}_{config.get('dataset_name')}_seed{seed}_{config.get('save_prefix', '').split('_')[-1]}"

    if "ours" in tags:
        run_name += f"_pow_{config.get('mask_threshold', 0.05)}"

    # Create local save directory; grids are written directly into it and
    # per-sample images into a sub-directory.
    save_dir = Path("ablation") / f"{run_name}_{timestamp}"
    save_dir.mkdir(parents=True, exist_ok=True)
    grid_dir = save_dir
    individual_dir = save_dir / "individual_images"
    individual_dir.mkdir(exist_ok=True)

    # Save configuration
    with open(save_dir / "config.txt", "w") as f:
        for key, value in config.items():
            f.write(f"{key}: {value}\n")

    with wandb.init(project=project_name, name=run_name, config=config, tags=tags):
        wandb.run.log_code(".")

        # From here on the configuration is read back from wandb
        config = wandb.config

        # Initialize appropriate model based on tags. An if/elif chain is used
        # (rather than a lookup table) so that only the selected class name
        # has to be resolvable.
        if "wiener" in tags:
            model = DenoisingWiener(**config)
        elif "kamb" in tags:
            model = DenoisingKamb(**config)
        elif "ours" in tags:
            model = KambWithWienerBasedPatches(**config)
        elif "optimal" in tags:
            # Same class as "kamb"; the two differ only in the config passed in
            model = DenoisingKamb(**config)
        elif "unet" in tags:
            model = AnotherUnet(**config)
        elif "niedoba" in tags:
            model = DenoisingNiedoba(**config)
        else:
            raise ValueError("Unknown model type in tags")

        # Train model using dataloader
        model.train(dataloader)

        # Generate samples with timing
        start_time = time.time()
        (
            _,
            noise_sample,
            trajectory_noisy,
            trajectory_eps,
            trajectory_x0,
            _trajectory_level_imgs,
            _trajectory_level_x0s,
        ) = model.sample(
            batch_size=N_BATCH,
            return_trajectory=True,
            seed=seed,
        )
        sample_time = time.time() - start_time

        # Generate samples from UNet using the same noise
        (
            _,
            unet_trajectory_noisy,
            unet_trajectory_x0,
            unet_trajectory_eps,
            unet_timesteps,
        ) = sample_from_unet(
            noise_sample,
            num_images,
            config.device,
            config["num_steps"],
            dataset_name=config["dataset_name"],
        )

        # Generate single-step x0 predictions for each noisy image in the trajectory
        unet_single_x0_preds, unet_single_eps_preds = generate_unet_single_x0_preds(
            trajectory_noisy,
            num_images,
            config.device,
            config["num_steps"],
            dataset_name=config["dataset_name"],
        )

        # Initialize metrics dictionary
        metrics = {
            "timestep": [],
            "r2_score_vs_unet": [],
            "r2_score_vs_unet_single": [],
            "r2_score_vs_unet_single_eps": [],
            "l2_distance_to_dataset": [],
            "mse_trajectory_vs_unet": [],
            "mse_trajectory_vs_unet_single": [],
            "mse_trajectory_vs_unet_single_eps": [],
        }

        # Log trajectories
        for i in tqdm(range(len(trajectory_noisy)), desc="Logging trajectories"):
            # Step label used for file names and as the wandb step: 1000 minus
            # the UNet timestep at this index.
            t = 1000 - unet_timesteps[i].item()

            # Nearest dataset images for both x0 predictions. Each call
            # iterates the whole dataloader again.
            closest_imgs_x0 = find_closest_images(trajectory_x0[i], dataloader, config)
            closest_imgs_unet_x0 = find_closest_images(
                unet_trajectory_x0[i], dataloader, config
            )

            # Save individual images
            image_sets = {
                "closest_real_ours": closest_imgs_x0,
                "ours_x0": trajectory_x0[i],
                "unet_single_x0": unet_single_x0_preds[i],
                "unet_trajectory_x0": unet_trajectory_x0[i],
                "closest_real_unet": closest_imgs_unet_x0,
            }

            for name, images in image_sets.items():
                for j, img in enumerate(images):
                    # (img + 1) / 2 shifts the images before saving;
                    # presumably they live in [-1, 1] -- TODO confirm
                    save_image(
                        (img.cpu() + 1) / 2,
                        individual_dir / f"step_{t:04d}_{name}_sample_{j:02d}.png",
                    )

            # Create and save grid image: one row of N_BATCH images per set
            x0_grid = make_grid(
                torch.cat(
                    [
                        normalize(closest_imgs_x0),
                        normalize(trajectory_x0[i]),
                        normalize(unet_single_x0_preds[i]),
                        normalize(unet_trajectory_x0[i]),
                        normalize(closest_imgs_unet_x0),
                    ],
                    dim=0,
                ),
                nrow=N_BATCH,
                padding=2,
            )
            save_image(
                (x0_grid.cpu() + 1) / 2, grid_dir / f"trajectory_step_{t:04d}.png"
            )

            # Calculate metrics
            current_metrics = {
                "timestep": t,
                "r2_score_vs_unet": calculate_r2_score(
                    trajectory_x0[i], unet_trajectory_x0[i]
                ),
                "r2_score_vs_unet_single": calculate_r2_score(
                    trajectory_x0[i], unet_single_x0_preds[i]
                ),
                "r2_score_vs_unet_single_eps": calculate_r2_score(
                    trajectory_eps[i], unet_single_eps_preds[i]
                ),
                "l2_distance_to_dataset": calculate_l2_distances(
                    trajectory_x0[i], closest_imgs_x0
                ),
                "mse_trajectory_vs_unet": calculate_mse(
                    trajectory_x0[i], unet_trajectory_x0[i]
                ),
                "mse_trajectory_vs_unet_single": calculate_mse(
                    trajectory_x0[i], unet_single_x0_preds[i]
                ),
                "mse_trajectory_vs_unet_single_eps": calculate_mse(
                    trajectory_eps[i], unet_single_eps_preds[i]
                ),
            }

            # Update metrics dictionary
            for key, value in current_metrics.items():
                metrics[key].append(value)

            # Log to wandb
            wandb.log(
                {
                    "trajectory_x0": wandb.Image(
                        (x0_grid.cpu() + 1) / 2,
                        caption="Top to bottom: Closest real (ours), Ours x0, UNet single x0, UNet trajectory x0, Closest real (UNet)",
                    ),
                    **{k: v for k, v in current_metrics.items() if k != "timestep"},
                    "current_timestep": t,
                },
                step=t,
            )

        # Prepare metrics data for JSON
        metrics_data = {"per_timestep": [], "summary": {}, "runtime": {}}

        # Add per-timestep metrics
        for i in range(len(metrics["timestep"])):
            timestep_data = {
                "timestep": int(metrics["timestep"][i]),
                **{k: float(v[i]) for k, v in metrics.items() if k != "timestep"},
            }
            metrics_data["per_timestep"].append(timestep_data)

        # Add summary statistics
        for metric_name in metrics.keys():
            if metric_name != "timestep":
                values = metrics[metric_name]
                metrics_data["summary"][metric_name] = {
                    "mean": float(np.mean(values)),
                    "std": float(np.std(values)),
                    "min": float(np.min(values)),
                    "max": float(np.max(values)),
                }

        # Add runtime information (timing covers model.sample only)
        metrics_data["runtime"] = {
            "total_sample_time": float(sample_time),
            "average_time_per_step": float(sample_time / len(trajectory_noisy)),
            "steps_per_second": float(len(trajectory_noisy) / sample_time),
        }

        # Save metrics to JSON
        with open(save_dir / "metrics.json", "w") as f:
            json.dump(metrics_data, f, indent=2)


def get_analytic_config(
    dataset_name: str, num_images: int = -1, custom_kernel_schedule: List[int] = None
) -> Dict:
    """Build the base configuration shared by all analytic models.

    Args:
        dataset_name: Dataset to look up via ``get_dataset_config``.
        num_images: Training-set size; ``-1`` is labelled "full" in
            ``save_prefix``.
        custom_kernel_schedule: Optional replacement for the dataset's default
            kernel size schedule. The list is stored as-is, not copied.

    Returns:
        A new configuration dict. ``stride_gen`` contains one ``1`` per entry
        of the chosen kernel schedule.
    """
    dataset_config = get_dataset_config(dataset_name)

    # An explicitly supplied schedule wins over the dataset default
    if custom_kernel_schedule is None:
        kernel_schedule = dataset_config["kernel_size_schedule"]
    else:
        kernel_schedule = custom_kernel_schedule

    # "-1" stands for the full dataset in the save prefix
    size_label = "full" if num_images == -1 else str(num_images)
    save_prefix = f"{dataset_name}_{size_label}"

    unit_strides = [1] * len(kernel_schedule)

    return {
        "resolution": dataset_config["img_size"],
        "device": "cuda",
        "denoiser": "knn",
        "temperature": 1.0,
        "latent_diffusion": False,  # latent diffusion is not configured here
        "stride": 2,
        "sigma_correction": True,
        "level_mixture_alpha": 1.0,
        "random_padding": False,
        "kernel_size": kernel_schedule,
        "kernel_overlap": 0.75,
        "stride_gen": unit_strides,
        "num_steps": 10,
        "denoiser_args": {"num_neighbors": 200},
        "fill_in_zeros_in_x0": True,
        "embed_w": -1.0,
        "aggregation_mode": "mean",
        "save_dir": "trained_models",
        "dataset_name": dataset_name,
        "save_prefix": save_prefix,
        "in_channels": dataset_config["in_channels"],
        "out_channels": dataset_config["out_channels"],
    }


def get_wiener_config(dataset_name: str, num_images: int = -1) -> Dict:
    """Return the analytic base config with a single full-image kernel.

    ``kernel_size`` is replaced by a one-element list holding the image
    resolution; every other entry (including ``stride_gen``) is left as
    produced by ``get_analytic_config``.
    """
    wiener_cfg = get_analytic_config(dataset_name, num_images)
    full_image_kernel = wiener_cfg["resolution"]
    wiener_cfg["kernel_size"] = [full_image_kernel]
    return wiener_cfg


def get_kamb_config(
    dataset_name: str, num_images: int = -1, custom_kernel_schedule: List[int] = None
) -> Dict:
    """Return the config for the Kamb model.

    This is the analytic base config unchanged, optionally with a custom
    kernel size schedule.
    """
    return get_analytic_config(dataset_name, num_images, custom_kernel_schedule)


def get_wiener_based_patches_config(dataset_name: str, num_images: int = -1) -> Dict:
    """Return the config for the KambWithWienerBasedPatches model.

    The analytic base config is returned as-is, with the dataset's default
    kernel size schedule.
    """
    return get_analytic_config(dataset_name, num_images)


def get_optimal_config(dataset_name: str, num_images: int = -1) -> Dict:
    """Return the config for the Optimal model.

    Starts from the analytic base config and replaces ``kernel_size`` with a
    one-element list holding the image resolution; all other entries are kept.
    """
    optimal_cfg = get_analytic_config(dataset_name, num_images)
    full_image_kernel = optimal_cfg["resolution"]
    optimal_cfg["kernel_size"] = [full_image_kernel]
    return optimal_cfg


def get_unet_baseline_config(dataset_name: str, num_images: int = -1) -> Dict:
    """Return the config used when running the "unet" model type.

    The analytic base config is returned as-is, with the dataset's default
    kernel size schedule.
    """
    return get_analytic_config(dataset_name, num_images)


def get_niedoba_config(dataset_name: str, num_images: int = -1) -> Dict:
    """Return the config for the Niedoba model.

    The analytic base config is returned as-is, with the dataset's default
    kernel size schedule.
    """
    return get_analytic_config(dataset_name, num_images)


def main():
    """Parse the command line, build the model config and run one experiment."""
    parser = argparse.ArgumentParser(description="Run ablation experiments")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["mnist", "fashion_mnist", "cifar10", "ffhq", "celeba_hq", "afhq"],
        default="cifar10",
        help="Dataset to use for experiments",
    )
    parser.add_argument(
        "--num_images",
        type=int,
        default=-1,
        help="Number of images to use (-1 = use all images)",
    )
    parser.add_argument(
        "--mask_threshold",
        type=float,
        default=0.02,
        help="Mask threshold for Wiener-based patches",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed (default: 42)",
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["wiener", "kamb", "ours", "optimal", "unet", "niedoba"],
        required=True,
        help="Model type to run",
    )
    args = parser.parse_args()

    # One config builder per --model choice; argparse has already restricted
    # args.model to exactly these keys.
    config_builders = {
        "wiener": get_wiener_config,
        "kamb": get_kamb_config,
        "ours": get_wiener_based_patches_config,
        "optimal": get_optimal_config,
        "unet": get_unet_baseline_config,
        "niedoba": get_niedoba_config,
    }
    config = config_builders[args.model](args.dataset, args.num_images)
    if args.model == "ours":
        # Only the "ours" model receives the mask threshold
        config["mask_threshold"] = args.mask_threshold
    tags = [args.model]

    # Get dataloader using unified loader
    dataloader = get_dataset_loader(
        dataset_name=args.dataset,
        num_images=args.num_images,
        batch_size=1024,
    )

    # Print basic statistics of the first element of the first batch
    first_batch = next(iter(dataloader))
    first_item = first_batch[0]
    print(f"\n\nDataset: {args.dataset}")
    print(f"Shape: {first_item.shape}")
    print(f"Min: {first_item.min()}, Max: {first_item.max()}")
    print(f"Mean: {first_item.mean()}, Std: {first_item.std()}")
    print("\n\n")

    # Run experiment
    run_experiment(config, dataloader, args.num_images, tags=tags, seed=args.seed)


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
