import logging
import os

import torch

from train.proc.bc_proc import BehaviouralCloningProc, BehaviouralCloningData
from train.proc.env_client_proc import EnvClientProc
from train.enums import ClientEnvOption
from train.proc.env_server_proc import EnvServerProc
from train.proc.eval_proc import EvaluateProc
from train.proc.rl_proc import ReinforcementLearningProc
from train.proc.subtask_proc import SubtaskIdentifierProc
from train.task import Task


class ExternalInterface:
    """
    Interface for all modules.
    """
    def __init__(self, config):
        """
        Stores the configuration object used by every other method of the interface.
        :param config: Configuration object; the methods of this class read nested settings
                       from it (e.g. ``config.env.device`` or ``config.checkpoint_path``).
        """
        self.config = config

    def get_device_list(self):
        """
        Checks and returns the list of devices, if the settings are matching the number of hardware units.
        :return:
        """
        device_list = []
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if 'cuda' == device and device == self.config.env.device:
            # check how many devices are available
            num_devices = torch.cuda.device_count()
            # check if number of requested devices is below or equal to available devices
            if self.config.monitor.num_evaluate_tasks_env <= num_devices:
                num_devices = self.config.monitor.num_evaluate_tasks_env
            else:
                logging.warning('IFace: Too many instances requested! Reduced to only {} gpu devices.'
                                .format(num_devices))
            # assign multiple gpu instances
            for i in range(num_devices):
                device_list.append(torch.device('{}:{}'.format(device, i)))
        else:
            # else assign requested number of cpu instances
            for i in range(self.config.monitor.num_evaluate_tasks_env):
                device_list.append(torch.device('cpu'))
        return device_list

    def create_paths(self):
        """
        Creates and checks all required paths for the modules to work.
        :return:
        """
        for file in [self.config.checkpoint_path,
                     self.config.performance_path,
                     self.config.subtask.outputdir,
                     self.config.rl.record_dir,
                     self.config.monitor.record_path,
                     self.config.monitor.experiences_path,
                     self.config.debug.subtask_outputdir,
                     self.config.bc.param_root,
                     self.config.bc.recording_dir]:
            logging.info('IFace: Creating {} path...'.format(file))
            os.makedirs(file, exist_ok=True)

    def save_consensus(self, tasks, consensus_file: str):
        """
        Persists the extracted task checkpoints from the consensus.
        :param tasks:
        :param consensus_file:
        :return:
        """
        for task in tasks:
            task.save()
        with open(consensus_file, 'w') as f:
            [f.write("%s\n" % c.id) for c in tasks]
        logging.info('IFace: Created and saved new consensus.')

    def get_consensus(self):
        """
        Loads existing task consensus file if checkpoints exist or creates a new task consensus.
        :return: List of Tasks based on the consensus
        """
        consensus_file = os.path.join(self.config.checkpoint_path, self.config.checkpoint_file)
        if self.config.debug.enabled and self.config.debug.recreate_consensus_on_startup \
                or not os.path.exists(consensus_file):
            logging.info('IFace: Creating new consensus...')
            proc = self.create_subtask_alignment()
            proc.wait()
            tasks = proc.consensus
            self.save_consensus(tasks, consensus_file)
        else:
            logging.info('IFace: Loading existing consensus...')
            tasks = self.load_consensus()
        return tasks

    def load_consensus(self):
        """
        Returns a consensus object for testing in the environment.
        :return: Task list which represents the consensus for the agent schedule.
        """
        logging.info('IFace: Loading pre-trained consensus file.')
        consensus_file = os.path.join(self.config.checkpoint_path, self.config.checkpoint_file)
        tasks = []
        with open(consensus_file, 'r') as fc:
            lines = fc.read().splitlines()
        for line in lines:
            task_path = os.path.join(self.config.checkpoint_path, line)
            task_file = os.path.join(task_path, 'task.ckpt')
            task = Task.load(self.config, task_file)
            tasks.append(task)
        logging.info('IFace: Loaded {} consensus elements.'.format(len(tasks)))
        return tasks

    def create_subtask_alignment(self) -> SubtaskIdentifierProc:
        """
        Creates the subtask identification process from the stored configuration.
        According to the original documentation the process executes the alignment on the
        original Minecraft dataset and breaks out subtasks for each behavioral cloning agent
        to pre-train initial agents (not verifiable from this file).
        :return: The process object, not the consensus itself; ``get_consensus`` waits on it
                 and then reads the result from its ``consensus`` attribute.
        """
        logging.info('IFace: Creating subtask alignment process.')
        return SubtaskIdentifierProc(self.config)

    def create_env_pool(self, num_envs: int) -> EnvServerProc:
        """
        Creates the environment server process for the Minecraft environments.
        :param num_envs: The number of environments to create.
        :return: The ``EnvServerProc`` built from the stored configuration and ``num_envs``.
        """
        logging.info('IFace: Creating environment server pool process.')
        return EnvServerProc(self.config, num_envs)

    def make_env(self, num_envs: int, option: ClientEnvOption = ClientEnvOption.Normal,
                 seed_list=None, env_server: bool=True) -> EnvClientProc:
        """
        Creates environment client to connect to remote environment
        :param num_envs: The number of environments to create.
        :param option: Option for normal, recorded or experience replayed environments.
        :param seed_list: If seeded environment is required.
        :param env_server: Forwarded unchanged to ``EnvClientProc``; presumably selects whether
                           the client talks to an environment server - TODO confirm.
        :return: Environment communication proxy.
        """
        logging.info('IFace: Preparing {} environments.'.format(num_envs))
        # all arguments are passed through by keyword together with the stored configuration
        return EnvClientProc(config=self.config,
                             num_envs=num_envs,
                             option=option,
                             seed_list=seed_list,
                             env_server=env_server)

    def clone_behaviour(self, task: Task, data_type: BehaviouralCloningData = BehaviouralCloningData.General,
                        env=None, n_workers: int = None) -> BehaviouralCloningProc:
        """
        Creates the behavioural cloning process for the given task. According to the original
        documentation the agent learns the subtasks as the later initial reinforcement
        learning policy.
        :param task: Task to clone behaviour; its ``id`` is used for the log message.
        :param data_type: Data type for training.
        :param env: Env if eval hook is used.
        :param n_workers: Number of used envs for eval hook
        :return: Returns process object.
        """
        logging.info("IFace: Creating behavioural cloning process for {} task.".format(task.id))
        # arguments are forwarded positionally, preceded by the stored configuration
        return BehaviouralCloningProc(self.config, task, data_type, env, n_workers)

    def tune_policy(self, envs, task: Task) -> ReinforcementLearningProc:
        """
        Creates the reinforcement learning process that fine-tunes the policy of a task.
        NOTE(review): per the original documentation the process starts PPO or RUDDER and
        runs until the policy is fine-tuned on the given task - not verifiable from this file.
        Be careful: This method interacts with the environment!
        :param envs: The gym environments where to initially place the agent.
        :param task: The task to fine-tune the policy; its ``id`` is used for the log message.
        :return: Returns process object.
        """
        # log prefix aligned with the 'IFace:' prefix used by the other methods (was 'IFce:')
        logging.info("IFace: Creating reinforcement learning process for {} task.".format(task.id))
        return ReinforcementLearningProc(self.config, envs, task)

    def evaluate_tasks(self, env, tasks, device) -> EvaluateProc:
        """
        Creates the evaluation process for the given tasks.
        NOTE(review): per the original documentation the tasks are executed based on the
        agent checkpoint - not verifiable from this file.
        Be careful: This method interacts with the environment!
        :param env: The gym environments where to initially place the agent.
        :param tasks: The task history to evaluate.
        :param device: The device where to load the model.
        :return: Returns process object.
        """
        # log prefix aligned with the 'IFace:' prefix used by the other methods (was 'IFce:')
        logging.info('IFace: Creating task evaluation process.')
        return EvaluateProc(self.config, env, tasks, device)
