import traceback
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict

# noinspection PyUnresolvedReferences
import XXX.notebook
import dataclasses
import numpy as np
import torch

import experiments.datasets.cifar10 as dataset_cifar10
import experiments.models.cifar10 as model_cifar10
from experiments.models import dropconnect_resnet_v2

import experiments.models.zero_entropy_noise
import experiments.utils.ignite_output
from XXX.uib.losses import cross_entropies, very_approx_regularizers
from experiments.dynamics import dynamics
from experiments.models import stochastic_model
from experiments.utils.experiment_YYY import embedded_experiments
from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper
from experiments.utils.ignite_output import IgniteOutput

# Latent-space regularizers, looked up by name through Experiment.regularizer.
# Every entry is built with stochastic=True.
Regularizers = {
    "mean_squared_Z": very_approx_regularizers.squared_sum(stochastic=True),
    "entropy_via_variance_Z__Y": very_approx_regularizers.estimate_entropy_Z__Y(stochastic=True),
    "entropy_via_variance_Z": very_approx_regularizers.estimate_entropy(
        very_approx_regularizers.covariance_trace, stochastic=True
    ),
}


@dataclass
class Experiment:
    """Configuration of a single CIFAR-10 run with a regularized stochastic latent.

    ``run`` turns these fields into a model, loss, optimizer and LR scheduler and
    hands everything to ``run_common_experiment``.
    """

    seed: int  # seeds torch and is forwarded to run_common_experiment
    num_samples: int  # forwarded to StochasticCifar10Resnet(num_samples=...)
    gamma: float  # weight of the regularizer term; a value equal to 0.0 skips it
    train_batch_size: int
    val_batch_size: int
    epochs: int  # used as max_epochs
    log_interval: int
    latent_layer_name: str  # layer whose output the LatentExtractor captures
    regularizer: str  # key into the module-level Regularizers table
    capacity: int  # latent width; also the input size of the final Linear layer

    def run(self, store):
        """Build, train and evaluate the configured model.

        The steps are:

        - Enable cuDNN benchmark mode and seed torch with ``self.seed``.
        - Create the CIFAR-10 dataloaders (augmentation on, no validation split).
        - Wrap a ``resnet18_v2`` dropconnect backbone plus a ``Linear(capacity, 10)``
          head as a stochastic model on CUDA.
        - Optimise cross-entropy + ``gamma * relu(regularizer)`` with Adam.
        - Cut the LR by sqrt(0.1) on loss plateaus.

        Args:
            store: Logging sink, passed unchanged to ``run_common_experiment``.
        """
        weight = self.gamma
        n_classes = 10

        torch.backends.cudnn.benchmark = True

        # Seed before any data/model construction so both are reproducible.
        torch.manual_seed(self.seed)
        loaders = dataset_cifar10.dataloaders(
            self.train_batch_size,
            self.val_batch_size,
            train_only=False,
            augmentation=True,
            validation_size=0,
        )

        # Construction order (backbone, then head) is kept so the seeded RNG
        # is consumed exactly as before.
        backbone = model_cifar10.StochasticCifar10Resnet(
            resnet_factory=dropconnect_resnet_v2.resnet18_v2,
            capacity=self.capacity,
            num_samples=self.num_samples,
            dropout_rate=0.1,
            fc_dropout_rate=0.1,
        )
        head = torch.nn.Linear(self.capacity, n_classes)
        net = stochastic_model.as_stochastic_model(torch.nn.Sequential(backbone, head))
        net.cuda()

        adam = torch.optim.Adam(net.parameters(), lr=5e-4)

        extractor = dynamics.LatentExtractor(
            net.wrapped_model, layer_name=self.latent_layer_name, noise_injector=None
        )

        penalty = Regularizers[self.regularizer]

        def regularized_loss(latent_x_k_z, prediction_x_k_y, labels_x):
            total = cross_entropies.stochastic_continuous_encoder_decoder_cross_entropy(prediction_x_k_y, labels_x)
            if weight == 0.0:
                # Unregularized run: the regularizer is never evaluated.
                return total
            # relu keeps the (approximate) regularizer estimate non-negative.
            return total + torch.nn.functional.relu(penalty(latent_x_k_z, labels_x)) * weight

        trainer = dynamics.StochasticContinuousFullLossDynamics(
            net, adam, regularized_loss, latent_extractor=extractor
        )

        plateau = ReduceLROnPlateauWrapper(
            trainer.optimizer,
            output_transform=IgniteOutput.get_loss,
            mode="min",
            patience=10,
            factor=0.1 ** 0.5,
            threshold_mode="rel",
            verbose=True,
            min_lr=1e-7,
        )

        # One independent settings dict per evaluator role (never shared).
        summary_settings = {
            role: dict(fixed_Z__X=0.0, k=1, num_splits=6)
            for role in ("train_evaluator", "evaluator", "validator")
        }

        run_common_experiment(
            dynamics=trainer,
            dataloaders=loaders,
            max_epochs=self.epochs,
            in_capacity=self.capacity,
            out_capacity=n_classes,
            train_lr_schedulers=[plateau],
            log_interval=self.log_interval,
            store=store,
            seed=self.seed,
            stochastic=True,
            discrete_summary=False,
            continuous_summary=summary_settings,
            device="cuda",
            mean_l2_squared=True,
            minibatch_entropies=True,
            train_eval=True,
            zero_eval=True,
        )


if __name__ == "__main__":
    # Regularization-strength grids, one per regularizer.
    gammas_Z2 = np.geomspace(1e-6, 1e1, num=10, endpoint=False)
    gammas_Z = np.geomspace(1e-5, 1, num=10, endpoint=True)
    gammas_ZY = np.geomspace(1e-5, 1e1, num=10, endpoint=True)

    # (base seed, regularizer key, gamma grid); this order fixes the job ids.
    sweep = [
        (770047, "mean_squared_Z", gammas_Z2),
        (804859, "entropy_via_variance_Z__Y", gammas_ZY),
        (694773, "entropy_via_variance_Z", gammas_Z),
    ]

    all_experiments = [
        Experiment(
            seed=base_seed,
            gamma=gamma,
            num_samples=1,
            train_batch_size=128,
            val_batch_size=512,
            epochs=150,
            log_interval=10,
            regularizer=regularizer_key,
            latent_layer_name="0",
            capacity=256,
        )
        for base_seed, regularizer_key, grid in sweep
        for gamma in grid
    ]

    # Unregularized baseline: a copy of the last configuration with gamma=0.
    baseline = dataclasses.replace(all_experiments[-1], seed=215377, gamma=0)
    all_experiments.append(baseline)

    for job_id, store in embedded_experiments(__file__, len(all_experiments)):
        job = all_experiments[job_id]
        # Offset the seed by the job id so jobs sharing a base seed differ.
        job.seed = job.seed + job_id
        print(job)
        store["experiment"] = job
        store["log"] = {}

        try:
            job.run(store=store["log"])
        except Exception:
            # Record the traceback in the store, then propagate the failure.
            store["exception"] = traceback.format_exc()
            raise
