""" Module for initializing and loading deep neural network variants with built-in performance
    logging and saving functionality
"""

from collections import OrderedDict
import enum
import os
from typing import Any, Dict
import yaml

from accelerate import Accelerator

from text2graph.models.base_model import BaseModel
from text2graph.models.grapher import Grapher
from text2graph.models.serialize import SerializedGraphGenerator


#pylint: disable=invalid-name
class ModelFactory(enum.Enum):
    """ NN model loading factory.

        Each member is named after a model class and carries that class as its
        value, so the class can be retrieved from its name given as a string
        via ``ModelFactory[name].value``.
    """
    # Member names intentionally mirror the class names (hence the pylint
    # invalid-name suppression above this class).
    SerializedGraphGenerator = SerializedGraphGenerator
    Grapher = Grapher


def init_and_load_models(
    info_flow: Dict[str, Any],
    saving_directory: str,
    accelerator: Accelerator
) -> Dict[str, BaseModel]:
    """ Initializes an instance of each model class specified by the
        dictionary info_flow. Then loads the weights of the model if info_flow
        contains a directory to load that model from.

        Args:
            info_flow: maps each model name to its configuration dictionary.
                The text before the first underscore of the model name selects
                the ModelFactory member to instantiate. The optional 'metadata'
                entry of a configuration is passed to the model constructor and
                the optional 'model_dir' entry is the directory to load saved
                metadata and parameters from. info_flow itself is not modified.
            saving_directory: stored under the 'model_dir' key of the metadata
                handed to every model.
            accelerator: forwarded to each model's load_parameters call.

        Returns:
            An ordered dictionary mapping each model name to its model
            instance, in the iteration order of info_flow.

        Raises:
            FileNotFoundError: if a 'model_dir' is given but does not contain
                the file '<model_name>_metadata.yaml'.
            KeyError: if the model name prefix is not a ModelFactory member.
    """
    model_dict = OrderedDict()
    for model_name, model_config in info_flow.items():
        model_class = model_name.split("_")[0]

        # Work on a copy so the caller's configuration is left untouched; a
        # missing or empty 'metadata' entry is treated as an empty dictionary.
        metadata = dict(model_config.get('metadata') or {})
        metadata['name'] = model_name
        metadata['model_dir'] = saving_directory

        # Looked up once with .get so a configuration without a 'model_dir'
        # key simply means "do not load anything".
        load_dir = model_config.get('model_dir')
        if load_dir is not None:
            model_metadata_path = os.path.join(load_dir, f"{model_name}_metadata.yaml")
            # Explicit raise rather than assert: asserts vanish under `python -O`.
            if not os.path.isfile(model_metadata_path):
                raise FileNotFoundError(f"{model_metadata_path} must be a file")
            with open(model_metadata_path, 'r', encoding="utf-8") as ymlfile:
                # safe_load returns None for an empty file
                loaded_metadata = yaml.safe_load(ymlfile) or {}
            # Saved values take precedence over the ones set above, including
            # 'name' and 'model_dir' when the saved file contains them.
            metadata.update(loaded_metadata)

        model = ModelFactory[model_class].value(metadata=metadata)
        if load_dir is not None:
            model.load_parameters(load_dir, accelerator)
        model_dict[model_name] = model
    return model_dict
