import numpy as np

from stable_baselines3.common.callbacks import BaseCallback

from provably_safe_benchmark.util.util import ActionSpace
from time import process_time

# Key prefix for the values the callback records under "benchmark_deploy/"
# (e.g. "theta", "thdot", "is_safety_violation", "safety_activity",
# "env_reward").
TB_LOG = "benchmark_deploy/"
# Key prefix for the values the callback records under "benchmark_deploy_sup/"
# (e.g. "is_outside_safe_region", "episode_length", "fail_safe_action",
# "policy_action").
TB_LOG_SUP = "benchmark_deploy_sup/"


class DeployPendulumCallback(BaseCallback):
    """Record state, safety and wrapper diagnostics of a pendulum run.

    On every rollout start and on every step the callback records the first
    two state components as ``theta`` / ``thdot``, whether the state lies
    outside ``safe_region`` and whether that counts as a safety violation.
    On every step it additionally records the diagnostics found in the step
    info under one of the keys ``"projection"``, ``"masking"``,
    ``"replacement"`` or ``"baseline"``, and dumps the logger using
    ``n_calls`` as the step.

    :param safe_region: object supporting the ``in`` operator for a state.
    :param action_space: action space kind; only compared against
        ``ActionSpace.Discrete`` when normalising the masking activity.
    :param verbose: when greater than 0, episode statistics and the elapsed
        process time are recorded at the end of each episode.
    """

    # Divisors that normalise the "masking" safety activity.
    # NOTE(review): presumably the number of discrete actions and the width of
    # the continuous action interval -- confirm against the environment.
    _DISCRETE_MASK_DIVISOR = 21
    _CONTINUOUS_MASK_DIVISOR = 60

    def __init__(self, safe_region, action_space: ActionSpace, verbose=0):
        # BaseCallback.__init__ already stores ``verbose`` on the instance.
        super().__init__(verbose)
        self._safe_region = safe_region
        self._space = action_space
        # Initialised here so that ``_on_step`` can always compute an elapsed
        # time, even if no rollout-start hook has fired yet.
        self._start_time = process_time()

    def _on_rollout_start(self) -> None:
        """Record the initial state and restart the process-time reference."""
        state = self.training_env.get_attr('state')[0]

        # No safety wrapper has acted yet, so being outside the safe region
        # is always reported as a violation here.
        outside = self._record_state(state)
        self.logger.record(TB_LOG + "is_safety_violation", outside)

        self.logger.dump(step=self.n_calls)
        self._start_time = process_time()

    def _on_step(self) -> bool:
        """Record the diagnostics of the latest step of the first environment.

        :return: always ``True``.
        :raises KeyError: if the callback locals hold no step info, or the
            step info contains none of the known wrapper keys.
        """
        infos = self._first_env_info()

        if "episode" in infos:
            terminal_obs = infos["terminal_observation"]
            # Workaround: at episode end the state is rebuilt from the
            # terminal observation instead of the environment's ``state``.
            # NOTE(review): assumes the observation is [cos, sin, thdot];
            # arcsin only recovers angles within [-pi/2, pi/2], so
            # np.arctan2(obs[1], obs[0]) may be needed for larger angles --
            # confirm against the environment.
            state = [np.arcsin(terminal_obs[1]), terminal_obs[2]]

            if self.verbose > 0:
                self.logger.record(
                    TB_LOG + "process_time", process_time() - self._start_time
                )
                episode_infos = infos["episode"]
                self.logger.record(TB_LOG_SUP + "episode_length", episode_infos['l'])
                self.logger.record(TB_LOG_SUP + "episode_time", episode_infos['t'])
                self.logger.record(TB_LOG_SUP + "episode_return", episode_infos['r'])
        else:
            state = self.training_env.get_attr('state')[0]

        wrapper_info = self._record_wrapper_info(infos)

        # Leaving the safe region is not counted as a violation when a
        # masking/replacement wrapper reports a fail-safe action. The
        # fail-safe entry is only looked up when actually needed.
        outside = self._record_state(state)
        violation = outside and (
            "baseline" in infos
            or "projection" in infos
            or wrapper_info["fail_safe_action"] is None
        )
        self.logger.record(TB_LOG + "is_safety_violation", violation)

        # General information
        self.logger.record(TB_LOG + "env_reward", wrapper_info["env_reward"])
        if wrapper_info["policy_action"] is not None:
            self.logger.record(TB_LOG_SUP + "policy_action", wrapper_info["policy_action"])
        if "pun_reward" in wrapper_info and wrapper_info["pun_reward"] is not None:
            self.logger.record(TB_LOG_SUP + "pun_reward", wrapper_info["pun_reward"])

        self.logger.dump(step=self.n_calls)
        return True

    def _first_env_info(self):
        """Return the step info of the first environment from the locals."""
        step_infos = self.locals.get('info')
        if step_infos is None:
            # Stock stable-baselines3 rollout loops name this local "infos".
            step_infos = self.locals.get('infos')
        if step_infos is None:
            raise KeyError("No step info ('info' or 'infos') in the callback locals")
        return step_infos[0]

    def _record_state(self, state) -> bool:
        """Record theta, thdot and the safe-region flag; return that flag."""
        self.logger.record(TB_LOG + "theta", state[0])
        self.logger.record(TB_LOG + "thdot", state[1])
        outside = state not in self._safe_region
        self.logger.record(TB_LOG_SUP + "is_outside_safe_region", outside)
        return outside

    def _record_wrapper_info(self, infos):
        """Record the wrapper-specific entries and return the wrapper info."""
        if "projection" in infos:
            wrapper_info = infos["projection"]
            self._record_projection(wrapper_info)
        elif "masking" in infos:
            wrapper_info = infos["masking"]
            self._record_masking(wrapper_info)
        elif "replacement" in infos:
            wrapper_info = infos["replacement"]
            self._record_replacement(wrapper_info)
        elif "baseline" in infos:
            # The baseline has no wrapper-specific entries to record.
            wrapper_info = infos["baseline"]
        else:
            raise KeyError(f"No wrapper information in {infos}")
        return wrapper_info

    def _record_projection(self, wrapper_info) -> None:
        """Record whether a CBF correction was applied and its value (or 0)."""
        correction = wrapper_info["cbf_correction"]
        corrected = correction is not None
        self.logger.record(TB_LOG + "safety_activity", corrected)
        self.logger.record(TB_LOG_SUP + "cbf_correction", correction if corrected else 0)

    def _record_masking(self, wrapper_info) -> None:
        """Record the masking activity and the fail-safe action (or 0)."""
        safe_space = wrapper_info["safe_space"]
        if safe_space is None:
            activity = True
        elif self._space is ActionSpace.Discrete:
            # Share of entries equal to 0 in the mask.
            activity = np.count_nonzero(safe_space == 0) / self._DISCRETE_MASK_DIVISOR
        else:
            # One minus the normalised width of the [low, high] interval.
            activity = 1 - ((safe_space[1] - safe_space[0]) / self._CONTINUOUS_MASK_DIVISOR)
        self.logger.record(TB_LOG + "safety_activity", activity)

        fail_safe = wrapper_info["fail_safe_action"]
        self.logger.record(
            TB_LOG_SUP + "fail_safe_action", 0 if fail_safe is None else fail_safe
        )

    def _record_replacement(self, wrapper_info) -> None:
        """Record the sampled / fail-safe replacement action (0 when unused)."""
        sampled = wrapper_info["sample_action"]
        # The fail-safe entry is only consulted when no action was sampled.
        fail_safe = None if sampled is not None else wrapper_info["fail_safe_action"]

        self.logger.record(
            TB_LOG_SUP + "fail_safe_action", 0 if fail_safe is None else fail_safe
        )
        self.logger.record(
            TB_LOG + "safety_activity", sampled is not None or fail_safe is not None
        )
        self.logger.record(TB_LOG_SUP + "sample_action", 0 if sampled is None else sampled)