import os
import sys

from numpy.core.numeric import Inf
cur_dir = os.path.dirname(__file__)
sys.path.append(os.path.join(cur_dir, '..'))

import numpy as np
import random
import gym
from gym import spaces


import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy.spatial import ConvexHull
from matplotlib.patches import Polygon
from pathlib import Path
from multiprocessing import Pool
from functools import partial

from cvxopt import matrix, solvers


class DoubleIntegratorEnv(gym.Env):
    """Discrete-time double integrator with box state constraints and a robust CBF filter.

    Dynamics: ``s' = A s + B a`` with ``A = [[1, 0.1], [0, 1]]`` and
    ``B = [[0], [0.1]]`` on the state ``s = [x, v]``. The safe set is
    ``{s : D s <= d}``, i.e. ``|x| <= 2.01`` and ``|v| <= 2.01``.
    ``action_space`` bounds the scalar action to ``[-4, 4]``.

    Args:
        seed: seed applied to the global ``random`` and ``np.random`` generators.
        level: the robust CBF filter checks ``level + 1`` evenly spaced
            multiplicative perturbations of ``A`` in ``[-0.1, 0.1]`` (plus the
            nominal ``A``).
    """

    def __init__(self, seed=0, level=10):
        super().__init__()
        self.action_space = spaces.Box(low=-4, high=4, shape=(1,))
        self.observation_space = spaces.Box(low=-4, high=4, shape=(2,))
        self.len = 100  # episode horizon, in steps
        self.t = 0
        self.level = level

        # Nominal dynamics s' = A s + B a and safe set {s : D s <= d}.
        self._A = np.array([[1.0, 0.1], [0.0, 1.0]])
        self._B = np.array([[0.0], [0.1]])
        self._D = np.array([[1.0, 0], [0, 1.0], [-1.0, 0], [0, -1.0]])
        self._d = np.array([[2.01], [2.01], [2.01], [2.01]])

        self.seed(seed)
        self.reset()

    def seed(self, seed):
        """Seed the global ``random`` and ``np.random`` generators (no per-env RNG)."""
        random.seed(seed)
        np.random.seed(seed)

    def _constraint_violation(self, s_next):
        """Return the mean positive part of ``D s_next - d`` over the box constraints."""
        # Reshape ``D s_next`` to a column before subtracting the (4, 1) column
        # ``d``: for a 1-D state the raw (4,) product would broadcast to (4, 4).
        residual = np.reshape(np.dot(self._D, s_next), (self._d.shape[0], -1)) - self._d
        return np.maximum(residual, 0).mean()

    def _perturbed_A(self, eps):
        """Return a copy of ``A`` scaled entrywise by the rollout perturbation pattern.

        Entry ``[0, 0]`` is scaled by ``1 + eps``, entry ``[0, 1]`` by
        ``1 + |eps|`` and the second row by ``1 + 2|eps|``.
        """
        A_eps = self._A.copy()
        A_eps[0, 0] += 1 * eps * A_eps[0, 0]
        A_eps[0, 1] += 1 * abs(eps) * A_eps[0, 1]
        A_eps[1, 0] += 2 * abs(eps) * A_eps[1, 0]
        A_eps[1, 1] += 2 * abs(eps) * A_eps[1, 1]
        return A_eps

    def dynamics(self, s, a):
        """Nominal one-step transition.

        Returns:
            (s_next, c) where ``c`` is the mean constraint violation of ``s_next``.
        """
        s_next = np.dot(self._A, s) + np.dot(self._B, a)
        c = self._constraint_violation(s_next)
        return s_next, c

    def perturbed_dynamics(self, s, a, eps=0.0, cbf=False):
        """One transition under a randomly perturbed ``A``, optionally CBF-filtered.

        Args:
            s: current state.
            a: nominal action; when ``cbf`` is True it is reshaped to (1, 1), so it
                must hold exactly one element.
            eps: a single draw ``noise ~ N(0, |eps|)`` perturbs ``A`` via
                ``_perturbed_A(noise)``; 0 keeps the nominal ``A``.
            cbf: if True, the applied action is the robust CBF-filtered action,
                computed on the nominal model.

        Returns:
            (s_next, c, info): ``c`` is the mean constraint violation and ``info``
            holds ``"a_act"`` (the applied action) and ``"risk_value"`` (= ``c``).
        """
        if eps != 0.0:
            # Draw one scalar so it can scale the scalar entries of A; a (2, 2)
            # draw cannot be written into A[i, j] and raises a ValueError.
            noise = np.random.normal(0, abs(eps))
            A_eps = self._perturbed_A(noise)
        else:
            A_eps = self._A.copy()
        B_eps = self._B.copy()

        if cbf:
            # Min-max robust filtering uses the nominal A, not the sampled A_eps.
            a_act, _ = self.compute_robust_safe_action_by_CBF(self._A, self._B, self._D, self._d, s, a)
        else:
            a_act = a
        s_next = np.dot(A_eps, s) + np.dot(B_eps, a_act)
        c = self._constraint_violation(s_next)
        info = {"a_act": a_act, "risk_value": c}
        return s_next, c, info

    def calculate_control_invariant_set(self, num_iterations=10, num_state_samples=10, num_action_samples=10, eps=None):
        """Grid approximation of the states that can be kept inside the box for one step.

        A ``num_state_samples x num_state_samples`` grid over ``[-2, 2]^2`` is
        filtered, keeping each state for which some of ``num_action_samples``
        evenly spaced actions in ``action_space`` yields a successor (under ``A``
        perturbed by ``_perturbed_A(eps)``) satisfying ``D s' <= d``.

        Args:
            num_iterations: number of filtering passes.
            num_state_samples: grid points per axis.
            num_action_samples: actions tried per state.
            eps: perturbation for ``_perturbed_A``; ``None`` or 0 uses nominal ``A``.

        Returns:
            (M, 2) array of retained grid states. If a pass retains nothing, the
            previous set is returned unchanged.
        """
        A_eps = self._A.copy() if eps is None or eps == 0.0 else self._perturbed_A(eps)
        B_eps = self._B.copy()
        action_space_samples = np.linspace(self.action_space.low[0], self.action_space.high[0], num_action_samples)
        x = np.linspace(-2, 2, num_state_samples)
        y = np.linspace(-2, 2, num_state_samples)
        X, Y = np.meshgrid(x, y)
        C = np.vstack((X.flatten(), Y.flatten())).T

        # NOTE(review): successors are only checked against the box D s' <= d, not
        # for membership in the current C, so every pass after the first keeps all
        # states. The result is a one-step-feasible set, not a true invariant set.
        for _ in range(num_iterations):
            new_C = []
            for state in C:
                invariant_for_action = False
                for action in action_space_samples:
                    next_state = A_eps @ state.reshape(-1, 1) + B_eps * action
                    next_state = next_state.flatten()
                    if np.all(self._D @ next_state.reshape(-1, 1) <= self._d.reshape(-1, 1)):
                        invariant_for_action = True
                        break
                if invariant_for_action:
                    new_C.append(state)
            C = np.array(new_C) if new_C else C

        return C

    def reset(self, deterministic=False):
        """Reset the step counter and the state.

        The state is ``[0, 0]`` when ``deterministic`` is True, otherwise
        ``[x0, 0]`` with ``x0`` drawn from ``np.random.uniform(-2.5, 2.5)``.

        Returns:
            (state, None); the second element stands in for an info dict.
        """
        self.t = 0
        self.s = np.array([0, 0]) if deterministic else np.array([np.random.uniform(-2.5, 2.5), 0])
        infos = None
        return self.s, infos

    def step(self, a, eps=0.0, cbf=False):
        """Advance one step with ``perturbed_dynamics``.

        Returns:
            (state, reward, done, info). ``done`` is set once ``self.len`` steps
            have been taken or any state component reaches magnitude 5;
            ``info['constraint_violation']`` holds the mean constraint violation.
        """
        self.s, c, infos = self.perturbed_dynamics(self.s, a, eps, cbf=cbf)
        # Count this step before testing the horizon so an episode lasts exactly
        # ``self.len`` steps (testing first allowed ``self.len + 1``).
        self.t += 1
        done = self.t >= self.len or np.any(np.abs(self.s) >= 5)
        r = self.r_function(self.s[0], self.s[1])
        infos['constraint_violation'] = c
        return self.s, r, done, infos

    def r_function(self, x, y):
        """Reward: sum of four clipped paraboloid bumps in the (x, v) plane.

        The bumps are centred at (1.5, -1.5) and (-1.5, 1.5) with peak 4, and at
        (-2.2, -2.2) and (2.2, 2.2) with peak 5. Each bump is clipped below at 0.
        """
        return max(4 - (2 * pow((x - 1.5), 2) + 2 * pow((y + 1.5), 2)), 0) + \
               max(5 - (3 * pow((x + 2.2), 2) + 3 * pow((y + 2.2), 2)), 0) + \
               max(5 - (3 * pow((x - 2.2), 2) + 3 * pow((y - 2.2), 2)), 0) + \
               max(4 - (2 * pow((x + 1.5), 2) + 2 * pow((y - 1.5), 2)), 0)

    def compute_robust_safe_action_by_CBF(self, A, B, D, d, s, u_nom):
        """Minimally correct ``u_nom`` so a barrier condition holds for a family of models.

        With ``h(x) = d - D x``, decision variable ``z = [delta, xi]`` and
        ``u = u_nom + delta``, the cvxopt QP is::

            minimize    0.5 * (delta**2 + 1e5 * xi**2)
            subject to  (D A' B) delta - xi <= alpha * h(s) - D A' (A' s + B u_nom)
                        for A' = A and A' = (1 + e) A,
                            e in linspace(-0.1, 0.1, level + 1)
                        low <= u_nom + delta <= high

        with ``alpha = 1``. The QP is solved only if some barrier right-hand side
        is <= 0; otherwise ``u_nom`` is returned unchanged.

        Returns:
            (u_safe, h_min): the filtered action as a 1-D array and the smallest
            barrier right-hand side over all models.
        """
        A = A.reshape(2, 2)
        B = B.reshape(2, 1)
        D = D.reshape(4, 2)
        d = d.reshape(4, 1)
        s = s.reshape(2, 1)
        u_nom = u_nom.reshape(1, 1)

        alpha = 1
        n_u = B.shape[1]
        slack_col = -np.ones((d.shape[0], 1))

        def h(x):
            return d - D @ x

        def cbf_rows(A_model):
            # NOTE(review): the drift term is D A' (A' s + B u_nom), not
            # D (A' s + B u_nom); confirm this is the intended barrier condition.
            G_rows = np.hstack([D @ A_model @ B, slack_col])
            h_rows = alpha * h(s) - D @ A_model @ (A_model @ s + B @ u_nom)
            return G_rows, h_rows

        # Nominal model plus every point of the symmetric uncertainty grid
        # (including eps = -0.1, which the previous index range skipped).
        models = [A] + [A + e * A for e in np.linspace(-0.1, 0.1, self.level + 1)]
        blocks = [cbf_rows(A_model) for A_model in models]
        G1 = np.vstack([G_rows for G_rows, _ in blocks])
        h1 = np.vstack([h_rows for _, h_rows in blocks])

        # Action box on the corrected action: delta <= high - u_nom, -delta <= u_nom - low.
        G2 = np.hstack([np.eye(n_u), np.zeros((n_u, 1))])
        G3 = np.hstack([-np.eye(n_u), np.zeros((n_u, 1))])
        h2 = self.action_space.high.reshape(-1, 1) - u_nom
        h3 = u_nom - self.action_space.low.reshape(-1, 1)

        # cvxopt minimizes 0.5 z^T P z + q^T z: small corrections, heavily penalized slack.
        P = matrix(np.diag(np.hstack([np.ones(n_u), [1e5]])), tc='d')
        q = matrix(np.zeros(n_u + 1), tc='d')
        G = matrix(np.vstack([G1, G2, G3]), tc='d')
        h_all = matrix(np.vstack([h1, h2, h3]), tc='d')

        if np.any(h1 <= 0):
            solvers.options['show_progress'] = False
            sol = solvers.qp(P, q, G, h_all)
            z = np.array(sol['x']).flatten()
            u_safe = (u_nom.flatten() + z[:-1]).reshape(-1)
        else:
            u_safe = u_nom.flatten()

        return u_safe, h1.min()



def _collect_ratio_parallel(env, method_name, num_state_samples, num_action_samples, processes):
    """Evaluate ``env.<method_name>`` on uniformly sampled states in a process pool.

    States are drawn with ``np.random.uniform(-4, 4)`` in shape
    ``(num_state_samples, env.observation_space.shape[0])``. Each row is passed
    positionally, together with ``num_action_samples=...`` as a keyword.

    Returns:
        list with one result per sampled state, in sampling order.
    """
    states = np.random.uniform(-4, 4, size=(num_state_samples, env.observation_space.shape[0]))
    # The bound method is pickled to the workers, so env must be picklable.
    func = partial(getattr(env, method_name), num_action_samples=num_action_samples)
    with Pool(processes=processes) as pool:
        return pool.map(func, states)


def collect_safety_ratio_parallel(env, num_state_samples=1000, num_action_samples=200, processes=None):
    """Run ``env.calculate_safety_ratio`` on random states in parallel.

    NOTE(review): the DoubleIntegratorEnv class in this file does not define
    ``calculate_safety_ratio``; the caller must supply an env that does.

    Args:
        env: object exposing ``observation_space.shape`` and ``calculate_safety_ratio``.
        num_state_samples: number of states sampled uniformly from ``[-4, 4]^n``.
        num_action_samples: forwarded as a keyword to the evaluator.
        processes: pool size; ``None`` uses ``os.cpu_count()`` as before.

    Returns:
        list of evaluator results, one per sampled state.
    """
    return _collect_ratio_parallel(env, "calculate_safety_ratio", num_state_samples, num_action_samples, processes)


def collect_high_reward_ratio_parallel(env, num_state_samples=1000, num_action_samples=200, processes=None):
    """Run ``env.calculate_high_reward_ratio`` on random states in parallel.

    NOTE(review): the DoubleIntegratorEnv class in this file does not define
    ``calculate_high_reward_ratio``; the caller must supply an env that does.

    Args:
        env: object exposing ``observation_space.shape`` and ``calculate_high_reward_ratio``.
        num_state_samples: number of states sampled uniformly from ``[-4, 4]^n``.
        num_action_samples: forwarded as a keyword to the evaluator.
        processes: pool size; ``None`` uses ``os.cpu_count()`` as before.

    Returns:
        list of evaluator results, one per sampled state.
    """
    return _collect_ratio_parallel(env, "calculate_high_reward_ratio", num_state_samples, num_action_samples, processes)


def plot_convex_hull(ax, control_invariant_set, color, label):
    """Draw the convex hull of a 2-D point set on ``ax`` as a translucent fill plus edges.

    Args:
        ax: matplotlib Axes to draw on.
        control_invariant_set: (N, 2) array of points; ``ConvexHull`` needs at
            least three non-collinear points.
        color: colour of both the fill (alpha 0.3) and the hull edges.
        label: legend label attached to the filled polygon.
    """
    hull = ConvexHull(control_invariant_set)
    hull_vertices = control_invariant_set[hull.vertices].tolist()
    hull_vertices.append(hull_vertices[0])
    poly = Polygon(hull_vertices, closed=True, fill=True, color=color, alpha=0.3, label=label)
    ax.add_patch(poly)
    for simplex in hull.simplices:
        # Plain '-' format: 'k-' together with color= makes matplotlib warn that
        # the colour is defined twice (the keyword wins anyway).
        ax.plot(control_invariant_set[simplex, 0], control_invariant_set[simplex, 1], '-', color=color)

def plot_environment(ax, X, Y, Z, title, label):
    """Draw a 100-level filled contour of ``Z`` over the (x, v) grid with a colorbar.

    Args:
        ax: matplotlib Axes to draw on; its own figure receives the colorbar.
        X, Y: grid coordinates as passed to ``contourf``.
        Z: values to contour; levels span ``Z.min()`` to ``Z.max()``.
        title: Axes title.
        label: colorbar label.

    Returns:
        The object returned by ``ax.contourf``.
    """
    cs = ax.contourf(X, Y, Z, levels=np.linspace(Z.min(), Z.max(), 100), cmap='viridis')
    # Use the figure that owns ``ax`` rather than a module-level ``fig``, which
    # only exists when this file is run as a script.
    ax.figure.colorbar(cs, ax=ax, label=label)
    ax.set_xlabel('x')
    ax.set_ylabel('v')
    ax.set_title(title)
    return cs

def save_fig_with_unique_name(directory, filename_base, extension=".png", dpi=300):
    """Save the current matplotlib figure into ``directory`` without overwriting.

    The first free name among ``<filename_base><extension>``,
    ``<filename_base>_v1<extension>``, ``<filename_base>_v2<extension>``, ... is
    used. ``directory`` and any missing parents are created first.
    """
    out_dir = Path(directory)
    out_dir.mkdir(parents=True, exist_ok=True)

    candidate = out_dir / f"{filename_base}{extension}"
    version = 0
    while candidate.exists():
        version += 1
        candidate = out_dir / f"{filename_base}_v{version}{extension}"

    plt.savefig(candidate, dpi=dpi)

  
if __name__ == '__main__':
    env = DoubleIntegratorEnv()
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))

    # Roll out one short episode (at most 10 steps) from the origin with
    # sampled actions and no CBF filtering, recording the (x, v) trajectory.
    s = env.reset(deterministic=True)[0]
    s_0, s_1 = [s[0]], [s[1]]
    for _ in range(10):
        a = env.action_space.sample()
        s_, r, done, infos = env.step(a, cbf=False)
        s_0.append(s_[0])
        s_1.append(s_[1])
        if done:
            print("done")
            break
    s_0 = np.array(s_0)
    s_1 = np.array(s_1)

    # Same trajectory on both panels: start marker, path, arrow on the last segment.
    for ax in axs:
        ax.plot(s_0[0], s_1[0], 'o', color='black')
        ax.plot(s_0, s_1, linestyle='-', color='black')
        ax.annotate(
            '',
            xy=(s_0[-1], s_1[-1]),
            xytext=(s_0[-2], s_1[-2]),
            arrowprops=dict(arrowstyle="->", color='black', lw=1.5),
        )

    # Background fields on [-4, 4]^2: Z is the smallest margin to the edges of
    # the |x|, |v| <= 2 box (positive inside), R is the environment reward.
    x = np.linspace(-4, 4, 100)
    y = np.linspace(-4, 4, 100)
    X, Y = np.meshgrid(x, y)
    Z = np.minimum.reduce([X + 2, 2 - X, Y + 2, 2 - Y])
    # otypes=[float]: r_function may return the int 0 where every bump is
    # clipped, which would otherwise make np.vectorize infer an int array.
    R = np.vectorize(env.r_function, otypes=[float])(X, Y)

    csZ = axs[0].contourf(X, Y, Z, levels=np.linspace(Z.min(), Z.max(), 100), cmap='viridis')
    fig.colorbar(csZ, ax=axs[0], label='Risk Value')
    csR = axs[1].contourf(X, Y, R, levels=np.linspace(R.min(), R.max(), 100), cmap='viridis')
    fig.colorbar(csR, ax=axs[1], label='Reward Value')

    fill_color = ['blue', 'green', 'yellow', 'red']
    for color, eps in zip(fill_color, [0.3, 0.5, -0.3, -0.5]):
        control_invariant_set = env.calculate_control_invariant_set(
            num_iterations=10, num_state_samples=20, num_action_samples=10, eps=eps)
        hull = ConvexHull(control_invariant_set)
        # Hull edges on both panels. The fmt is only '-': a 'k' would conflict
        # with color= and trigger a matplotlib warning.
        for simplex in hull.simplices:
            for ax in axs:
                ax.plot(control_invariant_set[simplex, 0], control_invariant_set[simplex, 1], '-', color=color)

        # Translucent fill with a legend entry, on the safety panel only.
        hull_vertices = control_invariant_set[hull.vertices].tolist()
        hull_vertices.append(hull_vertices[0])
        poly = Polygon(hull_vertices, closed=True, fill=True, color=color, alpha=0.3,
                       label=f"Safe Invariant Set (eps={eps})")
        axs[0].add_patch(poly)

    # Outline of the |x|, |v| <= 2 box on both panels (a patch can belong to one Axes only).
    safe_set = patches.Rectangle((-2, -2), 4, 4, linewidth=1.5, edgecolor='black', facecolor='none', label='safe set')
    safe_set_copy = patches.Rectangle((-2, -2), 4, 4, linewidth=1.5, edgecolor='black', facecolor='none', label='safe set')
    axs[0].add_patch(safe_set)
    axs[1].add_patch(safe_set_copy)

    for ax, title in zip(axs, ['Safety', 'Reward']):
        ax.legend()
        ax.set_xlabel('x')
        ax.set_ylabel('v')
        ax.set_title(title)

    # Writes ../imgs/envtest_double_intergrator.png (or _v1, _v2, ... if taken);
    # the path is relative to the current working directory.
    save_fig_with_unique_name("../imgs", "envtest_double_intergrator", extension=".png", dpi=300)
