"""
A point robot that can exert force along the x and y axes. The robot's goal is
to reach the origin; it experiences zero-mean Gaussian noise and air resistance
proportional to its velocity.

State representation is (x, vx, y, vy, closest_distance_to_racing_line).

Action representation is (fx, fy), and mass is assumed to be 1.
"""

import os
import pickle

import os.path as osp
import numpy as np
from gym import Env
from gym import utils
from gym.spaces import Box
import torch

from .pointbot_const import *
from .gap import get_gap

def process_action(a):
    """Clamp each component of action ``a`` into [-MAX_FORCE, MAX_FORCE]."""
    lower_bound, upper_bound = -MAX_FORCE, MAX_FORCE
    clamped = np.clip(a, lower_bound, upper_bound)
    return clamped

def lqr_gains(A, B, Q, R, T):
    """Compute finite-horizon discrete-time LQR gains by backward Riccati recursion.

    For dynamics x_{t+1} = A x_t + B u_t with stage cost x'Qx + u'Ru and
    terminal cost x'Qx, returns time-varying gains such that u_t = Ks[t] x_t.

    Args:
        A: (n, n) state-transition matrix.
        B: (n, m) input matrix.
        Q: (n, n) state-cost matrix (also used as the terminal cost).
        R: (m, m) input-cost matrix.
        T: horizon length (number of gains to produce).

    Returns:
        (Ks, Ps): Ks is a list of T (m, n) gain matrices, Ks[t] for step t;
        Ps is a list of T + 1 (n, n) cost-to-go matrices with Ps[T] == Q.
    """
    Ps = [Q]
    for _ in range(T):
        P = Ps[-1]
        # Solve rather than form an explicit inverse: same value, better conditioned.
        feedback = np.linalg.solve(R + B.T.dot(P).dot(B), B.T.dot(P).dot(A))
        Ps.append(Q + A.T.dot(P).dot(A) - A.T.dot(P).dot(B).dot(feedback))
    # Ps was built from the terminal step backwards; reverse so Ps[t] is time t.
    Ps.reverse()
    Ks = []
    for t in range(T):
        # K_t = -(R + B' P_{t+1} B)^{-1} B' P_{t+1} A; P_{t+1} appears in both factors.
        P_next = Ps[t + 1]
        Ks.append(-np.linalg.solve(R + B.T.dot(P_next).dot(B), B.T.dot(P_next).dot(A)))
    return Ks, Ps


class PointBot(Env, utils.EzPickle):
    """Point-mass robot with linear dynamics and Gaussian process noise.

    The physical state is (x, vx, y, vy). Observations returned by ``reset``
    and ``step`` additionally carry the gap to the racing line, appended by
    ``append_gap_to_state``.

    NOTE(review): ``self.obstacle`` and ``self.mode`` are only assigned in
    ``set_mode``, so it must be called before ``step``; ``self.racing_line``
    stays ``None`` until ``set_maneuver`` is called.
    """

    def __init__(self):
        utils.EzPickle.__init__(self)
        self.hist = self.cost = self.done = self.time = self.state = None
        # Discrete-time dynamics: each position integrates its velocity, and
        # each velocity is scaled by (1 - AIR_RESIST) every step.
        self.A = np.eye(4)
        self.A[2, 3] = self.A[0, 1] = 1
        self.A[1, 1] = self.A[3, 3] = 1 - AIR_RESIST
        # Forces (fx, fy) act directly on (vx, vy); mass is 1.
        self.B = np.array([[0, 0], [1, 0], [0, 0], [0, 1]])
        self.horizon = HORIZON
        self.action_space = Box(-np.ones(2) * MAX_FORCE, np.ones(2) * MAX_FORCE)
        # np.float was deprecated in NumPy 1.20 and removed in 1.24; np.inf is portable.
        # NOTE(review): this space is 4-D while observations also carry the gap.
        self.observation_space = Box(-np.ones(4) * np.inf, np.ones(4) * np.inf)
        self.start_state = START_STATE
        self.racing_line = None

        self.noise_scale = NOISE_SCALE
        self.noise_on_init_state = NOISE_ON_INIT_STATE

    def set_mode(self, mode):
        """Select the obstacle ``OBSTACLE[mode]``; mode 1 also moves the start state."""
        self.mode = mode
        self.obstacle = OBSTACLE[mode]
        if self.mode == 1:
            self.start_state = [-100, 0, 0, 0]

    def set_maneuver(self, maneuver, torch_device):
        """Load the start state and racing line for ``maneuver`` from MANEUVERS.

        A float tensor copy of the racing line is kept on ``torch_device``.
        """
        self.maneuver = maneuver
        self.start_state = MANEUVERS[maneuver]['start_state']
        self.racing_line = MANEUVERS[maneuver]['racing_line']
        self.racing_line_torch = torch.tensor(self.racing_line, device=torch_device, dtype=torch.float)

    def process_action(self, state, action):
        """Identity hook: return ``action`` unchanged."""
        return action

    def step(self, a):
        """Advance one step with action ``a`` and return ``(state, cost, done, info)``.

        ``self.done`` is never set to True here because the task horizon is ignored.
        """
        # np.asarray lets callers pass plain lists as well as arrays.
        a = np.asarray(a).reshape(-1)
        a = process_action(a)
        next_state = self._next_state(self.state, a)
        cur_cost = self.step_cost(next_state.reshape(1, -1), a.reshape(1, -1))
        self.cost.append(cur_cost)
        # When self.obstacle reports the current physical state (gap removed)
        # as truthy, the state is frozen instead of advanced.
        if not self.obstacle(self.state[:4]):
            self.state = next_state
        self.time += 1
        self.hist.append(self.state)
        # The task horizon is deliberately ignored:
        # self.done = HORIZON <= self.time

        # In line-follower mode no failure cost is added at the end of an episode.
        if not LINE_FOLLOWER_MODE:
            if self.done and not self.is_stable(self.state):
                # TODO remove this after check
                print("add cost!!")
                self.cost[-1] += FAILURE_COST
                cur_cost += FAILURE_COST
        # NOTE(review): outside LINE_FOLLOWER_MODE step_cost returns a scalar,
        # so cur_cost[0] would fail there — confirm those modes are still used.
        return self.state, cur_cost[0], self.done, {}

    def reset(self):
        """Reset to ``start_state`` with zero velocity and return the observation.

        When ``noise_on_init_state`` is set, unit Gaussian noise is added to the
        initial position.
        """
        self.state = self.start_state + np.random.randn(4)

        # Without initial-state noise, recompute the state noise-free
        # (randn(4) * 0.0 still consumes a draw from the global RNG).
        if not self.noise_on_init_state:
            self.state = self.start_state + np.random.randn(4) * 0.0
        # Zero the velocity components so only position noise survives.
        self.state *= np.array([1.0, 0.0, 1.0, 0.0])

        self.state = self.append_gap_to_state(self.state)

        self.time = 0
        self.cost = []
        self.done = False
        self.hist = [self.state]
        return self.state

    def _next_state(self, s, a):
        """Return A s + B a + noise for the physical part of ``s``, with the gap re-appended."""
        s = s[:4]  # drop the gap component
        s = self.A.dot(s) + self.B.dot(a) + self.noise_scale * np.random.randn(len(s))
        return self.append_gap_to_state(s)

    def append_gap_to_state(self, s):
        """Concatenate ``s`` with the gap returned by ``get_gap`` for its (x, y) position."""
        x = s.reshape(1, -1)[:, 0:1]
        y = s.reshape(1, -1)[:, 2:3]
        points = np.hstack([x, y])
        gap, _ = get_gap(points, self.racing_line)
        return np.concatenate([s, gap])

    def step_cost(self, s, a):
        """Cost for a batch of gap-augmented states ``s`` (rows) and actions ``a``.

        In line-follower mode the cost is the negated reward
        speed / SPEED_CONST * (1 - |gap| * GAP_CONST), computed with numpy or
        torch to match the type of ``s``.
        """
        if LINE_FOLLOWER_MODE:
            GAP_CONST = 1.0 / 2.5
            SPEED_CONST = 0.1
            speed_x = s[:, 1:2]
            speed_y = s[:, 3:4]
            gap = s[:, 4:5]

            # numpy and torch share sqrt/square/abs, so one expression serves both.
            lib = np if isinstance(s, np.ndarray) else torch
            speed = lib.sqrt(lib.square(speed_x) + lib.square(speed_y)) / SPEED_CONST
            r = speed * (1. - (lib.abs(gap) * GAP_CONST))
            return -r.reshape(-1)  # [N, 1] -> [N]; reward to cost
        elif HARD_MODE:
            return int(np.linalg.norm(np.subtract(GOAL_STATE, s)) > GOAL_THRESH) + self.obstacle(s)
        else:
            return np.linalg.norm(np.subtract(GOAL_STATE, s))

    def collision_cost(self, obs):
        """Return ``self.obstacle(obs)``."""
        return self.obstacle(obs)

    def end(self):
        """No-op."""
        pass

    def values(self):
        """Return suffix sums (cost-to-go) of the per-step costs recorded this episode."""
        return np.cumsum(np.array(self.cost)[::-1])[::-1]

    def sample(self):
        """Draw an action uniformly from [-MAX_FORCE, MAX_FORCE) in each dimension."""
        return np.random.random(2) * 2 * MAX_FORCE - MAX_FORCE

    def plot_trajectory(self, states=None):
        """Scatter-plot the (x, y) positions of ``states`` (default: this episode's history)."""
        # `is None`: `==` against an array argument would compare element-wise.
        if states is None:
            states = self.hist
        states = np.array(states)
        # NOTE(review): `plt` is not among the visible imports; presumably it is
        # provided by the pointbot_const star import — confirm.
        plt.scatter(states[:, 0], states[:, 2])
        plt.show()

    def is_stable(self, s):
        """Return whether ``s`` is within GOAL_THRESH (Euclidean) of GOAL_STATE."""
        # NOTE(review): observations include the gap; confirm GOAL_STATE's length matches.
        return np.linalg.norm(np.subtract(GOAL_STATE, s)) <= GOAL_THRESH

    def teacher(self, sess=None):
        """Return a new PointBotTeacher (``sess`` is unused)."""
        return PointBotTeacher()

class PointBotTeacher(object):
    """Scripted demonstrator for PointBot: hand-coded open-loop actions, then LQR.

    NOTE(review): ``self.env.set_mode(...)`` must be called before
    ``get_rollout``, because the script branches on ``self.env.mode``.
    """

    def __init__(self):
        self.env = PointBot()
        # Finite-horizon LQR gains for the 4-D physical state (x, vx, y, vy),
        # with state cost I and action cost 50 * I.
        self.Ks, self.Ps = lqr_gains(self.env.A, self.env.B, np.eye(4), 50 * np.eye(2), HORIZON)
        self.demonstrations = []
        self.outdir = "data/pointbot"

    def get_rollout(self, max_attempts=None):
        """Return a demonstration rollout whose final observation is stable.

        Unstable rollouts are discarded and re-collected in a loop (not by
        recursion), so repeated failures cannot overflow the call stack.

        Args:
            max_attempts: optional cap on the number of rollouts tried;
                ``None`` (default) retries until one succeeds.

        Returns:
            dict with keys "obs", "ac", "cost_sum", "costs", "values",
            "stabilizable_obs".

        Raises:
            RuntimeError: if ``max_attempts`` rollouts all end unstable.
        """
        attempts = 0
        while max_attempts is None or attempts < max_attempts:
            attempts += 1
            rollout = self._collect_rollout()
            if rollout is not None:
                return rollout
        raise RuntimeError("no stable rollout after %d attempts" % max_attempts)

    def _collect_rollout(self):
        """Run one scripted episode; return its data dict, or None if it ends unstable."""
        obs = self.env.reset()
        O, A, cost_sum, costs = [obs], [], 0, []
        noise_std = 0.2
        for i in range(HORIZON):
            # NOTE(review): noise_idx is re-drawn every step, so whether step i
            # is noised is decided independently per step — confirm it was not
            # meant to be drawn once per rollout.
            if self.env.mode == 1:
                noise_idx = np.random.randint(int(HORIZON * 2 / 3))
                if i < HORIZON / 2:
                    action = [0.1, 0.1]
                else:
                    action = self._expert_control(obs, i)
            else:
                noise_idx = np.random.randint(int(HORIZON))
                if i < HORIZON / 4:
                    action = [0.1, 0.25]
                elif i < HORIZON / 2:
                    action = [0.4, 0.]
                elif i < HORIZON / 3 * 2:
                    action = [0, -0.5]
                else:
                    action = self._expert_control(obs, i)

            if i < noise_idx:
                action = (np.array(action) + np.random.normal(0, noise_std, self.env.action_space.shape[0])).tolist()

            A.append(action)
            # PointBot.step reshapes its argument, so hand it an array, not a list.
            obs, cost, done, info = self.env.step(np.asarray(action))
            O.append(obs)
            cost_sum += cost
            costs.append(cost)
            if done:
                break

        if not self.env.is_stable(obs):
            return None

        costs = np.array(costs)
        values = np.cumsum(costs[::-1])[::-1]
        return {
            "obs": np.array(O),
            "ac": np.array(A),
            "cost_sum": cost_sum,
            "costs": costs,
            "values": values,
            "stabilizable_obs": O
        }

    def _get_gain(self, t):
        """Return the LQR gain for time step ``t``."""
        return self.Ks[t]

    def _expert_control(self, s, t):
        """Return the LQR action K_t x for the physical part x of observation ``s``."""
        # Observations carry the racing-line gap after (x, vx, y, vy), while the
        # gains are (2, 4); only the first four components are fed back.
        return self._get_gain(t).dot(s[:4])
