"""Experiment cases: one configuration dict per data set / model combination."""

from math import log10

from torch import logspace

# Shallow ReLU net on the synthetic 1d regression data. The problem is small
# enough to compute the Gram matrix explicitly.
SYNTHETIC_SHALLOW = dict(
    data_name="synthetic_1d_regression",
    model_name="shallow_relu",
    # 20 log-spaced widths from 10 to 10**7, cast to int and de-duplicated.
    widths=logspace(start=1, end=7, steps=20).int().unique(),
    num_initializations=3,
    num_perturbations=5,  # used for the close-to-linearity condition
    epsilon=1e-7,  # damping added to the Gram matrix
)
# Deep ReLU net ("deep_relu_5") on the synthetic 1d regression data. The
# problem is small enough to compute the Gram matrix explicitly.
SYNTHETIC_DEEP = dict(
    data_name="synthetic_1d_regression",
    model_name="deep_relu_5",
    # 15 log-spaced widths from 10 to 10**4, cast to int and de-duplicated.
    widths=logspace(start=1, end=4, steps=15).int().unique(),
    num_initializations=3,
    num_perturbations=5,  # used for the close-to-linearity condition
    epsilon=1e-7,  # damping added to the Gram matrix
)
# Less deep ReLU net ("deep_relu_3") on the synthetic 1d regression data. The
# problem is small enough to compute the Gram matrix explicitly.
SYNTHETIC_LESS_DEEP = dict(
    data_name="synthetic_1d_regression",
    model_name="deep_relu_3",
    # 15 log-spaced widths from 10 to 10**4, cast to int and de-duplicated.
    widths=logspace(start=1, end=4, steps=15).int().unique(),
    num_initializations=3,
    num_perturbations=5,  # used for the close-to-linearity condition
    epsilon=1e-7,  # damping added to the Gram matrix
)
CIFAR10SUB400_WRN = {  # small enough to compute the Gram matrix explicitly
    "data_name": "cifar10_subset_400",
    "model_name": "wideresnet",
    # 10 log-spaced widths from 10 to 40. Round before casting: `.int()`
    # truncates, and float32 10**log10(40) lands just below 40, which would
    # silently turn the largest width into 39.
    "widths": logspace(1, log10(40), 10).round().int().unique(),
    "num_initializations": 3,
    "num_perturbations": 5,  # for close-to-linearity condition
    "epsilon": 1e-7,  # Gram matrix damping
}
# NOTE(review): unlike the subset case there is no "epsilon" entry here;
# presumably the Gram matrix is not formed explicitly for the full data set.
CIFAR10_WRN = {
    "data_name": "cifar10",
    "model_name": "wideresnet",
    # 10 log-spaced widths from 10 to 40. Round before casting: `.int()`
    # truncates, and float32 10**log10(40) lands just below 40, which would
    # silently turn the largest width into 39.
    "widths": logspace(1, log10(40), 10).round().int().unique(),
    "num_initializations": 3,
    "num_perturbations": 5,  # for close-to-linearity condition
}
CIFAR100_WRN = {
    "data_name": "cifar100",
    "model_name": "wideresnet",
    # 10 log-spaced widths from 10 to 40. Round before casting: `.int()`
    # truncates, and float32 10**log10(40) lands just below 40, which would
    # silently turn the largest width into 39.
    "widths": logspace(1, log10(40), 10).round().int().unique(),
    "num_initializations": 3,
    "num_perturbations": 5,  # for close-to-linearity condition
}

# Cases that are run by default. The full-size CIFAR10_WRN and CIFAR100_WRN
# configurations are currently disabled; append them here to include them.
CASES = [SYNTHETIC_SHALLOW, SYNTHETIC_LESS_DEEP, SYNTHETIC_DEEP, CIFAR10SUB400_WRN]

# Copies of the CIFAR-10 subset case that differ only in the subset size
# (400 and 2,000 data points). `dict(base, key=value)` builds a new dict and
# overrides `data_name` while keeping every other setting.
CASES_VARY_DATA = [
    dict(CIFAR10SUB400_WRN, data_name=f"cifar10_subset_{subset_size}")
    for subset_size in (400, 2_000)
]
