from copy import deepcopy
from re import L
from time import time
from typing import List, Tuple, Type

import torchvision
from adversarialML.biologically_inspired_models.src.mlp_mixer_models import (
    ConsistentActivationMixerBlock, ConsistentActivationMixerMLP,
    FirstNExtractionClassifier, LinearLayer, MixerBlock, MixerMLP, MLPMixer,
    NormalizationLayer, UnfoldPatchExtractor)
from adversarialML.biologically_inspired_models.src.models import (
    ConsistentActivationLayer, ConvEncoder, GeneralClassifier, IdentityLayer,
    ScanningConsistentActivationLayer, SequentialLayers, XResNet34, XResNet18, SupervisedContrastiveTrainingWrapper)
from adversarialML.biologically_inspired_models.src.retina_preproc import (
    AbstractRetinaFilter, RetinaBlurFilter, RetinaNonUniformPatchEmbedding,
    RetinaSampleFilter)
from adversarialML.biologically_inspired_models.src.supconloss import \
    TwoCropTransform
from adversarialML.biologically_inspired_models.src.trainers import (
    ActivityOptimizationSchedule, AdversarialParams, AdversarialTrainer,
    ConsistentActivationModelAdversarialTrainer,
    MixedPrecisionAdversarialTrainer)
from adversarialML.biologically_inspired_models.src.mlp_mixer_tasks import get_dataset_params
from mllib.adversarial.attacks import (AttackParamFactory, SupportedAttacks,
                                       SupportedBackend)
from mllib.datasets.dataset_factory import (ImageDatasetFactory,
                                            SupportedDatasets)
from mllib.models.base_models import MLP
from mllib.optimizers.configs import (AbstractOptimizerConfig, AbstractSchedulerConfig, AdamOptimizerConfig,
                                      CosineAnnealingWarmRestartsConfig,
                                      CyclicLRConfig, LinearLRConfig,
                                      ReduceLROnPlateauConfig,
                                      SequentialLRConfig, SGDOptimizerConfig)
from mllib.runners.configs import BaseExperimentConfig, TrainingParams
from mllib.tasks.base_tasks import AbstractTask
from torch import nn
from mllib.adversarial.attacks import TorchAttackAPGDInfParams

from mlp_mixer_tasks import get_resize_crop_flip_autoaugment_transforms

# Directory handed to the training / experiment configs built below as their
# log location.
_LOGDIR = '/share/workhorse3/hippo/biologically_inspired_models/logs/'
# Perturbation budgets; one APGD test-attack config is created per entry
# (see get_common_adversarial_params).
_EPS_LIST = [0.0, 0.008, 0.016, 0.024, 0.032, 0.048, 0.064]
# Value passed as `nepochs` in get_common_training_params.
_NEPOCHS = 300
# Value passed as `early_stop_patience` in get_common_training_params.
_PATIENCE = 50
# Value passed as `nsteps` for every APGD test attack.
_APGD_STEPS = 50

def get_imagenet10_params(num_train=13_000, num_test=500, train_transforms=None, test_transforms=None):
    """Return dataset parameters for ``SupportedDatasets.IMAGENET10``.

    Args:
        num_train: Forwarded unchanged to ``get_dataset_params``.
        num_test: Forwarded unchanged to ``get_dataset_params``.
        train_transforms: Transforms for the training split. ``None`` (the
            default) is replaced by a new empty list.
        test_transforms: Transforms for the test split. ``None`` (the
            default) is replaced by a new empty list.

    Returns:
        Whatever ``get_dataset_params`` returns for these arguments.
    """
    # A literal ``[]`` default would be one list object shared by every call;
    # build a fresh list per call instead.
    train_transforms = [] if train_transforms is None else train_transforms
    test_transforms = [] if test_transforms is None else test_transforms
    # NOTE(review): the data directory is the imagenet-100 one, identical to
    # the path used by the ImageNet-100 helpers - confirm this is intended.
    return get_dataset_params('/home/hippo/workhorse3/imagenet-100/bin/64', SupportedDatasets.IMAGENET10,
                              num_train, num_test, train_transforms, test_transforms)

def get_imagenet100_params(num_train=25, num_test=1, train_transforms=None, test_transforms=None):
    """Return dataset parameters for ``SupportedDatasets.IMAGENET100``.

    Args:
        num_train: Forwarded unchanged to ``get_dataset_params``.
        num_test: Forwarded unchanged to ``get_dataset_params``.
        train_transforms: Transforms for the training split. ``None`` (the
            default) is replaced by a new empty list.
        test_transforms: Transforms for the test split. ``None`` (the
            default) is replaced by a new empty list.

    Returns:
        Whatever ``get_dataset_params`` returns for these arguments.
    """
    # A literal ``[]`` default would be one list object shared by every call;
    # build a fresh list per call instead.
    train_transforms = [] if train_transforms is None else train_transforms
    test_transforms = [] if test_transforms is None else test_transforms
    return get_dataset_params('/home/hippo/workhorse3/imagenet-100/bin/64', SupportedDatasets.IMAGENET100,
                              num_train, num_test, train_transforms, test_transforms)

def get_imagenet100_64_params(num_train=127500, num_test=5000, train_transforms=None, test_transforms=None):
    """Return dataset parameters for ``SupportedDatasets.IMAGENET100_64``.

    Args:
        num_train: Forwarded unchanged to ``get_dataset_params``.
        num_test: Forwarded unchanged to ``get_dataset_params``.
        train_transforms: Transforms for the training split. ``None`` (the
            default) is replaced by a new empty list.
        test_transforms: Transforms for the test split. ``None`` (the
            default) is replaced by a new empty list.

    Returns:
        Whatever ``get_dataset_params`` returns for these arguments.
    """
    # A literal ``[]`` default would be one list object shared by every call;
    # build a fresh list per call instead.
    train_transforms = [] if train_transforms is None else train_transforms
    test_transforms = [] if test_transforms is None else test_transforms
    return get_dataset_params('/home/hippo/workhorse3/imagenet-100/bin/64', SupportedDatasets.IMAGENET100_64,
                              num_train, num_test, train_transforms, test_transforms)

def get_conv_patch_extractor_params(input_size, hidden_size, patch_size):
    """Configure a ``ConvEncoder`` used as the patch extractor and count patches.

    The encoder is configured with a single entry in each per-layer list:
    kernel size and stride both equal to ``patch_size``, zero padding,
    ``hidden_size`` units, and ``nn.Identity`` as the activation.

    Args:
        input_size: Size of the input. Its last two entries are treated as
            the spatial extent, so both ``[H, W]`` and channel-first
            ``[C, H, W]`` (e.g. ``[3, 64, 64]``) are handled.
        hidden_size: Number of units for the single encoder layer.
        patch_size: Side length of a patch; used as kernel size and stride.

    Returns:
        A ``(patch_params, npatches)`` tuple: the populated
        ``ConvEncoder`` parameters and the number of whole patches that fit
        in the spatial extent.
    """
    patch_params: ConvEncoder.ModelParams = ConvEncoder.get_params()
    patch_params.common_params.input_size = input_size
    patch_params.common_params.num_units = [hidden_size]
    patch_params.common_params.activation = nn.Identity
    patch_params.conv_params.kernel_sizes = [patch_size]
    patch_params.conv_params.padding = [0]
    patch_params.conv_params.strides = [patch_size]
    # Index from the end so a leading channel entry is never mistaken for a
    # spatial dimension: with [3, 64, 64] and patch_size 8, indices 0 and 1
    # would give (3 // 8) * (64 // 8) == 0 patches.
    height, width = input_size[-2], input_size[-1]
    npatches = (height // patch_size) * (width // patch_size)
    return patch_params, npatches

def get_basic_mixer_mlp_params(activation, dropout_p, input_size, hidden_size):
    """Return ``MixerMLP`` parameters populated from the given settings.

    ``input_size`` is stored wrapped in a one-element list, while
    ``hidden_size`` is stored as ``num_units`` exactly as given.
    """
    params: MixerMLP.ModelParams = MixerMLP.get_params()
    common = params.common_params
    common.input_size = [input_size]
    common.num_units = hidden_size
    common.activation = activation
    common.dropout_p = dropout_p
    return params

def get_basic_mixer_block_params(mlpc_params, mlps_params, num_patches, hidden_size):
    """Return ``MixerBlock`` parameters wired to the two given MLP configs.

    ``mlpc_params`` becomes the channel MLP, ``mlps_params`` the spatial MLP,
    and the block's input size is set to ``[num_patches, hidden_size]``.
    """
    params: MixerBlock.ModelParams = MixerBlock.get_params()
    params.common_params.input_size = [num_patches, hidden_size]
    params.spatial_mlp_params = mlps_params
    params.channel_mlp_params = mlpc_params
    return params

def get_linear_classifier_params(hidden_size, nclasses):
    """Return ``LinearLayer`` parameters for the classification head.

    The layer takes ``hidden_size`` inputs, produces ``nclasses`` units and
    uses ``nn.Identity`` as its activation.
    """
    params: LinearLayer.ModelParams = LinearLayer.get_params()
    common = params.common_params
    common.activation = nn.Identity
    common.input_size = hidden_size
    common.num_units = nclasses
    return params

def get_mlp_mixer_params(input_size, patch_params, cls_params, mixer_block_params):
    """Assemble ``MLPMixer`` parameters from already-built component configs.

    Stores the patch generator, mixer block and classifier configs on a fresh
    ``MLPMixer`` parameter object and records ``input_size`` on its common
    parameters.
    """
    params: MLPMixer.ModelParams = MLPMixer.get_params()
    params.classifier_params = cls_params
    params.mixer_block_params = mixer_block_params
    params.patch_gen_params = patch_params
    params.common_params.input_size = input_size
    return params

def get_basic_mlp_mixer_params(input_size, nclasses, patch_size, hidden_size, mlpc_hidden, mlps_hidden, activation, dropout_p, num_blocks):
    """Build the full ``MLPMixer`` configuration from scalar hyper-parameters.

    A conv patch extractor, a channel MLP (input ``hidden_size``), a spatial
    MLP (input = number of patches), a linear classifier and one mixer block
    config are created; the block config is then repeated ``num_blocks``
    times.
    """
    patch_cfg, n_patches = get_conv_patch_extractor_params(input_size, hidden_size, patch_size)
    block_cfg = get_basic_mixer_block_params(
        get_basic_mixer_mlp_params(activation, dropout_p, hidden_size, mlpc_hidden),
        get_basic_mixer_mlp_params(activation, dropout_p, n_patches, mlps_hidden),
        n_patches,
        hidden_size,
    )
    classifier_cfg = get_linear_classifier_params(hidden_size, nclasses)
    # List repetition: every entry refers to the same block config object.
    repeated_blocks = [block_cfg] * num_blocks
    return get_mlp_mixer_params(input_size, patch_cfg, classifier_cfg, repeated_blocks)

def get_adv_experiment_params(trainer_cls: Type[AdversarialTrainer], training_params: TrainingParams, adv_params:AdversarialParams,
                                optimizer_config:AbstractOptimizerConfig, scheduler_config: AbstractSchedulerConfig, batch_size: int,
                                exp_name: str = '', num_training=5):
    """Bundle trainer, optimizer and scheduler settings into an experiment config.

    Args:
        trainer_cls: Trainer class; its ``TrainerParams`` is instantiated with
            the class itself plus the training and adversarial parameters.
        training_params: Training settings. Modified in place when the
            scheduler is cyclic (see below).
        adv_params: Adversarial attack settings for the trainer.
        optimizer_config: Optimizer configuration.
        scheduler_config: LR scheduler configuration.
        batch_size: Batch size stored on the experiment config.
        exp_name: Experiment name stored on the experiment config.
        num_training: Stored on the experiment config as ``num_trainings``.

    Returns:
        The populated ``BaseExperimentConfig``.
    """
    if isinstance(scheduler_config, CyclicLRConfig):
        # Presumably the cyclic schedule is meant to advance per optimizer
        # step rather than per epoch - TODO confirm against the trainer.
        training_params.scheduler_step_after_epoch = False
    trainer_params = trainer_cls.TrainerParams(
        trainer_cls,
        training_params=training_params,
        adversarial_params=adv_params,
    )
    # NOTE(review): this keyword was previously spelled ``logidr``; it is
    # corrected to ``logdir`` to match the spelling used for TrainingParams in
    # this file - confirm against the BaseExperimentConfig definition.
    return BaseExperimentConfig(
        trainer_params=trainer_params,
        optimizer_config=optimizer_config,
        scheduler_config=scheduler_config,
        batch_size=batch_size,
        logdir=_LOGDIR,
        exp_name=exp_name,
        num_trainings=num_training,
    )

def get_apgd_inf_params(eps_list, nsteps, eot_iters=1):
    """Return one ``TorchAttackAPGDInfParams`` per entry of ``eps_list``.

    Every config shares ``nsteps`` and ``eot_iters``; each one is seeded with
    the wall-clock time read at the moment it is created.
    """
    attack_params = []
    for eps in eps_list:
        attack_params.append(
            TorchAttackAPGDInfParams(eps, nsteps, eot_iter=eot_iters, seed=time())
        )
    return attack_params

def get_common_training_params():
    """Return the ``TrainingParams`` shared by the tasks in this module.

    Uses the module-level log directory, epoch count and early-stopping
    patience, and tracks ``'val_acc'`` in ``'max'`` mode.
    """
    settings = {
        'logdir': _LOGDIR,
        'nepochs': _NEPOCHS,
        'early_stop_patience': _PATIENCE,
        'tracked_metric': 'val_acc',
        'tracking_mode': 'max',
    }
    return TrainingParams(**settings)

def get_common_adversarial_params():
    """Return ``AdversarialParams`` whose test attacks are the module's APGD sweep.

    One APGD config is generated for each value in ``_EPS_LIST``, each using
    ``_APGD_STEPS`` steps.
    """
    test_attacks = get_apgd_inf_params(_EPS_LIST, _APGD_STEPS)
    return AdversarialParams(testing_attack_params=test_attacks)

class Imagenet10AutoAugmentMLPMixer8LWD1e_4Task(AbstractTask):
    """Task definition: 8-block MLP-Mixer on ImageNet-10 with AutoAugment.

    Trained with Adam (weight decay 1e-4) under a cyclic learning-rate
    schedule, and evaluated with the module's common APGD attack sweep.
    """

    def get_dataset_params(self):
        """Return ImageNet-10 dataset params with AutoAugment training transforms."""
        # The arguments 64 and 8 are presumably the crop size and padding -
        # TODO confirm against get_resize_crop_flip_autoaugment_transforms.
        return get_imagenet10_params(
            train_transforms=get_resize_crop_flip_autoaugment_transforms(64, 8, torchvision.transforms.AutoAugmentPolicy.IMAGENET)
        )

    def get_model_params(self):
        """Return the MLP-Mixer configuration for this task."""
        # The positional call here previously supplied only 8 of the 9 required
        # arguments, which raises TypeError and shifts every value after
        # input_size into the wrong slot. Keywords make the mapping explicit.
        # NOTE(review): nclasses=10 is inferred from the ImageNet-10 dataset
        # used by this task - confirm.
        return get_basic_mlp_mixer_params(
            input_size=[3, 64, 64],
            nclasses=10,
            patch_size=8,
            hidden_size=128,
            mlpc_hidden=512,
            mlps_hidden=64,
            activation=nn.GELU,
            dropout_p=0.,
            num_blocks=8,
        )

    def get_experiment_params(self) -> BaseExperimentConfig:
        """Return the experiment config: mixed-precision adversarial trainer, Adam, cyclic LR."""
        return get_adv_experiment_params(
            MixedPrecisionAdversarialTrainer,
            get_common_training_params(),
            get_common_adversarial_params(),
            AdamOptimizerConfig(weight_decay=1e-4),
            # Step sizes are written as 250 * <count>; 30 + 270 equals the
            # module's 300-epoch budget, so 250 is presumably the number of
            # scheduler steps per epoch - TODO confirm.
            CyclicLRConfig(base_lr=1e-6, max_lr=0.001, step_size_up=250*30, step_size_down=250*270),
            512
        )
    