import copy
import os
import pickle
from enum import Enum
import logging
import torch
from threading import RLock

from train import Config
from train.common.utils import import_module
from train.package_manager import PackageManager


# Training stage reached by a model held in the ModelRegistry. Statuses are
# written into the pickled meta file when a model is saved and restored with
# ModelStatus(...) when it is registered again, so the member names and integer
# values below must stay stable across versions.
# Built with the functional Enum API; ``module`` is passed explicitly so that
# pickled members resolve back to this module.
ModelStatus = Enum(
    'ModelStatus',
    [
        ('Untrained', 0),
        ('GeneralBehaviourCloned', 1),
        ('TaskBehaviourCloned', 2),
        ('FineTuned', 3),
    ],
    module=__name__,
)


class ModelRegistry:
    """
    Singleton ModelRegistry instance to handle model pre-training for general agents.

    Maps a target id to a ``(model, status)`` tuple. Each entry is persisted under
    ``<config.checkpoint_path>/<target>/`` as a weights file (``model.ckpt``) and a
    pickled meta file (``meta.ckpt``). All public methods serialize on one
    class-level re-entrant lock, which lets ``get`` call ``register`` and
    ``register`` call ``save`` while the lock is already held.
    """
    __instance = None
    __mutex = RLock()

    def __init__(self, config):
        """
        Virtually private constructor; use :meth:`get_instance` instead.
        :param config: Configuration; ``checkpoint_path`` and ``model.model`` are read from it.
        :raises Exception: If the singleton instance already exists.
        """
        if ModelRegistry.__instance is not None:
            raise Exception("This class is a singleton!")
        self.config = config
        self.registry = {}
        self.model_ckpt = 'model.ckpt'
        self.meta_file = 'meta.ckpt'
        self.unsaved_changes = True
        logging.info('ModelRegistry: Created new registry.')

    @staticmethod
    def get_instance(config: 'Config' = None):
        """
        Returns singleton of the ModelRegistry class, creating it on the first call.
        :param config: Configuration used only when the singleton is first created;
            it is ignored on later calls.
        :return: The shared ModelRegistry instance.
        """
        with ModelRegistry.__mutex:
            if ModelRegistry.__instance is None:
                ModelRegistry.__instance = ModelRegistry(config)
            return ModelRegistry.__instance

    def reset(self):
        """
        Resets the model registry to empty the dictionary of all registered models.
        Files already written to the checkpoint path are left untouched.
        :return: None
        """
        with ModelRegistry.__mutex:
            self.registry = {}
            self.unsaved_changes = True
            logging.info('ModelRegistry: Models reset!')

    def register(self, target):
        """
        Register a new target to the instance. Creates the target model if it does not
        exist yet; otherwise the existing checkpoint weights (and status, when a meta
        file is present) are loaded.
        :param target: Target id for the registry.
        :return: The registered ``(model, status)`` tuple (the stored model itself, not a copy).
        """
        with ModelRegistry.__mutex:
            status = ModelStatus.Untrained
            logging.info('ModelRegistry: Model not ready. Instantiate new model.')
            if PackageManager.get_instance().enabled():
                net = PackageManager.get_instance().model.Network
            else:
                net = import_module(self.config.model.model).Network
            model = net()
            model_dir = os.path.join(self.config.checkpoint_path, target)
            model_file = os.path.join(model_dir, self.model_ckpt)
            has_checkpoint = os.path.exists(model_file)
            if has_checkpoint:
                # always load the models first to cpu; merge into the fresh state dict
                # so keys missing from the checkpoint keep their initial values
                state_dict = model.state_dict()
                state_dict.update(torch.load(model_file, map_location='cpu'))
                model.load_state_dict(state_dict, strict=False)
                # load the model meta information
                # NOTE: pickle.load is only safe because checkpoints are produced locally
                # by save(); never point checkpoint_path at untrusted files.
                meta_file = os.path.join(model_dir, self.meta_file)
                if os.path.exists(meta_file):
                    with open(meta_file, 'rb') as ft:
                        meta_dict = pickle.load(ft)
                    assert target == meta_dict['target']
                    status = ModelStatus(meta_dict['status'])
                logging.info('ModelRegistry: Previous best model file found. Loading weights...')
            # The entry must be in the registry before save() runs: save() only
            # persists targets it can find there.
            self.registry[target] = (model, status)
            self.unsaved_changes = True
            if not has_checkpoint:
                logging.info('ModelRegistry: No previous best model file found. Saving weights...')
                self.save(target)
            logging.info('ModelRegistry: Model ready!')
            return self.registry[target]

    def get(self, target):
        """
        Returns or registers a new model for the given target.
        The model is always deep-copied, so callers can never modify the registered
        instance, even when this call is the one that registers it.
        :param target: Target id for the registry.
        :return: ``(model_copy, status)`` tuple.
        """
        with ModelRegistry.__mutex:
            if target not in self.registry:
                self.register(target)
            # copy to assure non-modified models
            logging.info('ModelRegistry: Preparing model.')
            model, status = self.registry[target]
            return copy.deepcopy(model), status

    def update(self, target, model, status):
        """
        Updates the target with the given model and status (in memory only; call
        :meth:`save` to persist).
        :param target: Target name for the registry id.
        :param model: Model to update.
        :param status: Status to update.
        :return: None
        """
        with ModelRegistry.__mutex:
            assert target in self.registry
            logging.info('ModelRegistry: Updating model.')
            self.registry[target] = (model, status)
            self.unsaved_changes = True

    def save(self, target):
        """
        Saves the given target object if available in the registry.
        Writes ``<checkpoint_path>/<target>/meta.ckpt`` (pickled target and status) and
        ``<checkpoint_path>/<target>/model.ckpt`` (model state dict). This is the
        same ``model.ckpt`` path that :meth:`register` loads from.
        :param target: Target id for the registry.
        :return: None
        """
        with ModelRegistry.__mutex:
            if target not in self.registry:
                logging.warning('ModelRegistry: Cannot save unregistered target {}.'.format(target))
                return
            model, status = self.registry[target]

            dict_obj = {
                'target': target,
                'status': status
            }

            # create the dictionary to the task transition
            model_path = os.path.join(self.config.checkpoint_path, target)
            if not os.path.exists(model_path):
                logging.info("ModelRegistry: Creating checkpoint path for {} model...".format(target))
                # exist_ok guards against another process creating it in between
                os.makedirs(model_path, exist_ok=True)

            # save the task information
            meta_file = os.path.join(model_path, self.meta_file)
            logging.info("ModelRegistry: Saving {} model meta object...".format(target))
            with open(meta_file, 'wb') as f:
                pickle.dump(dict_obj, f)

            # save model under the checkpoint name register() reads back
            model_file = os.path.join(model_path, self.model_ckpt)
            logging.info('ModelRegistry: Saving model...')
            torch.save(model.state_dict(), model_file)

            self.unsaved_changes = False
            logging.info('ModelRegistry: Saving complete.')
