import logging

from train.enums import ExecMode, ClientEnvOption
from train.monitor.timeouts import *
from train.proc.bc_proc import BehaviouralCloningData
from train.statement.base_stmt import SubStatement, Statement
from train.statement.eval_stmt import EvaluateAgentStatement
from train.statement.prepare_env_stmt import PrepareEnvironmentStatement
from train.statement.prepare_env_stmt_local import PrepareLocalEnvironmentStatement
from train.model_registry import ModelStatus


class LearningStatement(SubStatement):
    """Drive the learning pipeline for the current task.

    In ``ExecMode.Train`` mode, behavioural cloning is run when the task's
    model is still untrained, and a behaviour-cloned model is then fine-tuned
    with reinforcement learning. In every mode the statement finishes by
    marking itself complete.
    """

    # Model statuses that indicate a behaviour-cloned (but not yet fine-tuned)
    # model. BehaviouralCloningStatement records TaskBehaviourCloned after
    # cloning, so both cloned states must qualify for fine-tuning; checking
    # only GeneralBehaviourCloned would skip fine-tuning for a model cloned
    # during this very run.
    _FINE_TUNABLE_STATUSES = (ModelStatus.GeneralBehaviourCloned,
                              ModelStatus.TaskBehaviourCloned)

    def __init__(self, stmt: Statement):
        super(LearningStatement, self).__init__(stmt)

    def exec(self):
        """Run behavioural cloning and/or fine-tuning as the model status requires.

        Raises:
            AssertionError: if a sub-statement does not report success.
        """
        # if agent does not already exist, then create and train an agent
        if self.config.mode == ExecMode.Train:
            logging.info('Statement: Entered learning execution mode.')

            # train model on existing data if the model is untrained
            if self.task.model_status == ModelStatus.Untrained:
                stmt = BehaviouralCloningStatement(self)
                stmt.exec()
                assert stmt.success()
            else:
                logging.info('Statement: Model already trained on behavioural cloning data!')

            # fine-tune behavioural cloning policy with reinforcement learning;
            # model_status is re-read so a model cloned just above qualifies
            if self.task.model_status in self._FINE_TUNABLE_STATUSES:
                stmt = FineTuneStatement(self)
                stmt.exec()
                assert stmt.success()
            else:
                logging.info('Statement: Model already fine-tuned via reinforcement learning!')

        self._complete()


class BehaviouralCloningStatement(SubStatement):
    """Train the task's model by behavioural cloning on the general dataset.

    Optionally prepares replay environments (when ``config.bc.use_eval_env``
    is set) that are handed to the cloning process, and releases them once
    the process has finished.
    """

    def __init__(self, stmt: Statement):
        super(BehaviouralCloningStatement, self).__init__(stmt)
        # Inverse of config.subtask.consensus_code (value -> key). Not read
        # within this class; kept as a public attribute for external users.
        self.encoding = {v: k for k, v in self.config.subtask.consensus_code.items()}

    def exec(self):
        """Run behavioural cloning, store the resulting model and save the task.

        Raises:
            AssertionError: if environment preparation or the cloning process
                does not succeed, or the process yields no model.
        """
        logging.info('Statement: Executing behavioural cloning process...')

        envs, num_envs = None, None
        env_stmt = None
        if self.config.bc.use_eval_env:
            num_envs = self.config.env.num_env
            if self.config.env.env_server:
                env_stmt = PrepareEnvironmentStatement(self,
                                                       num_envs=num_envs,
                                                       option=ClientEnvOption.Replay)
            else:
                env_stmt = PrepareLocalEnvironmentStatement(self,
                                                            num_envs=num_envs,
                                                            option=ClientEnvOption.Replay)
            env_stmt.exec()
            assert env_stmt.success()
            envs = env_stmt.envs

        try:
            proc = self.ext_iface.clone_behaviour(self.task,
                                                  data_type=BehaviouralCloningData.General,
                                                  env=envs,
                                                  n_workers=num_envs)
            proc.wait()
            assert proc.success() and proc.model is not None
            # NOTE(review): cloning uses BehaviouralCloningData.General but the
            # status recorded is TaskBehaviourCloned — confirm this is intended.
            self.task.update_model(proc.model, ModelStatus.TaskBehaviourCloned)
            self.task.save()
        finally:
            # Release the reserved environments even if cloning fails; they
            # were previously never released on any path.
            if env_stmt is not None:
                env_stmt.finalize()

        self._complete()


class FineTuneStatement(SubStatement):
    """Evaluate the current agent, then let a follow-up statement decide
    whether reinforcement-learning fine-tuning should run."""

    def __init__(self, stmt: Statement):
        super().__init__(stmt)

    def exec(self):
        """Evaluate the agent and run the fine-tune-or-skip decision.

        Raises:
            AssertionError: if either sub-statement does not report success.
        """
        logging.info('Statement: Executing fine-tuning process...')

        # Measure the agent first; the decision below depends on the result.
        evaluation = EvaluateAgentStatement(self)
        evaluation.exec()
        assert evaluation.success()

        # The decision statement is parented on the evaluation statement so
        # it can read the evaluation outcome through its parent.
        decision = EvalFineTuneOrSkipStatement(evaluation)
        decision.exec()
        assert decision.success()

        self._complete()


class EvalFineTuneOrSkipStatement(SubStatement):
    """Fine-tune the task's model unless its evaluated reward is already good enough.

    Must be constructed with the EvaluateAgentStatement that measured the
    agent; its ``mean_reward`` is read through ``self.parent``.
    """

    def __init__(self, stmt: EvaluateAgentStatement):
        super(EvalFineTuneOrSkipStatement, self).__init__(stmt)

    def exec(self):
        """Skip or run fine-tuning based on the parent's evaluated mean reward.

        Fine-tuning is skipped when the mean reward meets or exceeds the
        task's ``expected_mean_reward``; otherwise environments are prepared,
        the policy is tuned, the model is stored as FineTuned and the task
        saved.

        Raises:
            AssertionError: if environment preparation or tuning does not
                succeed, or tuning yields no model.
        """
        mean_reward = self.parent.mean_reward
        # Performance is sufficient once the expected reward is reached; the
        # original comparison was inverted and fine-tuned well-performing
        # models while skipping under-performing ones.
        if mean_reward >= self.task.expected_mean_reward:
            logging.info('Statement: Skipped fine-tuning due to sufficient performance: {} mean reward.'
                         .format(mean_reward))
        else:
            logging.info('Statement: Evaluating statements for fine-tuning...')
            # prepare the environment for the current task
            if self.config.env.env_server:
                stmt = PrepareEnvironmentStatement(self, num_envs=self.config.env.num_env)
            else:
                stmt = PrepareLocalEnvironmentStatement(self, num_envs=self.config.env.num_env)
            stmt.exec()
            assert stmt.success()

            try:
                logging.info('Statement: Executing fine-tuning statement...')
                # fine-tune model by interacting in the environment
                proc = self.ext_iface.tune_policy(stmt.envs, self.task)
                proc.wait()
                assert proc.success() and proc.model is not None
                self.task.update_model(proc.model, ModelStatus.FineTuned)
                self.task.save()
            finally:
                # release reserved environments even if tuning fails
                stmt.finalize()

        self._complete()
