import copy
import os
import pickle
from enum import Enum
import logging
import torch
import numpy as np
from threading import RLock
import json

from train.package_manager import PackageManager
from train.subtask_identifier.data_handling import from_data_to_seq_actions
from train.model_registry import ModelRegistry, ModelStatus


class TaskType(Enum):
    """
    How a :class:`Task` is solved.

    Members are stored inside pickled task checkpoints and restored with ``TaskType(...)``,
    so the member values below must stay stable.
    """
    # solved by a model obtained from the model registry
    Learning = 0
    # solved by replaying an action sequence from a human demonstration
    Imitation = 1


class Task:
    """
    Task specific object containing all necessary information to solve a Minecraft task-transition
    built through the consensus.

    Two flavours exist, selected by ``task_type``:

    * ``TaskType.Learning`` -- solved by a model obtained from the ``ModelRegistry``; its weights
      are checkpointed to ``<config.checkpoint_path>/<id>/<model_ckpt>``.
    * ``TaskType.Imitation`` -- solved by replaying an action sequence extracted from a human
      demonstration stored under ``<data_dir_transition>/<id>``.

    Methods that touch shared state do so while holding ``self.mutex``. The lock is re-entrant,
    so locked methods may call each other (e.g. ``get_model`` -> ``save``).
    """
    def __init__(self, config, id: str, target: str, task_type: TaskType):
        """
        Create a task with no model, data or imitation sequence loaded yet.

        :param config: Meta controller config. This class reads ``checkpoint_path`` and
            ``package_task_profile`` from it and passes it to ``ModelRegistry.get_instance``.
        :param id: Task identifier; also used as the task's checkpoint/data sub-directory name.
        :param target: Name of the inventory item this task is meant to obtain.
        :param task_type: Whether the task is solved by a learned model or by imitation.
        """
        # make object thread safe; re-entrant so locked methods may call other locked methods
        self.mutex = RLock()

        # meta controller config file
        self.config = config

        # task properties
        self.id = id
        self.target = target
        self.median_task_steps = np.inf

        self.data_dir_target = None
        self.data_dir_transition = None

        self.task_type = task_type
        # data_id represents the environment id
        self.data_id = None
        self.task_file = 'task.ckpt'

        # data flags
        self.target_data_ready = False
        self.transition_data_ready = False

        # aux. flags for task checking
        self.model_ready = False
        self.imitation_ready = False
        # rotation offset used by _select_imitation_file; not incremented in this class --
        # presumably updated by the caller after a failed imitation attempt (verify)
        self.imitation_fail_cnt = 0
        self.unsaved_changes = True
        # limits the "item not found" warning in success() to once per task
        self.print_warning = True

        # model parameter for collecting tasks
        # only relevant for TaskType.Learning
        self.model_ckpt = 'model.ckpt'
        self._model = None
        self.model_status = ModelStatus.Untrained
        self.eval_hook = None
        self.model_registry: ModelRegistry = ModelRegistry.get_instance(self.config)

        # imitation action sequence
        # only relevant for TaskType.Imitation
        self.imitation_ckpt = 'metainfo.json'
        self._imitation_commands = None

        logging.info("Task: Instantiating new task {}".format(id))

    @staticmethod
    def clone_list(task_list):
        """
        Deep copy a list of tasks.

        :param task_list: Non-empty list of tasks.
        :return: A new list with a clone (see :meth:`clone`) of every task, in the same order.
        """
        assert len(task_list) > 0
        return [Task.clone(task) for task in task_list]

    @staticmethod
    def clone(task):
        """
        Deep copy a task object.

        Flags and scalar fields are copied. The model and the imitation commands are deep-copied.
        ``config`` and ``eval_hook`` are shared with the original. The clone gets its own lock.
        :param task: Task to copy.
        :return: The new task.
        """
        new_task = Task(task.config, task.id, task.target, task.task_type)
        new_task.data_dir_target = task.data_dir_target
        new_task.data_dir_transition = task.data_dir_transition
        new_task.data_id = task.data_id
        new_task.task_file = task.task_file
        new_task.target_data_ready = task.target_data_ready
        new_task.transition_data_ready = task.transition_data_ready
        new_task.model_ready = task.model_ready
        new_task.imitation_ready = task.imitation_ready
        new_task.imitation_fail_cnt = task.imitation_fail_cnt
        new_task.unsaved_changes = task.unsaved_changes
        new_task.print_warning = task.print_warning
        new_task.model_ckpt = task.model_ckpt
        new_task._model = copy.deepcopy(task._model)
        new_task.model_status = task.model_status
        new_task.eval_hook = task.eval_hook
        new_task.imitation_ckpt = task.imitation_ckpt
        new_task.median_task_steps = task.median_task_steps
        new_task._imitation_commands = copy.deepcopy(task._imitation_commands)

        logging.info("Task: Cloned task {}".format(task.id))
        return new_task

    def _switch_package_profile(self):
        """Activate the package profile configured for ``target``, or the default profile."""
        pm = PackageManager.get_instance()
        if pm.enabled():
            if self.target in self.config.package_task_profile:
                pm.switch(self.config.package_task_profile[self.target])
            else:
                pm.default()

    def get_model(self):
        """
        Return a copy of this task's model, loading it first if it is not ready.

        Loading does the following:
        * activates the package profile for ``target``;
        * fetches a model and its status from the ``ModelRegistry``;
        * if ``<checkpoint_path>/<id>/<model_ckpt>`` exists, merges those weights into the model
          (loaded on the cpu, ``strict=False``); otherwise saves the fresh model.
        A load also sets the unsaved_changes flag.
        :return: A deep copy of the model (copied due to multi-processing).
        """
        assert self.task_type == TaskType.Learning
        with self.mutex:
            if not self.model_ready:
                logging.info('Task: Model not ready. Loading model instance.')
                # prepare package manager to setup environment
                self._switch_package_profile()
                # get general model
                self._model, self.model_status = self.model_registry.get(self.target)
                model_file = os.path.join(self.config.checkpoint_path, self.id, self.model_ckpt)
                if os.path.exists(model_file):
                    logging.info('Task: Previous best model file found. Loading weights...')
                    # always load the models first to cpu; merging into the current state dict
                    # keeps registry values for keys absent from the checkpoint
                    state_dict = self._model.state_dict()
                    state_dict.update(torch.load(model_file, map_location='cpu'))
                    self._model.load_state_dict(state_dict, strict=False)
                else:
                    self.save()
                self.model_ready = True
                self.unsaved_changes = True
                logging.info('Task: Model ready!')
            # copy due to multi-processing
            return copy.deepcopy(self._model)

    def notify_transition(self, next_task):
        """
        Notify task transitions to handle model unloading.

        If this is a learning task and ``next_task`` is a different object, the model is unloaded.
        :param next_task: The successor of the current task.
        """
        with self.mutex:
            # check if we are still operating with the same object and cleanup if not
            if self.task_type == TaskType.Learning and next_task is not self:
                self._unload_model()
                logging.info('Task: Transition encountered! Unload previous model.')

    def update_model(self, model, model_status):
        """
        Overrides the current model and marks the task as having unsaved changes.

        :param model: The new model.
        :param model_status: The current model status.
        """
        with self.mutex:
            self._model = model
            self.model_status = model_status
            self.unsaved_changes = True
            logging.info('Task: Updated model weights.')

    def _unload_model(self):
        """
        Move the model to the cpu to free cuda memory and mark it as not ready.

        The next :meth:`get_model` call loads it again. Called from :meth:`notify_transition`
        with ``self.mutex`` held.
        """
        assert self.task_type == TaskType.Learning
        if self.model_ready and self._model is not None:
            logging.info('Task: Moving model to cpu to free cuda memory.')
            self._model = self._model.cpu()
        self.model_ready = False
        logging.info('Task: Model unloaded.')

    def _select_imitation_file(self, current_state):
        """
        Select the imitation metainfo to choose an action sequence from human demonstrations.

        Selection works as follows:
        * reads ``<data_dir_transition>/<id>/<imitation_ckpt>``, a JSON list of entries with
          ``n_act`` and ``path_to_folder`` keys;
        * keeps only entries with at least one action, shortest first;
        * picks entry ``imitation_fail_cnt`` modulo the number of candidates, so repeated
          failures rotate through the available demonstrations.
        :param current_state: Current environment state (currently unused).
        :return: Path to the selected ``rendered.npz`` file.
        :raises FileNotFoundError: If no entry contains actions or the selected file is missing.
        """
        task_dir = os.path.join(self.data_dir_transition, self.id)
        metainfo_file = os.path.join(task_dir, self.imitation_ckpt)
        with open(metainfo_file, 'r') as o_f:
            list_files = json.load(fp=o_f)

        # only sequences that actually contain actions are usable; sort them by length
        candidates = sorted((f for f in list_files if f['n_act'] > 0), key=lambda k: k['n_act'])
        if not candidates:
            raise FileNotFoundError("No non-empty imitation sequence listed in: {}".format(metainfo_file))

        # parenthesised modulo keeps the index in range for any failure count
        selected = candidates[self.imitation_fail_cnt % len(candidates)]
        imitation_file = os.path.join(task_dir, self.data_id,
                                      # TODO: remove fix for wrong path definition
                                      selected['path_to_folder'].replace('/', ''), 'rendered.npz')

        if not os.path.exists(imitation_file):
            raise FileNotFoundError("File does not exist: {}".format(imitation_file))

        return imitation_file

    def _load_imitation_actions(self, current_state):
        """
        Load the action sequence of one demonstration into ``_imitation_commands``.

        If ``data_dir_transition`` is unset or does not exist, logs an error and leaves the task
        unchanged.
        :param current_state: Forwarded to :meth:`_select_imitation_file`.
        """
        if self.data_dir_transition is None or not os.path.exists(self.data_dir_transition):
            logging.error("Path does not exists {}".format(self.data_dir_transition))
            return

        selected_file = self._select_imitation_file(current_state)
        logging.info("Task: Loading imitation file {}".format(selected_file))

        # the .npz archive keeps its file handle open until closed; close it once converted
        # NOTE(review): assumes from_data_to_seq_actions materialises its result -- verify
        with np.load(selected_file) as data:
            self._imitation_commands = from_data_to_seq_actions(data)
        self.imitation_ready = True

    def get_imitation_actions(self, current_state):
        """
        Return the imitation actions, loading them on the first call.

        Best effort: if loading fails or yields nothing, a warning is logged and ``[]`` is returned.
        :param current_state: Forwarded to the loader to select a demonstration.
        :return: The loaded action sequence, or ``[]``.
        """
        assert self.task_type == TaskType.Imitation
        with self.mutex:
            try:
                # If no actions are loaded yet (first call to the Task), load them
                if not self.imitation_ready:
                    self._load_imitation_actions(current_state)

                # the loader may have bailed out early and left the commands unset
                if self._imitation_commands is not None and len(self._imitation_commands) > 0:
                    return self._imitation_commands
                logging.warning("Task: Loaded action commands are empty!")
                return []
            except Exception as e:
                logging.warning("Task: Exception occurred in loading task! {}".format(e))
                return []

    def set_imitation_actions(self, imitation_sequence):
        """
        Sets the imitation actions and marks them as ready.

        :param imitation_sequence: List of actions.
        """
        assert self.task_type == TaskType.Imitation
        with self.mutex:
            self._imitation_commands = imitation_sequence
            self.imitation_ready = True
            self.unsaved_changes = True

    def success(self, old_meta, new_meta, task_sequence):
        """
        Check if the task is complete by comparing the inventory before and after an action.

        x_meta = {
            'state': Dict(...),
            'action': OrderedDict(...),
            'reward': float,
            'done': bool
        }
        state is the original observation state from the Minecraft environment.
        :param old_meta: State before the action was performed.
        :param new_meta: State after the action was performed.
        :param task_sequence: Sequence of required resources given a task.
        :return: Tuple ``(done, val)``:
            * ``done`` is True when the amount of ``target`` in the inventory increased;
            * ``val`` is the absolute change in that amount, capped at ``len(task_sequence)``.
            Returns ``(False, 0)`` if either meta is missing.
        """
        # states are not ready: report "not done" with the same tuple shape as below
        if old_meta is None or new_meta is None:
            return False, 0
        old_inventory = old_meta['state']['inventory']
        new_inventory = new_meta['state']['inventory']
        # check if item is available in inventory
        if self.target in old_inventory:
            # negative when the amount of the target item grew
            diff = old_inventory[self.target] - new_inventory[self.target]
        else:
            diff = 0
            # print only once per task
            if self.print_warning:
                logging.warning('Task: Item {} not found in inventory'.format(self.target))
                self.print_warning = False
        val = min(abs(diff), len(task_sequence))
        return diff < 0, val

    def save(self):
        """
        Serializes the task instance and saves all necessary changes.

        The following are written under ``<checkpoint_path>/<id>``:
        * the model's state dict to ``model_ckpt``, for learning tasks with a model;
        * a pickled dict of task metadata to ``task_file``.
        Afterwards the unsaved_changes flag is cleared.
        """
        with self.mutex:
            dict_obj = {
                'id': self.id,
                'target': self.target,
                'data_dir_target': self.data_dir_target,
                'data_dir_transition': self.data_dir_transition,
                'median_task_steps': self.median_task_steps,
                'task_type': self.task_type,
                'data_id': self.data_id,
                'model_status': self.model_status
            }

            # create the directory for the task transition (exist_ok avoids a check/create race)
            task_path = os.path.join(self.config.checkpoint_path, self.id)
            if not os.path.exists(task_path):
                logging.info("Task: Creating checkpoint path for {} task...".format(self.id))
            os.makedirs(task_path, exist_ok=True)

            # save the model
            if self.task_type == TaskType.Learning and self._model is not None:
                logging.info('Task: Saving learning task...')
                model_file = os.path.join(task_path, self.model_ckpt)
                torch.save(self._model.state_dict(), model_file)

            # save the task information
            task_file = os.path.join(task_path, self.task_file)
            logging.info("Task: Saving {} task object...".format(self.id))
            with open(task_file, 'wb') as f:
                pickle.dump(dict_obj, f)
            self.unsaved_changes = False
            logging.info('Task: Saving complete.')

    @staticmethod
    def load(config, task_file):
        """
        Loads a task from a previous checkpoint written by :meth:`save`.

        The model itself is not loaded here; :meth:`get_model` loads it lazily.
        :param config: The parameter config.
        :param task_file: The task file.
        :return: The restored task, with both data-ready flags set.
        """
        # NOTE: pickle can execute arbitrary code -- only load checkpoints from a trusted location
        with open(task_file, 'rb') as ft:
            task_dict = pickle.load(ft)
        task = Task(config,
                    id=task_dict['id'],
                    target=task_dict['target'],
                    task_type=TaskType(task_dict['task_type']))
        task.data_dir_target = task_dict['data_dir_target']
        task.data_dir_transition = task_dict['data_dir_transition']
        task.median_task_steps = task_dict['median_task_steps']
        task.data_id = task_dict['data_id']
        task.model_status = ModelStatus(task_dict['model_status'])
        # assume that loading a consensus requires all data prepared in advance
        task.transition_data_ready = True
        task.target_data_ready = True
        return task
