import pandas as pd
from datetime import datetime

from benchmarking.utils import Experiment
from models.data_utils import get_dataloaders_for_tabular
from models.data_utils.util import get_device
from models import DeepCAE, StackedCAE, JointVAE, ConvAE, StandardAE, TransformerAE, PCA
from models import (
    DeepCAETrainer,
    StackedCAETrainer,
    JointVAETrainer,
    ConvAETrainer,
    StandardAETrainer,
    TransformerAETrainer,
    KernelPCATrainer,
    StackedCAETrainer
)

# from .utils import save_model


# Lookup table from a model class to the trainer class that fits it.
# `train_model` indexes this mapping with the model class it is given, so a
# model class without an entry here raises KeyError there.
MODEL_TO_TRAINER_MAPPING = {
    DeepCAE: DeepCAETrainer,
    StackedCAE: StackedCAETrainer,
    JointVAE: JointVAETrainer,
    ConvAE: ConvAETrainer,
    StandardAE: StandardAETrainer,
    TransformerAE: TransformerAETrainer,
    PCA: KernelPCATrainer,
}


def _resize_hidden_spec(hidden_spec, input_dim, bottleneck_dim):
    """Return a resized copy of ``hidden_spec``; the input is not modified.

    The last entry is replaced by ``bottleneck_dim``. For two-layer specs the
    first entry is additionally set halfway between ``input_dim`` and the
    bottleneck.
    """
    # Copy first: the spec usually lives inside the caller's config dict and
    # must not be overwritten as a side effect of training.
    resized = list(hidden_spec)
    resized[-1] = bottleneck_dim
    if len(resized) == 2:
        # For the MultiLayer experiments comparing DeepCAE and StackedCAE.
        resized[-2] = round((input_dim + resized[-1]) / 2)
    return resized


def _train_pca(model_class, dataset_name, hidden_dim, train_loader, test_loader):
    """Fit the PCA baseline and return the experiment holding its results."""
    experiment = Experiment(
        {"hidden_dim": hidden_dim},
        model_class.__name__,
        f"Infer from date: {datetime.now()}",
        "",
        dataset_name,
    )
    trainer_class = MODEL_TO_TRAINER_MAPPING[model_class]
    trainer = trainer_class(model=model_class(hidden_dim))
    trainer.train(
        train_loader=train_loader,
        test_loader=test_loader,
        experiment=experiment,
    )
    return experiment


def train_model(
    model_class,
    dataset_name: str,
    config: dict,
    dim_reduction: float,
    device_id: int = None,
    tuning: bool = False,
    early_stopping: bool = True,
    patience: int = 15,
    min_delta: float = 0.995,
    checkpointing: bool = True,
) -> pd.DataFrame:
    """
    Train a model on a given dataset and return the results as a one-row
    pandas DataFrame.

    Args:
        model_class: Model class to instantiate; it must be a key of
            MODEL_TO_TRAINER_MAPPING.
        dataset_name: Dataset folder name; the data is read from
            ``artifacts/data/<dataset_name>/processed.csv``.
        config: Training configuration. If it has a ``"models"`` key and/or a
            key equal to the model class name, those sub-dicts are used. The
            resulting dict must provide ``"epochs"`` and ``"lr"``. It is not
            used for the non-tuning PCA path. The dict passed in is never
            modified.
        dim_reduction: Factor applied to the input dimension to obtain the
            size of the bottleneck (``round(input_dim * dim_reduction)``).
        device_id: The optional device ID is the ID for the CPU or GPU to
            run it on.
        tuning: Forwarded to the data loaders and to the trainer (as
            ``hyperparametertuning``). PCA only takes its dedicated shortcut
            path when this is False.
        early_stopping, patience, min_delta, checkpointing: Forwarded
            unchanged to the trainer's ``train`` method.

    Returns:
        A DataFrame with a single row (index 0) built from
        ``experiment.flatten()``.

    Raises:
        AssertionError: If the resolved config has no ``"epochs"`` entry.
        KeyError: If the resolved config has no ``"lr"`` entry, or if
            ``model_class`` has no registered trainer.
    """

    # Get the device
    device = get_device(device_id)

    # Get data first, before model size definition
    train_loader, test_loader = get_dataloaders_for_tabular(
        batch_size=64,
        path_to_data=f"artifacts/data/{dataset_name}/processed.csv",
        device=device,
        tuning=tuning,
    )
    # The feature count is read from the second axis of the first batch.
    input_dim = next(iter(train_loader)).size()[1]
    bottleneck_dim = round(input_dim * dim_reduction)

    model_class_name = model_class.__name__

    if model_class == PCA and not tuning:
        # Special treatment for PCA: no config, no epochs, no device move.
        experiment = _train_pca(
            model_class, dataset_name, bottleneck_dim, train_loader, test_loader
        )
        return pd.DataFrame(experiment.flatten(), index=[0])

    # Narrow the config down to the section of this model, if it is nested.
    if "models" in config:
        config = config["models"]
    if model_class_name in config:
        config = config[model_class_name]
    assert "epochs" in config, "Config has not attr epochs"

    # Init model
    model_kwargs = {"input_dim": input_dim}

    hidden_spec = config.get("hidden_spec")
    if hidden_spec:
        hidden_spec = _resize_hidden_spec(hidden_spec, input_dim, bottleneck_dim)
        # Shallow-copy the config so the experiment still records the layer
        # sizes actually used, without writing into the caller's dict.
        config = {**config, "hidden_spec": hidden_spec}
        model_kwargs["hidden_spec"] = hidden_spec
    else:
        # If there is no hidden spec, it has to be JointVAE and
        # there is a hidden_dim parameter instead.
        model_kwargs["hidden_dim"] = bottleneck_dim

    if channel_spec := config.get("channel_spec"):
        model_kwargs["channel_spec"] = channel_spec

    if latent_spec := config.get("latent_spec"):
        model_kwargs["latent_spec"] = latent_spec
    elif model_class == JointVAE:
        # Per default, we use no discrete variables.
        # We only use the same number of hidden continuous features as hidden_dim
        model_kwargs["latent_spec"] = {"cont": bottleneck_dim}

    model = model_class(**model_kwargs)
    model.to(device)

    # Define trainer
    trainer_kwargs = {
        "model": model,
        "print_loss_every": 50,
    }
    # Optional trainer settings: lambda_c in case of DeepCAE, the capacities
    # in case of JointVAE. Only entries with a truthy value are forwarded.
    # NOTE(review): a falsy value such as lambda_c = 0 is therefore NOT passed
    # on, and the trainer falls back to whatever default it has - confirm
    # this is intended.
    for option in ("lambda_c", "cont_capacity", "disc_capacity"):
        if value := config.get(option):
            trainer_kwargs[option] = value

    trainer = MODEL_TO_TRAINER_MAPPING[model_class](**trainer_kwargs)

    experiment = Experiment(
        config,
        model_class_name,
        f"Infer from date: {datetime.now()}",
        model,
        dataset_name,
    )
    trainer.train(
        train_loader=train_loader,
        test_loader=test_loader,
        lr=config["lr"],
        epochs=config["epochs"],
        experiment=experiment,
        hyperparametertuning=tuning,
        device_id=device_id,
        early_stopping=early_stopping,
        patience=patience,
        min_delta=min_delta,
        checkpointing=checkpointing,
    )

    # Now extract the final metrics from the experiment and return them
    return pd.DataFrame(experiment.flatten(), index=[0])
