import XXX.notebook

from experiments.datasets import DataLoaders
from experiments.dynamics import dynamics

# Needed because we reload earlier results to locate the model snapshots
from experiments.utils.jupyter import results_loader

import dataclasses
import traceback
from dataclasses import dataclass

from XXX.uib.losses import cross_entropies

import experiments.datasets.cifar10 as dataset_cifar10
from experiments.models.zero_entropy_noise import StochasticInjectZeroEntropyNoise

from experiments.utils.experiment_YYY import embedded_experiments

import torch

import experiments.runs.iclr_experiments.cifar10_no_dropout_surrogates_training as base_experiment

from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper
from experiments.utils.ignite_output import IgniteOutput


@dataclass
class Config:
    """One job that retrains and evaluates a decoder on the snapshots of a base run.

    The job looks up the stored result of the base experiment whose ``job_id``
    matches, rebuilds that experiment's model, and then, for every snapshot
    listed in the result's log, freezes ``model.wrapped_model[0]``, trains
    ``model.wrapped_model[1]`` and records an evaluation.
    """

    # Seed passed to ``torch.manual_seed`` at the start of ``run``.
    seed: int
    # Passed twice (positionally) to ``dataset_cifar10.dataloaders``.
    batch_size: int
    # Selects the stored base-experiment result to reload.
    job_id: int

    def run(self, store):
        """Retrain and evaluate the decoder for every snapshot of the base run.

        Args:
            store: Mutable mapping that receives the results. After the call,
                ``store["epochs"][epoch]`` holds ``"decoder_training"`` (the
                log filled in by the training run) and ``"evaluation"`` (the
                first entry of the evaluation run's ``"test_epochs"``) for
                each snapshot epoch.

        Raises:
            RuntimeError: If the stored results do not contain exactly one
                result whose ``job_id`` equals ``self.job_id``.
        """
        import pprint

        # Number of decoder-training epochs. The LR decay below is derived
        # from it so that the learning rate shrinks by a factor of 100 over
        # the whole training run.
        decoder_epochs = 25

        torch.backends.cudnn.benchmark = True
        torch.manual_seed(self.seed)

        # NOTE(review): test_only=True is requested although the decoder is
        # trained below — confirm the loaders still provide training data.
        dataloaders = dataset_cifar10.dataloaders(
            self.batch_size,
            self.batch_size,
            train_only=False,
            test_only=True,
            augmentation=True,
            normalize=True,
            validation_size=0
        )

        # Load results
        loaded_results = results_loader.load_YYY_files('src/experiments/runs/iclr_experiments/results_no_dropout')
        results = results_loader.filter_dict(loaded_results, v=lambda result: result.job_id == self.job_id)

        # Explicit check instead of `assert`: asserts are removed under
        # `python -O`, and a bare assert gives no hint about what went wrong.
        if len(results) != 1:
            raise RuntimeError(
                f"Expected exactly one stored result with job_id={self.job_id}, "
                f"found {len(results)}"
            )

        result = results_loader.get_any(results)
        config = base_experiment.Experiment(**result.experiment._asdict())

        model = config.create_model()

        if config.inject_noise:
            noise_injector = StochasticInjectZeroEntropyNoise(1)
            noise_injector.cuda()
        else:
            noise_injector = None

        loss_latent_extractor = dynamics.LatentExtractor(
            model.wrapped_model, layer_name=config.latent_layer_name, noise_injector=noise_injector
        )

        regularizer_func = base_experiment.Regularizers[config.regularizer]

        def full_stochastic_loss(latent_x_k_z, prediction_x_k_y, labels_x):
            """Cross-entropy loss plus, when gamma is non-zero, the weighted regularizer."""
            loss = cross_entropies.stochastic_continuous_encoder_decoder_cross_entropy(prediction_x_k_y, labels_x)
            if config.gamma != 0.0:
                loss = loss + regularizer_func(latent_x_k_z, labels_x) * config.gamma
            return loss

        pprint.pprint(config)

        store["epochs"] = {}

        for epoch, snapshot_name in result.log.snapshots.items():
            print(f"--'Epoch {epoch}--")
            print(f"Loading snapshot {snapshot_name}")

            model.load_state_dict(torch.load(snapshot_name))
            model.cuda()

            # NOTE: maybe this only needs to be done once, but I'm not taking risks.
            for params in model.wrapped_model[0].parameters():
                params.requires_grad = False

            # Only the second sub-module is optimized; the first one was
            # frozen just above.
            optimizer = torch.optim.Adam(model.wrapped_model[1].parameters(), lr=5e-4, weight_decay=config.weight_decay)
            dynamics_instance = dynamics.StochasticContinuousFullLossDynamics(
                model, optimizer, full_stochastic_loss, latent_extractor=loss_latent_extractor
            )

            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                dynamics_instance.optimizer,
                gamma=0.01 ** (1 / decoder_epochs)
            )
            schedulers = [scheduler]

            # Each nested dict is assigned into the store first and then read
            # back, exactly as before — presumably the store may wrap assigned
            # values; verify before simplifying.
            store["epochs"][epoch] = {}
            epoch_log = store["epochs"][epoch]

            epoch_log["decoder_training"] = {}
            decoder_training_log = epoch_log["decoder_training"]

            # Keyword arguments common to the training and the evaluation run.
            shared_run_kwargs = dict(
                dynamics=dynamics_instance,
                dataloaders=dataloaders,
                in_capacity=config.capacity,
                out_capacity=10,
                train_lr_schedulers=schedulers,
                log_interval=config.log_interval,
                seed=config.seed,
                stochastic=True,
                discrete_summary=False,
                device="cuda",
                train_eval=False,
            )

            # Train the decoder
            run_common_experiment(
                max_epochs=decoder_epochs,
                store=decoder_training_log,
                continuous_summary=False,
                mean_l2_squared=False,
                minibatch_entropies=False,
                zero_eval=False,
                test_eval=False,
                **shared_run_kwargs
            )

            evaluation_log = {}
            # Evaluate the decoder (max_epochs=0: no further training epochs)
            run_common_experiment(
                max_epochs=0,
                store=evaluation_log,
                continuous_summary=dict(
                    train_evaluator=dict(fixed_Z__X=0.0, k=1, num_splits=1),
                    evaluator=dict(fixed_Z__X=0.0, k=1, num_splits=1),
                    validator=dict(fixed_Z__X=0.0, k=1, num_splits=1),
                ),
                mean_l2_squared=True,
                minibatch_entropies=True,
                zero_eval=True,
                **shared_run_kwargs
            )

            epoch_log["evaluation"] = evaluation_log['test_epochs'][0]


# One Config for each job_id in 0..81. The seed is shifted by 31 per job so
# that every job is seeded differently.
configs = [
    Config(seed=458326 + 31 * job_id, batch_size=128, job_id=job_id)
    for job_id in range(82)
]

if __name__ == "__main__":
    # Script entry point: run every (job_id, store) pair handed out by
    # embedded_experiments for this file.
    num_jobs = len(configs)

    for job_id, store in embedded_experiments(__file__, num_jobs):
        job_config = configs[job_id]
        print(job_config)

        # Record the job's parameters next to its (initially empty) log.
        store["config"] = dataclasses.asdict(job_config)
        store["log"] = {}

        try:
            job_config.run(store=store["log"])
        except Exception:
            # Keep the traceback in the store, then let the failure propagate.
            store["exception"] = traceback.format_exc()
            raise
