import XXX.notebook

import enum
import traceback
from dataclasses import dataclass

import torch

import experiments.datasets.cifar10 as dataset_cifar10
import experiments.models.cifar10 as model_cifar10
from XXX.uib import categorical_iq_loss
from XXX.uib.modules.categorical_decoder import CategoricalDecoder
from XXX.uib.modules.encoder_decoder import CategoricalEncoderDecoder
from experiments.dynamics.dynamics import CategoricalEncoderEncodingLossDynamics
from experiments.utils.experiment_YYY import embedded_experiments
from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper
from experiments.utils.ignite_output import IgniteOutput


@dataclass
class Experiment:
    """Configuration and driver for a single CIFAR-10 training run.

    ``run`` builds a ``CategoricalEncoderDecoder`` (a
    ``DeterministicCifar10Resnet`` encoder followed by a ``CategoricalDecoder``),
    trains it with ``CategoricalEncoderEncodingLossDynamics`` using the IQ loss
    built from ``categorical_iq_loss.iq.decoder_uncertainty``, and delegates the
    training/evaluation loop to ``run_common_experiment``.
    """

    # Seed for ``torch.manual_seed``; also forwarded to ``run_common_experiment``.
    seed: int
    # Passed as ``capacity`` to the ResNet encoder and as the first argument of
    # ``CategoricalDecoder`` (presumably its input width); also reported as
    # ``in_capacity`` to ``run_common_experiment``.
    capacity: int
    train_batch_size: int
    val_batch_size: int
    # Forwarded to ``run_common_experiment`` as ``max_epochs``.
    epochs: int
    # Forwarded to ``run_common_experiment`` as ``log_interval``.
    log_interval: int
    # Number of decoder outputs; also reported as ``out_capacity`` to
    # ``run_common_experiment``. The default of 10 matches CIFAR-10's class count.
    out_capacity: int = 10
    # Initial Adam learning rate (the scheduler may reduce it down to ``min_lr``).
    lr: float = 1e-4

    def run(self, store):
        """Build model, optimizer and LR scheduler, then train and evaluate.

        Args:
            store: Forwarded unchanged to ``run_common_experiment`` as its
                ``store`` argument (presumably where logs/metrics are recorded).

        Requires a CUDA device: the encoder and decoder are moved with
        ``.cuda()`` and ``run_common_experiment`` is called with ``device="cuda"``.
        """
        # Seed before building dataloaders and model so data order and weight
        # initialisation are reproducible for a given ``seed``.
        torch.manual_seed(self.seed)
        dataloaders = dataset_cifar10.dataloaders(
            self.train_batch_size, self.val_batch_size, augmentation=True, train_only=False
        )

        # The decoder width comes from ``out_capacity`` (previously a hard-coded
        # 10) so the model agrees with the ``out_capacity`` reported below.
        model = CategoricalEncoderDecoder(
            model_cifar10.DeterministicCifar10Resnet(capacity=self.capacity),
            CategoricalDecoder(self.capacity, self.out_capacity),
        )

        model.encoder.cuda()
        model.decoder.cuda()

        optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)

        dynamics_instance = CategoricalEncoderEncodingLossDynamics(
            model, optimizer, categorical_iq_loss.iq_loss(categorical_iq_loss.iq.decoder_uncertainty)
        )

        # Plateau-based LR decay driven by the value extracted with
        # ``IgniteOutput.get_loss``: shrink the LR by a factor of sqrt(10)
        # after 5 epochs without relative improvement, never below 1e-7.
        scheduler = ReduceLROnPlateauWrapper(
            dynamics_instance.optimizer,
            output_transform=IgniteOutput.get_loss,
            mode="min",
            patience=5,
            factor=0.1 ** 0.5,
            threshold_mode="rel",
            verbose=True,
            min_lr=1e-7,
        )

        run_common_experiment(
            dynamics=dynamics_instance,
            dataloaders=dataloaders,
            max_epochs=self.epochs,
            in_capacity=self.capacity,
            out_capacity=self.out_capacity,
            train_lr_schedulers=scheduler,
            log_interval=self.log_interval,
            store=store,
            seed=self.seed,
            stochastic=False,
            discrete_summary=True,
            continuous_summary=False,
            device="cuda",
            minibatch_entropies=False,
            train_eval=True,
        )


if __name__ == "__main__":
    # Eight replicate runs that differ only in their seed (1238 .. 1245);
    # every other hyper-parameter is shared.
    shared_settings = dict(
        capacity=100,
        train_batch_size=128,
        val_batch_size=512,
        epochs=100,
        log_interval=10,
    )
    configs = [Experiment(seed=run_seed, **shared_settings) for run_seed in range(1238, 1238 + 8)]

    # ``embedded_experiments`` yields (job index, store) pairs; each job runs
    # the config at that index and records its config and log in the store.
    for job_index, job_store in embedded_experiments(__file__, len(configs)):
        experiment = configs[job_index]
        print(experiment)
        job_store["config"] = experiment
        job_store["log"] = {}

        try:
            # Read "log" back from the store instead of reusing the literal,
            # in case the store wraps or copies values on assignment.
            experiment.run(store=job_store["log"])
        except Exception:
            # Persist the traceback next to the config, then re-raise.
            job_store["exception"] = traceback.format_exc()
            raise
