from pathlib import Path

import hydra
from datasets import DatasetDict
from hydra import compose

from nn_core.common import PROJECT_ROOT
from nn_core.serialization import NNCheckpointIO, load_model

from rel2abs.aes.wandb_mapping import AES_MODELS
from rel2abs.aes.wandb_utils import get_run_dir, local_checkpoint_selection
from rel2abs.pl_modules.aes.pl_autoencoder import LightningAutoencoder

# Root directory of the pre-encoded autoencoder datasets; load_embeds_from_cfg
# reads <DATASET_DIR>/<dataset_name> with DatasetDict.load_from_disk.
DATASET_DIR: Path = PROJECT_ROOT.joinpath("data", "encoded_data", "aes")


def load_datasets_from_cfg(dataset_name):
    """Build the Hydra datamodule for ``dataset_name`` and return its splits.

    The ``default`` config is composed with the ``nn=aes`` and
    ``train=reconstruction`` overrides, and the dataset config is selected as
    ``nn/data/datasets=vision/<dataset_name>``. ``hydra.compose`` needs an
    already-initialized Hydra context; presumably the caller sets one up
    (e.g. ``hydra.initialize(...)``) — confirm at call sites.

    Args:
        dataset_name: Name of the vision dataset config to select.

    Returns:
        ``(train_dataset, val_dataset)``, where ``val_dataset`` is the first
        element of the datamodule's ``val_datasets``.
    """
    overrides = [
        "nn=aes",
        "train=reconstruction",
        f"nn/data/datasets=vision/{dataset_name}",
    ]
    data_cfg = compose(config_name="default", overrides=overrides).nn.data

    # _recursive_=False: nested sub-configs are passed to the datamodule
    # un-instantiated, leaving it to build them itself.
    datamodule = hydra.utils.instantiate(data_cfg, _recursive_=False)
    datamodule.setup()

    return datamodule.train_dataset, datamodule.val_datasets[0]


def load_embeds_from_cfg(dataset_name) -> DatasetDict:
    """Load the pre-encoded ``DatasetDict`` stored for ``dataset_name``.

    Despite the name, no config is composed here: the dataset is read directly
    from ``DATASET_DIR / dataset_name`` via ``DatasetDict.load_from_disk``.

    Args:
        dataset_name: Sub-directory of ``DATASET_DIR`` holding the saved dataset.

    Returns:
        The ``DatasetDict`` loaded from disk.
    """
    embeds_path = DATASET_DIR.joinpath(dataset_name)
    return DatasetDict.load_from_disk(embeds_path)


def load_model_from_wandb(
    dataset_name: str,
    model_name: str,
    decoder_norm: str,
    run_name: str,
    *,
    entity: str = "...",
    project: str = "rel2abs",
):
    """Load a trained ``LightningAutoencoder`` and its config from a W&B run.

    The W&B run id is looked up in the nested ``AES_MODELS`` mapping as
    ``AES_MODELS[dataset_name][model_name][decoder_norm][run_name]``.

    Args:
        dataset_name: First-level key into ``AES_MODELS``.
        model_name: Second-level key into ``AES_MODELS``.
        decoder_norm: Third-level key into ``AES_MODELS``.
        run_name: Fourth-level key into ``AES_MODELS``.
        entity: W&B entity owning the project. The default keeps the
            previously hard-coded value; NOTE(review): ``"..."`` looks like a
            redacted placeholder — pass the real entity.
        project: W&B project name (defaults to ``"rel2abs"``).

    Returns:
        ``(model, cfg)``: the model restored from the selected checkpoint and
        the ``"cfg"`` entry stored in that same checkpoint.

    Raises:
        KeyError: If the key combination is not registered in ``AES_MODELS``.
    """
    try:
        run_id = AES_MODELS[dataset_name][model_name][decoder_norm][run_name]
    except KeyError as err:
        raise KeyError(
            "No run registered in AES_MODELS for "
            f"[{dataset_name!r}][{model_name!r}][{decoder_norm!r}][{run_name!r}]"
        ) from err

    run_dir = get_run_dir(entity=entity, project=project, run_id=run_id)
    # NOTE(review): the meaning of the second argument (0) is defined by
    # local_checkpoint_selection — presumably which checkpoint to pick; confirm.
    ckpt_path = local_checkpoint_selection(run_dir, 0)

    model = load_model(module_class=LightningAutoencoder, checkpoint_path=ckpt_path)
    # The checkpoint is read a second time only to extract the stored config;
    # mapping to CPU avoids needing the device it was saved from.
    cfg = NNCheckpointIO.load(path=ckpt_path, map_location="cpu")["cfg"]

    return model, cfg


def load_model_from_column_name(dataset_name: str, column_name: str):
    """Load the model identified by a ``<model>_<decoder_norm>_<run>`` column name.

    ``column_name`` is split on ``"_"`` into exactly three parts, which are
    forwarded as ``model_name``, ``decoder_norm`` and ``run_name`` to
    ``load_model_from_wandb``.

    Args:
        dataset_name: Dataset key forwarded to ``load_model_from_wandb``.
        column_name: String of the form ``"<model_name>_<decoder_norm>_<run_name>"``.

    Returns:
        Whatever ``load_model_from_wandb`` returns (``(model, cfg)``).

    Raises:
        ValueError: If ``column_name`` does not split into exactly three parts.
    """
    parts = column_name.split("_")
    if len(parts) != 3:
        # Previously this surfaced as an opaque tuple-unpacking ValueError.
        raise ValueError(
            "Expected a column name of the form "
            "'<model_name>_<decoder_norm>_<run_name>', "
            f"got {column_name!r} ({len(parts)} underscore-separated parts)"
        )
    model_name, decoder_norm, run_name = parts
    return load_model_from_wandb(dataset_name, model_name, decoder_norm, run_name)
