import dataclasses
import traceback
from dataclasses import dataclass
from pprint import pprint

# noinspection PyUnresolvedReferences
import XXX.notebook
import torch
from ignite.engine import Events

from XXX.uib.losses import cross_entropies
from XXX.uib.modules.l2_and_variance_summarizer import L2AndVarianceSummarizer
from experiments.datasets import fast_mnist
from experiments.dynamics.dynamics import LatentExtractor, StochasticContinuousFullLossDynamics
from experiments.models.zero_entropy_noise import _inject_zero_entropy_noise, StochasticInjectZeroEntropyNoise
from experiments.models import mnist
from experiments.models.stochastic_model import StochasticModelWrapper
from experiments.utils.experiment_YYY import embedded_experiments
from experiments.utils.ignite_dynamics import run_common_experiment, install_common_summarizer_iqs


@dataclasses.dataclass
class Config:
    """Hyper-parameters for a single `run` invocation.

    The field names mirror `run`'s keyword arguments (everything except `store`), so an
    instance is forwarded as ``run(store=..., **dataclasses.asdict(config))``.
    """

    # Seed handed to `torch.manual_seed` and to `run_common_experiment`.
    seed: int
    # When True, zero-entropy noise is applied after the latent rescaling.
    inject_noise: bool
    # Latent width; must match the decoder's `Linear(capacity, 10)` input size.
    capacity: int
    # Batch sizes forwarded to `fast_mnist.dataloaders`.
    train_batch_size: int
    val_batch_size: int
    # Maximum number of training epochs.
    epochs: int
    # Logging interval forwarded to `run_common_experiment`.
    log_interval: int


def run(seed, inject_noise, capacity, train_batch_size, val_batch_size, epochs, log_interval, store):
    """Train a stochastic-dropout MNIST model with an epoch-dependent latent rescaling.

    The latent hook divides its input by the standard deviation over dims 0 and 1 and
    multiplies it by ``base ** (1 - current_epoch)``. With ``inject_noise`` it then also
    applies ``StochasticInjectZeroEntropyNoise``. At the start of every epoch the
    ``fc2`` weights are multiplied in place by ``base``.

    Args:
        seed: passed to ``torch.manual_seed`` and to ``run_common_experiment``.
        inject_noise: if True, apply zero-entropy noise after the rescaling.
        capacity: latent width; must equal the decoder input size (``Linear(capacity, 10)``).
        train_batch_size: forwarded to ``fast_mnist.dataloaders``.
        val_batch_size: forwarded to ``fast_mnist.dataloaders``.
        epochs: forwarded as ``max_epochs`` to ``run_common_experiment``.
        log_interval: forwarded to ``run_common_experiment``.
        store: forwarded to ``run_common_experiment`` (a plain dict in the ``__main__`` block).

    Note: the model is moved to CUDA unconditionally, so a GPU is required.
    """
    torch.manual_seed(seed)

    decoder = torch.nn.Sequential(torch.nn.Linear(capacity, 10))

    # Read by `scale_input` on each call; rebound by the EPOCH_STARTED hook in `extra_hooks`.
    current_epoch = 1

    # Per-epoch growth factor of the fc2 weights and decay factor of the latent scale.
    base = 1.5

    def scale_input(input):
        return input / torch.std(input, dim=(0, 1)) * base ** (-current_epoch + 1)

    # Built unconditionally and before the model so that the construction order after
    # seeding does not depend on `inject_noise`. (The meaning of the argument 12, also
    # passed to StochasticDropoutModel, is not visible in this file.)
    zero_entropy_noise = StochasticInjectZeroEntropyNoise(12)

    def inject_noise_and_scale_input(input):
        return zero_entropy_noise(scale_input(input))

    # The architecture is the same for both settings; only the latent hook differs.
    model = StochasticModelWrapper(torch.nn.Sequential(mnist.StochasticDropoutModel(capacity, 12), decoder), 1)
    noise_injector = inject_noise_and_scale_input if inject_noise else scale_input
    # NOTE(review): presumably LatentExtractor applies `noise_injector` to the output of
    # the named layer; confirm in experiments.dynamics.dynamics.
    latent_extractor = LatentExtractor(model, layer_name="wrapped_model.0.fc2", noise_injector=noise_injector)

    model.cuda()

    # Created after `.cuda()` so the optimizer holds the device-resident parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    def full_stochastic_loss(latent_x_k_z, prediction_x_k_y, labels_x):
        # `latent_x_k_z` is unused; the three-argument signature is presumably what
        # StochasticContinuousFullLossDynamics calls the loss with.
        return cross_entropies.stochastic_continuous_encoder_decoder_cross_entropy(prediction_x_k_y, labels_x)

    dynamics_instance = StochasticContinuousFullLossDynamics(
        model, optimizer, full_stochastic_loss, latent_extractor=latent_extractor
    )

    dataloaders = fast_mnist.dataloaders(train_batch_size, val_batch_size)

    in_capacity = capacity
    out_capacity = 10

    def extra_hooks(trainer, evaluator):
        # TODO: add code to scale the weights
        @trainer.on(Events.EPOCH_STARTED)
        def update_current_epoch(_):
            nonlocal current_epoch
            current_epoch = trainer.state.epoch

            # ignite fires EPOCH_STARTED for epoch 1 too, so the weights are scaled by
            # `base` already in the first epoch, while the latent factor is still 1 there.
            with torch.no_grad():
                model.wrapped_model[0].fc2.weight.mul_(base)

        summarizer = L2AndVarianceSummarizer(in_capacity, out_capacity)
        install_common_summarizer_iqs(summarizer, evaluator, stochastic=True)

    run_common_experiment(
        dynamics=dynamics_instance,
        dataloaders=dataloaders,
        max_epochs=epochs,
        in_capacity=in_capacity,
        out_capacity=out_capacity,
        train_lr_schedulers=None,
        log_interval=log_interval,
        store=store,
        seed=seed,
        stochastic=True,
        discrete_summary=False,
        continuous_summary=False,
        device="cuda",
        minibatch_entropies=False,
        extra_hooks=extra_hooks,
    )


if __name__ == "__main__":
    # Sweep: seeds 10..13, each run once without and once with zero-entropy noise.
    configs = []
    for i in range(4):
        for inject_noise in (False, True):
            configs.append(
                Config(
                    seed=10 + i,
                    inject_noise=inject_noise,
                    capacity=128,
                    train_batch_size=128,
                    val_batch_size=512,
                    epochs=20,
                    log_interval=10,
                )
            )

    pprint(configs)

    # `embedded_experiments` yields (job_id, store) pairs; job_id indexes into `configs`.
    for job_id, store in embedded_experiments(__file__, len(configs)):
        config = configs[job_id]
        print(config)
        config_fields = dataclasses.asdict(config)
        store["config"] = config_fields
        store["log"] = {}

        try:
            run(store=store["log"], **config_fields)
        except Exception:
            # Persist the traceback alongside the run's results, then propagate.
            store["exception"] = traceback.format_exc()
            raise
