import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
from enum import Enum
import matplotlib.pyplot as plt
from sklearn.kernel_ridge import KernelRidge
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel, RBF
from collections import defaultdict

def d(x, y):
    """Return the Euclidean (L2) distance between the points ``x`` and ``y``."""
    offset = x - y
    return np.linalg.norm(offset)

class RobotNavEnv(gym.Env):
    """Point-robot navigation on the square arena [0, 10] x [0, 10].

    The robot starts at ``start_pos``. Each discrete action (1..4) moves it
    one unit up/down/left/right. The move is rotated by a random angle drawn
    uniformly from ``[-r*pi, r*pi]``, and the position is then clipped to the
    arena. An episode ends when the robot is within ``delta`` of ``goal_pos``
    or after ``H`` steps.

    Reward at a position is ``goal_reward`` inside the goal region. Elsewhere
    it is ``-5 * exp(-a * distance)`` summed over all obstacles.

    Args:
        r: Scale of the heading perturbation, as a fraction of pi.
    """

    metadata = {'render.modes': ['human']}

    # Unit displacement for each action id; action_space is Discrete(4, start=1).
    # These arrays are only read (rotated copies are returned), never mutated.
    _DIRECTIONS = {
        1: np.array([0., 1.]),   # up
        2: np.array([0., -1.]),  # down
        3: np.array([-1., 0.]),  # left
        4: np.array([1., 0.]),   # right
    }

    def __init__(self, r=0.1):
        # spaces
        self.action_space = spaces.Discrete(4, start=1)
        self.observation_space = spaces.Box(low=0, high=10, shape=(2,), dtype=np.float32)

        # parameters
        self.start_pos = np.array([0., 0.])
        self.goal_pos = np.array([10., 0.])
        self.delta = 3          # radius of the goal region
        self.goal_reward = 100
        self.H = 100            # episode step limit checked by _done
        self.r = r              # heading-noise scale

        # obstacles: centres and the decay rate of each one's penalty
        self.obs_pos = [np.array([5., 0.])]
        self.obs_a = [0.5]

        # episode
        self.h = 0
        self.pos = np.copy(self.start_pos)
        self.total_reward = 0.

    def reset(self):
        """Put the robot back at ``start_pos``, clear counters, return the observation."""
        self.pos = np.copy(self.start_pos)
        self.h = 0
        self.total_reward = 0.
        return self._get_observation()

    @staticmethod
    def _clip(pos):
        """Clip a position to the [0, 10] x [0, 10] arena."""
        return np.clip(pos, 0., 10.)

    def _perturbed_direction(self, action):
        """Return the unit move for ``action`` rotated by a fresh random heading error.

        Raises:
            KeyError: if ``action`` is not one of 1..4.
        """
        base = self._DIRECTIONS[action]
        angle = self.r * np.random.uniform(-np.pi, np.pi)
        cos, sin = np.cos(angle), np.sin(angle)
        rotation = np.array([[cos, sin], [-sin, cos]])
        return rotation @ base

    def step(self, action):
        """Apply ``action`` with heading noise and return the Gym 4-tuple.

        Returns:
            tuple: ``(observation, reward, done, info)``. The reward is
            evaluated at the robot's new position, and ``info`` is a dict.
        """
        direction = self._perturbed_direction(action)
        self.h += 1
        self.pos = self._clip(self.pos + direction)

        observation = self._get_observation()
        # The reward depends on where the robot ended up, not on the action id.
        step_reward = self._calculate_reward()
        self.total_reward += step_reward
        done = self._done()
        info = {}

        return observation, step_reward, done, info

    def _get_observation(self):
        """Return a copy of the current position, so callers cannot mutate env state."""
        return np.copy(self.pos)

    def close(self):
        plt.close()

    def _done(self):
        """True once the goal region is reached or the step limit ``H`` is hit."""
        return (d(self.pos, self.goal_pos) < self.delta) or (self.h >= self.H)

    def _calculate_reward(self, pos=None):
        """Reward at ``pos`` (defaults to the current position).

        Returns ``goal_reward`` inside the goal region. Otherwise returns the
        sum of the obstacle penalties ``-5 * exp(-a * distance)``.
        """
        if pos is None:
            pos = self.pos
        if d(pos, self.goal_pos) < self.delta:
            return self.goal_reward
        reward = 0
        for obs_pos, obs_a in zip(self.obs_pos, self.obs_a):
            reward -= 5 * np.exp(-obs_a * d(pos, obs_pos))
        return reward

    def sample_next_states(self, s, action, num_samples=500):
        """Draw ``num_samples`` independent successor states of ``s`` under ``action``.

        Uses the same dynamics as :meth:`step`: a perturbed unit move followed
        by clipping to the arena. Episode state is left untouched. Each sample
        gets its own perturbation of the base direction; rotations are not
        compounded across samples.

        Returns:
            list of np.ndarray: one sampled next position per draw.
        """
        return [self._clip(s + self._perturbed_direction(action))
                for _ in range(num_samples)]

# T: number of training episodes run by CVaR_VI; it also scales the
#    exploration bonus in get_greedy_action.
# H: number of per-step Q regressors; CVaR_VI episodes last at most H - 1 steps.
T, H = 600, 30

def get_greedy_action(Qhat, s, l, include_bonus, n_actions=4, bonus_scale=None):
    """Pick the action with the highest (optionally optimistic) Q estimate.

    All candidate actions ``1..n_actions`` are scored with a single
    ``Qhat.predict`` call on the rows ``[*s, a]``.

    Args:
        Qhat: Regressor with ``predict(X, return_std=True) -> (mean, std)``,
            e.g. a ``GaussianProcessRegressor``.
        s: 1-D state vector.
        l: Regularisation constant; the bonus is divided by ``sqrt(l)``.
        include_bonus: If true, add ``bonus_scale * std / sqrt(l)`` to every
            predicted mean.
        n_actions: Number of actions; action ids are ``1..n_actions``.
        bonus_scale: Bonus multiplier; defaults to the module-level ``T``.

    Returns:
        tuple: ``(action, value)``, the 1-based best action id and its score.
        Ties go to the lowest action id.
    """
    s = np.asarray(s, dtype=float)
    actions = np.arange(1, n_actions + 1, dtype=float)
    # One row per candidate action: the state followed by the action id.
    X = np.column_stack([np.tile(s, (n_actions, 1)), actions])
    q, std = Qhat.predict(X, return_std=True)
    q = np.asarray(q, dtype=float).ravel()
    if include_bonus:
        scale = T if bonus_scale is None else bonus_scale
        q = q + scale * np.asarray(std, dtype=float).ravel() / np.sqrt(l)
    best = int(np.argmax(q))
    return best + 1, q[best]

def CVaR_estimator(Qhat, s, a, l, alpha=0.3, model=None):
    """Monte-Carlo estimate of the lower-tail CVaR of the next-state reward.

    Samples successor states of ``(s, a)`` from the environment model and
    scores each one with the model's reward function. It then returns the
    mean of the worst ``alpha`` fraction of those rewards.

    Args:
        Qhat: Next-step regressor. NOTE(review): currently unused. The
            bootstrapped value term is disabled, so only the immediate reward
            of each sampled next state enters the estimate.
        s: Current state.
        a: Action id, either an int or a size-1 array (as stored by CVaR_VI).
        l: Regularisation constant. NOTE(review): currently unused.
        alpha: Tail fraction; ``1.0`` gives the plain sample mean.
        model: Object providing ``sample_next_states(s, a)`` and
            ``_calculate_reward(pos)``. Defaults to the module-level ``env``.

    Returns:
        Mean of the lowest ``max(1, int(n * alpha))`` sampled rewards.

    Raises:
        ValueError: if the model returns no next states.
    """
    if model is None:
        model = env
    # .item() accepts plain ints and size-1 arrays; int() on an ndim>0 array
    # is deprecated in recent NumPy.
    action = int(np.asarray(a).item())
    next_states = model.sample_next_states(s, action)
    rewards = np.sort([model._calculate_reward(ns) for ns in next_states])
    if rewards.size == 0:
        raise ValueError("sample_next_states returned no states")
    # Keep at least the single worst sample, so a small alpha cannot yield an
    # empty tail (whose mean would be NaN).
    k = max(1, int(len(rewards) * alpha))
    return np.average(rewards[:k])


def _fit_q_functions(Qhat, trajectories, l, alpha):
    """Refit every ``Qhat[h]`` on its stored transitions, from the last step backwards."""
    for h in range(H - 1, -1, -1):
        transitions = trajectories[h]
        if not transitions:
            continue
        X = [np.concatenate([s, a], axis=0) for s, a, _, _ in transitions]
        # Terminal transitions regress on the reward alone; the others add the
        # CVaR estimate for the following step.
        y = [r if done else r + CVaR_estimator(Qhat[h + 1], s, a, l, alpha=alpha)
             for s, a, r, done in transitions]
        Qhat[h].fit(np.stack(X), np.array(y))


def CVaR_VI(env, gamma=1, l=0.1, alpha=1.0):
    """Run optimistic fitted value iteration with a CVaR backup on ``env``.

    One Gaussian-process regressor per step index ``h`` (``H`` in total) maps
    ``[*state, action]`` to an estimated return. Each of the ``T`` episodes
    does two things. First it refits the regressors on all transitions
    collected so far. Then it rolls out for at most ``H - 1`` steps, acting
    greedily on the GP mean plus an exploration bonus. Each episode's total
    reward, and then the number of goal-reaching episodes, are printed at
    the end.

    Args:
        env: Environment with ``reset``, ``step`` and ``goal_reward``.
        gamma: Discount factor. NOTE(review): currently unused.
        l: GP noise level (the regressor's ``alpha``); also scales the bonus.
        alpha: CVaR tail fraction passed to ``CVaR_estimator``.
            NOTE(review): ``CVaR_estimator`` samples from the module-level
            ``env``, not from this argument.

    Returns:
        dict[int, GaussianProcessRegressor]: the regressor for each step index.
    """
    kernel = RBF()
    rewards = []
    success = 0
    # trajectories[h] holds every (state, action, reward, done) seen at step h.
    trajectories = defaultdict(list)
    Qhat = {h: GaussianProcessRegressor(kernel=kernel, alpha=l, n_restarts_optimizer=1)
            for h in range(H)}
    for _episode in range(T):
        _fit_q_functions(Qhat, trajectories, l, alpha)
        s = env.reset()
        episode_reward = 0
        for h in range(H - 1):
            a, _ = get_greedy_action(Qhat[h], s, l, include_bonus=True)
            s_, r, done, _ = env.step(a)
            episode_reward += r
            # Store the transition before stopping. Otherwise the goal-reaching
            # step (the only positive reward) never reaches the regressors.
            trajectories[h].append((s, np.array([a]), r, done))
            if done:
                if r == env.goal_reward:
                    success += 1
                break
            s = s_
        rewards.append(episode_reward)
    print(rewards)
    print(success)
    return Qhat

# The environment must live at module scope: CVaR_estimator reads the global `env`.
env = RobotNavEnv(r=0.01)
# Train; CVaR_VI prints its results and returns the per-step regressors.
Q = CVaR_VI(env=env, alpha=1.0)