from typing import Tuple

import numpy as np
from pypoman import compute_polytope_halfspaces, compute_polytope_vertices


class SafeRegion:
    """Polytope-shaped safe region kept in half-space and vertex form.

    At least one of the two representations must be supplied; the missing one
    is derived with pypoman. When both are supplied they are stored as given
    and are not checked against each other.

    :param A: Half-space representation (Ax <= b)
    :param b: Half-space representation (Ax <= b)
    :param vertices: Vertex representation of the safe region
    :param seed: Seed handed to ``np.random.default_rng`` for the internal
        random generator
    :raises ValueError: if neither ``vertices`` nor both ``A`` and ``b`` are given
    """

    def __init__(self, A=None, b=None, vertices=None, seed=None):
        has_halfspaces = A is not None and b is not None
        if vertices is None and not has_halfspaces:
            raise ValueError("Missing representation")

        # Half-spaces: use the supplied pair, otherwise derive from vertices.
        if has_halfspaces:
            self._A, self._b = A, b
        else:
            self._A, self._b = compute_polytope_halfspaces(vertices)

        # Vertices: use the supplied ones, otherwise derive from (A, b).
        if vertices is None:
            self._v = compute_polytope_vertices(self._A, self._b)
        else:
            self._v = vertices

        self._seed = seed
        self._rng = np.random.default_rng(self._seed)

    def _constraint_sides(self, state, bound=None):
        """Evaluate both sides of ``A x <= b (+ bound)`` for one or many states.

        :param state: a single state (1-D) or a batch of states (2-D, one
            state per row); any array-like is accepted
        :param bound: optional offset added to ``b``

        :return: tuple ``(lhs, rhs)`` that can be compared or subtracted
            element-wise. For a batch, ``lhs`` has one column per state and
            ``rhs`` is a column vector that broadcasts across those columns.
        :raises ValueError: if ``state`` is neither 1-D nor 2-D
        """
        state = np.asarray(state)
        rhs = np.asarray(self._b)
        if bound is not None:
            rhs = rhs + bound

        if state.ndim == 1:
            return np.matmul(self._A, state), rhs
        if state.ndim == 2:
            return np.matmul(self._A, state.T), rhs[:, np.newaxis]
        raise ValueError("Invalid state shape")

    def contains(self, state, bound=None):
        """Check if state lies within safe region + bound.

        :param state: state to check (1-D), or several states (2-D, one per row)
        :param bound: bound to add to safe region, i.e. the test becomes
            ``A x <= b + bound``. ``None`` means no relaxation.

        :return: True iff state is inside safe region.
            If multiple states are given, return True iff all states are inside the safe region.
        :raises ValueError: if ``state`` is neither 1-D nor 2-D
        """
        if bound is None:
            # Delegate so that a subclass overriding __contains__ is honoured.
            return state in self
        lhs, rhs = self._constraint_sides(state, bound)
        return bool(np.all(lhs <= rhs))

    def euclidean_dist_to_safe_region(self, state):
        """Measure how far state violates the safe region, 0 in case of containment.

        The result is the Euclidean norm of the positive parts of
        ``A x - b``; satisfied constraints contribute nothing. For several
        states the norm is taken over the violations of all states together,
        so the result is always a scalar.

        :param state: state to check (1-D), or several states (2-D, one per row)

        :return: Euclidean norm of the constraint violations
        :raises ValueError: if ``state`` is neither 1-D nor 2-D
        """
        lhs, rhs = self._constraint_sides(state)
        # Element-wise clip at zero (np.maximum), NOT a reduction (np.max):
        # negative residuals mean "satisfied" and must not count as distance.
        violation = np.maximum(lhs - rhs, 0)
        return np.linalg.norm(violation)

    def __contains__(self, state):
        """Check if state lies within safe region.

        :param state: state to check (1-D), or several states (2-D, one per row)

        :return: True iff state is inside safe region.
            If multiple states are given, return True iff all states are inside the safe region.
        :raises ValueError: if ``state`` is neither 1-D nor 2-D
        """
        lhs, rhs = self._constraint_sides(state)
        return bool(np.all(lhs <= rhs))

    @property
    def rng(self):
        """
        :return: the seed of the internal random generator.
            NOTE: despite the property name this is the seed, not the
            ``Generator`` object; kept as is for backward compatibility.
        """
        return self._seed

    @rng.setter
    def rng(self, seed):
        """
        :param seed: new seed; the internal random generator is re-created from it
        """
        self._seed = seed
        self._rng = np.random.default_rng(self._seed)

    @property
    def vertices(self):
        """
        :return: vertex representation of the safe region
        """
        return self._v

    @vertices.setter
    def vertices(self, vertices) -> None:
        """
        Only the stored vertices are replaced; the half-space representation
        is not recomputed.

        :param vertices: vertex representation of the safe region
        """
        self._v = vertices

    @property
    def halfspaces(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        :return: half-space representation (A,b) of the safe region
        """
        return self._A, self._b

    @halfspaces.setter
    def halfspaces(self, Ab: Tuple[np.ndarray, np.ndarray]) -> None:
        """
        Only the stored (A,b) pair is replaced; the vertex representation is
        not recomputed.

        :param Ab: half-space representation (A,b) of the safe region
        """
        self._A, self._b = Ab

    @classmethod
    def compute_safe_region(cls):
        """
        Not implemented in this class.

        :return: safe region
        :raises NotImplementedError: always
        """
        raise NotImplementedError

    def sample(self):
        """
        Not implemented in this class.

        :return: sample state within safe region
        :raises NotImplementedError: always
        """
        raise NotImplementedError
