import logging

import numpy as np

from train.enums import ClientEnvOption

from train.monitor.consts import *
from train.monitor.timeouts import *
from train.statement.base_stmt import SubStatement, Statement
from train.statement.device_stmt import CheckDeviceStatement
from train.statement.prepare_env_stmt import PrepareEnvironmentStatement
from train.statement.prepare_env_stmt_local import PrepareLocalEnvironmentStatement
from train.task import Task


class InvalidOptionException(Exception):
    """Raised when the configured client environment option is not recognised."""


class EvaluateAgentStatement(SubStatement):
    """Sub-statement that evaluates the task history once per available device.

    For every device reported by ``CheckDeviceStatement`` one environment is
    prepared and one evaluation process is started asynchronously. After all
    processes have been joined, the mean of their results is stored in
    ``mean_reward`` and the environments are finalized.
    """

    def __init__(self, stmt: Statement):
        """Create the statement and resolve the configured environment option.

        Args:
            stmt: Statement forwarded to ``SubStatement.__init__``.

        Raises:
            InvalidOptionException: If ``config.client_env_option`` is not one
                of ``"all"``, ``"record"``, ``"normal"`` or ``"replay"``.
        """
        super().__init__(stmt)
        self.mean_reward = UNINIT_EXPECTED_MEAN_REWARD

        # Built locally so the enum members are only resolved on construction.
        options_by_name = {
            "all": ClientEnvOption.All,
            "record": ClientEnvOption.Record,
            "normal": ClientEnvOption.Normal,
            "replay": ClientEnvOption.Replay,
        }
        option_name = self.config.client_env_option
        try:
            self.option = options_by_name[option_name]
        except (KeyError, TypeError):
            # TypeError covers unhashable config values, which are equally invalid.
            raise InvalidOptionException(
                'Unknown client_env_option {!r}; expected one of {}.'.format(
                    option_name, sorted(options_by_name))
            ) from None

    def exec(self):
        """Run one evaluation per device and record the mean reward.

        Checks are raised explicitly rather than written as ``assert``
        statements so that they are still enforced under ``python -O``.

        Raises:
            AssertionError: If the device check fails, an environment cannot
                be prepared, any evaluation process does not finish
                successfully, or an environment is still set after finalize.
        """
        logging.info('Statement: Evaluation starts with %s task.', self.task.id)
        device_stmt = CheckDeviceStatement(self)
        device_stmt.exec()
        if not device_stmt.success():
            raise AssertionError('Device check failed.')

        env_stmts = []
        processes = []

        # Execute multiple evaluations in parallel: one environment and one
        # asynchronous process per device.
        # NOTE(review): a failure inside this loop leaves already started
        # processes running and their environments un-finalized (as before);
        # no API for stopping a running process is visible here.
        for device in device_stmt.device_list:
            # Reserve the environment, remote or local depending on the config.
            if self.config.env.env_server:
                env_stmt = PrepareEnvironmentStatement(self, num_envs=1, option=self.option)
            else:
                env_stmt = PrepareLocalEnvironmentStatement(self, num_envs=1, option=self.option)
            env_stmt.exec()
            if not env_stmt.success():
                raise AssertionError(
                    'Environment preparation failed for device {!r}.'.format(device))
            env_stmts.append(env_stmt)

            # Each evaluation gets its own copy of the task history.
            task_history = Task.clone_list(self.task_history)
            proc = self.ext_iface.evaluate_tasks(env_stmt.envs, task_history, device)
            proc.timeout = TIMEOUT_12_HOURS
            proc.run_async()
            processes.append(proc)

        try:
            # Join back processes.
            for proc in processes:
                proc.wait()

            # Check if all processes finished successfully.
            outcomes = [proc.success() for proc in processes]
            if not all(outcomes):
                raise AssertionError(
                    '{} of {} evaluation processes failed.'.format(
                        outcomes.count(False), len(outcomes)))

            # A process without a result contributes a reward of 0.0.
            # NOTE(review): with an empty device list np.mean([]) yields nan;
            # presumably the device check guarantees at least one device.
            rewards = [proc.result if proc.result is not None else 0.0 for proc in processes]
            self.mean_reward = np.mean(rewards)
        finally:
            # Finalize the environments even when an evaluation failed, so
            # that prepared environments are not left behind.
            for env_stmt in env_stmts:
                env_stmt.finalize()

        if not all(env_stmt.envs is None for env_stmt in env_stmts):
            raise AssertionError('Not all environments were finalized.')

        self._complete()
