"""Configuration utilities for the analytic models."""
from typing import Dict


def get_dataset_config(dataset_name: str) -> dict:
    """Look up the static image/kernel configuration for a dataset.

    Args:
        dataset_name: One of "mnist", "fashion_mnist", "cifar10", "ffhq",
            "celeba_hq" or "afhq".

    Returns:
        A new dict with the keys "img_size", "in_channels", "out_channels"
        and "kernel_size_schedule" (a list of ints).

    Raises:
        KeyError: If ``dataset_name`` is not one of the names above.
    """
    # Kernel-size schedules; datasets that use identical values share one.
    grayscale_28 = (28, 23, 17, 13, 9, 5, 3)
    cifar_32 = (32, 32, 32, 29, 25, 17, 13, 9, 7, 3)
    rgb_64 = (64, 45, 33, 25, 17, 9, 5, 3)

    # dataset name -> (image size, channel count, kernel-size schedule)
    specs = {
        "mnist": (28, 1, grayscale_28),
        "fashion_mnist": (28, 1, grayscale_28),
        "cifar10": (32, 3, cifar_32),
        "ffhq": (64, 3, rgb_64),
        "celeba_hq": (64, 3, rgb_64),
        "afhq": (64, 3, rgb_64),
    }

    side, channels, schedule = specs[dataset_name]
    return {
        "img_size": side,
        "in_channels": channels,
        "out_channels": channels,
        # Copy into a fresh list so callers can mutate their result freely.
        "kernel_size_schedule": list(schedule),
    }


def get_unet_config(dataset_name: str, num_images: int) -> Dict:
    """Return the UNet configuration for a dataset and training-set size.

    A checkpoint is chosen from the table of trained models: the entry whose
    subset size equals ``num_images``, or otherwise the next larger one. If
    ``num_images`` exceeds every subset size, the full-dataset model (key
    ``-1``) is used. The chosen size is printed to stdout.

    Args:
        dataset_name: Dataset key, e.g. "mnist" or "cifar10".
        num_images: Number of training images; ``-1`` selects the model
            trained on the full dataset.

    Returns:
        A dict of hyperparameters, checkpoint locations and dataset
        properties for the selected model.

    Raises:
        KeyError: If ``dataset_name`` has no entry in the model table.
        ValueError: If the dataset's image size has no architecture preset.
    """
    # Template for model paths - to be replaced with actual paths.
    # Layout: dataset -> subset size (-1 = full dataset) -> (directory, checkpoint file)
    model_paths = {
        "mnist": {
            -1: (
                "trained_models/unet/unet_mnist_-1_noattn_20250513_194551",
                "ckpt_epoch_200.pt",
            ),
            # unet_mnist_-1_noattn_20250513_195201 200
        },
        "fashion_mnist": {
            -1: (
                "trained_models/unet/unet_fashion_mnist_-1_noattn_20250513_194633",
                "ckpt_epoch_200.pt",
            ),
            # unet_fashion_mnist_-1_noattn_20250514_001525 200
        },
        "cifar10": {
            -1: (
                "trained_models/unet/unet_cifar10_-1_noattn_20250313_232926",
                "ckpt_epoch_200.pt",
            ),
            100: (
                "trained_models/unet/unet_cifar10_100_noattn_20250313_232926",
                "ckpt_epoch_70000.pt",
            ),
            1000: (
                "trained_models/unet/unet_cifar10_1000_noattn_20250313_232925",
                "ckpt_epoch_10000.pt",
            ),
            10000: (
                "trained_models/unet/unet_cifar10_10000_noattn_20250312_035606",
                "ckpt_epoch_1000.pt",
            ),
            # unet_cifar10_-1_noattn_20250512_160306 200
        },
        "ffhq": {
            -1: ("trained_models/unet/unet_ffhq_-1_noattn", "ckpt_epoch_200.pt"),
            # broken ?
        },
        "celeba_hq": {
            -1: (
                "trained_models/unet/unet_celeba_hq_-1_noattn_20250514_030749",
                "ckpt_epoch_200.pt",
            ),
            # unet_celeba_hq_-1_noattn_20250514_030841 200
        },
        "afhq": {
            -1: (
                "trained_models/unet/unet_afhq_-1_noattn_20250514_002002",  # broken: unet_afhq_-1_noattn_20250513_195651
                "ckpt_epoch_200.pt",
            ),
            # unet_afhq_-1_noattn_20250515_004233 200
        },
    }

    # Pick the model size: exact match or next largest size.
    available_sizes = sorted(model_paths[dataset_name])
    if num_images == -1:
        model_size = -1  # Use full dataset model
    else:
        # When the request is larger than every subset model, the full-dataset
        # model is the "next largest" one; falling back to the biggest subset
        # would select a model trained on fewer images than requested.
        fallback_size = -1 if -1 in available_sizes else available_sizes[-1]
        model_size = next(
            (size for size in available_sizes if size >= num_images),
            fallback_size,
        )
    print(
        f"Using trained UNET with dataset size: {'full' if model_size == -1 else model_size}"
    )

    # Get dataset-specific configuration
    dataset_config = get_dataset_config(dataset_name)
    img_size = dataset_config["img_size"]

    # Get model architecture based on image size
    if img_size == 28:  # MNIST, FashionMNIST
        channel = 64
        channel_mult = [1, 2, 2]
    elif img_size in (32, 64):  # CIFAR10 (32); FFHQ, CelebA-HQ, AFHQ (64)
        channel = 128
        channel_mult = [1, 2, 3, 4]
    else:
        raise ValueError(f"Unsupported image size: {img_size}")

    weight_dir, checkpoint_name = model_paths[dataset_name][model_size]

    return {
        "epoch": 200,
        "batch_size": 32,
        "T": 1000,
        "channel": channel,
        "random_seed": 42,
        "eval_random_seed": 42,
        "subset_size": model_size,
        "channel_mult": channel_mult,
        "attn": [],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.0,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": img_size,
        "grad_clip": 1.0,
        "device": "cuda",
        "dataset_root": "data/",
        "checkpoint_freq": 20,
        "use_wandb": True,
        "sample_freq": 20,
        # The same checkpoint file name is used for both training and test loading.
        "training_load_weight": checkpoint_name,
        "save_weight_dir": weight_dir,
        "test_load_weight": checkpoint_name,
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8,
        "model_type": "unet",
        "in_channels": dataset_config["in_channels"],
        "out_channels": dataset_config["out_channels"],
        "dataset_name": dataset_name,
    }
