import traceback
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict

# noinspection PyUnresolvedReferences
import XXX.notebook
import dataclasses
import numpy as np
import torch
from ignite.engine import Events

import experiments.datasets.cifar10 as dataset_cifar10
import experiments.models.cifar10 as model_cifar10
from experiments.models.deterministic.resnet_v2 import resnet18_v2

import experiments.models.zero_entropy_noise
import experiments.utils.ignite_output
from XXX.uib.losses import cross_entropies, very_approx_regularizers
from XXX.uib import kraskov_general_continuous_iq_loss
from experiments.dynamics import dynamics
from experiments.models import stochastic_model
from experiments.utils.experiment_YYY import embedded_experiments
from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper
from experiments.utils.ignite_output import IgniteOutput

# Regularizers selectable by name through ``Experiment.regularizer``.
# Each value is invoked as ``fn(latent, labels)`` inside the training loss.
Regularizers = {
    # Variance / squared-norm based surrogates.
    "mean_squared_Z": very_approx_regularizers.squared_sum(stochastic=True),
    "entropy_via_variance_Z__Y": very_approx_regularizers.estimate_entropy_Z__Y(stochastic=True),
    "entropy_via_variance_Z": very_approx_regularizers.estimate_entropy(
        very_approx_regularizers.covariance_trace,
        stochastic=True,
    ),
    # Kraskov-estimator based losses.
    "kraskov_Z__Y": kraskov_general_continuous_iq_loss.iq_loss(
        kraskov_general_continuous_iq_loss.iq.H_Z__Y
    ),
    "kraskov_Z": kraskov_general_continuous_iq_loss.iq_loss(
        kraskov_general_continuous_iq_loss.iq.H_Z
    ),
    # Ignores its inputs apart from the latent's device and returns a single zero,
    # so this entry adds nothing to the loss.
    "weight_decay": lambda latent, label: torch.zeros(1, device=latent.device),
}


@dataclass
class Experiment:
    """Hyperparameters and training driver for one CIFAR-10 regularizer run.

    The fields are plain configuration values; ``run`` turns them into a
    model, loss, optimizer and LR scheduler and hands everything to
    ``run_common_experiment``.
    """

    # Passed to torch.manual_seed and forwarded to run_common_experiment.
    seed: int
    # Constructor argument of the noise injector (used only when inject_noise is set).
    num_samples: int
    # Multiplier of the regularizer term; the term is skipped entirely when 0.
    gamma: float
    # Batch sizes handed to dataset_cifar10.dataloaders.
    train_batch_size: int
    val_batch_size: int
    # Forwarded as max_epochs.
    epochs: int
    log_interval: int
    # layer_name given to the latent extractor.
    latent_layer_name: str
    # Key into the module-level Regularizers mapping.
    regularizer: str
    # Encoder capacity; also the input width of the final linear layer.
    capacity: int
    # A snapshot is written every `snapshot_every` completed epochs.
    snapshot_every: int
    # weight_decay argument of the Adam optimizer.
    weight_decay: float
    # Whether a StochasticInjectZeroEntropyNoise module is attached to the latent extractor.
    inject_noise: bool

    def run(self, store):
        """Train the configured model on CUDA and record snapshots.

        Args:
            store: Mutable mapping for results. ``store["snapshots"]`` is
                reset and then filled with ``epoch -> snapshot filename``;
                the same mapping is forwarded to ``run_common_experiment``.

        NOTE(review): ``save_snapshot`` reads ``job_id`` and ``timestamp``
        from the *module-level* ``store`` (declared ``global``), not from
        this argument. It therefore only works when a global of that name
        exists, as it does inside the ``__main__`` loop of this file.
        """
        # Read once up front; the loss closure below uses this captured value.
        reg_weight = self.gamma

        torch.backends.cudnn.benchmark = True

        # Seed before creating the data pipeline and the model.
        torch.manual_seed(self.seed)
        dataloaders = dataset_cifar10.dataloaders(
            self.train_batch_size,
            self.val_batch_size,
            train_only=False,
            augmentation=True,
            validation_size=0,
        )

        model = self.create_model()
        model.cuda()

        optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=self.weight_decay)

        noise_injector = None
        if self.inject_noise:
            noise_injector = experiments.models.zero_entropy_noise.StochasticInjectZeroEntropyNoise(self.num_samples)
            noise_injector.cuda()

        latent_extractor = dynamics.LatentExtractor(
            model.wrapped_model,
            layer_name=self.latent_layer_name,
            noise_injector=noise_injector,
        )

        regularizer = Regularizers[self.regularizer]

        def full_stochastic_loss(latent_x_k_z, prediction_x_k_y, labels_x):
            """Cross-entropy plus, for a non-zero weight, the weighted regularizer."""
            cross_entropy = cross_entropies.stochastic_continuous_encoder_decoder_cross_entropy(
                prediction_x_k_y, labels_x
            )
            if reg_weight == 0.0:
                return cross_entropy
            # Keep the regularizer output as the left operand: the weight may be a
            # NumPy scalar (the sweeps are built with np.geomspace), and putting it
            # first would dispatch the multiplication to NumPy instead.
            return cross_entropy + regularizer(latent_x_k_z, labels_x) * reg_weight

        dynamics_instance = dynamics.StochasticContinuousFullLossDynamics(
            model, optimizer, full_stochastic_loss, latent_extractor=latent_extractor
        )

        plateau_scheduler = ReduceLROnPlateauWrapper(
            dynamics_instance.optimizer,
            output_transform=IgniteOutput.get_loss,
            mode="min",
            patience=10,
            factor=0.1 ** 0.5,
            threshold_mode="rel",
            verbose=True,
            min_lr=1e-7,
        )

        store["snapshots"] = {}
        # NOTE(review): the entry is read back from the store instead of keeping
        # the dict literal - presumably the store may wrap assigned values;
        # confirm before simplifying.
        snapshot_store = store["snapshots"]

        def save_snapshot(epoch):
            """Write the model weights to disk and record the filename under `epoch`."""
            # Deliberately the module-level `store`, not run()'s argument (see run's docstring).
            global store
            filename = (
                "/scratch/ZZZ/gigi/models/edl/snapshots/"
                f"cifar10_no_dropout_{store['job_id']}_{epoch}_{store['timestamp']}.pyt"
            )
            torch.save(model.state_dict(), filename)
            snapshot_store[epoch] = filename
            print(f"Epoch {epoch} --> {filename}")

        # Epoch 0 holds the untrained weights.
        save_snapshot(0)

        def extra_hooks(trainer, evaluator, train_evaluator, validator):
            """Register the periodic snapshot handler on the trainer engine."""

            def _snapshot_current_epoch(engine):
                save_snapshot(engine.state.epoch)

            periodic_event = Events.EPOCH_COMPLETED(every=self.snapshot_every)
            trainer.add_event_handler(periodic_event, _snapshot_current_epoch)

        run_common_experiment(
            dynamics=dynamics_instance,
            dataloaders=dataloaders,
            max_epochs=self.epochs,
            in_capacity=self.capacity,
            out_capacity=10,
            train_lr_schedulers=[plateau_scheduler],
            log_interval=self.log_interval,
            store=store,
            seed=self.seed,
            stochastic=True,
            discrete_summary=False,
            continuous_summary=False,
            device="cuda",
            mean_l2_squared=False,
            minibatch_entropies=False,
            train_eval=False,
            zero_eval=False,
            extra_hooks=extra_hooks,
        )

    def create_model(self):
        """Build the network for this configuration.

        Returns:
            ``as_stochastic_model`` applied to a ``Sequential`` of a
            ``DeterministicCifar10Resnet`` (created with the ``resnet18_v2``
            factory) followed by ``Linear(capacity, 10)``.
        """
        # Encoder first, head second - the same construction order as the layers appear.
        encoder = model_cifar10.DeterministicCifar10Resnet(resnet_factory=resnet18_v2, capacity=self.capacity)
        head = torch.nn.Linear(self.capacity, 10)
        return stochastic_model.as_stochastic_model(torch.nn.Sequential(encoder, head))


# Log-spaced regularization weights, 20 values per sweep.
# With endpoint=False the upper bound itself is not part of the grid.
gammas_Z2 = np.geomspace(start=1e-6, stop=1e1, num=20, endpoint=False)
gammas_Z = np.geomspace(start=1e-5, stop=1, num=20, endpoint=True)
gammas_ZY = np.geomspace(start=1e-5, stop=1e1, num=20, endpoint=True)
# NOTE(review): the two Kraskov grids are not referenced by the code visible in this file.
gammas_kraskov_Z = np.geomspace(start=1e-5, stop=1, num=20, endpoint=True)
gammas_kraskov_ZY = np.geomspace(start=1e-5, stop=1e1, num=20, endpoint=True)
gammas_weight_decay = np.geomspace(start=1e-6, stop=1e1, num=20, endpoint=False)


def create_simple_surrogate_experiment(seed, gamma, regularizer):
    """Return the standard sweep ``Experiment`` for a given seed, weight and regularizer.

    Everything except the three arguments is fixed: training batch size 128,
    150 epochs, capacity 256, no optimizer weight decay, noise injection on.
    """
    settings = {
        # Varied per job.
        "seed": seed,
        "gamma": gamma,
        "regularizer": regularizer,
        # Shared training setup.
        "num_samples": 1,
        "train_batch_size": 128,
        "val_batch_size": 512,
        "epochs": 150,
        "log_interval": 10,
        "latent_layer_name": "0",
        "capacity": 256,
        "snapshot_every": 5,
        "weight_decay": 0,
        "inject_noise": True,
    }
    return Experiment(**settings)


def create_weight_decay_surrogate_experiment(seed, gamma, regularizer):
    """Return a sweep ``Experiment`` whose optimizer weight decay equals ``gamma``.

    ``gamma`` is passed on twice: as the experiment's ``gamma`` and as its
    ``weight_decay``. The remaining settings are fixed: training batch size
    128, 150 epochs, capacity 256, noise injection on.
    """
    settings = {
        # Varied per job.
        "seed": seed,
        "gamma": gamma,
        "regularizer": regularizer,
        # The swept weight doubles as the optimizer's weight decay.
        "weight_decay": gamma,
        # Shared training setup.
        "num_samples": 1,
        "train_batch_size": 128,
        "val_batch_size": 512,
        "epochs": 150,
        "log_interval": 10,
        "latent_layer_name": "0",
        "capacity": 256,
        "snapshot_every": 5,
        "inject_noise": True,
    }
    return Experiment(**settings)


def create_kraskov_surrogate_experiment(seed, gamma, regularizer):
    """Return a sweep ``Experiment`` that trains with batches of 1024 samples.

    Apart from the training batch size, the fixed settings are: 150 epochs,
    capacity 256, no optimizer weight decay, noise injection on.
    """
    settings = {
        # Varied per job.
        "seed": seed,
        "gamma": gamma,
        "regularizer": regularizer,
        # Larger training batches than the other factories in this file.
        "train_batch_size": 1024,
        # Shared training setup.
        "num_samples": 1,
        "val_batch_size": 512,
        "epochs": 150,
        "log_interval": 10,
        "latent_layer_name": "0",
        "capacity": 256,
        "snapshot_every": 5,
        "weight_decay": 0,
        "inject_noise": True,
    }
    return Experiment(**settings)


# Complete job list. Jobs are looked up by position, so the order matters.
# Within a sweep, consecutive seeds are 31 apart.
all_experiments = [
    create_simple_surrogate_experiment(770047 + 31 * index, weight, "mean_squared_Z")
    for index, weight in enumerate(gammas_Z2)
]
all_experiments += [
    create_simple_surrogate_experiment(804859 + 31 * index, weight, "entropy_via_variance_Z__Y")
    for index, weight in enumerate(gammas_ZY)
]
all_experiments += [
    create_simple_surrogate_experiment(694773 + 31 * index, weight, "entropy_via_variance_Z")
    for index, weight in enumerate(gammas_Z)
]
# Kraskov sweeps are currently disabled:
# all_experiments += [
#     create_kraskov_surrogate_experiment(189643 + 31 * index, weight, "kraskov_Z__Y")
#     for index, weight in enumerate(gammas_ZY)
# ]
# all_experiments += [
#     create_kraskov_surrogate_experiment(986845 + 31 * index, weight, "kraskov_Z")
#     for index, weight in enumerate(gammas_Z)
# ]
all_experiments += [
    create_weight_decay_surrogate_experiment(986845 + 31 * index, weight, "weight_decay")
    for index, weight in enumerate(gammas_weight_decay)
]
# Two gamma=0 baselines copied from the first job: one unchanged otherwise,
# one with noise injection switched off.
all_experiments += [
    dataclasses.replace(all_experiments[0], seed=468135, gamma=0.),
    dataclasses.replace(all_experiments[0], seed=468135, gamma=0., inject_noise=False),
]

if __name__ == "__main__":
    # Script entry point: run each job that embedded_experiments yields for this file.
    # NOTE: `store` has to remain a module-level name here, because
    # Experiment.run's save_snapshot reads it through `global store`.
    for job_id, store in embedded_experiments(__file__, len(all_experiments)):
        experiment = all_experiments[job_id]
        # Offset the configured seed by the job index before running.
        experiment.seed += job_id
        print(experiment)

        store["experiment"] = experiment
        store["log"] = {}

        try:
            # The run receives only the "log" entry (read back from the store) as its store.
            experiment.run(store=store["log"])
        except Exception:
            # Record the traceback in the store, then let the failure propagate.
            store["exception"] = traceback.format_exc()
            raise
