import traceback
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict

# noinspection PyUnresolvedReferences
import XXX.notebook
import dataclasses
import numpy as np
import torch
from ignite.engine import Events

import experiments.datasets.imagenette as dataset_imagenette
from experiments.models import dropconnect_resnet_v2

import experiments.models.zero_entropy_noise
import experiments.utils.ignite_output
from XXX.uib.losses import cross_entropies, very_approx_regularizers
from experiments.dynamics import dynamics
from experiments.models import stochastic_model
from experiments.models.stochastic_model import StochasticModel
from experiments.utils.experiment_YYY import embedded_experiments
from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper
from experiments.utils.ignite_output import IgniteOutput

# Latent regularizers selectable by name via ``Experiment.regularizer``.
# Each value is a callable invoked as ``regularizer(latent, labels)``; all are
# constructed in their stochastic variant (``stochastic=True``).
Regularizers = {
    "mean_squared_Z": very_approx_regularizers.squared_sum(stochastic=True),
    "entropy_via_variance_Z__Y": very_approx_regularizers.estimate_entropy_Z__Y(stochastic=True),
    "entropy_via_variance_Z": very_approx_regularizers.estimate_entropy(
        very_approx_regularizers.covariance_trace,
        stochastic=True,
    ),
}

class StochasticImagenetResnet(StochasticModel):
    """StochasticModel wrapping a ResNet built by ``resnet_factory``.

    The factory is asked for a network with ``capacity`` output units, so the
    ResNet's output serves as a ``capacity``-dimensional representation.
    ``num_samples`` is handed to the ``StochasticModel`` base class (presumably
    the number of stochastic forward samples per input — confirm in StochasticModel).
    """

    def __init__(self, *, resnet_factory, num_samples, dropout_rate, fc_dropout_rate, pretrained=False, capacity=10):
        super().__init__(num_samples)

        resnet_kwargs = dict(
            pretrained=pretrained,
            num_classes=capacity,
            dropout_rate=dropout_rate,
            fc_dropout_rate=fc_dropout_rate,
        )
        self.resnet = resnet_factory(**resnet_kwargs)

    def stochastic_forward_impl(self, x):
        # Return the raw ResNet outputs; no (log-)softmax is applied here.
        return self.resnet(x)


@dataclass
class Experiment:
    """Configuration and driver for one Imagenette training run.

    ``run`` trains a stochastic ResNet (``capacity``-wide output) followed by a
    linear head end to end, optionally adding ``gamma`` times a latent regularizer
    to the cross-entropy loss. Once when the training hooks are installed and then
    every ``snapshot_every`` epochs, it freezes the ResNet, fits a fresh linear
    decoder on top of it, evaluates that decoder and logs the result under
    ``store["combined_evals"][epoch]``.
    """

    seed: int  # torch RNG seed; also passed to run_common_experiment
    num_samples: int  # passed to the stochastic model and the noise injector
    gamma: float  # weight of the regularizer term; 0 skips the regularizer entirely
    train_batch_size: int
    val_batch_size: int
    epochs: int  # number of main training epochs
    log_interval: int
    latent_layer_name: str  # layer name given to dynamics.LatentExtractor (e.g. "0" = the ResNet)
    regularizer: str  # key into the module-level ``Regularizers`` dict
    capacity: int  # output width of the ResNet, i.e. the input width of the linear head
    snapshot_every: int  # run the frozen-encoder decoder evaluation every this many epochs

    @staticmethod
    def _set_requires_grad(module, requires_grad):
        """Set ``requires_grad`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = requires_grad

    def run(self, store):
        """Train the model, logging the main run and periodic decoder evaluations.

        Args:
            store: dict-like log sink. ``run_common_experiment`` writes the main
                training log into it; decoder evaluations are written to
                ``store["combined_evals"][epoch]``.

        Side effects: uses CUDA, and saves the final ``model.state_dict()`` to
        ``imagenette_surrogate_{regularizer}_{gamma}_{seed}.pyt`` in the working
        directory.
        """
        gamma = self.gamma

        torch.backends.cudnn.benchmark = True

        # Seed before building data loaders and models so initialization is reproducible.
        torch.manual_seed(self.seed)
        dataloaders = dataset_imagenette.dataloaders(self.train_batch_size, self.val_batch_size)
        test_only_dataloaders = dataset_imagenette.dataloaders(self.train_batch_size, self.val_batch_size, test_only=True)

        resnet_model = StochasticImagenetResnet(
            resnet_factory=dropconnect_resnet_v2.resnet18_v2,
            capacity=self.capacity,
            num_samples=self.num_samples,
            dropout_rate=0.1,
            fc_dropout_rate=0.1,
        )

        model = stochastic_model.as_stochastic_model(
            torch.nn.Sequential(
                resnet_model,
                torch.nn.Linear(self.capacity, 10),
            )
        )
        model.cuda()

        optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

        noise_injector = experiments.models.zero_entropy_noise.StochasticInjectZeroEntropyNoise(self.num_samples)
        noise_injector.cuda()

        loss_latent_extractor = dynamics.LatentExtractor(
            model.wrapped_model, layer_name=self.latent_layer_name, noise_injector=noise_injector
        )

        regularizer_func = Regularizers[self.regularizer]

        def full_stochastic_loss(latent_x_k_z, prediction_x_k_y, labels_x):
            """Cross-entropy plus ``gamma`` times the ReLU-clamped latent regularizer."""
            loss = cross_entropies.stochastic_continuous_encoder_decoder_cross_entropy(prediction_x_k_y, labels_x)
            # With gamma == 0 the regularizer is not evaluated at all (baseline run).
            if gamma != 0.0:
                # Clamp at zero before weighting (presumably because the entropy
                # estimators can go negative).
                loss = loss + torch.nn.functional.relu(regularizer_func(latent_x_k_z, labels_x)) * gamma
            return loss

        dynamics_instance = dynamics.StochasticContinuousFullLossDynamics(
            model, optimizer, full_stochastic_loss, latent_extractor=loss_latent_extractor
        )

        # Divide the learning rate by sqrt(10) after 10 epochs without relative loss improvement.
        scheduler = ReduceLROnPlateauWrapper(
            dynamics_instance.optimizer,
            output_transform=IgniteOutput.get_loss,
            mode="min",
            patience=10,
            factor=0.1 ** 0.5,
            threshold_mode="rel",
            verbose=True,
            min_lr=1e-7,
        )
        schedulers = [scheduler]

        def extra_hooks(trainer, evaluator, train_evaluator, validator):
            """Attach the periodic frozen-encoder decoder evaluation to ``trainer``."""
            store["combined_evals"] = {}

            def combined_evaluator(engine):
                """Fit and evaluate a fresh linear decoder on the frozen ResNet."""
                epoch = engine.state.epoch

                print(f"Combined eval {epoch}")

                # Freeze the encoder so only the new linear decoder is trained.
                # NOTE(review): requires_grad does not stop BatchNorm running statistics
                # from updating if the ResNet has BatchNorm layers and the decoder run
                # puts it in train mode — confirm this is acceptable.
                self._set_requires_grad(resnet_model, False)
                try:
                    combined_model = stochastic_model.as_stochastic_model(
                        torch.nn.Sequential(
                            resnet_model,
                            torch.nn.Linear(self.capacity, 10),
                        )
                    )
                    combined_model.cuda()

                    def decoder_loss(latent_x_k_z, prediction_x_k_y, labels_x):
                        """Plain cross-entropy; the latent is ignored (no regularizer)."""
                        loss = cross_entropies.stochastic_continuous_encoder_decoder_cross_entropy(prediction_x_k_y,
                                                                                                   labels_x)
                        return loss

                    decoder_epochs = 12

                    # Only the new Linear (index 1 of the Sequential) is optimized.
                    decoder_optimizer = torch.optim.Adam(combined_model.wrapped_model[1].parameters(), lr=5e-4)
                    # NOTE(review): this reuses the extractor built on model.wrapped_model; that works
                    # for layer "0" because both Sequentials share the same resnet_model instance.
                    combined_dynamics_instance = dynamics.StochasticContinuousFullLossDynamics(
                        combined_model, decoder_optimizer, decoder_loss, latent_extractor=loss_latent_extractor
                    )

                    # Decay the decoder LR by a total factor of 0.01 over the decoder epochs.
                    decoder_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                        combined_dynamics_instance.optimizer,
                        gamma=0.01 ** (1 / decoder_epochs)
                    )
                    decoder_schedulers = [decoder_scheduler]

                    store["combined_evals"][epoch] = {}
                    epoch_log = store["combined_evals"][epoch]

                    epoch_log["decoder_training"] = {}
                    decoder_training_log = epoch_log["decoder_training"]

                    # Train the decoder
                    run_common_experiment(
                        dynamics=combined_dynamics_instance,
                        dataloaders=test_only_dataloaders,
                        max_epochs=decoder_epochs,
                        in_capacity=self.capacity,
                        out_capacity=10,
                        train_lr_schedulers=decoder_schedulers,
                        log_interval=self.log_interval,
                        store=decoder_training_log,
                        seed=self.seed,
                        stochastic=True,
                        discrete_summary=False,
                        continuous_summary=False,
                        device="cuda",
                        mean_l2_squared=False,
                        minibatch_entropies=False,
                        train_eval=False,
                        zero_eval=False,
                        test_eval=False
                    )

                    evaluation_log = {}
                    # Evaluate the decoder (max_epochs=0: presumably no further training,
                    # only the evaluation passes run).
                    run_common_experiment(
                        dynamics=combined_dynamics_instance,
                        dataloaders=test_only_dataloaders,
                        max_epochs=0,
                        in_capacity=self.capacity,
                        out_capacity=10,
                        train_lr_schedulers=decoder_schedulers,
                        log_interval=self.log_interval,
                        store=evaluation_log,
                        seed=self.seed,
                        stochastic=True,
                        discrete_summary=False,
                        continuous_summary=dict(
                            train_evaluator=dict(fixed_Z__X=0.0, k=1, num_splits=2),
                            evaluator=dict(fixed_Z__X=0.0, k=1, num_splits=1),
                            validator=dict(fixed_Z__X=0.0, k=1, num_splits=1),
                        ),
                        device="cuda",
                        mean_l2_squared=True,
                        minibatch_entropies=True,
                        train_eval=False,
                        zero_eval=True,
                    )

                    # Keep only the first test-epoch record (schema defined by run_common_experiment).
                    epoch_log["evaluation"] = evaluation_log['test_epochs'][0]
                finally:
                    # Always unfreeze the encoder, even if decoder training/evaluation fails.
                    self._set_requires_grad(resnet_model, True)

            trainer.add_event_handler(Events.EPOCH_COMPLETED(every=self.snapshot_every), combined_evaluator)

            # Zeroth run: evaluate the untrained encoder once when the hooks are installed.
            combined_evaluator(trainer)

        run_common_experiment(
            dynamics=dynamics_instance,
            dataloaders=dataloaders,
            max_epochs=self.epochs,
            in_capacity=self.capacity,
            out_capacity=10,
            train_lr_schedulers=schedulers,
            log_interval=self.log_interval,
            store=store,
            seed=self.seed,
            stochastic=True,
            discrete_summary=False,
            continuous_summary=False,
            device="cuda",
            mean_l2_squared=False,
            minibatch_entropies=False,
            train_eval=False,
            zero_eval=False,
            test_eval=True,
            extra_hooks=extra_hooks
        )

        torch.save(model.state_dict(), f"imagenette_surrogate_{self.regularizer}_{self.gamma}_{self.seed}.pyt")


if __name__ == "__main__":
    # Regularization strengths swept for "mean_squared_Z": 16 log-spaced values in [1e-6, 1e1).
    gammas_Z2 = np.geomspace(1e-6, 1e1, num=16, endpoint=False)
    # Sweeps for the other regularizers (currently disabled):
    # gammas_Z = np.geomspace(1e-5, 1, num=16, endpoint=True)
    # gammas_ZY = np.geomspace(1e-5, 1e1, num=16, endpoint=True)

    def create_experiment(seed, gamma, regularizer):
        """Return an Experiment using this sweep's fixed hyperparameters."""
        shared_settings = dict(
            num_samples=1,
            train_batch_size=64,
            val_batch_size=256,
            epochs=100,
            log_interval=10,
            latent_layer_name="0",
            capacity=256,
            snapshot_every=10,
        )
        return Experiment(seed=seed, gamma=gamma, regularizer=regularizer, **shared_settings)

    all_experiments = [create_experiment(770047, gamma, "mean_squared_Z") for gamma in gammas_Z2]
    # Disabled: + [create_experiment(804859, gamma, "entropy_via_variance_Z__Y") for gamma in gammas_ZY]
    # Disabled: + [create_experiment(694773, gamma, "entropy_via_variance_Z") for gamma in gammas_Z]

    # Unregularized baseline: a copy of the last sweep entry with gamma=0 and its own seed.
    all_experiments.append(dataclasses.replace(all_experiments[-1], seed=215377, gamma=0))

    for job_id, store in embedded_experiments(__file__, len(all_experiments)):
        experiment = all_experiments[job_id]
        # Offset the seed by the job index so each job gets a distinct seed.
        experiment.seed = experiment.seed + job_id
        print(experiment)
        store["experiment"] = experiment
        store["log"] = {}

        try:
            experiment.run(store=store["log"])
        except Exception:
            # Record the traceback in the store, then let the failure propagate.
            store["exception"] = traceback.format_exc()
            raise
