from adversarialML.biologically_inspired_models.src.imagenet_mlp_mixer_tasks_commons import *
class Imagenet75_32x32_AutoAugmentMLPMixer8L1xWide00Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.0, 8), Adam weight_decay=5e-5.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 1x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=5e-5) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L1xWide00Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.0, 8), Adam weight_decay=1e-4.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 1x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-4) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L1xWide00Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.0, 8), Adam weight_decay=1e-3.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 1x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-3) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L1xWide00Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.0, 8), Adam weight_decay=1e-2.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 1x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-2) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L1xWide01Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.1, 8), Adam weight_decay=5e-5.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 1x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=5e-5) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L1xWide01Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.1, 8), Adam weight_decay=1e-4.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 1x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-4) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L1xWide01Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.1, 8), Adam weight_decay=1e-3.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 1x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-3) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L1xWide01Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.1, 8), Adam weight_decay=1e-2.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 1x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-2) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L2xWide00Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 2, trailing args (0.0, 8), Adam weight_decay=5e-5.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 2x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=5e-5) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L2xWide00Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 2, trailing args (0.0, 8), Adam weight_decay=1e-4.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 2x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-4) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L2xWide00Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 2, trailing args (0.0, 8), Adam weight_decay=1e-3.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 2x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-3) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L2xWide00Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 2, trailing args (0.0, 8), Adam weight_decay=1e-2.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 2x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-2) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L2xWide01Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 2, trailing args (0.1, 8), Adam weight_decay=5e-5.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 2x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=5e-5) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L2xWide01Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 2, trailing args (0.1, 8), Adam weight_decay=1e-4.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 2x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-4) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L2xWide01Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 2, trailing args (0.1, 8), Adam weight_decay=1e-3.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 2x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-3) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L2xWide01Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 2, trailing args (0.1, 8), Adam weight_decay=1e-2.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 2x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-2) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L4xWide00Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 4, trailing args (0.0, 8), Adam weight_decay=5e-5.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 4x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=5e-5) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L4xWide00Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 4, trailing args (0.0, 8), Adam weight_decay=1e-4.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 4x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-4) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L4xWide00Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 4, trailing args (0.0, 8), Adam weight_decay=1e-3.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 4x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-3) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L4xWide00Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 4, trailing args (0.0, 8), Adam weight_decay=1e-2.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 4x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-2) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L4xWide01Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 4, trailing args (0.1, 8), Adam weight_decay=5e-5.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 4x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=5e-5) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L4xWide01Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 4, trailing args (0.1, 8), Adam weight_decay=1e-4.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 4x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-4) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L4xWide01Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 4, trailing args (0.1, 8), Adam weight_decay=1e-3.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 4x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-3) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer8L4xWide01Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 4, trailing args (0.1, 8), Adam weight_decay=1e-2.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    8 layers, 4x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 8
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-2) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L1xWide00Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.0, 12), Adam weight_decay=5e-5.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    12 layers, 1x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 12
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=5e-5) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L1xWide00Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.0, 12), Adam weight_decay=1e-4.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    12 layers, 1x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 12
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-4) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L1xWide00Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.0, 12), Adam weight_decay=1e-3.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    12 layers, 1x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 12
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-3) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L1xWide00Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.0, 12), Adam weight_decay=1e-2.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    12 layers, 1x width and dropout 0.0 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.0, 12
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-2) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L1xWide01Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.1, 12), Adam weight_decay=5e-5.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    12 layers, 1x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 12
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=5e-5) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L1xWide01Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.1, 12), Adam weight_decay=1e-4.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    12 layers, 1x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 12
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-4) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L1xWide01Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task configuration: MLP-Mixer params with multiplier 1, trailing args (0.1, 12), Adam weight_decay=1e-3.

    What each positional argument of get_basic_mlp_mixer_params means is not
    visible in this file; judging by the class name they are presumably
    12 layers, 1x width and dropout 0.1 -- TODO confirm.
    """

    def get_dataset_params(self):
        """Build the dataset config via get_imagenet75_64_params with 32-pixel train/test transforms."""
        tv_transforms = torchvision.transforms
        train_pipeline = get_resize_crop_flip_autoaugment_transforms(
            32, 4, tv_transforms.AutoAugmentPolicy.CIFAR10
        )
        test_pipeline = [tv_transforms.Resize(32), tv_transforms.CenterCrop(32)]
        return get_imagenet75_64_params(
            train_transforms=train_pipeline,
            test_transforms=test_pipeline,
        )

    def get_model_params(self):
        """Build the model config via get_basic_mlp_mixer_params with a [3, 32, 32] first argument."""
        width_factor = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4, width_factor * 128, width_factor * 512, width_factor * 64, nn.GELU, 0.1, 12
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Build the experiment config: Adam (weight_decay=1e-3) plus a CyclicLRConfig schedule."""
        # 741 is presumably the number of optimizer steps per epoch -- TODO confirm.
        steps = 741
        training_cfg = get_common_training_params()
        attack_cfg = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=steps * 10,
            step_size_down=steps * 290,
            cycle_momentum=False,
        )
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer, training_cfg, attack_cfg, optimizer_cfg, scheduler_cfg, 128, num_training=5
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L1xWide01Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (1x widths), Adam with weight_decay=1e-2.

    The model helper receives widths 1*128, 1*512 and 1*64, then 0.1 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 1
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.1, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L2xWide00Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (2x widths), Adam with weight_decay=5e-5.

    The model helper receives widths 2*128, 2*512 and 2*64, then 0.0 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.0, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L2xWide00Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (2x widths), Adam with weight_decay=1e-4.

    The model helper receives widths 2*128, 2*512 and 2*64, then 0.0 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.0, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L2xWide00Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (2x widths), Adam with weight_decay=1e-3.

    The model helper receives widths 2*128, 2*512 and 2*64, then 0.0 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.0, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L2xWide00Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (2x widths), Adam with weight_decay=1e-2.

    The model helper receives widths 2*128, 2*512 and 2*64, then 0.0 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.0, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L2xWide01Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (2x widths), Adam with weight_decay=5e-5.

    The model helper receives widths 2*128, 2*512 and 2*64, then 0.1 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.1, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L2xWide01Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (2x widths), Adam with weight_decay=1e-4.

    The model helper receives widths 2*128, 2*512 and 2*64, then 0.1 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.1, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L2xWide01Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (2x widths), Adam with weight_decay=1e-3.

    The model helper receives widths 2*128, 2*512 and 2*64, then 0.1 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.1, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L2xWide01Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (2x widths), Adam with weight_decay=1e-2.

    The model helper receives widths 2*128, 2*512 and 2*64, then 0.1 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 2
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.1, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L4xWide00Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (4x widths), Adam with weight_decay=5e-5.

    The model helper receives widths 4*128, 4*512 and 4*64, then 0.0 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.0, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L4xWide00Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (4x widths), Adam with weight_decay=1e-4.

    The model helper receives widths 4*128, 4*512 and 4*64, then 0.0 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.0, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L4xWide00Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (4x widths), Adam with weight_decay=1e-3.

    The model helper receives widths 4*128, 4*512 and 4*64, then 0.0 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.0, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L4xWide00Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (4x widths), Adam with weight_decay=1e-2.

    The model helper receives widths 4*128, 4*512 and 4*64, then 0.0 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.0, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L4xWide01Dropout5e_5WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (4x widths), Adam with weight_decay=5e-5.

    The model helper receives widths 4*128, 4*512 and 4*64, then 0.1 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.1, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=5e-5)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L4xWide01Dropout1e_4WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (4x widths), Adam with weight_decay=1e-4.

    The model helper receives widths 4*128, 4*512 and 4*64, then 0.1 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.1, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-4)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L4xWide01Dropout1e_3WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (4x widths), Adam with weight_decay=1e-3.

    The model helper receives widths 4*128, 4*512 and 4*64, then 0.1 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.1, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-3)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )


class Imagenet75_32x32_AutoAugmentMLPMixer12L4xWide01Dropout1e_2WDAdamLinearWarmupDecayTask(AbstractTask):
    """Task config: MLP-Mixer (4x widths), Adam with weight_decay=1e-2.

    The model helper receives widths 4*128, 4*512 and 4*64, then 0.1 and 12
    as its trailing arguments (presumably dropout rate and layer count, going
    by the class name -- confirm against get_basic_mlp_mixer_params).
    """

    def get_dataset_params(self):
        """Return dataset params with AutoAugment (CIFAR10 policy) train transforms
        and Resize(32) + CenterCrop(32) test transforms."""
        augment_policy = torchvision.transforms.AutoAugmentPolicy.CIFAR10
        train_tfms = get_resize_crop_flip_autoaugment_transforms(32, 4, augment_policy)
        test_tfms = [
            torchvision.transforms.Resize(32),
            torchvision.transforms.CenterCrop(32),
        ]
        return get_imagenet75_64_params(train_transforms=train_tfms, test_transforms=test_tfms)

    def get_model_params(self):
        """Return the MLP-Mixer params built by get_basic_mlp_mixer_params."""
        width_mult = 4
        input_shape = [3, 32, 32]
        return get_basic_mlp_mixer_params(
            input_shape, 75, 4,
            width_mult * 128, width_mult * 512, width_mult * 64,
            nn.GELU, 0.1, 12,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer,
        Adam optimizer and a cyclic LR schedule with a short rise and long decay."""
        trainer_cls = MixedPrecisionAdversarialTrainer
        training_params = get_common_training_params()
        adv_test_params = get_apgd_testing_adversarial_params()
        optimizer_cfg = AdamOptimizerConfig(weight_decay=1e-2)
        # 741 is presumably the number of iterations per epoch -- TODO confirm.
        scheduler_cfg = CyclicLRConfig(
            base_lr=5e-5,
            max_lr=0.002,
            step_size_up=741 * 10,
            step_size_down=741 * 290,
            cycle_momentum=False,
        )
        # The positional 128 is presumably the batch size -- verify against
        # get_adv_experiment_params.
        return get_adv_experiment_params(
            trainer_cls,
            training_params,
            adv_test_params,
            optimizer_cfg,
            scheduler_cfg,
            128,
            num_training=5,
        )

