import XXX.notebook

import dataclasses
import enum
from dataclasses import dataclass
import traceback

from XXX.uib.modules.categorical_decoder import CategoricalDecoder
from XXX.uib import categorical_iq_loss
from XXX.uib.modules.encoder_decoder import CategoricalEncoderDecoder
from experiments.dynamics.dynamics import CategoricalEncoderFullLossDynamics, CategoricalEncoderEncodingLossDynamics
from experiments.models import permutation_mnist

from experiments.utils.experiment_YYY import embedded_experiments
from experiments.datasets import fast_mnist

import torch

from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper
from experiments.utils.ignite_output import IgniteOutput


class CrossEntropyType(enum.Enum):
    """Which cross-entropy variant an experiment uses: ``"decoder"`` or ``"prediction"``.

    NOTE(review): this enum is not referenced by the configs or functions in
    this script; it appears to be kept for compatibility with earlier runs.
    """

    # Members are declared in this order; iteration order follows it.
    decoder = "decoder"
    prediction = "prediction"


@dataclasses.dataclass
class Config:
    """Hyper-parameters for one experiment job.

    Field names mirror the keyword arguments of ``run`` (everything except
    ``store``), so ``run(**dataclasses.asdict(config), store=...)`` works.
    """

    seed: int  # passed to torch.manual_seed in create_dynamics
    capacity: int  # encoder width and decoder input size
    train_batch_size: int  # batch size for the training dataloader
    val_batch_size: int  # batch size for the validation dataloader
    epochs: int  # forwarded to run_common_experiment as max_epochs
    log_interval: int  # forwarded to run_common_experiment unchanged


def create_dynamics(seed, capacity):
    """Build a freshly seeded encoder/decoder model and wrap it in training dynamics.

    Args:
        seed: passed to ``torch.manual_seed`` before any module is constructed,
            so module construction draws from a seeded RNG.
        capacity: size handed to both the encoder (``NoDropoutModel``) and the
            ``CategoricalDecoder`` (whose second argument is 10 — presumably
            the number of MNIST classes; confirm against CategoricalDecoder).

    Returns:
        A ``CategoricalEncoderEncodingLossDynamics`` holding the model, an Adam
        optimizer with lr=1e-4, and the decoder-uncertainty IQ loss.
    """
    torch.manual_seed(seed)

    # Keep construction order (encoder, then decoder): it fixes which part of
    # the seeded RNG stream each module consumes.
    encoder = permutation_mnist.NoDropoutModel(capacity)
    decoder = CategoricalDecoder(capacity, 10)
    model = CategoricalEncoderDecoder(encoder, decoder)

    # Only these two sub-modules are moved to the GPU.
    for part in (model.encoder, model.decoder):
        part.cuda()

    adam = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = categorical_iq_loss.iq_loss(categorical_iq_loss.iq.decoder_uncertainty)
    return CategoricalEncoderEncodingLossDynamics(model, adam, loss_fn)


def run(seed, capacity, train_batch_size, val_batch_size, epochs, log_interval, store):
    """Train and evaluate one configuration on fast_mnist via ``run_common_experiment``.

    Args:
        seed: forwarded to ``create_dynamics`` (torch seeding) and to
            ``run_common_experiment``.
        capacity: model width; also reported as ``in_capacity``.
        train_batch_size, val_batch_size: dataloader batch sizes.
        epochs: maximum number of epochs.
        log_interval: forwarded unchanged to ``run_common_experiment``.
        store: mapping-like object that ``run_common_experiment`` writes into
            (its exact contents are decided there, not here).
    """
    loaders = fast_mnist.dataloaders(train_batch_size, val_batch_size, train_only=False)

    # Seeds torch internally; note it runs after the dataloaders are built.
    dynamics = create_dynamics(seed, capacity)

    # An alternative OneCycleLR schedule (max_lr=5e-4) used to be sketched here.
    # Plateau schedule: multiply the LR by 0.8 when the loss extracted by
    # IgniteOutput.get_loss stops decreasing (absolute threshold) for 3 steps.
    plateau_settings = dict(
        mode="min",
        patience=3,
        factor=0.8,
        threshold_mode="abs",
        verbose=True,
    )
    lr_scheduler = ReduceLROnPlateauWrapper(
        dynamics.optimizer,
        output_transform=IgniteOutput.get_loss,
        **plateau_settings,
    )

    run_common_experiment(
        dynamics=dynamics,
        dataloaders=loaders,
        max_epochs=epochs,
        in_capacity=capacity,
        out_capacity=10,  # matches the 10 passed to CategoricalDecoder in create_dynamics
        train_lr_schedulers=lr_scheduler,
        log_interval=log_interval,
        store=store,
        seed=seed,
        stochastic=False,
        discrete_summary=True,
        continuous_summary=False,
        train_eval=True,
    )


if __name__ == "__main__":
    # Eight repeats of the same setup; only the seed (6546..6553) differs.
    # NOTE(review): an earlier comment here mentioned choosing "prediction"
    # over decoder cross-entropy, but none of these configs selects a
    # CrossEntropyType — confirm whether that choice still applies.
    configs = [
        Config(
            seed=run_seed,
            capacity=100,
            train_batch_size=128,
            val_batch_size=5000,
            epochs=100,
            log_interval=10,
        )
        for run_seed in range(6546, 6546 + 8)
    ]
    print(configs)

    for job_index, job_store in embedded_experiments(__file__, len(configs)):
        job_config = configs[job_index]
        print(job_config)

        # Persist the hyper-parameters next to the (initially empty) run log.
        job_store["config"] = dataclasses.asdict(job_config)
        job_store["log"] = {}

        try:
            run(**dataclasses.asdict(job_config), store=job_store["log"])
        except Exception:
            # Save the formatted traceback in the store, then let the failure propagate.
            job_store["exception"] = traceback.format_exc()
            raise
