import pickle
from typing import List, Tuple, Union

import pytorch_lightning as pl
import torch

from src.chemutils.constants import INFERENCE_HPS

from src.data.mol_module import MolDataModule

from src.model.flow_vae import FlowMAGNet
from src.model.vae import MAGNet
from src.utils import WB_LOG_PATH


def load_model_from_id(
    collection: str,
    run_id: str,
    dataset: str = "zinc",
    seed_model: int = 0,
    model_class: type = MAGNet,
    load_config: Union[dict, None] = None,
    return_config: bool = False,
) -> Union[MAGNet, FlowMAGNet, Tuple[Union[MAGNet, FlowMAGNet], dict]]:
    """
    Load a trained model from collection / run id and set its inference parameters.

    The saved constructor config (``load_config.pkl``) and the ``last.ckpt`` state dict
    are read from ``WB_LOG_PATH / collection / run_id / "checkpoints"``. Entries in
    ``load_config`` override saved config entries with the same key.

    Args:
        collection: Sub-directory of ``WB_LOG_PATH`` that holds the run.
        run_id: Run sub-directory inside ``collection``.
        dataset: Dataset name passed to ``MolDataModule``; its lower-cased form selects
            the sampling hyper-parameters from ``INFERENCE_HPS``.
        seed_model: Seed passed to ``seed_everything`` before anything is built.
        model_class: Model class to instantiate (``MAGNet`` or ``FlowMAGNet``). For
            ``FlowMAGNet``, ``load_flow_modules`` is set depending on whether the
            checkpoint contains ``flow_model`` parameters.
        load_config: Optional constructor keyword overrides. The dict is not mutated.
        return_config: If True, also return the effective constructor config.

    Returns:
        The model (moved to CUDA, in eval mode, with a fresh ``pl.Trainer`` whose
        ``datamodule`` is set), or ``(model, config)`` when ``return_config`` is True.
    """
    # Work on a copy so neither the caller's dict nor a shared default gets mutated.
    overrides = dict(load_config) if load_config is not None else {}

    pl.seed_everything(seed_model)
    dm = MolDataModule(dataset, batch_size=1, num_workers=6, shuffle=False)
    dm.setup()

    checkpoint_dir = WB_LOG_PATH / collection / run_id / "checkpoints"
    # NOTE: pickle.load / torch.load can execute code embedded in the file;
    # only load checkpoints from trusted runs.
    with open(checkpoint_dir / "load_config.pkl", "rb") as file:
        config = pickle.load(file)

    # Load to CPU first; the whole model is moved to the GPU further down, so there is
    # no need to materialize every checkpoint tensor (e.g. optimizer state) on the GPU.
    state_dict = torch.load(checkpoint_dir / "last.ckpt", map_location="cpu")["state_dict"]

    load_separate_modules = False
    if model_class is FlowMAGNet:
        if any("flow_model" in key for key in state_dict):
            print("Found Flow Modules in state dict, loading full FlowMAGNet model.")
            overrides["load_flow_modules"] = True
        else:
            # Checkpoint comes from a plain MAGNet: flow modules must be created fresh.
            print("Loading MAGNet model, initializing new Flow Modules!")
            overrides["load_flow_modules"] = False
            load_separate_modules = True

    # Merge only after every override is known, so no key reaches the constructor twice.
    config = {**config, **overrides}

    model = model_class(feature_sizes=dm.feature_sizes, **config)
    model.load_state_dict(state_dict)

    if load_separate_modules:
        print("Initializing Flow Modules...")
        model.initialize_flow_modules()

    model.cuda()
    model.eval()

    model.trainer = pl.Trainer()
    model.trainer.datamodule = dm

    sampling_params = INFERENCE_HPS[dataset.lower()]
    for key, value in sampling_params.items():
        setattr(model, key, value)
    if return_config:
        return model, config
    return model
