from typing import List

import XXX.notebook

import YYY

from experiments.runs.adversarial_attacks import fgsm

# Bc we are reloading results to find the snapshots
from experiments.datasets import DataLoaders
from torch.utils.data import DataLoader
from experiments.models.permutation_mnist import DeterministicModelResFC
from experiments.utils.jupyter import results_loader

import dataclasses
import traceback
from dataclasses import dataclass

from ignite.engine import Events
from ignite.handlers import EarlyStopping

from XXX.uib.information_quantities import decoder_uncertainty, reverse_decoder_uncertainty, H_Z, H_Y__Z
from XXX.uib.losses import cross_entropies
from XXX.uib import kraskov_continuous_iq_loss

import experiments.models.cifar10 as model_cifar10
from XXX.uib.modules.cluster_decoder import ClusterDecoder, GaussianMixtureDecoder
from XXX.uib.modules.decoder_interface import PassthroughDecoder
from XXX.uib.modules.encoder_decoder import EncoderDecoder
from experiments.models import ic_resnet_v2
from experiments.models import dropconnect_resnet
from experiments.models import dropconnect_resnet_v2, stochastic_model

import experiments.datasets.cifar10 as dataset_cifar10
from experiments.dynamics.dynamics import DecodingLossDynamics, EncodingLossDynamics, LatentExtractor, \
    StochasticContinuousDynamics, TwoLossesDynamics
from experiments.models.zero_entropy_noise import InjectZeroEntropyNoise, StochasticInjectZeroEntropyNoise

from experiments.utils.experiment_YYY import embedded_experiments

from XXX.progress_bar import with_progress_bar

import torch
import numpy as np

import experiments.runs.iclr_experiments.cifar10_no_dropout_surrogates_training as base_experiment

from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper, EarlyExitCriterion
from experiments.utils.ignite_output import IgniteOutput

from foolbox import PyTorchModel, accuracy, samples
import foolbox.attacks as fa
import numpy as np
from experiments.models import stochastic_model


@dataclass
class Config:
    """Evaluate the adversarial robustness of one previously trained CIFAR-10 run.

    ``run`` reloads the stored training result whose ``job_id`` matches this
    config and restores its epoch-150 model snapshot. It then attacks the full
    test set with several foolbox attacks at every value in ``epsilons`` and
    writes the resulting accuracies into ``store``.
    """

    seed: int  # passed to torch.manual_seed at the start of run()
    batch_size: int  # used for both the train and the test dataloader
    job_id: int  # selects which stored training result is evaluated
    epsilons: List[float]  # attack budgets; accuracies are reported per value

    @staticmethod
    def _accuracy_from_success(success):
        """Return ``1 - mean(success)`` along the last axis.

        ``success`` is a boolean tensor whose last axis runs over test samples
        (True = the attack fooled the model on that sample). All leading axes
        are kept in the result.
        """
        return 1.0 - success.float().mean(dim=-1)

    def run(self, store):
        """Run the attack evaluation and record the results in ``store``.

        Writes ``store["epochs"][epoch]["attack_accs"][attack_name]``, a tensor
        with one accuracy per epsilon. Also writes
        ``store["epochs"][epoch]["robust_accs"]``, the accuracy when a sample
        counts as broken if *any* attack succeeded on it.

        Raises:
            RuntimeError: if the stored results do not contain exactly one
                entry for ``job_id``, or if the test loader did not yield
                exactly one result per test-set sample.
        """
        torch.backends.cudnn.benchmark = True
        torch.manual_seed(self.seed)

        dataloaders = dataset_cifar10.dataloaders(
            self.batch_size,
            self.batch_size,
            train_only=False,
            augmentation=True,
            # presumably keeps images in [0, 1] to match bounds=(0, 1) below;
            # normalization is applied by fmodel via `preprocessing`.
            normalize=False,
            validation_size=0,
        )

        # Reload the training results and select the run this job evaluates.
        loaded_results = results_loader.load_YYY_files('src/experiments/runs/iclr_experiments/results_no_dropout')
        results = results_loader.filter_dict(loaded_results, v=lambda result: result.job_id == self.job_id)

        # Explicit check rather than `assert`, so it still fires under `python -O`.
        if len(results) != 1:
            raise RuntimeError(
                f"Expected exactly one stored result with job_id={self.job_id}, found {len(results)}"
            )

        result = results_loader.get_any(results)
        config = base_experiment.Experiment(**result.experiment._asdict())

        model = config.create_model()

        if config.inject_noise:
            noise_injector = StochasticInjectZeroEntropyNoise(1)
            noise_injector.cuda()
        else:
            noise_injector = None

        # The return value is discarded. Presumably the constructor hooks the
        # latent layer (and the noise injector) into `model.wrapped_model`;
        # confirm in LatentExtractor.
        LatentExtractor(
            model.wrapped_model, layer_name=config.latent_layer_name, noise_injector=noise_injector
        )

        import pprint
        pprint.pprint(config)

        # Per-channel mean/std (the usual CIFAR-10 statistics), applied by foolbox
        # along axis -3. This assumes NCHW input, where axis -3 is the channel axis.
        preprocessing = dict(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010], axis=-3)

        # Currently disabled: fa.L2CarliniWagnerAttack(), fa.LinfinityBrendelBethgeAttack().
        attacks = {
            "FGSM": fa.FGSM(),
            "PGD": fa.LinfPGD(),
            "BasicIterative": fa.LinfBasicIterativeAttack(),
            "DeepFool": fa.LinfDeepFoolAttack(),
        }

        store["epochs"] = {}

        # Only a single snapshot is evaluated, instead of every entry of
        # `result.log.snapshots`.
        epoch = 150
        snapshot_name = result.log.snapshots[epoch]

        store["epochs"][epoch] = {}
        # Read back from `store` so later writes land in the object the store holds.
        epoch_accs = store["epochs"][epoch]

        # Load on the CPU. load_state_dict copies the tensors into the model's
        # parameters, and the model is then moved to CUDA.
        model.load_state_dict(torch.load(snapshot_name, map_location="cpu"))
        model.cuda()
        model.eval()
        fmodel = PyTorchModel(stochastic_model.AsDeterministicModel(model), bounds=(0, 1),
                              preprocessing=preprocessing)

        epoch_accs["attack_accs"] = {}
        attack_accs = epoch_accs["attack_accs"]

        num_samples = len(dataloaders.test.dataset)
        # attack_success[attack, epsilon, sample] is True when that attack fooled the model.
        attack_success = torch.zeros((len(attacks), len(self.epsilons), num_samples), dtype=torch.bool)
        for i, (attack_name, attack) in enumerate(attacks.items()):
            current_batch_index = 0
            for images, labels in with_progress_bar(dataloaders.test, unit_scale=self.batch_size):
                images, labels = images.cuda(), labels.cuda()

                # `success` is indexed as (epsilon, sample-in-batch).
                _, _, success = attack(fmodel, images, labels, epsilons=self.epsilons)
                current_batch_size = success.shape[1]
                attack_success[i, :, current_batch_index:current_batch_index + current_batch_size] = success

                current_batch_index += current_batch_size

            # Unfilled slots would stay False, count as "attack failed", and inflate accuracy.
            if current_batch_index != num_samples:
                raise RuntimeError(
                    f"{attack_name}: test loader yielded {current_batch_index} samples, expected {num_samples}"
                )

            attack_accs[attack_name] = self._accuracy_from_success(attack_success[i])

        # A sample counts as broken if any attack succeeded on it.
        epoch_accs["robust_accs"] = self._accuracy_from_success(attack_success.any(dim=0))


# Perturbation budgets passed to every attack. Each job receives its own list copy.
_EPSILONS = (
    0.0, 0.0005, 0.001, 0.0015, 0.002, 0.003, 0.005, 0.01,
    0.02, 0.03, 0.05, 0.1, 0.2, 0.3, 0.5, 1.0,
)

# One evaluation config per stored training job (job ids 0..81), each with a
# distinct, deterministic seed.
configs = [
    Config(
        seed=458326 + 31 * index,
        job_id=index,
        batch_size=1024,
        epsilons=list(_EPSILONS),
    )
    for index in range(82)
]


if __name__ == "__main__":
    # Each embedded experiment slot runs the Config at the same index. Its
    # dataclass fields go under "config", and run() writes into "log".
    for job_id, store in embedded_experiments(__file__, len(configs)):
        job_config = configs[job_id]
        print(job_config)
        store["config"] = dataclasses.asdict(job_config)
        store["log"] = {}

        try:
            # Pass store["log"] as read back from the store (not the literal
            # above) in case the store wraps assigned values.
            job_config.run(store=store["log"])
        except Exception:
            # Record the traceback alongside the results, then propagate.
            store["exception"] = traceback.format_exc()
            raise
