import numpy as np
from stable_baselines3.common.callbacks import BaseCallback
from time import process_time
import time

from action_masking.util.util import ActionSpace



class TDQuadrotorCallback(BaseCallback):
    """Callback that logs per-episode benchmark metrics of a safety wrapper.

    On every environment step the callback reads the step's ``info`` dict,
    looks for one of the wrapper entries ``"projection"``, ``"masking"``,
    ``"replacement"`` or ``"baseline"``, and accumulates the values found
    there. When the ``info`` dict carries an ``"episode"`` entry, the
    accumulated values are averaged over the episode length, recorded with
    the logger, and the per-episode accumulators are cleared.

    Only the first entry of the (vectorised) ``infos``/``dones`` locals is
    inspected.
    """

    # Logger key prefixes; replaced by the "deploy" variants in ``__init__``
    # when the callback is constructed with ``train=False``.
    TB_LOG = "benchmark_train/"
    TB_LOG_SUP = "benchmark_train_sup/"

    def __init__(self, safe_region, action_space: ActionSpace, action_space_area=1, verbose=0, train: bool = True):
        """Store the configuration and initialise all accumulators.

        Args:
            safe_region: Stored as ``_safe_region``; not used by this class.
            action_space: Kind of action space; selects how the masking
                branch measures safety activity.
            action_space_area: Stored as ``_action_space_area``; not used by
                this class.
            verbose: Verbosity level; ``> 0`` logs additional episode
                statistics, ``> 1`` additionally logs the raw safety
                activity count.
            train: ``True`` when used during training, ``False`` for
                deployment. Selects the logger prefixes and the names of the
                locals read in ``_on_step``.
        """
        # Forward ``verbose`` so the base class sets ``self.verbose``.
        super().__init__(verbose)
        if not train:
            # Shadow the class-level prefixes on this instance only.
            self.TB_LOG = "benchmark_deploy/"
            self.TB_LOG_SUP = "benchmark_deploy_sup/"
        self._safe_region = safe_region
        self._space = action_space
        self._action_space_area = action_space_area
        self._is_train = train
        self._reset()
        self._reset_episode()

    def _reset(self):
        """Reset the callback-wide state to its initial values."""
        self.start_time = None
        self.episode_count = 0
        self.is_safety_violated = 0
        self.total_env_reward = .0
        self.safe_space_area = .0

    def _reset_episode(self):
        """Clear every accumulator that is summed up over one episode."""
        self.total_env_reward = .0
        self.total_policy_action = .0
        self.total_safety_activity = .0
        self.total_fail_safe_action = .0
        self.total_safe_space_area = .0
        # These two are incremented with ``+=`` in the "projection" and
        # "replacement" branches, so they must exist before the first step.
        self.total_cbf_correction = .0
        self.total_sampled_action = .0

    def _accumulate_wrapper_info(self, infos):
        """Add this step's wrapper statistics to the episode accumulators.

        Args:
            infos: The ``info`` dict of the current step.

        Returns:
            The wrapper-specific sub-dict found in ``infos``.

        Raises:
            KeyError: If ``infos`` contains none of the known wrapper keys.
        """
        if "projection" in infos:
            wrapper_info = infos["projection"]
            if wrapper_info["cbf_correction"] is not None:
                self.total_safety_activity += 1
                self.total_cbf_correction += abs(wrapper_info["cbf_correction"])

        elif "masking" in infos:
            wrapper_info = infos["masking"]
            safe_space = wrapper_info["safe_space"]
            if safe_space is not None:
                self.safe_space_area = safe_space.volume
                self.total_safe_space_area += self.safe_space_area

                # NOTE(review): only discrete action spaces contribute to the
                # safety activity here; other spaces add nothing. The divisor
                # 21 is presumably the number of discrete actions - confirm.
                if self._space is ActionSpace.Discrete:
                    self.total_safety_activity += np.count_nonzero(safe_space == 0) / 21

            if wrapper_info["fail_safe_action"] is not None:
                self.total_fail_safe_action += abs(wrapper_info["fail_safe_action"])

        elif "replacement" in infos:
            wrapper_info = infos["replacement"]
            # A sampled replacement takes precedence over the fail-safe
            # action; at most one of the two is counted per step.
            if wrapper_info["sample_action"] is not None:
                self.total_safety_activity += 1
                self.total_sampled_action += abs(wrapper_info["sample_action"])
            elif wrapper_info["fail_safe_action"] is not None:
                self.total_safety_activity += 1
                self.total_fail_safe_action += abs(wrapper_info["fail_safe_action"])

        elif "baseline" in infos:
            wrapper_info = infos["baseline"]

        else:
            raise KeyError(f"No wrapper information in {infos}")

        return wrapper_info

    def _log_episode(self, infos):
        """Record the averaged episode metrics and clear the accumulators.

        Args:
            infos: The ``info`` dict of the final step; must contain an
                ``"episode"`` entry with the keys ``'l'`` (length) and, for
                ``verbose > 0``, ``'t'`` and ``'r'``.
        """
        episode_infos = infos["episode"]
        episode_length = episode_infos['l']

        self.logger.record(self.TB_LOG + "is_safety_violation", self.is_safety_violated)
        self.logger.record(self.TB_LOG + "avg_env_reward", self.total_env_reward / episode_length)
        # NOTE(review): ``total_policy_action`` is never incremented in this
        # class, so this value is 0 unless it is set from outside.
        self.logger.record(self.TB_LOG_SUP + "avg_policy_action", self.total_policy_action / episode_length)

        if self.verbose > 0:
            self.logger.record(self.TB_LOG_SUP + "episode_length", episode_length)
            self.logger.record(self.TB_LOG_SUP + "episode_time", episode_infos['t'])
            self.logger.record(self.TB_LOG_SUP + "episode_return", episode_infos['r'])
            self.logger.record(self.TB_LOG_SUP + "total_steps", self.model._total_timesteps)

        if "baseline" not in infos:
            self.logger.record(self.TB_LOG_SUP + "avg_safe_space_area", self.total_safe_space_area / episode_length)
            self.logger.record(self.TB_LOG + "avg_safety_activity", self.total_safety_activity / episode_length)
            if self.verbose > 1:
                self.logger.record(self.TB_LOG_SUP + "total_safety_activity", self.total_safety_activity)

        self._reset_episode()
        # The violation flag is reported once per episode, so clear it here.
        self.is_safety_violated = 0

    def _on_step(self) -> bool:
        """Collect the statistics of one environment step.

        Returns:
            Always ``True``, i.e. the callback never requests an early stop.

        Raises:
            KeyError: If the step's ``info`` dict holds no wrapper entry.
        """
        # from rlsampling: store the current policy distribution parameters
        # in the rollout buffer (requires a buffer with ``add_distribution``).
        dist = self.model.policy.action_dist
        self.model.rollout_buffer.add_distribution(
            dist.distribution.mean,
            dist.distribution.stddev,
        )

        if self._is_train:
            infos = self.locals.get("infos")[0]
            dones = self.locals.get("dones")[0]
        else:
            # NOTE(review): outside of training the locals are expected under
            # the singular names - presumably because the deployment loop
            # names them differently; confirm against the calling code.
            infos = self.locals.get("info")[0]
            dones = self.locals.get("done")[0]

        # First step of a new episode: remember when it started.
        if self.start_time is None:
            self.start_time = time.time()
            self.episode_count += 1

        # Last step of the episode: log the elapsed wall-clock time.
        if dones:
            runtime = time.time() - self.start_time
            self.start_time = None
            self.logger.record(self.TB_LOG_SUP + "runtime_per_episode", runtime)

        wrapper_info = self._accumulate_wrapper_info(infos)
        self.total_env_reward += wrapper_info["env_reward"]

        if "episode" in infos:
            self._log_episode(infos)

        self.logger.dump(step=self.n_calls)
        # Only the per-step value is cleared here. A full ``_reset()`` at this
        # point would wipe ``start_time`` and ``total_env_reward`` on every
        # step and thereby break the per-episode runtime and reward average.
        self.safe_space_area = .0
        return True

    def _on_rollout_end(self) -> None:
        """Record action statistics of the rollout that was just collected."""
        # From rlsampling
        actions = self.model.rollout_buffer.actions
        self.logger.record("train/action_mean", np.mean(actions))
        self.logger.record("train/action_std", np.std(actions))

        # ``mean_actions`` holds the distribution means stored through
        # ``add_distribution`` - presumably; verify against the buffer class.
        mean_actions = self.model.rollout_buffer.mean_actions
        self.logger.record("policy/mean_action_dist_mean", np.mean(mean_actions))
        self.logger.record("policy/std_action_dist_mean", np.std(mean_actions))
