import XXX.notebook

import dataclasses
import enum
import traceback
from dataclasses import dataclass

from XXX.uib.losses import cross_entropies

import experiments.models.cifar10 as model_cifar10
import experiments.datasets.cifar10 as dataset_cifar10
from experiments.models import dropconnect_resnet_v2

from XXX.uib.modules.decoder_interface import PassthroughDecoder
from XXX.uib.modules.encoder_decoder import EncoderDecoder
from experiments.dynamics.dynamics import StochasticContinuousDynamics

from experiments.models import stochastic_model

from experiments.utils.experiment_YYY import embedded_experiment, embedded_experiments

import torch

from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper
from experiments.utils.ignite_output import IgniteOutput


@enum.unique
class CrossEntropyType(enum.Enum):
    """Selects which cross-entropy loss variant an experiment trains with.

    Each member's value equals its name, so a member can be recovered from
    its string form with ``CrossEntropyType("decoder")``.
    """

    # Train with the encoder/decoder cross-entropy loss.
    decoder = "decoder"
    # Train with the encoder/prediction cross-entropy loss.
    prediction = "prediction"


@dataclasses.dataclass
class Config:
    """Settings for a single experiment job.

    The field names match the keyword parameters of ``run`` (apart from
    ``store``), so ``dataclasses.asdict(config)`` can be unpacked straight
    into that call.
    """

    # Seed for torch's random number generator; also recorded with the run.
    seed: int
    # Feature width the network emits before the final linear layer.
    capacity: int
    # Forwarded to the stochastic model as its ``num_samples`` argument.
    num_samples: int
    # Which cross-entropy loss variant to optimise.
    cross_entropy_type: CrossEntropyType
    # Batch size of the training dataloader.
    train_batch_size: int
    # Batch size of the validation dataloader.
    val_batch_size: int
    # Upper bound on the number of training epochs.
    epochs: int
    # Forwarded to the experiment runner as its ``log_interval`` argument.
    log_interval: int
    # Forwarded to the dataloader factory as its ``train_only`` flag.
    train_only: bool


def run(
    seed,
    capacity,
    num_samples,
    cross_entropy_type,
    train_batch_size,
    val_batch_size,
    epochs,
    log_interval,
    store,
    train_only,
):
    """Train a stochastic CIFAR-10 ResNet with the chosen cross-entropy loss.

    Builds the CIFAR-10 dataloaders (with augmentation), a
    ``StochasticCifar10Resnet`` based on ``dropconnect_resnet_v2.resnet18_v2``
    followed by a linear output layer, an Adam optimizer and a
    reduce-on-plateau learning-rate schedule, then hands everything to
    ``run_common_experiment``. The model is moved to CUDA, so a GPU is
    required.

    Args:
        seed: Seed passed to ``torch.manual_seed`` and to the experiment
            runner.
        capacity: Feature width produced by the ResNet; also the input width
            of the final linear layer.
        num_samples: ``num_samples`` argument of the stochastic ResNet.
        cross_entropy_type: A ``CrossEntropyType`` member, or the string
            value of one (``"decoder"`` / ``"prediction"``), selecting the
            loss function.
        train_batch_size: Batch size for the training dataloader.
        val_batch_size: Batch size for the validation dataloader.
        epochs: Maximum number of epochs to train for.
        log_interval: Forwarded unchanged to ``run_common_experiment``.
        store: Object forwarded to ``run_common_experiment`` as ``store``.
        train_only: Forwarded to the dataloader factory.

    Raises:
        ValueError: If ``cross_entropy_type`` does not identify a known
            ``CrossEntropyType``.
    """
    # Resolve the loss first so a bad setting fails before any expensive
    # dataset/model construction. Looking an Enum up by an existing member
    # returns that member; looking it up by an unknown value raises
    # ValueError.
    cross_entropy_type = CrossEntropyType(cross_entropy_type)
    if cross_entropy_type is CrossEntropyType.decoder:
        loss_function = XXX.uib.losses.cross_entropies.stochastic_continuous_encoder_decoder_cross_entropy
    elif cross_entropy_type is CrossEntropyType.prediction:
        loss_function = XXX.uib.losses.cross_entropies.stochastic_continuous_encoder_prediction_cross_entropy
    else:
        # Only reachable if a member is added to the enum without a loss here.
        raise ValueError(f"Unsupported cross_entropy_type: {cross_entropy_type!r}")

    # Output width of the final linear layer; the same value is reported to
    # the experiment runner below, so keep it in one place.
    out_capacity = 10

    torch.manual_seed(seed)
    dataloaders = dataset_cifar10.dataloaders(
        train_batch_size, val_batch_size, train_only=train_only, augmentation=True
    )

    model = stochastic_model.as_stochastic_model(
        torch.nn.Sequential(
            model_cifar10.StochasticCifar10Resnet(
                resnet_factory=dropconnect_resnet_v2.resnet18_v2,
                capacity=capacity,
                num_samples=num_samples,
                dropout_rate=0.1,
                fc_dropout_rate=0.1,
            ),
            torch.nn.Linear(capacity, out_capacity),
        )
    )
    model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

    dynamics_instance = StochasticContinuousDynamics(model, optimizer, loss_function)

    # Disabled alternative schedule, kept for reference:
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(dynamics.optimizer, 1e-4, epochs=epochs,
    #                                                            steps_per_epoch=len(dataloaders.train))

    scheduler = ReduceLROnPlateauWrapper(
        dynamics_instance.optimizer,
        output_transform=IgniteOutput.get_loss,
        mode="min",
        patience=5,
        # sqrt(0.1): two consecutive reductions lower the rate tenfold.
        factor=0.1 ** 0.5,
        threshold_mode="rel",
        verbose=True,
        min_lr=5e-5,
    )

    run_common_experiment(
        dynamics=dynamics_instance,
        dataloaders=dataloaders,
        max_epochs=epochs,
        in_capacity=capacity,
        out_capacity=out_capacity,
        train_lr_schedulers=scheduler,
        log_interval=log_interval,
        store=store,
        seed=seed,
        stochastic=True,
        discrete_summary=False,
        continuous_summary=False,
        device="cuda",
        minibatch_entropies=False,
    )


if __name__ == "__main__":
    # Settings common to every job; only the seed and loss variant differ.
    shared_settings = dict(
        num_samples=8,
        capacity=100,
        train_batch_size=32,
        val_batch_size=256,
        epochs=100,
        log_interval=10,
        train_only=False,
    )
    loss_variants = (CrossEntropyType.decoder, CrossEntropyType.prediction)

    # Variant-major order: eight seeds (646..653) for the decoder loss,
    # followed by the same eight seeds for the prediction loss.
    configs = [
        Config(seed=646 + seed_offset, cross_entropy_type=variant, **shared_settings)
        for variant in loss_variants
        for seed_offset in range(8)
    ]

    # Each yielded job id indexes into the config list built above.
    for job_id, store in embedded_experiments(__file__, len(configs)):
        job_config = configs[job_id]
        print(job_config)
        store["config"] = dataclasses.asdict(job_config)
        store["log"] = {}

        try:
            run(store=store["log"], **dataclasses.asdict(job_config))
        except Exception:
            # Record the traceback with the job's results, then let the
            # failure propagate.
            store["exception"] = traceback.format_exc()
            raise
