from abc import abstractmethod
from typing import Tuple
import os
import pickle


import numpy as np
import scipy as sc

from pypoman import compute_polytope_halfspaces, compute_polytope_vertices
import cvxpy as cp

from action_masking.util.sets import Zonotope
from action_masking.util.util import normalize_Ab


class SafeRegion:
    """
    Base class for safe regions.

    Subclasses must implement ``__contains__``. ``sample_ctrl`` additionally reads the
    attributes ``f_ctrl``, ``A_ctrl``, ``b_ctrl``, ``A_eq_ctrl``, ``x_bounds``, ``G_center``
    and ``G_ctrl``, none of which are set by this base class.
    """

    def __init__(self, x_goal: np.ndarray, x_lim_low: np.ndarray, x_lim_high: np.ndarray, seed=None):
        """
        Args:
        x_goal (np.ndarray): Goal state.
        x_lim_low (np.ndarray): Lower bound of the state space.
        x_lim_high (np.ndarray): Upper bound of the state space.
        seed (int, optional): Random seed. Defaults to None.
        """
        self.x_goal = x_goal
        self.x_lim_low = x_lim_low
        self.x_lim_high = x_lim_high
        self._seed = seed
        self._rng = np.random.default_rng(self._seed)

    @property
    def rng(self):
        """Return the seed the internal random generator was last created from.

        NOTE(review): despite its name this property returns the seed, not the
        ``np.random.Generator``; it is left unchanged so existing callers keep working.
        """
        return self._seed

    @rng.setter
    def rng(self, seed):
        """Store ``seed`` and re-create the internal random generator from it."""
        self._seed = seed
        self._rng = np.random.default_rng(self._seed)

    @classmethod
    def compute_safe_region(cls):
        """
        :return: safe region
        """
        raise NotImplementedError

    def sample(self):
        """Sample a point from the control invariant set.

        Draws uniformly from the box [x_lim_low, x_lim_high] and rejects candidates
        until one satisfies ``candidate in self``.

        Returns:
            np.ndarray: A state for which the containment check succeeded.
        """
        while True:
            candidate = self._rng.uniform(self.x_lim_low, self.x_lim_high)
            try:
                if candidate in self:
                    return candidate
            # We need this because the __contains__ method of the zonotope class uses a solver
            # that sometimes fails randomly.
            except ValueError:
                print("Contains solver failed, retrying...")

    def sample_ctrl(self, x):
        """Sample a control input from the control invariant set.

        Solves the linear program described by ``f_ctrl``, ``A_ctrl``, ``b_ctrl`` and
        ``A_eq_ctrl`` with right-hand side ``x - x_goal``, then maps the solution (without its
        first entry) through ``G_ctrl`` and offsets it by ``G_center``.

        Args:
            x (np.ndarray): State of the system.
        Returns:
            np.ndarray: Control input.
        Raises:
            ValueError: If the solver reports a status other than 0 or 4, or returns no
                solution vector.
        """
        b_eq_ctrl = x - self.x_goal
        sol = sc.optimize.linprog(
            self.f_ctrl, A_ub=self.A_ctrl, b_ub=self.b_ctrl, A_eq=self.A_eq_ctrl, b_eq=b_eq_ctrl, bounds=self.x_bounds
        )
        # Status 4 is still accepted as before, but only if the solver actually produced a
        # solution vector; otherwise fail with the same error as any other solver failure
        # instead of a TypeError from indexing ``None``.
        if sol.status not in (0, 4) or sol.x is None:
            raise ValueError("Error in fail-safe planner, no solution found for state {}".format(x))

        # The first decision variable is the objective variable; the rest parametrize the control.
        x_para = sol.x[1:]
        u_ctrl = self.G_center + self.G_ctrl @ x_para

        clean_solution = sol.status == 0 and sol.fun <= 1 + 1e-6
        if not clean_solution:
            # Keep the result but make the questionable solve visible to the user.
            print("Solver status is {}".format(sol.status))
            print("Check state {} and control {}".format(x, u_ctrl))
        return u_ctrl

    @abstractmethod
    def __contains__(self, state):
        """
        Args:
            state: state
        Returns: True iff state is inside safe region
        """
        raise NotImplementedError


class ControlInvariantSetPolytope(SafeRegion):
    """Control invariant set based on half-spaces."""

    def __init__(
        self,
        polytope_csv: str = None,
        vertices_csv: str = None,
        A_pkl: str = None,
        b_pkl: str = None,
        S_ctrl_csv: str = None,
        S_RCI_csv: str = None,
        x_goal: np.ndarray = None,
        x_lim_low: np.ndarray = None,
        x_lim_high: np.ndarray = None,
        seed=None,
    ):
        """
        Initialize the control invariant set.
        Either set polytope_csv or (A_pkl and b_pkl and vertices_csv).
        All paths are appended to the directory two levels above this file.
        Args:
            polytope_csv (str): Path to the CSV file containing the half-space representation of the safe region.
                The last column is read as b, all preceding columns as A.
            vertices_csv(str): Path to the CSV file containing vertex representation of the safe region.
            A_pkl (str): Path to the pickle file containing the A matrix.
            b_pkl (str): Path to the pickle file containing the b vector.
            S_ctrl_csv (str): Path to the CSV file containing the S_ctrl.
            S_RCI_csv (str): Path to the CSV file containing the S_RCI.
            x_lim_low (np.ndarray): Lower bound of the state space.
            x_lim_high (np.ndarray): Upper bound of the state space.
            x_goal (np.ndarray): Goal state.
            seed (int, optional): Random seed. Defaults to None.
        Raises:
            ValueError: If both or neither representation is given, if A_pkl is given without
                b_pkl and vertices_csv, or if S_ctrl_csv / S_RCI_csv is missing.
        """
        root = os.path.dirname(os.path.abspath(__file__)) + "/../../"
        vertices = None

        if polytope_csv is not None and A_pkl is not None:
            raise ValueError("Only one representation can be given. Either polytope or (A and b and vertices).")
        if polytope_csv is None and A_pkl is None:
            raise ValueError("One representation must be given. Either polytope or (A and b and vertices).")

        if polytope_csv is not None:
            polytope = np.genfromtxt(root + polytope_csv, delimiter=",")
            A = polytope[:, :-1]
            b = polytope[:, -1]
            A_normalized, b_normalized = normalize_Ab(A, b)
        else:
            if b_pkl is None or vertices_csv is None:
                # Previously this surfaced as a TypeError from concatenating a path with None.
                raise ValueError("A_pkl requires b_pkl and vertices_csv to be given as well.")
            # NOTE: pickle can execute arbitrary code while loading; only point these
            # arguments at trusted files.
            with open(root + A_pkl, "rb") as a_file:
                A_normalized = pickle.load(a_file)
            with open(root + b_pkl, "rb") as b_file:
                b_normalized = pickle.load(b_file)
            vertices = np.genfromtxt(root + vertices_csv, delimiter=",")

        # Legacy loading path, kept for reference:
        # polytope = np.genfromtxt(root + 'matlab/polytope_LongQuadrotor.csv', delimiter=',')
        # A = polytope[:, :-1]
        # b = polytope[:, -1] + A @ (self.x_goal-self.x_halfspace)
        # A, b = remove_redundant_constraints(A,b)

        if S_ctrl_csv is None or S_RCI_csv is None:
            raise ValueError("Missing robust controller information")

        # Polytope init
        halfspaces_missing = A_normalized is None or b_normalized is None
        if vertices is None and halfspaces_missing:
            raise ValueError("Missing representation")

        if vertices is not None:
            if halfspaces_missing:
                # Only vertices available: derive the half-space representation from them.
                self._A, self._b = compute_polytope_halfspaces(vertices)
            else:
                self._A = A_normalized
                self._b = b_normalized
            self._v = vertices
        else:
            # Only half-spaces available: derive the vertices from them.
            self._A = A_normalized
            self._b = b_normalized
            self._v = compute_polytope_vertices(self._A, self._b)

        # Init base class
        super().__init__(x_goal=x_goal, x_lim_low=x_lim_low, x_lim_high=x_lim_high, seed=seed)

        # In both CSV files the first column is read as a center and the remaining columns
        # as a generator matrix.
        zonotope_S_ctrl = np.genfromtxt(root + S_ctrl_csv, delimiter=",")
        self.G_ctrl = zonotope_S_ctrl[:, 1:]
        self.G_center = zonotope_S_ctrl[:, 0]
        zonotope_S_RCI = np.genfromtxt(root + S_RCI_csv, delimiter=",")
        self.G_S = zonotope_S_RCI[:, 1:]

        # Linear program data used by sample_ctrl, with decision vector [t, p]:
        #   minimize t   s.t.   p_i - t <= 0,  -p_i - t <= 0,  G_S @ p = b_eq,
        # where every variable is bounded by x_bounds.
        n_state, n_gen = self.G_S.shape
        self.A_ctrl = np.c_[-np.ones(2 * n_gen), np.vstack((np.eye(n_gen), -np.eye(n_gen)))]
        self.b_ctrl = np.zeros(2 * n_gen)
        self.A_eq_ctrl = np.c_[np.zeros(n_state), self.G_S]
        self.f_ctrl = np.zeros(n_gen + 1)
        self.f_ctrl[0] = 1
        self.x_bounds = (-1, 1)

    def contains(self, state, bound=None):
        """Check if state lies within safe region + bound.

        :param state: state to check; a single state (1-D) or one state per row (2-D)
        :param bound: bound to add to safe region. It is added to the right-hand side b
            of the half-space representation.

        :return: True iff state is inside safe region.
            If multiple states are given, return True iff all states are inside the safe region.
        :raises ValueError: if state is neither 1-D nor 2-D
        """
        if bound is None:
            return state in self

        # Accept sequences as well, consistent with __contains__.
        state = np.asarray(state)
        relaxed_b = self._b + bound
        if state.ndim == 1:
            return np.all(self._A @ state <= relaxed_b)
        if state.ndim == 2:
            # One column per state; b is broadcast across the columns.
            return np.all(self._A @ state.T <= relaxed_b[:, np.newaxis])
        raise ValueError("Invalid state shape")

    def euclidean_dist_to_safe_region(self, state):
        """Check distance to safe region, 0 in case of containment.

        For every state the largest half-space violation ``max(A @ state - b)`` is computed and
        clipped at zero, so contained states contribute 0 as documented (previously a contained
        state yielded the absolute value of its negative margin).

        :param state: state to check; a single state (1-D) or one state per row (2-D)

        :return: Euclidean norm of the clipped per-state violations
        :raises ValueError: if state is neither 1-D nor 2-D
        """
        state = np.asarray(state)
        if state.ndim == 1:
            violation = self._A @ state - self._b
        elif state.ndim == 2:
            violation = self._A @ state.T - self._b[:, np.newaxis]
        else:
            raise ValueError("Invalid state shape")
        worst_violation = np.max(violation, axis=0)
        return np.linalg.norm(np.maximum(worst_violation, 0.0))

    def __contains__(self, state):
        """Check if state lies within safe region.

        :return: True iff state is inside safe region.
            If multiple states are given, return True iff all states are inside the safe region.
        :raises ValueError: if state is neither 1-D nor 2-D
        """
        state = np.asarray(state)
        if state.ndim == 1:
            return np.all(self._A @ state <= self._b)
        if state.ndim == 2:
            return np.all(self._A @ state.T <= self._b[:, np.newaxis])
        raise ValueError("Invalid state shape")

    @property
    def vertices(self):
        """
        :return: vertex representation of the safe region
        """
        return self._v

    @vertices.setter
    def vertices(self, vertices) -> None:
        """
        :param vertices: vertex representation of the safe region
        """
        self._v = vertices

    @property
    def polytope(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        :return: half-space/polytope representation (A,b) of the safe region
        """
        return self._A, self._b

    @polytope.setter
    def polytope(self, Ab: Tuple[np.ndarray, np.ndarray]) -> None:
        """
        :param Ab: half-space/polytope representation (A,b) of the safe region
        """
        self._A, self._b = Ab


class ControlInvariantSetZonotope(SafeRegion):
    """Control invariant set based on Zonotopes."""

    def __init__(
        self,
        S_ctrl_csv: str = None,
        S_RCI_csv: str = None,
        x_goal: np.ndarray = None,
        x_lim_low: np.ndarray = None,
        x_lim_high: np.ndarray = None,
        seed=None,
    ):
        """
        Initialize the control invariant set.
        All paths are appended to the directory two levels above this file.
        Args:
            S_ctrl_csv (str): Path to the CSV file containing the S_ctrl.
            S_RCI_csv (str): Path to the CSV file containing the S_RCI.
            x_goal (np.ndarray): Goal state.
            x_lim_low (np.ndarray): Lower bound of the state space.
            x_lim_high (np.ndarray): Upper bound of the state space.
            seed (int, optional): Random seed. Defaults to None.
        Raises:
            ValueError: If S_ctrl_csv or S_RCI_csv is missing.
        """
        if S_ctrl_csv is None or S_RCI_csv is None:
            # Fail with a clear message instead of a TypeError from path concatenation.
            raise ValueError("Missing robust controller information")

        root = os.path.dirname(os.path.abspath(__file__)) + "/../../"

        # Init base class
        super().__init__(x_goal=x_goal, x_lim_low=x_lim_low, x_lim_high=x_lim_high, seed=seed)

        # In both CSV files the first column is read as the center and the remaining columns
        # as the generator matrix.
        zonotope_S_ctrl = np.genfromtxt(root + S_ctrl_csv, delimiter=",")
        self.G_ctrl = zonotope_S_ctrl[:, 1:]
        self.G_center = zonotope_S_ctrl[:, 0]
        zonotope_S_RCI = np.genfromtxt(root + S_RCI_csv, delimiter=",")
        self.G_S = zonotope_S_RCI[:, 1:]
        self.c_S = zonotope_S_RCI[:, 0]
        # The Zonotope is handed its center as a column vector.
        self.RCI_zonotope = Zonotope(G=self.G_S, c=self.c_S.reshape((-1, 1)))

        # Linear program data used by sample_ctrl, with decision vector [t, p]:
        #   minimize t   s.t.   p_i - t <= 0,  -p_i - t <= 0,  G_S @ p = b_eq,
        # where every variable is bounded by x_bounds.
        n_state, n_gen = self.G_S.shape
        self.A_ctrl = np.c_[-np.ones(2 * n_gen), np.vstack((np.eye(n_gen), -np.eye(n_gen)))]
        self.b_ctrl = np.zeros(2 * n_gen)
        self.A_eq_ctrl = np.c_[np.zeros(n_state), self.G_S]
        self.f_ctrl = np.zeros(n_gen + 1)
        self.f_ctrl[0] = 1
        self.x_bounds = (-1, 1)

    def __contains__(self, state):
        """
        Check if state lies within safe region.
        Args:
            state (np.ndarray): State of the system. Plain sequences are accepted as well.
        Returns:
            True iff state is inside safe region.
        """
        # Convert first so that lists/tuples work too, then pass the state as a column vector.
        column = np.asarray(state).reshape(-1, 1)
        return self.RCI_zonotope.contains_point(column)

class ControlInvariantSetZonotopeFelixController(SafeRegion):
    """Control invariant set based on Zonotopes with a sequence of per-step control zonotopes."""

    def __init__(
        self,
        S_ctrl_csv: str = None,
        S_RCI_csv: str = None,
        K_ctrl_csv : str = None,
        N_ctrl : int = None,
        x_goal: np.ndarray = None,
        x_lim_low: np.ndarray = None,
        x_lim_high: np.ndarray = None,
        seed=None,
    ):
        """
        Initialize the control invariant set.
        All paths are appended to the directory two levels above this file.
        Args:
            S_ctrl_csv (str): Path to the CSV file containing the S_ctrl. Its rows are split into
                N_ctrl equally sized consecutive groups, one control zonotope per group.
            S_RCI_csv (str): Path to the CSV file containing the S_RCI.
            K_ctrl_csv (str): Path to the CSV file containing the gain matrix that is multiplied
                with the state in sample_ctrl.
            N_ctrl (int): Number of control zonotopes stored in S_ctrl_csv.
            x_goal (np.ndarray): Goal state.
            x_lim_low (np.ndarray): Lower bound of the state space.
            x_lim_high (np.ndarray): Upper bound of the state space.
            seed (int, optional): Random seed. Defaults to None.
        Raises:
            ValueError: If N_ctrl is missing or not positive.
        """
        if N_ctrl is None or N_ctrl < 1:
            # Previously this surfaced as a TypeError / ZeroDivisionError further down.
            raise ValueError("N_ctrl must be a positive integer")

        root = os.path.dirname(os.path.abspath(__file__)) + "/../../"

        # Init base class
        super().__init__(x_goal=x_goal, x_lim_low=x_lim_low, x_lim_high=x_lim_high, seed=seed)

        # First column is read as the center, remaining columns as the generator matrix.
        zonotope_S_RCI = np.genfromtxt(root + S_RCI_csv, delimiter=",")
        self.G_S = zonotope_S_RCI[:, 1:]
        self.c_S = zonotope_S_RCI[:, 0]
        self.RCI_zonotope = Zonotope(G=self.G_S, c=self.c_S.reshape((-1, 1)))

        zonotope_S_ctrl = np.genfromtxt(root + S_ctrl_csv, delimiter=",")
        # Number of rows belonging to each of the N_ctrl stacked control zonotopes.
        N_u = zonotope_S_ctrl.shape[0] // N_ctrl
        self.Zonotopes_ctrl = []
        for t in range(N_ctrl):
            rows = slice(N_u * t, N_u * (t + 1))
            G_ctrl = zonotope_S_ctrl[rows, 1:]
            c_ctrl = zonotope_S_ctrl[rows, 0]
            self.Zonotopes_ctrl.append(Zonotope(G=G_ctrl, c=c_ctrl.reshape((N_u, 1))))

        self.K_gains_ctrl = np.genfromtxt(root + K_ctrl_csv, delimiter=",")
        # Index of the control zonotope used by the next sample_ctrl call.
        self.k_ctrl = 0
        # Generator parametrization computed whenever sample_ctrl runs with k_ctrl == 0.
        self.gamma = None

    def __contains__(self, state):
        """
        Check if state lies within safe region.
        Args:
            state (np.ndarray): State of the system. Plain sequences are accepted as well.
        Returns:
            True iff state is inside safe region.
        """
        # Convert first so that lists/tuples work too, then pass the state as a column vector.
        column = np.asarray(state).reshape(-1, 1)
        return self.RCI_zonotope.contains_point(column)

    def sample_ctrl(self, x):
        """Compute the control input for state x from the current control zonotope.

        The input is ``G_k @ gamma + c_k`` (squeezed) plus ``K_gains_ctrl @ x``, where k is
        ``k_ctrl`` and gamma is the generator parametrization obtained from the RCI zonotope
        the last time this method ran with ``k_ctrl == 0``.

        NOTE(review): k_ctrl is not advanced here; presumably the caller increments it between
        calls - confirm against the environment code.

        Args:
            x (np.ndarray): State of the system.
        Returns:
            np.ndarray: Control input.
        """
        # Todo: check that k_ctrl is not larger than 1-N_ctrl
        # Wrap around first, so that restarting the sequence also refreshes gamma below
        # (previously a wrapped index kept using the gamma of the previous sequence).
        if self.k_ctrl >= len(self.Zonotopes_ctrl) - 1:
            self.k_ctrl = 0
        if self.k_ctrl == 0:
            self.gamma = self.RCI_zonotope.get_generator_parametrization_for_point(x.reshape((len(x), 1)))

        zonotope_k = self.Zonotopes_ctrl[self.k_ctrl]
        u = np.squeeze(zonotope_k.G @ self.gamma + zonotope_k.c) + self.K_gains_ctrl @ x
        return u

    def sample(self):
        """Sample a point from the control invariant set.

        Draws uniformly from the box [x_lim_low / 10, x_lim_high / 10] and rejects candidates
        until one satisfies ``candidate in self``.

        Returns:
            np.ndarray: A state for which the containment check succeeded.
        """
        while True:
            candidate = self._rng.uniform(self.x_lim_low / 10, self.x_lim_high / 10)
            try:
                if candidate in self:
                    return candidate
            # We need this because the __contains__ method of the zonotope class uses a solver
            # that sometimes fails randomly.
            except ValueError:
                print("Contains solver failed, retrying...")


