import gym
import warnings
import numpy as np
from functools import reduce
import cvxpy as cp
import gurobipy
from provably_safe_benchmark.sb3_contrib.common.utils import fetch_fn
from provably_safe_benchmark.util.tictoc import tic, toc  # noqa: F401
from provably_safe_benchmark.util.util import remove_redundant_constraints


class ActionProjectionWrapper(gym.Wrapper):
    r"""Gym wrapper that replaces unsafe actions by their projection onto a safe action set.

    On every ``step`` the (optionally transformed) policy action is checked by
    propagating it through the dynamics and testing the result with
    ``safe_region.contains``. If the check fails, a quadratic program looks for the
    smallest correction (in the Euclidean norm) such that the corrected action
    satisfies a discretized control-barrier-function (CBF) condition and the action
    limits. If the QP is infeasible, the solver fails, or the returned solution
    violates the constraints, the action of ``safe_control_fn`` is used instead.

    The wrapper reads ``env.state`` in ``step`` and, when a non-empty ``noise_set`` is
    given, ``env.E_d`` in the constructor. ``safe_region`` has to provide
    ``halfspaces`` (a pair ``(C, d)``) and ``contains(state, tolerance)``.
    All ``*_fn`` arguments are resolved through ``fetch_fn(env, fn)``.

    :param env: Gym environment
    :param safe_region: Safe region instance
    :param action_limits: Action limits in [[min_1,...,min_n],[max_1,...,max_n]] format
    :param actuated_dynamics_fn: Actuated dynamics function, called as ``g(env)``
    :param unactuated_dynamics_fn: Unactuated dynamics function, called as ``f(env)``
    :param dynamics_fn: Dynamics function, called as ``dynamics(env, action)``
    :param noise_set: Vertices of the noise set (one vertex per row).
        If None, the noise set is assumed to be the zero set.
    :param punishment_fn: Reward punishment function, called as
        ``fn(env, policy_action, reward, correction)`` after a projected step
    :param alter_action_space: Alternative gym action space
    :param transform_action_space_fn: Action space transformation function
    :param generate_wrapper_tuple: Generate tuple (wrapper action, environment reward)
    :param inv_transform_action_space_fn: Inverse action space transformation function
    :param safe_control_fn: Fail-safe controller, called as ``fn(env, safe_region)``
        whenever the projection QP does not yield a usable solution
    :param gamma: CBF gamma (the constant ``\alpha`` in the derivation of ``step``)
    :param safety_margin: Safety margin for projection.
    :param use_zero_trick: remove redundant constraints that are only active in one of two action dimensions
    :raises ValueError: If neither ``dynamics_fn`` nor ``unactuated_dynamics_fn`` is
        given, or if the environment's action space is not a ``gym.spaces.Box``.
    """

    def __init__(self,
                 env: gym.Env,
                 safe_region,
                 action_limits,
                 actuated_dynamics_fn,
                 unactuated_dynamics_fn=None,
                 dynamics_fn=None,
                 noise_set=None,
                 punishment_fn=None,
                 alter_action_space=None,
                 transform_action_space_fn=None,
                 generate_wrapper_tuple=False,
                 inv_transform_action_space_fn=None,
                 safe_control_fn=None,
                 gamma=1,
                 safety_margin=0.00,
                 use_zero_trick=True):

        super().__init__(env)
        # ``gamma`` is the constant \alpha of the CBF condition derived in ``step``.
        self._alpha = gamma
        self._safe_region = safe_region
        self._generate_wrapper_tuple = generate_wrapper_tuple

        if unactuated_dynamics_fn is None and dynamics_fn is None:
            raise ValueError("Specify ``dynamics_fn`` or ``unactuated_dynamics``")

        if not hasattr(self.env, "state"):
            warnings.warn("Environment has no attribute ``state``")

        self._dynamics_fn = fetch_fn(self.env, dynamics_fn)
        self._g = fetch_fn(self.env, actuated_dynamics_fn)
        self._f = fetch_fn(self.env, unactuated_dynamics_fn)
        self._punishment_fn = fetch_fn(self.env, punishment_fn)
        self._transform_action_space_fn = fetch_fn(self.env, transform_action_space_fn)
        self._safe_control_fn = fetch_fn(self.env, safe_control_fn)
        # Always define the attribute: ``step`` reads it whenever
        # ``generate_wrapper_tuple`` is set, independent of ``alter_action_space``.
        self._inv_transform_action_space_fn = (
            fetch_fn(self.env, inv_transform_action_space_fn)
            if inv_transform_action_space_fn is not None else None
        )
        self._noise_set = noise_set

        if not hasattr(self.env, "action_space"):
            warnings.warn("Environment has no attribute ``action_space``")

        # The number of QP variables is taken from the *environment's* action space,
        # i.e. before it is optionally replaced by ``alter_action_space``.
        if isinstance(self.action_space, gym.spaces.Box):
            num_actions = int(np.prod(self.action_space.shape))
        else:
            raise ValueError(f"{type(self.action_space)} not supported")

        if alter_action_space is not None:
            self.action_space = alter_action_space
            if transform_action_space_fn is None:
                warnings.warn("Set ``alter_action_space`` but no ``transform_action_space_fn``")
            elif generate_wrapper_tuple and inv_transform_action_space_fn is None:
                warnings.warn("``generate_wrapper_tuple`` but no ``inv_transform_action_space_fn``")

        # Half-space representation
        self._C, self._d = safe_region.halfspaces
        if self._noise_set is not None and self._noise_set.size > 0:
            # Extend the half-space constraints by the noise set: the constraints are
            # stacked once per noise vertex, each copy with its offset reduced by
            # C @ (E_d @ vertex).
            n_noise_vertices = self._noise_set.shape[0]
            constraint_length = self._d.shape[0]
            d_noise = np.tile(self._d, n_noise_vertices)
            for i in range(n_noise_vertices):
                rows = slice(i * constraint_length, (i + 1) * constraint_length)
                d_noise[rows] -= self._C @ (self.env.E_d @ self._noise_set[i])
            self._d = d_noise
            self._C = np.tile(self._C, (n_noise_vertices, 1))
        self.action_limits = action_limits
        # QP objective: 1/2 x^T P x + q^T x with P = I and q = 0, i.e. the smallest
        # correction in the Euclidean norm.
        self._P = np.eye(num_actions)
        self._q = np.zeros(num_actions)
        self.use_zero_trick = use_zero_trick
        # Duplicate trick not fast enough for most QPs.
        self.use_duplicate_trick = False
        self.use_reach_trick = False
        # True if the most recent ``step`` had to project the action.
        self.last_projected = False
        self.safety_margin = safety_margin
        # True if the most recent ``step`` fell back to the fail-safe controller.
        self._infeasible_step = False

    def step(self, action):
        r"""project the action to the safe set if necessary using CBF.

        The action projection is based on CBF (https://arxiv.org/pdf/1903.11199.pdf).
        The general formulation is
            \dot{h}(x, u) >= - \alpha(h(x)).
        We choose a linear CBF with \alpha(x) = \alpha = constant.
        We discretize the CBF
            \Delta h(x_{t}, u_{t}) >= - \alpha * h(x_{t})
            h(x_{t+1}) - h(x_{t}) >= - \alpha * h(x_{t})
            h(x_{t+1}) >= (1 - \alpha) * h(x_{t})
        Using h(x) = -C x + d, and a = a_{\pi} + \tilde{a}, we have
            -C x_{t+1} + d >= (1 - \alpha) * (-C x_{t} + d)
            -C(f(x_{t}) + g(x_{t}) (a_{\pi} + \tilde{a})) + d >= (1 - \alpha) * (-C x_{t} + d)
            C g(x_{t}) \tilde{a} <= C ((1 - \alpha) x_{t} - (f(x_{t}) + g(x_{t}) a_{\pi})) + \alpha d
            G \tilde{a} <= h

        :param action: Policy action (in the wrapper's action space)
        :return: ``(obs, reward, done, info)`` of the wrapped environment, where
            ``reward`` is replaced by the output of ``punishment_fn`` if the action was
            projected and a punishment function is set. ``info["projection"]`` holds
            the keys ``env_reward``, ``cbf_correction``, ``pun_reward``, ``infeasible``
            and ``policy_action``; ``info["wrapper_tuple"]`` is added for projected
            steps when ``generate_wrapper_tuple`` is set.
        """
        self._infeasible_step = False
        if isinstance(self.action_space, gym.spaces.Box) and self.action_space.shape[0] == 1:
            action = action.item()

        # Optional action transformation
        if self._transform_action_space_fn is not None:
            action = self._transform_action_space_fn(action)

        # Check if action is safe
        if not self._safe_region.contains(self._predict_next_state(action), -1e-5):
            safe_action, sol = self._project_action(action)

            obs, reward, done, info = self.env.step(safe_action)
            info["projection"] = {"env_reward": reward, "cbf_correction": sol, "infeasible": self._infeasible_step}

            if self._generate_wrapper_tuple:
                wrapper_action = self._inv_transform_action_space_fn(safe_action) \
                    if self._inv_transform_action_space_fn is not None else safe_action
                info["wrapper_tuple"] = (np.asarray([wrapper_action]), np.asarray([reward], dtype=np.float32))

            # Optional reward punishment
            if self._punishment_fn is not None:
                punishment = self._punishment_fn(self.env, action, reward, sol)
                info["projection"]["pun_reward"] = punishment
                reward = punishment
            else:
                info["projection"]["pun_reward"] = None
            self.last_projected = True
        else:
            # Action is (virtually) safe
            obs, reward, done, info = self.env.step(action)
            info["projection"] = {"env_reward": reward, "cbf_correction": None, "pun_reward": None, "infeasible": self._infeasible_step}
            self.last_projected = False

        info["projection"]["policy_action"] = action

        return obs, reward, done, info

    def _predict_next_state(self, action):
        """Return the state the safety check is evaluated on for ``action``."""
        if self._dynamics_fn is not None:
            return self._dynamics_fn(self.env, action)
        # Only the split dynamics were given (the constructor guarantees that at
        # least one of both exists): x_{t+1} = f(x_t) + g(x_t) a.
        return self._f(self.env) + np.dot(self._g(self.env), action)

    def _project_action(self, action):
        """Return ``(safe_action, correction)`` with ``safe_action = action + correction``."""
        state = self.env.state
        G, h = self._cbf_constraints(action, state)

        # Construct action space for projection term of the action (a = a_{\pi} + \tilde{a})
        a_proj_space = np.array([self.action_limits[0] - action, self.action_limits[1] - action])

        if self.use_reach_trick:
            G, h = self._drop_unreachable_constraints(G, h, a_proj_space)

        # Add the action limit constraints:  a <= a_max  and  -a <= -a_min.
        identity = np.eye(a_proj_space.shape[1])
        G = np.append(G, np.concatenate([identity, -identity], axis=0), axis=0)
        h = np.append(h, np.concatenate([a_proj_space[1], -a_proj_space[0]], axis=0), axis=0)

        if self.use_zero_trick and G.shape[1] == 2:
            G, h = self._merge_axis_aligned_constraints(G, h)

        # Remove duplicate direction constraints
        if self.use_duplicate_trick:
            G, h = remove_redundant_constraints(G, h)

        sol = self._solve_projection_qp(G, h, action, state)
        return action + sol, sol

    def _cbf_constraints(self, action, state):
        """Build ``G`` and ``h`` of the CBF condition ``G a_tilde <= h`` (see ``step``)."""
        g_x = self._g(self.env)
        G = self._C @ g_x
        # Tighten every half-space offset by a fraction of its magnitude.
        d = self._d - self.safety_margin * np.abs(self._d)
        if self._f is not None:
            next_state = self._f(self.env) + np.dot(g_x, action)
        else:
            next_state = self._dynamics_fn(self.env, action)
        h = self._C @ ((1 - self._alpha) * state - next_state) + self._alpha * d
        return G, h

    @staticmethod
    def _drop_unreachable_constraints(G, h, a_proj_space):
        """Keep only the constraints that some vertex of the correction box violates.

        Cg is linear. Therefore, if Ga <= h for all vertices of the action space, then
        Ga <= h for all points in the action space. I.e., if the constraint is inactive
        for the extrema of the action space, it is inactive for the whole action space.
        """
        m = a_proj_space.shape[1]
        n_action_vertices = 2 ** m
        vert_ids = np.arange(n_action_vertices)
        # Bit j of the vertex id selects the lower (0) or upper (1) limit of dimension j.
        action_space_idx = ((vert_ids[:, None] & (1 << np.arange(m))) > 0).astype(int)
        action_space_vertices = a_proj_space[action_space_idx, np.arange(m)]
        G_rep = np.repeat(G[np.newaxis, ...], n_action_vertices, axis=0)
        h_rep = np.repeat(h[np.newaxis, ...], n_action_vertices, axis=0)
        if len(G_rep.shape) == 2:
            h_act_vert = np.einsum("ij,ik->ij", G_rep, action_space_vertices)
        else:
            h_act_vert = np.einsum("ijk,ik->ij", G_rep, action_space_vertices)
        # Plain boolean negation; avoids the ``np.bool8`` alias, which is not
        # available in current NumPy releases.
        constraint_active = ~np.all(h_act_vert <= h_rep, axis=0)
        return G[constraint_active], h[constraint_active]

    @staticmethod
    def _merge_axis_aligned_constraints(G, h):
        """Collapse constraints with a zero coefficient into four axis-aligned bounds.

        Rows of ``G`` without any zero entry are kept unchanged. The remaining rows are
        normalized and, for each of the directions (1,0), (-1,0), (0,1), (0,-1), only
        the tightest (smallest) normalized bound is kept. Expects two columns and at
        least one row per direction (given once the action limits are appended).
        """
        has_no_zero = np.all(G != 0, axis=1)
        axis_G = G[~has_no_zero]
        axis_h = h[~has_no_zero]
        row_norm = np.sqrt((axis_G * axis_G).sum(axis=1))
        # NOTE(review): all-zero rows (constraints the action cannot influence) have
        # norm 0, match none of the four directions below and are therefore dropped,
        # even if their ``h`` is negative. This is unchanged; only the
        # division warnings are silenced. Confirm that dropping them is intended.
        with np.errstate(divide="ignore", invalid="ignore"):
            unit_G = axis_G / row_norm.reshape(len(row_norm), 1)
            unit_h = axis_h / row_norm

        directions = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
        tightest = [
            np.min(unit_h[(unit_G[:, 0] == direction[0]) & (unit_G[:, 1] == direction[1])])
            for direction in directions
        ]
        return np.vstack((G[has_no_zero], directions)), np.append(h[has_no_zero], tightest)

    def _failsafe_correction(self, action):
        """Query the fail-safe controller; return ``(safe_action, correction)`` and flag the step."""
        safe_action = self._safe_control_fn(self.env, self._safe_region)
        self._infeasible_step = True
        return safe_action, safe_action - action

    def _solve_projection_qp(self, G, h, action, state):
        """Solve ``min 1/2 x^T P x + q^T x  s.t.  G x <= h`` and return the correction.

        Falls back to the fail-safe controller if the solver raises, reports
        infeasibility, returns no solution, or returns one that violates ``G x <= h``.
        """
        x = cp.Variable(self.action_limits[0].shape[0])
        prob = cp.Problem(cp.Minimize((1/2)*cp.quad_form(x, self._P) + self._q.T @ x),
                          [G @ x <= h])
        try:
            prob.solve(solver='GUROBI', reoptimize=True)
        except cp.SolverError as e:
            print("Solver Error: {}".format(e))
            return self._failsafe_correction(action)[1]

        if prob.status == "infeasible" or prob.status == "infeasible_inaccurate":
            print("[WARNING] QP is infeasible. Using failsafe action.")
            safe_action, sol = self._failsafe_correction(action)
            # NOTE(review): this diagnostic uses ``env.dynamics_fn(action, state)``,
            # a different call convention than ``self._dynamics_fn(env, action)`` —
            # confirm the environment provides it.
            print("Fail safe solution after QP infeasible is safe: {}".format(self._safe_region.contains(self.env.dynamics_fn(safe_action, state), 1e-10)))
            return sol

        sol = x.value
        # ``x.value`` is None when the solver finished without a solution; treat this
        # like a violated solution instead of failing on the arithmetic below.
        if sol is None or max(G @ sol - h) > 0.0:
            return self._failsafe_correction(action)[1]
        self._infeasible_step = False
        return sol
