"""
Module for studying a mountain car problem with continuous action.
Fitted Q-iteration (FQI) with linear function approximation is used
to conduct (offline) reinforcement learning.

-----------------------------------------------------------------

class MountainCar:  class for conducting FQI for mountain car.


Goal is to conduct fitted Q-iteration (FQI) to find the optimal
    policy that maximizes the expected cumulative return.

Functions:  generate_next_state: given state-action pairs, generate
                     next states and immediate reward.
            simulate_data: generate dataset of a given sample size.
            compute_feature: compute the feature mapping.
                comp_feature_p: compute the feature mapping for position p.
                comp_feature_v: compute the feature mapping for velocity v.
                comp_feature_f: compute the feature mapping for force f.
            compute_val_fun: calculate the value function and greedy policy
                             associated with a given Q-function.
            evaluate_value: evaluate the value of a given policy.
            conduct_FQI: conduct FQI with linear function approximation to
                         solve the reinforcement learning problem.
"""

import numpy as np
import numpy.linalg as npl
import numpy.random as npr
import math as math
import scipy.linalg as scilin
import collections as col

import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
# Axes3D import has side effects, it enables using projection='3d' in add_subplot
import matplotlib.pyplot as plt


class MountainCar:
    """
    Class for studying mountain car with continuous action.
    """

    def __init__(self, n, dim_p = 50, dim_v = 15, dim_f = 4):
        """
        Store the sample size, the feature dimensions and the fixed
        constants of the mountain car environment.

        INPUTS: n == sample size, stored as-is and read by simulate_data.
                dim_p == integer. Number of position features.
                dim_v == integer. Number of velocity features.
                dim_f == integer. Number of force features
                         (compute_val_fun treats the Q-function as a
                         cubic polynomial in the force, i.e. dim_f = 4).

        ACTIONS: Sets the attributes n, dim_p, dim_v, dim_f,
                 d = dim_p * dim_v * dim_f, and the environment
                 constants listed below.
        """

        # Sample size and feature dimensions.
        self.n = n
        self.dim_p, self.dim_v, self.dim_f = dim_p, dim_v, dim_f
        # Total number of features of the tensor-product basis.
        self.d = dim_p * dim_v * dim_f

        # Range of admissible forces (actions).
        self.min_action, self.max_action = -1.0, 1.0

        # Positions are clipped to this interval, which sits 1e-3 inside
        # (-1.2, 0.6); the slope used in generate_next_state divides by
        # (p+1.2)^2 * (p-0.6)^2.
        self.min_position = -1.2 + 1e-3
        self.max_position = 0.6 - 1e-3

        # Velocities are clipped to [-max_speed, max_speed].
        self.max_speed = 0.07
        # Multiplier of the force in the velocity update.
        self.power = 0.0015
        # Positions at or beyond this value earn the positive reward term.
        self.goal_position = 0.45

        # Standard deviations of the Gaussian noise added to the
        # position and velocity updates.
        self.sigma_p, self.sigma_v = 0.01, 0.0025

        # Discount factor of the cumulative return.
        self.discount = 0.97


    #################################################################

    def generate_next_state(self, p, v, f):
        """
        Given current state-action pairs (p, v, f),
        generate next states (p_new, v_new) and immediate rewards.

        INPUTS: p == n-dimensional vector. Current positions.
                v == n-dimensional vector. Current velocities.
                f == n-dimensional vector. Current forces.

        ACTIONS: p_new == n-dimensional vector. Next positions.
                 v_new == n-dimensional vector. Next velocities.
                 reward == n-dimensional vector. Immediate rewards.
                 The dynamic equations are as follows:
                    mountain: sin(3p)/3 + 0.025/((p+1.2)(0.6-p));
                    slope = derivative of Mountain;
                    v_new = bound(v + f * power - slope + noise(sigma_v));
                    p_new = bound(p + v_new + noise(sigma_p));
                    reward = - 0.1 * f^2
                             + 100 * positive_part(p - goal_position)^2.
        """

        n = p.size

        min_position, max_position = self.min_position, self.max_position
        max_speed = self.max_speed
        goal_position = self.goal_position
        power = self.power
        sigma_p, sigma_v = self.sigma_p, self.sigma_v

        slope = np.cos(3*p) + (p+0.3)/((p+1.2)**2*(p-0.6)**2)*0.05
        v_new = np.maximum(np.minimum(v + f * power - 0.0025 * slope \
                + npr.normal(0,sigma_v,n), max_speed), -max_speed)
        p_new = np.maximum(np.minimum(p + v_new \
                + npr.normal(0,sigma_p,n), max_position), min_position)
        idx = (p_new == max_position) & (v_new > 0)
        v_new[idx] = -v_new[idx]
        idx = (p_new == min_position) & (v_new < 0)
        v_new[idx] = 0

        reward = - np.power(f, 2) * 0.1
        idx_terminate = p >= goal_position
        reward[idx_terminate] += np.power(p[idx_terminate] \
                                 - goal_position, 2) * 100.0

        return p_new, v_new, reward

    #################################################################

    def update_params(self, n=100):

        if n != None:
            self.n = int(n)

    #################################################################

    def simulate_data(self):
        """
        Generate offline dataset that cosists of i.i.d. samples.

        INPUTS: n == scalar. Number of samples.

        ACTIONS: data == dictionary that has the following keys:
                    p == n-dimensional vector. Current positions.
                    v == n-dimensional vector. Current velocities.
                    f == n-dimensional vector. Current forces.
                    p_new == n-dimensional vector. Next positions.
                    v_new == n-dimensional vector. Next velocities.
                    reward == n-dimensional vector. Immediate rewards.
                 The initial state-action pairs (p,v,f) are drawn
                 from uniform distributions.
        """

        n = self.n
        min_position, max_position = self.min_position, self.max_position
        max_speed = self.max_speed
        min_action, max_action = self.min_action, self.max_action

        p = np.random.uniform(min_position, max_position, n)
        v = np.random.uniform(-max_speed, max_speed, n)
        f = np.random.uniform(min_action, max_action, n)
        p_new, v_new, reward = self.generate_next_state(p, v, f)

        self.data = {'p' : p, 'v' : v, 'f' : f,
                     'p_new' : p_new, 'v_new' : v_new,
                     'reward' : reward}

        return self.data

    #################################################################

    def compute_feature(self, p, v, f):
        """
        Compute the linear features at position p, velocity v
        and force f.

        INPUTS: p == n-dimensional vector. Input positions.
                v == n-dimensional vector. Input velocities.
                f == n-dimensional vector. Input forces.

        ACTIONS: Feature == d-by-n matrix.
                 The i-th column is the feature at the i-th (p,v,f).
        """

        n = self.n
        d = self.d

        feature_p = self.comp_feature_p(p)
        # print("feature_p", feature_p.shape)
        feature_v = self.comp_feature_v(v)
        # print("feature_v", feature_v.shape)
        feature_f = self.comp_feature_f(f)
        # print("feature_f", feature_f.shape)

        Feature = feature_p[:,None,None,:] * feature_v[None,:,None,:] \
                  * feature_f[None,None,:,:]
        Feature = Feature.reshape([d, n])

        return Feature

    #################################################################

    def comp_feature_p(self, p):
        """
        Compute the linear features at position p.

        INPUTS: p == n-dimensional vector. Input positions.

        ACTIONS: feature_p == dim_p-by-n matrix.
                 The i-th column is the feature at the i-th position.
        """

        dim_p = self.dim_p
        dim_p_1 = int(dim_p/2)
        dim_p_2 = dim_p - dim_p_1
        feature_p = np.append(np.cos(p[:,None] * np.arange(dim_p_1)), \
                    np.sin(p[:,None] * np.arange(1, dim_p_2+1)), axis=1).T

        return feature_p

    #################################################################

    def comp_feature_v(self, v):
        """
        Compute the linear features at velocity v.

        INPUTS: v == n-dimensional vector. Input velocities.

        ACTIONS: feature_v == dim_v-by-n matrix.
                 The i-th column is the feature at the i-th velocity.
        """

        dim_v = self.dim_v
        dim_v_1 = int((dim_v+1)/2)
        dim_v_2 = dim_v - dim_v_1
        feature_v = np.append(np.cos(v[:,None] * np.arange(dim_v_1)), \
                    np.sin(v[:,None] * np.arange(1, dim_v_2+1)), axis=1).T

        return feature_v

    #################################################################

    def comp_feature_f(self, f):
        """
        Compute the linear features at force f.

        INPUTS: f == n-dimensional vector. Input forces.

        ACTIONS: feature_f == dim_f-by-n matrix.
                 The i-th column is the feature at the i-th force.
                 feature_f(f) = (1, f, f^2, f^3).
        """

        dim_f = self.dim_f
        feature_f = np.power(f[:, None], np.arange(dim_f)).T

        return feature_f

    #################################################################

    def compute_val_fun(self, weight_Q, feature_p, feature_v):
        """
        Calculate the value function and greedy policy associated
        with a given Q-function.

        INPUTS: weight_Q == dim_p * dim_v * dim_f tensor.
                            The linear coefficients that define the
                            Q-function.
                feature_p == dim_p-by-n matrix.
                             The features of positions where the
                             value function is realized.
                feature_v == dim_v-by-n matrix.
                             The features of velocities where the
                             value function is realized.

        ACTIONS: vfun == n-dimensional vector.
                         The i-th entry is the value at the i-th (p,v).
                 f_star == n-dimension vector.
                           The i-th entry is the action at the i-th (p,v),
                           specified by the greedy policy associated with
                           the Q-function.
                 Recall that feature_f(f) = (1, f, f^2, f^3).
                 We use the property of third order polynomials to
                 calculate the maximum regarding the force, which has
                 an explicit analytical form.
        """

        n = feature_p.shape[1]

        dim_p, dim_v, dim_f = self.dim_p, self.dim_v, self.dim_f
        min_action, max_action = self.min_action, self.max_action

        weight_Q = weight_Q.reshape([dim_p, dim_v, dim_f])


        # qfun == n-by-dim_f matrix.
        #         For each index i, qfun[i,:] is the linear
        #         coefficients of Q-function(p[i], v[i]),
        #         i.e. Q-function(p[i], v[i], f)
        #                = qfun[i,0] + qfun[i,1]*f
        #                + qfun[i,2]*f^2 + qfun[i,3]*f^3.
        qfun = np.zeros([n, dim_f])
        for k in range(dim_f):
            qfun[:,k] = np.sum(feature_p.T * (weight_Q[:,:,k] @ feature_v).T, axis=1)
        # This line is equivalent to
        #   qfun[i,k] = feature_p[:,i].T @ weight_Q @ feature_v[:,i].

        # Calculate saddle points f0.
        # f0[i] = saddle point of Q-function(p[i], v[i], f).
        Delta = np.sqrt(qfun[:,2]**2 - 3*qfun[:,1]*qfun[:,3])
        f0 = - (qfun[:,2] + Delta) / (3*qfun[:,3])
        idx_de = np.abs(qfun[:,3]) < 1e-4
        f0[idx_de] = - qfun[idx_de,1] / (2*qfun[idx_de,2])
        f0 = np.where((f0 < min_action) | (f0 > max_action), np.nan, f0)

        # qfun_compare[i,0] = Q-function(p[i], v[i], f=-1)
        # qfun_compare[i,1] = Q-function(p[i], v[i], f=1)
        # qfun_compare[i,2] = Q-function(p[i], v[i], f=f0[i])
        qfun_compare = np.zeros([n,3])
        qfun_compare[:,0:2] = qfun @ np.power(np.array([-1,1])[:,None], np.arange(4)).T
        qfun_compare[:,2] = np.sum(qfun * np.power(f0[:,None], np.arange(4)), axis=1)

        # Return the value function and greedy policy.
        vfun = np.nanmax(qfun_compare, axis=1)
        idx_f = np.nanargmax(qfun_compare, axis=1)

        f_star = (idx_f * 2 - 1).astype(float)
        idx_temp = idx_f > 1.5
        f_star[idx_temp] = f0[idx_temp]

        return vfun, f_star

    #################################################################

    def evaluate_value(self, weight_Q, m=1000, L=1000):
        """
         Evaluate the value of a given policy.

         INPUTS: weight_Q == dim_p * dim_v * dim_f tensor.
                             The linear coefficients that define the
                             Q-function.
                 m == integer. Number of initial states to be evaluated.
                 L == integer. The length of each simulated trajectory.

         ACTIONS: Traj == dictionary. Information of each trajectory.
                  value == scalar. The average return of all trajectories.
                  The initial states are chosen as:
                    p = np.linspace(-0.6, -0.4, num=m, endpoint=True),
                    v = 0.
        """

        dim_p, dim_v, dim_f = self.dim_p, self.dim_v, self.dim_f
        discount = self.discount

        Traj_p = np.zeros([m, L+1])
        Traj_v = np.zeros([m, L+1])
        Traj_f = np.zeros([m, L])
        Traj_r = np.zeros([m, L])
        Traj_p[:, 0] = np.linspace(-0.6, -0.4, num=m, endpoint=True)

        for l in range(1, L+1):

            feature_p = self.comp_feature_p(Traj_p[:,l-1])
            feature_v = self.comp_feature_v(Traj_v[:,l-1])
            vfun, f_star = self.compute_val_fun(weight_Q, feature_p, feature_v)
            Traj_f[:, l-1] = f_star.copy()

            p_new, v_new, reward = self.generate_next_state(Traj_p[:,l-1], \
                                             Traj_v[:,l-1], Traj_f[:,l-1])
            Traj_p[:, l] = p_new.copy()
            Traj_v[:, l] = v_new.copy()
            Traj_r[:, l-1] = reward.copy()

        Traj = {'p' : Traj_p, 'v' : Traj_v, 'f' : Traj_f, 'r' : Traj_r}
        value = np.mean(Traj_r @ np.power(discount, np.arange(L)))

        return Traj, value

    #################################################################

    def conduct_FQI(self, ridge = 1e-2):
        """
         Conduct FQI

         INPUTS:

         ACTIONS:
        """

        discount = self.discount
        d = self.d

        Feature = self.compute_feature(self.data['p'], self.data['v'], self.data['f'])
        CovMt = Feature @ Feature.T
        feature_p_new = self.comp_feature_p(self.data['p_new'])
        feature_v_new = self.comp_feature_v(self.data['v_new'])
        reward = self.data['reward']

        num_iter = 500
        terminate = 0
        Bell_err = np.zeros(5)/0

        weight_Q = np.zeros(d)
        weight_Q_old = weight_Q.copy()

        weight_Q_min = weight_Q.copy()
        Bell_err_min = 1e5

        for i in range(num_iter):

            vfun, f_star = self.compute_val_fun(weight_Q, feature_p_new, feature_v_new)
            y = Feature @ (reward + discount * vfun)
            weight_Q = npl.solve(CovMt + ridge * np.eye(d), y)

            diff = npl.norm(weight_Q - weight_Q_old) / np.sqrt(d)

            if diff < Bell_err_min:
                Bell_err_min = diff
                weight_Q_min = weight_Q_old.copy()

            if (i+1)%10 == 0:
                print(self.n, i+1, diff)

            if diff > 5e-3:
                Bell_err = np.zeros(5)/0
                terminate = 0
            else:
                Bell_err[terminate] = diff
                terminate += 1

            if terminate == 5:
                weight_Q_output = weight_Q.copy()
                Bell_err_output = np.nanmean(Bell_err)
                break
            else:
                weight_Q_old = weight_Q.copy()

            if (i > 70) & (diff > 0.15):
                terminate = 100
                weight_Q_output = weight_Q_min
                Bell_err_output = Bell_err_min
                break

        if terminate < 5:
            weight_Q_output = weight_Q.copy()
            Bell_err_output = diff

        print("END of FQI", self.n, i, Bell_err_output, terminate)
        return weight_Q_output, Bell_err_output


    #################################################################
