import dataclasses
import enum
from dataclasses import dataclass

import XXX.uib.losses.cross_entropies
from XXX.uib.modules.decoder_interface import PassthroughDecoder
from XXX.uib.modules.encoder_decoder import CategoricalEncoderDecoder
from experiments.dynamics.dynamics import CategoricalEncoderFullLossDynamics
from experiments.models import permutation_mnist

from experiments.utils.experiment_YYY import embedded_experiment
from experiments.datasets import fast_mnist

import torch

from experiments.utils.ignite_dynamics import run_common_experiment, ReduceLROnPlateauWrapper
from experiments.utils.ignite_output import IgniteOutput


class CrossEntropyType(enum.Enum):
    """Selects which cross-entropy loss ``create_dynamics`` optimises.

    Member values are plain strings, so a member can be looked up from its
    value, e.g. ``CrossEntropyType("prediction")``.
    """

    # Selects deterministic_categorical_encoder_decoder_cross_entropy.
    decoder = "decoder"
    # Selects deterministic_categorical_encoder_prediction_cross_entropy.
    prediction = "prediction"


@dataclasses.dataclass
class Config:
    """Hyper-parameters for one training run.

    The field names match every argument of ``run`` except ``store``: the
    ``__main__`` block splats ``dataclasses.asdict(config)`` straight into
    ``run``, so the two must be kept in sync.
    """

    seed: int  # forwarded to torch.manual_seed (via create_dynamics) and run_common_experiment
    capacity: int  # width given to the encoder model and the decoder's input size
    cross_entropy_type: CrossEntropyType  # which cross-entropy loss to train with
    train_batch_size: int  # batch size of the training dataloader
    val_batch_size: int  # batch size of the validation dataloader
    epochs: int  # used as max_epochs for run_common_experiment
    log_interval: int  # forwarded to run_common_experiment


def create_dynamics(seed, capacity, cross_entropy_type: CrossEntropyType):
    """Build the training dynamics (model, optimizer, loss) for one MNIST run.

    Args:
        seed: passed to ``torch.manual_seed`` before any module is constructed,
            so parameter initialisation is reproducible for a given seed.
        capacity: width handed to ``permutation_mnist.DropoutModel`` and used as
            the input size of the linear decoder head.
        cross_entropy_type: which cross-entropy from
            ``XXX.uib.losses.cross_entropies`` to optimise. Accepts a
            ``CrossEntropyType`` member or its string value (``"decoder"`` or
            ``"prediction"``), e.g. when a config was loaded from JSON.

    Returns:
        A ``CategoricalEncoderFullLossDynamics`` wrapping the model (encoder and
        decoder moved to CUDA), an Adam optimizer with lr=1e-4, and the
        selected loss function.

    Raises:
        ValueError: if ``cross_entropy_type`` is not a valid ``CrossEntropyType``.
    """
    # Validate/normalise up front. Otherwise an unknown value falls through the
    # loss selection below and only fails (with UnboundLocalError) after the
    # model has already been built and moved to the GPU.
    cross_entropy_type = CrossEntropyType(cross_entropy_type)

    torch.manual_seed(seed)

    # Alternative, deeper decoder head kept for reference:
    # decoder = torch.nn.Sequential(torch.nn.Linear(capacity, capacity), torch.nn.ReLU(),
    #                               torch.nn.Linear(capacity, 10))

    decoder = torch.nn.Sequential(
        torch.nn.Linear(capacity, 10),
        # Linear emits the 10 class scores on the last axis, so normalise over it.
        # An explicit dim avoids PyTorch's implicit-dim deprecation warning and
        # matches the old implicit choice for 2-D (batch, features) input.
        torch.nn.Softmax(dim=-1),
    )

    model = CategoricalEncoderDecoder(permutation_mnist.DropoutModel(capacity), PassthroughDecoder(decoder))

    # NOTE: hard-codes a CUDA device; this fails on machines without a GPU.
    model.encoder.cuda()
    model.decoder.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    cross_entropies = XXX.uib.losses.cross_entropies
    if cross_entropy_type is CrossEntropyType.decoder:
        loss_function = cross_entropies.deterministic_categorical_encoder_decoder_cross_entropy
    elif cross_entropy_type is CrossEntropyType.prediction:
        loss_function = cross_entropies.deterministic_categorical_encoder_prediction_cross_entropy
    else:
        # Only reachable if a new enum member is added without a loss mapping.
        raise ValueError(f"Unsupported cross_entropy_type: {cross_entropy_type!r}")

    return CategoricalEncoderFullLossDynamics(model, optimizer, loss_function)


def run(seed, capacity, cross_entropy_type, train_batch_size, val_batch_size, epochs, log_interval, store):
    """Train and evaluate one model on MNIST, logging into ``store``.

    Args:
        seed: forwarded to ``create_dynamics`` (model initialisation) and to
            ``run_common_experiment``.
        capacity: encoder width; forwarded to ``create_dynamics`` and used as
            ``in_capacity`` for ``run_common_experiment``.
        cross_entropy_type: loss selector forwarded to ``create_dynamics``.
        train_batch_size: training batch size for ``fast_mnist.dataloaders``.
        val_batch_size: validation batch size for ``fast_mnist.dataloaders``.
        epochs: ``max_epochs`` for ``run_common_experiment``.
        log_interval: forwarded to ``run_common_experiment``.
        store: forwarded to ``run_common_experiment``. The ``__main__`` block
            passes a fresh dict here, presumably filled with training logs.
    """
    dataloaders = fast_mnist.dataloaders(train_batch_size, val_batch_size)
    dynamics = create_dynamics(seed, capacity, cross_entropy_type)

    # Alternative schedule kept for reference:
    # scheduler = torch.optim.lr_scheduler.OneCycleLR(dynamics.optimizer, 5e-4, epochs=epochs,
    #                                                 steps_per_epoch=len(dataloaders.train))

    # The keyword arguments mirror torch.optim.lr_scheduler.ReduceLROnPlateau
    # (scale LR by 0.8 after 3 steps without an absolute improvement in the
    # monitored loss). Presumably the wrapper forwards them; confirm in ignite_dynamics.
    scheduler = ReduceLROnPlateauWrapper(
        dynamics.optimizer,
        output_transform=IgniteOutput.get_loss,
        mode="min",
        patience=3,
        factor=0.8,
        threshold_mode="abs",
        verbose=True,
    )

    # in_capacity is the encoder width. out_capacity is the number of classes
    # produced by the Linear(capacity, 10) decoder head in create_dynamics.
    in_capacity = capacity
    out_capacity = 10

    run_common_experiment(
        dynamics=dynamics,
        dataloaders=dataloaders,
        max_epochs=epochs,
        in_capacity=in_capacity,
        out_capacity=out_capacity,
        train_lr_schedulers=scheduler,
        log_interval=log_interval,
        store=store,
        seed=seed,
        stochastic=False,
        discrete_summary=True,
        continuous_summary=False,
    )


if __name__ == "__main__":
    # Repeat the same configuration five times with consecutive seeds 145..149.
    # Each repetition gets its own experiment store.
    for seed in range(145, 150):
        config = Config(
            seed=seed,
            capacity=100,
            cross_entropy_type=CrossEntropyType.decoder,
            train_batch_size=128,
            val_batch_size=5000,
            epochs=100,
            log_interval=10,
        )

        experiment_store = embedded_experiment(__file__)
        experiment_store["config"] = dataclasses.asdict(config)
        experiment_store["log"] = {}

        # Config's fields are exactly run()'s arguments apart from ``store``.
        run(store=experiment_store["log"], **dataclasses.asdict(config))

#
# experiments.dynamics.mnist.train_mnist(experiments.dynamics.mnist.cel_mnist_dynamics(10, 10), 128, 2)
