import XXX.notebook

import dataclasses
import enum
import traceback
from dataclasses import dataclass

from XXX.uib.losses import cross_entropies

import experiments.models.cifar10 as model_cifar10
import experiments.datasets.cifar10 as dataset_cifar10
from XXX.uib.modules.decoder_interface import PassthroughDecoder
from XXX.uib.modules.encoder_decoder import CategoricalEncoderDecoder

from experiments.dynamics.dynamics import StochasticContinuousDynamics, \
    CategoricalEncoderFullLossDynamics
from experiments.models.zero_entropy_noise import _inject_zero_entropy_noise
from experiments.models.stochastic_model import StochasticModelWrapper

from experiments.utils.experiment_YYY import embedded_experiment, embedded_experiments

import torch

from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper
from experiments.utils.ignite_output import IgniteOutput


class CrossEntropyType(enum.Enum):
    """Selects which cross-entropy objective ``run`` optimizes.

    Each member picks one loss function from ``cross_entropies``:

    * ``decoder``    -> ``deterministic_categorical_encoder_decoder_cross_entropy``
    * ``prediction`` -> ``deterministic_categorical_encoder_prediction_cross_entropy``

    Member values equal their names, so ``CrossEntropyType("decoder")`` looks
    up the member from its string form.
    """

    decoder = "decoder"
    prediction = "prediction"


@dataclasses.dataclass
class Config:
    """Hyper-parameters for a single training run of this experiment.

    Field names match the keyword arguments of ``run`` (apart from ``store``).
    The script's entry point therefore invokes
    ``run(store=..., **dataclasses.asdict(config))``.
    """

    # Seed handed to torch.manual_seed and to the experiment runner.
    seed: int
    # Width of the encoder output / decoder input.
    capacity: int
    # Which cross-entropy loss is optimized.
    cross_entropy_type: CrossEntropyType
    # Dataloader batch sizes for training and validation.
    train_batch_size: int
    val_batch_size: int
    # Maximum number of training epochs.
    epochs: int
    # Logging interval forwarded to the experiment runner.
    log_interval: int


def run(seed, capacity, cross_entropy_type, train_batch_size, val_batch_size, epochs, log_interval, store):
    """Train a deterministic CIFAR-10 categorical encoder/decoder for one configuration.

    The encoder is ``model_cifar10.DeterministicCifar10Resnet(capacity=capacity)``.
    The decoder is an MLP ending in a softmax over ``out_capacity`` (10) classes.
    Both are moved to CUDA and optimized with Adam (lr=5e-4), under a
    plateau-based learning-rate scheduler, via ``run_common_experiment``.

    Args:
        seed: Seed for ``torch.manual_seed``, also forwarded to ``run_common_experiment``.
        capacity: Width of the encoder output / decoder input.
        cross_entropy_type: A ``CrossEntropyType`` member, or its string value,
            selecting which loss from ``cross_entropies`` is optimized.
        train_batch_size: Batch size of the training dataloader.
        val_batch_size: Batch size of the validation dataloader.
        epochs: Maximum number of training epochs.
        log_interval: Logging interval forwarded to ``run_common_experiment``.
        store: Object passed through to ``run_common_experiment`` as ``store``.

    Raises:
        ValueError: If ``cross_entropy_type`` is not a valid ``CrossEntropyType``.
    """
    out_capacity = 10

    # Resolve the loss before any data loading or model construction, so a bad
    # config fails fast. CrossEntropyType(...) accepts either a member or its
    # string value (e.g. a config read back from a serialized store) and raises
    # ValueError for anything else. Before this, such values fell through the
    # if/elif and caused an UnboundLocalError much later.
    cross_entropy_type = CrossEntropyType(cross_entropy_type)
    if cross_entropy_type == CrossEntropyType.decoder:
        loss_function = cross_entropies.deterministic_categorical_encoder_decoder_cross_entropy
    elif cross_entropy_type == CrossEntropyType.prediction:
        loss_function = cross_entropies.deterministic_categorical_encoder_prediction_cross_entropy
    else:
        # Guards against enum members added later without a matching loss.
        raise ValueError(f"Unsupported cross_entropy_type: {cross_entropy_type!r}")

    torch.manual_seed(seed)
    dataloaders = dataset_cifar10.dataloaders(train_batch_size, val_batch_size, train_only=False, augmentation=True)

    # MLP head: capacity -> 1000 -> 1000 -> out_capacity. The final softmax
    # means the decoder emits class probabilities, not logits (presumably what
    # the cross_entropies helpers expect; confirm).
    decoder = torch.nn.Sequential(
        torch.nn.Linear(capacity, 1000),
        torch.nn.ReLU(),
        torch.nn.Linear(1000, 1000),
        torch.nn.ReLU(),
        torch.nn.Linear(1000, out_capacity),
        torch.nn.Softmax(dim=-1),
    )

    model = CategoricalEncoderDecoder(model_cifar10.DeterministicCifar10Resnet(capacity=capacity), PassthroughDecoder(decoder))

    model.encoder.cuda()
    model.decoder.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

    dynamics_instance = CategoricalEncoderFullLossDynamics(model, optimizer, loss_function)

    # Multiplies the learning rate by sqrt(0.1) when the monitored value
    # (extracted with IgniteOutput.get_loss) has not improved for `patience`
    # scheduler steps, down to min_lr. How often it steps is decided by
    # run_common_experiment (passed as train_lr_schedulers).
    scheduler = ReduceLROnPlateauWrapper(
        dynamics_instance.optimizer,
        output_transform=IgniteOutput.get_loss,
        mode="min",
        patience=10,
        factor=0.1 ** 0.5,
        threshold_mode="rel",
        verbose=True,
        min_lr=1e-7,
    )

    run_common_experiment(dynamics=dynamics_instance, dataloaders=dataloaders, max_epochs=epochs, in_capacity=capacity,
                          out_capacity=out_capacity, train_lr_schedulers=scheduler, log_interval=log_interval,
                          store=store, seed=seed, stochastic=False, discrete_summary=True, continuous_summary=False,
                          device="cuda", minibatch_entropies=False, train_eval=True)


if __name__ == "__main__":
    # Sweep over both cross-entropy variants x 8 seeds (646..653). The same
    # seeds are reused for each variant; all other hyper-parameters are fixed.
    configs = []
    for cross_entropy_variant in (CrossEntropyType.decoder, CrossEntropyType.prediction):
        for seed_offset in range(8):
            configs.append(
                Config(
                    seed=646 + seed_offset,
                    capacity=100,
                    cross_entropy_type=cross_entropy_variant,
                    train_batch_size=128,
                    val_batch_size=512,
                    epochs=100,
                    log_interval=10,
                )
            )

    # embedded_experiments yields (job_id, store) pairs; each job runs the
    # config at index job_id and records its config and log into the store.
    for job_id, store in embedded_experiments(__file__, len(configs)):
        job_config = configs[job_id]
        print(job_config)
        store["config"] = dataclasses.asdict(job_config)
        store["log"] = {}

        try:
            run(store=store["log"], **dataclasses.asdict(job_config))
        except Exception:
            # Persist the full traceback alongside the results, then let the failure propagate.
            store["exception"] = traceback.format_exc()
            raise
