import warnings

import numpy as np

from stable_baselines3.common.callbacks import BaseCallback

from provably_safe_benchmark.util.util import ActionSpace

# Logger key prefixes used by TrainQuadrotorCallback when recording metrics:
#   TB_LOG     -> headline metrics (safety violation flag, avg env reward,
#                 avg safety activity)
#   TB_LOG_SUP -> supplementary / diagnostic metrics
TB_LOG, TB_LOG_SUP = "benchmark_train/", "benchmark_train_sup/"


class TrainQuadrotorCallback(BaseCallback):
    """Accumulate per-episode safety/state statistics and log them at episode end.

    On every step the callback reads the info dict of the first environment
    (``self.locals["infos"][0]``) and the ``state`` attribute of the first
    training environment. The info dict must carry one of the keys
    ``"projection"``, ``"masking"``, ``"replacement"`` or ``"baseline"``
    (checked in that order); its value is the wrapper-specific info dict.
    Statistics are summed over the episode and, when the info dict carries an
    ``"episode"`` entry, written to ``self.logger`` under the ``TB_LOG`` /
    ``TB_LOG_SUP`` prefixes, dumped, and reset.

    :param safe_region: Object providing ``contains(state, tol)`` and
        ``euclidean_dist_to_safe_region(state)``.
    :param action_space: Kind of action space; ``ActionSpace.Discrete`` selects
        the discrete masking metric, any other value the continuous one.
    :param action_space_area: Area of the full continuous action space, used to
        normalise the area of the safe action box reported by masking.
    :param verbose: 0 logs only the core metrics; >0 additionally accumulates
        state averages and logs episode length/time/return; >1 also logs
        totals, maxima and distances to the safe region.
    """

    def __init__(self, safe_region, action_space: ActionSpace, action_space_area = 1, verbose=0):
        # BaseCallback.__init__ already stores ``verbose`` on ``self.verbose``.
        super().__init__(verbose)
        self._safe_region = safe_region
        self._space = action_space
        self._action_space_area = action_space_area
        self._reset()

    def _reset(self):
        """Zero all per-episode accumulators and flags."""
        # Episode-level flags
        self.is_safety_violated = False
        self.is_outside_safe_region = False
        # Wrapper / reward accumulators
        self.total_env_reward = .0
        self.total_cbf_correction = .0
        self.total_fail_safe_action = .0
        self.total_safety_activity = .0
        self.total_sampled_action = .0
        self.total_policy_action = .0
        self.total_pun_reward = .0
        # State accumulators (only updated when verbose > 0)
        self.total_x = .0
        self.total_xdot = .0
        self.total_z = .0
        self.total_zdot = .0
        self.total_theta = .0
        self.total_thdot = .0
        self.max_theta = .0
        self.max_thdot = .0
        # Safe-region accumulators
        self.distance_to_saferegion = .0
        self.max_distance_to_saferegion = .0
        self.infeasible_opt = .0
        # Last seen state; written every step but not read by this class.
        self._prev_state = None

    def _on_step(self) -> bool:
        """Update the running statistics for one environment step.

        :return: Always ``True``; this callback never stops training.
        :raises KeyError: If the info dict carries none of the known wrapper keys
            (or ``"infos"`` is missing from ``self.locals``).
        """
        infos = self.locals["infos"][0]
        # NOTE(review): if training_env is an SB3 VecEnv that auto-resets a
        # finished sub-environment inside step(), then on the last step of an
        # episode this reads the *post-reset* state, so the terminal state is
        # never checked against the safe region. The terminal observation is
        # in infos["terminal_observation"] -- confirm whether it equals `state`.
        state = self.training_env.get_attr('state')[0]

        wrapper_info = self._accumulate_wrapper_info(infos)
        if self.verbose > 0:
            self._accumulate_state(state)
        self._check_safety(state, infos, wrapper_info)
        self._accumulate_general(wrapper_info)

        if "episode" in infos:
            self._log_episode(infos)
            self._reset()

        self._prev_state = state
        return True

    def _accumulate_wrapper_info(self, infos):
        """Add the wrapper-specific safety statistics of one step.

        :return: The wrapper's own info dict (``infos[<wrapper key>]``).
        :raises KeyError: If no known wrapper key is present in ``infos``.
        """
        if "projection" in infos:
            wrapper_info = infos["projection"]
            if wrapper_info["infeasible"]:
                self.infeasible_opt += 1
            if wrapper_info["cbf_correction"] is not None:
                self.total_safety_activity += 1
                self.total_cbf_correction += abs(wrapper_info["cbf_correction"])

        elif "masking" in infos:
            wrapper_info = infos["masking"]
            safe_space = wrapper_info["safe_space"]
            if safe_space is None:
                # No safe space given: the step counts as full safety activity.
                self.total_safety_activity += 1
            elif self._space is ActionSpace.Discrete:
                # Number of masked out (zero) actions / number of actions
                self.total_safety_activity += np.sum(safe_space == 0) / safe_space.shape[0]
            else:
                # 1 - area of safe box / area of full action space (floored at 0).
                # Rows of safe_space are action dimensions, columns are [low, high].
                safe_area = ((safe_space[0, 1] - safe_space[0, 0])
                             * (safe_space[1, 1] - safe_space[1, 0]))
                self.total_safety_activity += 1 - min(1, safe_area / self._action_space_area)

            if wrapper_info["fail_safe_action"] is not None:
                self.total_fail_safe_action += abs(wrapper_info["fail_safe_action"])

        elif "replacement" in infos:
            wrapper_info = infos["replacement"]
            if wrapper_info["sample_action"] is not None:
                self.total_safety_activity += 1
                self.total_sampled_action += abs(wrapper_info["sample_action"])
            elif wrapper_info["fail_safe_action"] is not None:
                self.total_safety_activity += 1
                self.total_fail_safe_action += abs(wrapper_info["fail_safe_action"])

        elif "baseline" in infos:
            wrapper_info = infos["baseline"]

        else:
            raise KeyError(f"No wrapper information in {infos}")

        return wrapper_info

    def _accumulate_state(self, state):
        """Add the components of ``state`` to the per-episode state totals.

        ``state`` is unpacked as ``(x, z, xdot, zdot, theta, thdot)``.
        """
        x, z, xdot, zdot, theta, thdot = state
        self.total_x += x
        self.total_xdot += xdot
        self.total_z += z
        self.total_zdot += zdot
        self.total_theta += theta
        self.total_thdot += thdot
        # NOTE(review): these track the largest *signed* value seen (floor 0,
        # the reset value), not the largest magnitude -- confirm intended.
        if theta > self.max_theta:
            self.max_theta = theta
        if thdot > self.max_thdot:
            self.max_thdot = thdot

    def _check_safety(self, state, infos, wrapper_info):
        """Record distance to the safe region and flag safety violations."""
        # 1e-10 is passed as the second argument of contains(); presumably a
        # numerical tolerance -- confirm against the safe-region implementation.
        if self._safe_region.contains(state, 1e-10):
            return

        self.is_outside_safe_region = True
        dist = self._safe_region.euclidean_dist_to_safe_region(state)
        self.distance_to_saferegion += dist
        self.max_distance_to_saferegion = max(self.max_distance_to_saferegion, dist)

        if "baseline" in infos:
            # Baseline runs: any exit from the safe region is a violation.
            self.is_safety_violated = True
        elif "fail_safe_action" in wrapper_info:
            # Exit counts as a violation unless a fail-safe action was applied.
            if wrapper_info["fail_safe_action"] is None:
                self.is_safety_violated = True
        elif "projection" in infos:
            # warnings (not print) so a per-step callback does not flood stdout.
            warnings.warn('projection lead outside of safe region')
            # Exit counts as a violation unless the projection was infeasible.
            if not wrapper_info["infeasible"]:
                self.is_safety_violated = True

    def _accumulate_general(self, wrapper_info):
        """Add env reward, policy action magnitude and punishment reward."""
        self.total_env_reward += wrapper_info["env_reward"]

        if wrapper_info["policy_action"] is not None:
            self.total_policy_action += abs(wrapper_info["policy_action"])

        if "pun_reward" in wrapper_info and wrapper_info["pun_reward"] is not None:
            self.total_pun_reward += wrapper_info["pun_reward"]

    def _log_episode(self, infos):
        """Record the accumulated episode statistics and dump the logger."""
        episode_infos = infos["episode"]
        episode_length = episode_infos['l']
        record = self.logger.record

        record(TB_LOG + "is_safety_violation", self.is_safety_violated)
        record(TB_LOG_SUP + "is_outside_safe_region", self.is_outside_safe_region)
        record(TB_LOG + "avg_env_reward", self.total_env_reward / episode_length)
        record(TB_LOG_SUP + "avg_policy_action", self.total_policy_action / episode_length)

        if self.verbose > 0:
            record(TB_LOG_SUP + "avg_x", self.total_x / episode_length)
            record(TB_LOG_SUP + "avg_xdot", self.total_xdot / episode_length)
            record(TB_LOG_SUP + "avg_z", self.total_z / episode_length)
            record(TB_LOG_SUP + "avg_zdot", self.total_zdot / episode_length)
            record(TB_LOG_SUP + "avg_theta", self.total_theta / episode_length)
            record(TB_LOG_SUP + "avg_thdot", self.total_thdot / episode_length)
            record(TB_LOG_SUP + "episode_length", episode_length)
            record(TB_LOG_SUP + "episode_time", episode_infos['t'])
            record(TB_LOG_SUP + "episode_return", episode_infos['r'])

            if self.verbose > 1:
                record(TB_LOG_SUP + "max_theta", self.max_theta)
                record(TB_LOG_SUP + "max_thdot", self.max_thdot)
                record(TB_LOG_SUP + "total_env_reward", self.total_env_reward)
                record(TB_LOG_SUP + "total_policy_action", self.total_policy_action)
                record(TB_LOG_SUP + "distance_to_saferegion", self.distance_to_saferegion)
                record(TB_LOG_SUP + "max_distance_to_saferegion", self.max_distance_to_saferegion)

        # Safety-wrapper metrics are meaningless for the unwrapped baseline.
        if "baseline" not in infos:
            record(TB_LOG_SUP + "avg_pun_reward", self.total_pun_reward / episode_length)
            if self.verbose > 1:
                record(TB_LOG_SUP + "total_pun_reward", self.total_pun_reward)

            record(TB_LOG + "avg_safety_activity", self.total_safety_activity / episode_length)
            if self.verbose > 1:
                record(TB_LOG_SUP + "total_safety_activity", self.total_safety_activity)

            if "projection" not in infos:
                record(TB_LOG_SUP + "avg_fail_safe_action", self.total_fail_safe_action / episode_length)
                if self.verbose > 1:
                    record(TB_LOG_SUP + "total_fail_safe_action", self.total_fail_safe_action)
            elif self.verbose > 1:
                record(TB_LOG_SUP + "total_cbf_correction", self.total_cbf_correction)
                record(TB_LOG_SUP + "avg_cbf_correction", self.total_cbf_correction / episode_length)
                record(TB_LOG_SUP + "infeasible_opt", self.infeasible_opt)

            if "replacement" in infos and self.verbose > 1:
                record(TB_LOG_SUP + "total_sample_action", self.total_sampled_action)
                record(TB_LOG_SUP + "avg_sample_action", self.total_sampled_action / episode_length)

        self.logger.dump(step=self.n_calls)