import time
import warnings

import numpy as np

from stable_baselines3.common.callbacks import BaseCallback
from action_masking.util.sets import Zonotope

from action_masking.util.util import ActionSpace

# Logger tag prefixes: core training metrics are recorded under TB_LOG,
# additional/diagnostic metrics under TB_LOG_SUP.
TB_LOG: str = "benchmark_train/"
TB_LOG_SUP: str = "benchmark_train_sup/"


class TrainQuadrotorCallback(BaseCallback):

    """
    Training callback that accumulates safety and state metrics for the quadrotor benchmark.

    On every step it reads the info dict of the first environment, accumulates per-episode
    statistics from the safety-wrapper entry ("projection", "masking", "replacement" or
    "baseline"), and when the info dict contains an "episode" entry it records averages/totals
    to the logger, resets the accumulators and dumps the logger at step ``n_calls``.

    Args:
        safe_region: Safe (RCI) set of states; must support ``state in safe_region``.
        action_space: ``ActionSpace`` member; ``ActionSpace.Discrete`` selects the
            discrete-mask branch, anything else the continuous branch.
        action_space_area: Area (volume) of the full continuous action space, used to
            normalise ``safety_activity`` for continuous masking.
        verbose: 0 logs core metrics only; > 0 additionally logs state averages and episode
            statistics; > 1 additionally logs totals, maxima and polytope comparison metrics.

    **TensorBoard metrics**
    Tags are prefixed with ``TB_LOG`` (core) or ``TB_LOG_SUP`` (supplementary).

    runtime_per_episode: Wall-clock time between the first and the last callback step of an episode.

    avg_env_reward / total_env_reward: Environment reward averaged / accumulated over one episode.

    avg_policy_action / total_policy_action: Absolute policy action averaged / accumulated over one episode.

    avg_pun_reward / total_pun_reward: Punishment reward reported by the wrapper (not for the baseline).

    total_cbf_correction: Projection only. The absolute CBF correction applied to the original
                        action, accumulated over one episode.

    avg_cbf_correction: Projection only. The absolute CBF correction applied to the original
                        action, averaged over one episode.

    total_fail_safe_action / avg_fail_safe_action: Number of fail-safe actions per episode (not logged for projection).
                            Fail-safe means we fall back to our fail-safe planner because we don't find another appropriate safe action.
                            The fail-safe planner usually just samples directly on our RCI set given the current state.

    safety_activity: Per-step measure of how strongly the safety layer intervened.
                     Continuous masking: 1 - area of safe space / maximum area of action space (clipped at 0).
                     Discrete masking: fraction of masked-out actions.
                     Projection (CBF correction applied), replacement (sampled or fail-safe action) and
                     masking fail-safe actions each add 1 for the step.

    total_safety_activity / avg_safety_activity: safety_activity accumulated / averaged over one episode.

    is_safety_violation: Number of steps in the episode counted as a safety violation of the actually selected final action.
                        This should be zero for all provably safe approaches.
                        An (unsafe) baseline could trigger this.

    is_outside_safe_region: Whether the state left the RCI set at any step of the episode.
                            This should only be caused by an (unsafe) baseline controller.
                            Depending on the fail-safe controller used, we might leave the safe set but know that
                            we will return within a few time steps when using this controller
                            (is_outside_safe_region = True, is_safety_violated = False).

    total_distance_to_saferegion / max_distance_to_saferegion: Currently disabled, because the distance
                            to zonotope safe regions is not implemented. If enabled, they would be the accumulated /
                            maximum distance to the RCI set over one episode while the state lies outside of it.

    avg_safe_space_area: Only for continuous masking currently.
                    Area (volume) of our safe action set per step, averaged over one episode.
                    Works for polytopes/intervals and zonotopes.

    safe_space_area_polytope: Only for masking with a zonotope representation, and only if the wrapper
                            additionally reports "safe_space_polytope" (hyperparam "log_polytope_space").
                            In this case, we additionally log a polytope set synchronously.
                            The logged value is the polytope safe action set area of the last step of the episode.

    avg_safe_space_area_polytope: safe_space_area_polytope averaged over one episode.

    avg_safe_space_area_zonotope_polytope_ratio: Per-step ratio safe_space_area / safe_space_area_polytope,
                            averaged over one episode (steps with a zero-area polytope are skipped).

    avg_x, avg_xdot, avg_z, avg_zdot, avg_theta, avg_thdot, max_theta, max_thdot: State statistics (verbose > 0).
    """

    def __init__(self, safe_region, action_space: ActionSpace, action_space_area=1, verbose=0):
        super(TrainQuadrotorCallback, self).__init__(verbose)
        self._safe_region = safe_region
        self._space = action_space
        self._action_space_area = action_space_area
        self.verbose = verbose
        # Run-wide episode counter. Initialised here instead of in _reset(), which runs after
        # every episode and previously kept this counter stuck at 0/1.
        self.episode_count = 0
        self._reset()

    def _reset(self):
        """Reset all per-episode accumulators (called on construction and after each logged episode)."""
        self.is_safety_violated = 0
        self.total_env_reward = .0
        self.total_cbf_correction = .0
        self.total_fail_safe_action = .0
        self.total_safety_activity = .0
        self.is_outside_safe_region = False
        self.total_sampled_action = .0
        self.total_policy_action = .0
        self.total_pun_reward = .0
        self.total_x = .0
        self.total_xdot = .0
        self.total_z = .0
        self.total_zdot = .0
        self.total_theta = .0
        self.total_thdot = .0
        # Signed maxima (theta/thdot are compared without abs()).
        self.max_theta = .0
        self.max_thdot = .0
        self.infeasible_opt = .0
        self._prev_state = None
        self.safe_space_area = .0
        # Stays None unless a polytope safe space is additionally reported in zonotope mode.
        self.safe_space_area_polytope = None
        self.total_safe_space_area = .0
        self.total_safe_space_area_polytope = .0
        self.total_safe_space_ratio = .0
        self.start_time = None

    @staticmethod
    def _interval_volume(bounds):
        """Return the volume of an axis-aligned box given as one pair of bounds per row.

        Each row of ``bounds`` holds the two bounds of one dimension; ``abs`` makes the result
        independent of their order. Works for any number of dimensions (the previous code
        only handled 2-D).
        """
        volume = 1
        for row in bounds:
            volume = volume * abs(row[1] - row[0])
        return volume

    def _on_step(self) -> bool:
        """Accumulate metrics for the current step; log and reset them when an episode ends.

        Only the first environment of the vectorised env is inspected.

        Returns:
            Always True, so this callback never stops training.

        Raises:
            KeyError: If the info dict contains none of the known wrapper entries.
        """
        infos = self.locals.get("infos")[0]
        state = self.training_env.get_attr('state')[0]

        self._track_episode_runtime()
        wrapper_info = self._accumulate_wrapper_info(infos)

        # State information
        if self.verbose > 0:
            self._accumulate_state(state)

        self._accumulate_safety(infos, wrapper_info, state)
        self._accumulate_general(wrapper_info)

        if "episode" in infos:
            self._log_episode(infos)
            self._reset()
            self.logger.dump(step=self.n_calls)

        self._prev_state = state

        return True

    def _track_episode_runtime(self):
        """Start the episode timer on the first step and record the runtime when the env reports done."""
        if self.start_time is None:
            self.start_time = time.time()
            self.episode_count += 1

        if self.locals.get("dones")[0]:
            runtime = time.time() - self.start_time
            self.start_time = None
            self.logger.record(TB_LOG_SUP + "runtime_per_episode", runtime)

    def _accumulate_wrapper_info(self, infos):
        """Accumulate wrapper-specific metrics and return the wrapper's info dict."""
        if "projection" in infos:
            wrapper_info = infos["projection"]
            if wrapper_info["infeasible"]:
                self.total_fail_safe_action += 1
            if wrapper_info["cbf_correction"] is not None:
                self.total_safety_activity += 1
                self.total_cbf_correction += abs(wrapper_info["cbf_correction"])

        elif "masking" in infos:
            wrapper_info = infos["masking"]
            self._accumulate_masking(wrapper_info)

        elif "replacement" in infos:
            wrapper_info = infos["replacement"]
            if wrapper_info["sample_action"] is not None:
                self.total_safety_activity += 1
                self.total_sampled_action += 1
            elif wrapper_info["fail_safe_action"] is not None:
                self.total_safety_activity += 1
                self.total_fail_safe_action += 1

        elif "baseline" in infos:
            wrapper_info = infos["baseline"]

        else:
            raise KeyError(f"No wrapper information in {infos}")

        return wrapper_info

    def _accumulate_masking(self, wrapper_info):
        """Accumulate safe-space area and safety-activity metrics for the masking wrapper."""
        safe_space = wrapper_info["safe_space"]
        if safe_space is not None:
            if self._space is ActionSpace.Discrete:
                # Number of masked out (false) actions / number of actions
                self.safety_activity = np.sum(safe_space == 0) / safe_space.shape[0]
                self.total_safety_activity += self.safety_activity

            else:  # Continuous action space
                if isinstance(safe_space, Zonotope):
                    # Generator mode: the safe space is a zonotope.
                    self.safe_space_area = safe_space.volume
                    self.total_safe_space_area += self.safe_space_area

                    # Generator mode may additionally log the polytope/interval safe space
                    # (hyperparam "log_polytope_space"). Read it here, where it is used, so an
                    # interval-mode safe space never touches an unbound variable.
                    safe_space_polytope = wrapper_info.get("safe_space_polytope", None)
                    if safe_space_polytope is not None:
                        self._accumulate_polytope(safe_space_polytope)
                else:
                    # Interval mode: the safe space is an array of per-dimension bounds.
                    self.safe_space_area = self._interval_volume(safe_space)
                    self.total_safe_space_area += self.safe_space_area

                # Compute safety activity as 1 - area of safe space / maximum area of action space
                self.safety_activity = 1 - min(1, (self.safe_space_area / self._action_space_area))
                self.total_safety_activity += self.safety_activity

        if wrapper_info["fail_safe_action"] is not None:
            self.total_fail_safe_action += 1
            self.total_safety_activity += 1

    def _accumulate_polytope(self, safe_space_polytope):
        """Accumulate the polytope safe-space area and the zonotope/polytope area ratio."""
        self.safe_space_area_polytope = self._interval_volume(safe_space_polytope)
        self.total_safe_space_area_polytope += self.safe_space_area_polytope
        # Skip the ratio for a degenerate (zero-area) polytope instead of dividing by zero.
        if self.safe_space_area_polytope > 0:
            self.total_safe_space_ratio += self.safe_space_area / self.safe_space_area_polytope

    def _accumulate_state(self, state):
        """Accumulate state sums and signed maxima; expects state = (x, z, xdot, zdot, theta, thdot)."""
        x, z, xdot, zdot, theta, thdot = state
        self.total_x += x
        self.total_xdot += xdot
        self.total_z += z
        self.total_zdot += zdot
        self.total_theta += theta
        self.total_thdot += thdot
        if theta > self.max_theta:
            self.max_theta = theta
        if thdot > self.max_thdot:
            self.max_thdot = thdot

    def _accumulate_safety(self, infos, wrapper_info, state):
        """Flag leaving the safe (RCI) region and count steps treated as safety violations."""
        if state in self._safe_region:
            return

        self.is_outside_safe_region = True
        # TODO: distance metrics (total/max_distance_to_saferegion) stay disabled until
        # euclidean_dist_to_safe_region is implemented for zonotope safe regions.

        # NOTE(review): this tests whether the key "fail_safe_action" exists, not whether a
        # fail-safe action was used. Masking and replacement info dicts always carry that key,
        # so every exit from the safe region counts as a violation for them. An earlier variant
        # tested `wrapper_info["fail_safe_action"] is None`. Confirm which semantics is intended.
        if "baseline" in infos or "fail_safe_action" in wrapper_info:
            print("previous state:", self._prev_state)
            print("current state:", state)
            self.is_safety_violated += 1
        elif "projection" in infos:
            print('projection lead outside of safe region')
            # Leaving the region is only a violation if the projection was feasible.
            if not wrapper_info["infeasible"]:
                self.is_safety_violated += 1

    def _accumulate_general(self, wrapper_info):
        """Accumulate environment reward, policy action magnitude and punishment reward."""
        self.total_env_reward += wrapper_info["env_reward"]

        if wrapper_info["policy_action"] is not None:
            self.total_policy_action += abs(wrapper_info["policy_action"])

        if "pun_reward" in wrapper_info and wrapper_info["pun_reward"] is not None:
            self.total_pun_reward += wrapper_info["pun_reward"]

    def _log_episode(self, infos):
        """Record the episode's averages/totals to the logger (does not reset or dump)."""
        episode_infos = infos["episode"]
        episode_length = episode_infos['l']

        self.logger.record(TB_LOG + "is_safety_violation", self.is_safety_violated)
        self.logger.record(TB_LOG_SUP + "is_outside_safe_region", self.is_outside_safe_region)
        self.logger.record(TB_LOG + "avg_env_reward", self.total_env_reward / episode_length)
        self.logger.record(TB_LOG_SUP + "avg_policy_action", self.total_policy_action / episode_length)

        if self.verbose > 0:
            self.logger.record(TB_LOG_SUP + "avg_x", self.total_x / episode_length)
            self.logger.record(TB_LOG_SUP + "avg_xdot", self.total_xdot / episode_length)
            self.logger.record(TB_LOG_SUP + "avg_z", self.total_z / episode_length)
            self.logger.record(TB_LOG_SUP + "avg_zdot", self.total_zdot / episode_length)
            self.logger.record(TB_LOG_SUP + "avg_theta", self.total_theta / episode_length)
            self.logger.record(TB_LOG_SUP + "avg_thdot", self.total_thdot / episode_length)
            self.logger.record(TB_LOG_SUP + "episode_length", episode_length)
            self.logger.record(TB_LOG_SUP + "episode_time", episode_infos['t'])
            self.logger.record(TB_LOG_SUP + "episode_return", episode_infos['r'])
            # NOTE(review): in SB3, `_total_timesteps` appears to be the learn() budget rather than
            # the current step count (`num_timesteps`) — confirm this is the intended value.
            self.logger.record(TB_LOG_SUP + "total_steps", self.model._total_timesteps)
            self.logger.record(TB_LOG_SUP + "avg_safe_space_area", self.total_safe_space_area / episode_length)

            if self.verbose > 1:
                self.logger.record(TB_LOG_SUP + "max_theta", self.max_theta)
                self.logger.record(TB_LOG_SUP + "max_thdot", self.max_thdot)
                self.logger.record(TB_LOG_SUP + "total_env_reward", self.total_env_reward)
                self.logger.record(TB_LOG_SUP + "total_policy_action", self.total_policy_action)

                # If the polytope/interval safe space was additionally logged, record the comparison
                # metrics, especially the (zonotope area)/(polytope area) ratio.
                if self.safe_space_area_polytope is not None:
                    self.logger.record(TB_LOG_SUP + "safe_space_area_polytope", self.safe_space_area_polytope)
                    self.logger.record(TB_LOG_SUP + "avg_safe_space_area_polytope", self.total_safe_space_area_polytope / episode_length)
                    self.logger.record(TB_LOG_SUP + "avg_safe_space_area_zonotope_polytope_ratio", self.total_safe_space_ratio / episode_length)

        if "baseline" not in infos:
            self.logger.record(TB_LOG_SUP + "avg_pun_reward", self.total_pun_reward / episode_length)
            if self.verbose > 1:
                self.logger.record(TB_LOG_SUP + "total_pun_reward", self.total_pun_reward)

            self.logger.record(TB_LOG + "avg_safety_activity", self.total_safety_activity / episode_length)
            if self.verbose > 1:
                self.logger.record(TB_LOG_SUP + "total_safety_activity", self.total_safety_activity)

            if "projection" not in infos:
                self.logger.record(TB_LOG_SUP + "avg_fail_safe_action", self.total_fail_safe_action / episode_length)
                if self.verbose > 1:
                    self.logger.record(TB_LOG_SUP + "total_fail_safe_action", self.total_fail_safe_action)
            elif self.verbose > 1:
                self.logger.record(TB_LOG_SUP + "total_cbf_correction", self.total_cbf_correction)
                self.logger.record(TB_LOG_SUP + "avg_cbf_correction", self.total_cbf_correction / episode_length)

            if "replacement" in infos and self.verbose > 1:
                self.logger.record(TB_LOG_SUP + "total_sample_action", self.total_sampled_action)
                self.logger.record(TB_LOG_SUP + "avg_sample_action", self.total_sampled_action / episode_length)