import numpy as np
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt

class AbstractState:
    """Axis-aligned box ``[lower_bound, upper_bound]`` with a random centre.

    ``reset()`` draws ``current_center`` uniformly from the box.
    """

    def __init__(self, lower_bound, upper_bound):
        # Keep the bounds as NumPy arrays so comparisons and sampling broadcast.
        self.lower_bound, self.upper_bound = np.array(lower_bound), np.array(upper_bound)

    def reset(self):
        """Draw a new ``current_center`` uniformly between the two bounds."""
        low, high = self.lower_bound, self.upper_bound
        self.current_center = np.random.uniform(low, high)

class Goal(AbstractState):
    """Box-shaped goal region over the ``index_range`` components of a state.

    The box is ``[lower_bound, upper_bound]`` (see ``AbstractState``).
    ``reset()`` must have been called before ``reward()`` is used, because
    ``reward`` reads ``current_center``, which only ``reset`` sets.
    """

    def __init__(self, lower_bound, upper_bound, index_range):
        super().__init__(lower_bound, upper_bound)
        # Any NumPy index (slice, index list, ...) selecting the goal-relevant
        # components from a full state vector.
        self.index_range = index_range
        # Previous state seen by ``reward``; None until the first call.
        self.current_state = None

    def predicate(self, state: np.ndarray):
        """Return True when the selected components of ``state`` lie in the box."""
        selected = state[self.index_range]
        return np.all(selected >= self.lower_bound) and np.all(selected <= self.upper_bound)

    def reward(self, state: np.ndarray):
        """Return the reward for arriving at ``state``.

        Returns 10 when ``predicate(state)`` holds. Otherwise returns how much
        the Euclidean distance from the selected components to
        ``current_center`` decreased since the previous call (positive when
        moving closer). On the first call there is no previous state, so the
        progress term is 0.0.

        NOTE(review): ``current_state`` is not cleared by ``reset()``, so the
        first reward after a reset measures progress from the previous
        episode's last state — confirm this is intended.
        """
        last_state = self.current_state
        # Store a copy: callers may reuse/mutate their observation array in
        # place, which would otherwise corrupt the "previous state".
        self.current_state = np.array(state, copy=True)
        if self.predicate(state):
            return 10
        if last_state is None:
            # First call after construction: no previous state to compare with.
            return 0.0
        target = self.current_center
        previous_distance = np.linalg.norm(last_state[self.index_range] - target)
        current_distance = np.linalg.norm(state[self.index_range] - target)
        return previous_distance - current_distance

    def plot(self):
        """Scatter 100 uniform samples from the box.

        Assumes 2-D bounds: samples are drawn with shape (100, 2) and columns
        0 and 1 are plotted.
        """
        data = np.random.uniform(low=self.lower_bound, high=self.upper_bound, size=(100, 2))
        plt.plot(data[:, 0], data[:, 1], 'bo')
        
class NewGoal(Goal):
    """Goal region given by a convex hull over the ``index_range`` components.

    ``hull`` is presumably a ``scipy.spatial.ConvexHull`` (its ``equations``,
    ``points`` and ``simplices`` attributes are used below).
    """

    def __init__(self, lower_bound, upper_bound, index_range, hull):
        super().__init__(lower_bound, upper_bound, index_range)
        self.hull = hull

    def _in_hull(self, point):
        """Return True if ``point`` (already in hull coordinates) lies in the hull.

        Each row of ``hull.equations`` is ``[normal, offset]`` with an outward
        normal, so the point is inside iff ``normal @ point <= -offset`` for
        every facet.
        """
        normals = self.hull.equations[:, :-1]
        offsets = self.hull.equations[:, -1]
        return np.all(np.dot(normals, point) <= -offsets)

    def predicate(self, state: np.ndarray):
        """Return True when the ``index_range`` components of ``state`` lie in the hull."""
        return self._in_hull(state[self.index_range])

    def reset(self):
        """Set ``current_center`` to a random convex combination of ``hull.points``.

        The weights are normalised uniform draws, so samples concentrate toward
        the centroid of the points; they are not uniformly distributed over
        the hull.
        """
        points = self.hull.points
        n_points = points.shape[0]

        while True:
            # Random convex-combination weights (non-negative, summing to 1).
            coeffs = np.random.rand(n_points)
            coeffs /= coeffs.sum()
            point = np.dot(coeffs, points)

            # Mathematically always inside; the re-check guards against
            # floating-point rounding pushing the point past a facet.
            if self._in_hull(point):
                self.current_center = point
                return

    def plot(self):
        """Plot the hull's points and its edges (first two coordinates only)."""
        plt.plot(self.hull.points[:, 0], self.hull.points[:, 1], 'o')
        for simplex in self.hull.simplices:
            plt.plot(self.hull.points[simplex, 0], self.hull.points[simplex, 1], 'k-')

class ModifiedGoal(Goal):
    """Goal region given by a convex hull, with centres drawn from the box.

    ``predicate`` tests the ``index_range`` components of a state against
    ``hull``. ``reset`` rejection-samples ``current_center`` from the box
    ``[lower_bound, upper_bound]`` until it falls inside the hull.
    """

    def __init__(self, lower_bound: np.ndarray, upper_bound: float, index_range, hull, reachable=False):
        super().__init__(lower_bound=lower_bound, upper_bound=upper_bound, index_range=index_range)
        # Presumably a scipy.spatial.ConvexHull (its ``equations`` are read below).
        self.hull = hull
        # Not read within this class; NOTE(review): confirm where it is consumed.
        self.reachable = reachable

    def _in_hull(self, point):
        """Return True if ``point`` (already in hull coordinates) lies in the closed hull.

        Each row of ``hull.equations`` is ``[normal, offset]`` with an outward
        normal, so the point is inside iff ``normal @ point <= -offset`` for
        every facet (up to floating-point rounding on the boundary).
        """
        point = np.asarray(point, dtype=float).reshape(-1)
        normals = self.hull.equations[:, :-1]
        offsets = self.hull.equations[:, -1]
        return bool(np.all(normals @ point <= -offsets))

    def reset(self):
        """Rejection-sample ``current_center`` from the box until it lies in the hull.

        Loops forever if the box and the hull do not overlap.
        """
        while True:
            super().reset()
            # current_center is already in hull coordinates (sampled from the
            # bounds), so test it directly rather than via ``predicate``.
            if self._in_hull(self.current_center):
                break

    def predicate(self, point):
        """Return True when the ``index_range`` components of ``point`` lie in the hull.

        Takes a full state vector, like ``Goal.predicate``, because
        ``Goal.reward`` calls ``self.predicate(state)`` with the full state.
        """
        return self._in_hull(np.asarray(point)[self.index_range])

