import XXX.notebook

import dataclasses
import enum
from dataclasses import dataclass
import traceback

import XXX.uib.losses.cross_entropies
from XXX.uib.modules.decoder_interface import PassthroughDecoder
from XXX.uib.modules.encoder_decoder import EncoderDecoder
from experiments.dynamics.dynamics import StochasticContinuousDynamics
from experiments.models import mnist

from experiments.utils.experiment_YYY import embedded_experiments
from experiments.datasets import fast_mnist, DataLoaders

import torch

from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper
from experiments.utils.ignite_output import IgniteOutput


class CrossEntropyType(enum.Enum):
    """Which cross-entropy loss ``create_dynamics`` trains the model with.

    Each member's value is identical to its name, so a member can be
    recovered from its string form, e.g. ``CrossEntropyType("decoder")``.
    """

    # Selects XXX.uib.losses.cross_entropies.stochastic_continuous_encoder_decoder_cross_entropy.
    decoder = "decoder"
    # Selects XXX.uib.losses.cross_entropies.stochastic_continuous_encoder_prediction_cross_entropy.
    prediction = "prediction"


@dataclasses.dataclass
class Config:
    """One experiment configuration.

    The field names match the keyword parameters of ``run`` (except ``store``),
    so ``run(store=..., **dataclasses.asdict(config))`` works directly.
    """

    # Seed handed to torch.manual_seed in create_dynamics and to run_common_experiment.
    seed: int
    # Width of the stochastic encoder output and input size of the linear decoder.
    capacity: int
    # Which cross-entropy loss to train with.
    cross_entropy_type: CrossEntropyType
    # Forwarded to mnist.StochasticDropoutModel; presumably the number of
    # stochastic samples drawn per input (TODO confirm against that model).
    num_samples: int
    # Batch sizes forwarded to fast_mnist.dataloaders.
    train_batch_size: int
    val_batch_size: int
    # Passed as max_epochs to run_common_experiment.
    epochs: int
    # Forwarded to run_common_experiment as log_interval.
    log_interval: int
    # Forwarded to fast_mnist.dataloaders as train_only.
    train_only: bool
    # Free-form note stored with the config; run() accepts but does not use it.
    comment: str


def create_dynamics(seed, capacity, cross_entropy_type: CrossEntropyType, num_samples: int):
    """Build the seeded model, optimizer and loss wrapped in StochasticContinuousDynamics.

    Args:
        seed: Value passed to ``torch.manual_seed`` before any parameters are created.
        capacity: Output width of the stochastic encoder and input width of the decoder.
        cross_entropy_type: A ``CrossEntropyType`` member or its string value
            (``"decoder"`` / ``"prediction"``) selecting the loss function.
        num_samples: Forwarded to ``mnist.StochasticDropoutModel``.

    Returns:
        A ``StochasticContinuousDynamics`` whose encoder and decoder live on CUDA.

    Raises:
        ValueError: If ``cross_entropy_type`` is not a known ``CrossEntropyType``.
    """
    # Validate before seeding or allocating anything on the GPU. Coercing through
    # the enum also accepts the plain string values. Previously an unknown value
    # fell through both branches and surfaced later as an UnboundLocalError.
    try:
        cross_entropy_type = CrossEntropyType(cross_entropy_type)
    except ValueError:
        valid = [member.value for member in CrossEntropyType]
        raise ValueError(
            f"Unknown cross_entropy_type {cross_entropy_type!r}; expected one of {valid}"
        ) from None

    torch.manual_seed(seed)

    # Single linear read-out to 10 outputs. No softmax is applied here; the
    # selected loss presumably consumes raw logits (TODO confirm in cross_entropies).
    decoder = torch.nn.Sequential(
        torch.nn.Linear(capacity, 10),
    )

    model = EncoderDecoder(mnist.StochasticDropoutModel(capacity, num_samples), PassthroughDecoder(decoder))

    model.encoder.cuda()
    model.decoder.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    if cross_entropy_type == CrossEntropyType.decoder:
        loss_function = XXX.uib.losses.cross_entropies.stochastic_continuous_encoder_decoder_cross_entropy
    elif cross_entropy_type == CrossEntropyType.prediction:
        loss_function = XXX.uib.losses.cross_entropies.stochastic_continuous_encoder_prediction_cross_entropy
    else:
        # Guards against new enum members being added without a matching loss.
        raise ValueError(f"No loss function registered for {cross_entropy_type!r}")

    return StochasticContinuousDynamics(model, optimizer, loss_function)


def run(
    seed,
    capacity,
    cross_entropy_type,
    num_samples,
    train_batch_size,
    val_batch_size,
    epochs,
    log_interval,
    train_only,
    store,
    comment,
):
    """Train and evaluate a single configuration, recording results into ``store``.

    Builds the MNIST data loaders, then the model dynamics (via ``create_dynamics``,
    which reseeds torch), and a plateau-based LR scheduler. All of them are handed
    to ``run_common_experiment``.

    ``comment`` is accepted only so ``dataclasses.asdict(Config)`` can be splatted
    in directly; it is not used here.
    """
    # Loaders are created before create_dynamics reseeds torch; keep this order
    # so results stay reproducible against earlier runs.
    loaders = fast_mnist.dataloaders(train_batch_size, val_batch_size, train_only=train_only)
    experiment_dynamics = create_dynamics(seed, capacity, cross_entropy_type, num_samples)

    # Presumably forwards to torch's ReduceLROnPlateau (TODO confirm): it watches
    # the loss extracted by IgniteOutput.get_loss and scales the LR by 0.8 after
    # 3 non-improving steps.
    lr_scheduler = ReduceLROnPlateauWrapper(
        experiment_dynamics.optimizer,
        output_transform=IgniteOutput.get_loss,
        mode="min",
        patience=3,
        factor=0.8,
        threshold_mode="abs",
        verbose=True,
    )

    run_common_experiment(
        dynamics=experiment_dynamics,
        dataloaders=loaders,
        max_epochs=epochs,
        # The encoder emits `capacity` features; the decoder emits 10 outputs
        # (see create_dynamics).
        in_capacity=capacity,
        out_capacity=10,
        train_lr_schedulers=lr_scheduler,
        log_interval=log_interval,
        store=store,
        seed=seed,
        stochastic=True,
        discrete_summary=False,
        continuous_summary=False,
    )


if __name__ == "__main__":
    # Eight seeds for each cross-entropy variant: both the decoder-XE and the
    # prediction-XE losses are run. List order defines the job_id -> config
    # mapping used below: ids 0-7 are decoder runs, ids 8-15 are prediction runs.
    num_seeds = 8
    base_seed = 56453
    configs = [
        Config(
            seed=base_seed + i,
            capacity=100,
            cross_entropy_type=cxt,
            num_samples=64,
            train_batch_size=16,
            val_batch_size=1024,
            epochs=100,
            log_interval=10,
            train_only=False,
            comment="",
        )
        for cxt in (CrossEntropyType.decoder, CrossEntropyType.prediction)
        for i in range(num_seeds)
    ]

    print(configs)

    for job_id, store in embedded_experiments(__file__, len(configs)):
        config = configs[job_id]
        print(config)
        # Convert once: the same plain dict is both recorded and splatted into run().
        config_dict = dataclasses.asdict(config)
        store["config"] = config_dict
        store["log"] = {}

        try:
            run(store=store["log"], **config_dict)
        except Exception:
            # Keep the traceback in the experiment store, then re-raise so the job still fails.
            store["exception"] = traceback.format_exc()
            raise
