from eos_line_search.optim import Optim
import torch.nn as nn

optimizers = []
# Disabled momentum baseline; uncomment to add it to the sweep.
# optimizers.append(Optim(opt_name="MM", momentum=0.9, init_step_size=0.001))

# SAM baseline. NOTE(review): step_size=-1 looks like a sentinel rather than a
# usable step size; presumably it is resolved elsewhere (e.g. from a tuned
# table) -- confirm in Optim / the training script.
for step in [-1]:
    optimizers += [
        Optim(
            opt_name="SAM", step_size=step, momentum=0.9, weight_decay=0.0001, rho=0.05
        )
    ]

# CDAT baseline with c set to sigma.
for sigma in [2.06]:
    optimizers += [Optim(opt_name="CDAT", c=sigma)]

for ct in [1e-4]:
    # New PoNoS with check for number of backtracks (no limit):
    # forward_option=11 is the variant "plus fix" (PoNoS11), forward_option=10
    # the variant without it (PoNoS10). Both pass M=10000 and xi=1 explicitly.
    optimizers.extend(
        Optim(
            opt_name="PoNoS",
            init_step_size=1,
            max_step_size=10000,
            c=ct,
            decrease_factor=0.5,
            reset_option=0,
            forward_option=variant,
            eps=0,
            adapt_c=0,
            nonmonotone_option=0,
            M=10000,
            adapt_M=0,
            xi=1,
            momentum=0,
        )
        for variant in (11, 10)
    )
    # Original PoNoS (reset_option=1, forward_option=0); M and xi are not passed.
    optimizers += [
        Optim(
            opt_name="PoNoS",
            init_step_size=1,
            max_step_size=10000,
            c=ct,
            decrease_factor=0.5,
            reset_option=1,
            forward_option=0,
            eps=0,
            adapt_c=0,
            nonmonotone_option=0,
            adapt_M=0,
            momentum=0,
        )
    ]
    # SLS with the same Armijo constant and backtracking factor.
    optimizers += [
        Optim(
            opt_name="SLS",
            init_step_size=1,
            max_step_size=10000,
            c=ct,
            decrease_factor=0.5,
            reset_option=0,
            forward_option=10,
            eps=0,
            momentum=0,
        )
    ]

# Single PoNoS configuration used for debugging runs (momentum=1 variant).
debug_optimizers = [
    Optim(
        opt_name="PoNoS",
        init_step_size=1,
        max_step_size=10000,
        c=0.0001,
        decrease_factor=0.5,
        reset_option=0,
        forward_option=10,
        eps=0,
        adapt_c=0,
        nonmonotone_option=0,
        adapt_M=0,
        momentum=1,
    )
]

# The forward_option=11 PoNoS configuration on its own (M=10000, xi=1).
new_optimizers = [
    Optim(
        opt_name="PoNoS",
        init_step_size=1,
        max_step_size=10000,
        c=1e-4,
        decrease_factor=0.5,
        reset_option=0,
        forward_option=11,
        eps=0,
        adapt_c=0,
        nonmonotone_option=0,
        M=10000,
        adapt_M=0,
        xi=1,
        momentum=0,
    )
]

# Tuned per-setup value keyed by (opt_name, variant, dataset, model).
# NOTE(review): for PoNoS the variant (10, 9, 0) presumably is forward_option,
# and for SAM the values coincide with the step sizes of the SAM sweep, for
# CDAT with c -- confirm against the consumer of this table.
_best_c_by_group = {
    ("PoNoS", 10): {
        "CIFAR10": {"CNN": 0.0001, "resnet34": 0.1, "vgg11": 0.0001, "tinyVIT": 0.01},
        "SVHN": {"CNN": 0.1, "resnet34": 0.1, "vgg11": 0.1, "tinyVIT": 0.001},
    },
    ("PoNoS", 9): {
        "CIFAR10": {"CNN": 0.0001, "resnet34": 0.1, "vgg11": 0.01, "tinyVIT": 0.0001},
        "SVHN": {"CNN": 0.1, "resnet34": 0.1, "vgg11": 0.1, "tinyVIT": 0.001},
    },
    ("PoNoS", 0): {
        "CIFAR10": {"CNN": 0.001, "resnet34": 0.1, "vgg11": 0.0001, "tinyVIT": 0.01},
        "SVHN": {"CNN": 0.1, "resnet34": 0.1, "vgg11": 0.1, "tinyVIT": 0.001},
    },
    ("SAM", 0): {
        "CIFAR10": {
            "CNN": 0.1,
            "resnet34": 0.01,
            "vgg11": 0.1,
            "tinyVIT": 0.01,
            "wide_resnet50_2": 0.0001,
        },
        "CIFAR100": {
            "CNN": 0.01,
            "resnet34": 0.01,
            "vgg11": 0.01,
            "wide_resnet50_2": 0.001,
            "tinyVIT": 0.01,
        },
        "SVHN": {
            "CNN": 0.1,
            "resnet34": 0.001,
            "vgg11": 0.1,
            "tinyVIT": 0.01,
            "wide_resnet50_2": 0.0001,
        },
        "EMNIST": {
            "CNN": 0.01,
            "resnet34": 0.001,
            "vgg11": 0.1,
            "tinyVIT": 0.01,
            "wide_resnet50_2": 0.0001,
        },
    },
    ("CDAT", 0): {
        "CIFAR10": {"CNN": 2.5, "resnet34": 2.5, "vgg11": 2.06},
        "SVHN": {"CNN": 2.5, "resnet34": 2.5, "vgg11": 2.06},
    },
}

# Flatten into the 4-tuple-keyed lookup; insertion order follows the table above.
best_c = {
    (opt_name, variant, dataset, model): value
    for (opt_name, variant), per_dataset in _best_c_by_group.items()
    for dataset, per_model in per_dataset.items()
    for model, value in per_model.items()
}
del _best_c_by_group

# Single PoNoS configuration (momentum=2 variant).
# NOTE(review): the name suggests it is used for assumption-checking runs -- confirm.
assmpt_optimizers = []
assmpt_optimizers.append(
    Optim(
        opt_name="PoNoS",
        init_step_size=1,
        max_step_size=10000,
        c=0.0001,
        decrease_factor=0.5,
        reset_option=0,
        forward_option=10,
        eps=0,
        adapt_c=0,
        nonmonotone_option=0,
        adapt_M=0,
        momentum=2,
    )
)

# Single PoNoS configuration (momentum=3 variant).
# NOTE(review): the name suggests it is used for approximation runs -- confirm.
approx_optimizers = []
approx_optimizers.append(
    Optim(
        opt_name="PoNoS",
        init_step_size=1,
        max_step_size=10000,
        c=0.0001,
        decrease_factor=0.5,
        reset_option=0,
        forward_option=10,
        eps=0,
        adapt_c=0,
        nonmonotone_option=0,
        adapt_M=0,
        momentum=3,
    )
)

# SAM step-size sweep over four decades; other SAM settings held fixed.
sam_optimizers = []

for step in (0.0001, 0.001, 0.01, 0.1):
    sam_optimizers += [
        Optim(
            opt_name="SAM", step_size=step, momentum=0.9, weight_decay=0.0001, rho=0.05
        )
    ]

# Ablation over the backtracking factor: the same four line-search setups,
# but with decrease_factor=0.9 and momentum=4.
delta_ablation_optimizers = []

for ct in [1e-4]:
    # New PoNoS with check for number of backtracks (no limit):
    # forward_option=11 is the variant "plus fix" (PoNoS11), forward_option=10
    # the variant without it (PoNoS10). Both pass M=10000 and xi=1 explicitly.
    delta_ablation_optimizers.extend(
        Optim(
            opt_name="PoNoS",
            init_step_size=1,
            max_step_size=10000,
            c=ct,
            decrease_factor=0.9,
            reset_option=0,
            forward_option=variant,
            eps=0,
            adapt_c=0,
            nonmonotone_option=0,
            M=10000,
            adapt_M=0,
            xi=1,
            momentum=4,
        )
        for variant in (11, 10)
    )
    # Original PoNoS (reset_option=1, forward_option=0); M and xi are not passed.
    delta_ablation_optimizers += [
        Optim(
            opt_name="PoNoS",
            init_step_size=1,
            max_step_size=10000,
            c=ct,
            decrease_factor=0.9,
            reset_option=1,
            forward_option=0,
            eps=0,
            adapt_c=0,
            nonmonotone_option=0,
            adapt_M=0,
            momentum=4,
        )
    ]
    # SLS with the same Armijo constant and backtracking factor.
    delta_ablation_optimizers += [
        Optim(
            opt_name="SLS",
            init_step_size=1,
            max_step_size=10000,
            c=ct,
            decrease_factor=0.9,
            reset_option=0,
            forward_option=10,
            eps=0,
            momentum=4,
        )
    ]

# Sweep over xi for the PoNoS10 configuration with nonmonotone_option=4 and momentum=5.
vit_optimizers = []

for xi_v in (0.5, 0.6, 0.7, 0.8, 0.9):
    # New PoNoS with check for number of backtracks (no limit) (PoNoS10)
    vit_optimizers += [
        Optim(
            opt_name="PoNoS",
            init_step_size=1,
            max_step_size=10000,
            c=1e-4,
            decrease_factor=0.5,
            reset_option=0,
            forward_option=10,
            eps=0,
            adapt_c=0,
            nonmonotone_option=4,
            M=10000,
            adapt_M=0,
            xi=xi_v,
            momentum=5,
        )
    ]

# Plain gradient-descent warm-up configurations, both with step_size=0.01 and no momentum.
warmup_optimizers = [
    Optim(opt_name="warmup_GD_small", step_size=0.01, momentum=0),
    Optim(opt_name="warmup_GD", step_size=0.01, momentum=0),
]

# all_models = {
#    "convnext_tiny": Model(model_type="convnext_tiny"),
#    "efficientnet": Model(model_type="efficientnet"),
#    "densenet121": Model(model_type="densenet121"),
#    "tinyVIT": Model(model_type="tinyVIT"),
#    "swin_t": Model(model_type="swin_t"),
#    "efficientnet_v2_s": Model(model_type="efficientnet_v2_s"),
#    "maxvit_t": Model(model_type="maxvit_t"),
#    "wide_resnet50_2": Model(model_type="wide_resnet50_2"),
#    "resnet34": Model(model_type="resnet34"),
#    "resnet34-leakyrelu": Model(model_type="resnet34-leakyrelu"),
#    "resnet34-init": Model(model_type="resnet34-init"),
#    "vgg11": Model(model_type="vgg11"),
#    "CNN": CNNModel(
#        model_type="CNN", activation_fn=nn.ReLU, pooling=nn.MaxPool2d, window_size=2
#    ),
#    "MLP": MLPModel(model_type="MLP", activation_fn=nn.ReLU, num_layers=3, width=100),
# }

# Gradient-descent step sizes per (dataset, model), each written as 2 / k.
# NOTE(review): presumably k is a target sharpness so that 2 / eta = k
# (edge-of-stability threshold) -- confirm against the experiment code.
_gd_step_denominators = {
    ("CIFAR10", "MLP"): (10, 20, 30),
    ("CIFAR10", "CNN"): (10, 20, 30),
    ("CIFAR10", "resnet34"): (200, 400, 600),  # Maybe smaller
    ("CIFAR10", "vgg11"): (50, 100, 150),
    ("CIFAR10", "tinyVIT"): (200,),  # Not sure
    ("SVHN", "CNN"): (20, 50, 80),
    ("SVHN", "resnet34"): (500, 1000, 1500),  # Maybe smaller
    ("SVHN", "vgg11"): (20, 55, 90),
    ("SVHN", "tinyVIT"): (200,),  # Not sure
    ("Imagenette", "CNN"): (20,),  # Not sure
    ("Imagenette", "resnet34"): (1000,),  # Not sure
    ("Imagenette", "vgg11"): (2000,),  # Not sure
    ("Imagenette", "tinyVIT"): (100,),  # Not sure
}

GD_steps = {
    setup: [2 / k for k in denominators]
    for setup, denominators in _gd_step_denominators.items()
}
del _gd_step_denominators
