# Copyright 2017 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

r"""Optimal policy for LQR levels.

The LQR control problem is described in
https://en.wikipedia.org/wiki/Linear-quadratic_regulator#Infinite-horizon.2C_discrete-time_LQR
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import logging
from dm_control.mujoco import wrapper
import numpy as np
from six.moves import range

try:
    import scipy.linalg as sp  # pylint: disable=g-import-not-at-top
except ImportError:
    sp = None


def _solve_dare(a, b, q, r):
    """Solves the Discrete-time Algebraic Riccati Equation (DARE) by iteration.

    Algebraic Riccati Equation:
    ```none
    P_{t-1} = Q + A' * P_{t} * A -
              A' * P_{t} * B * (R + B' * P_{t} * B)^{-1} * B' * P_{t} * A
    ```

    Args:
      a: A 2 dimensional numpy array, transition matrix A.
      b: A 2 dimensional numpy array, control matrix B.
      q: A 2 dimensional numpy array, symmetric positive definite cost matrix.
      r: A 2 dimensional numpy array, symmetric positive definite cost matrix

    Returns:
      A numpy array, a real symmetric matrix P which is the solution to DARE.

    Raises:
      RuntimeError: If the computed P matrix is not symmetric and
        positive-definite.
    """
    p = np.eye(len(a))
    for _ in range(1000000):
        a_p = a.T.dot(p)  # A' * P_t
        a_p_b = np.dot(a_p, b)  # A' * P_t * B
        # Algebraic Riccati Equation.
        p_next = (
            q
            + np.dot(a_p, a)
            - a_p_b.dot(np.linalg.solve(b.T.dot(p.dot(b)) + r, a_p_b.T))
        )
        p_next += p_next.T
        p_next *= 0.5
        if np.abs(p - p_next).max() < 1e-12:
            break
        p = p_next
    else:
        logging.warning("DARE solver did not converge")
    try:
        # Check that the result is symmetric and positive-definite.
        np.linalg.cholesky(p_next)
    except np.linalg.LinAlgError:
        raise RuntimeError(
            "ARE solver failed: P matrix is not symmetric and " "positive-definite."
        )
    return p_next


def solve(env):
    """Computes the optimal value function and linear policy for an LQR level.

    The discrete-time linear system (a, b) and the quadratic costs (q, r) are
    assembled from the environment's physics model, the resulting discrete
    algebraic Riccati equation is solved, and the optimal feedback gain is
    derived from its solution.

    Args:
      env: An instance of `control.EnvironmentV2` with LQR level.

    Returns:
      p: A numpy array, the Hessian of the optimal total cost-to-go (value
        function at state x) is V(x) = .5 * x' * p * x.
      k: A numpy array which gives the optimal linear policy u = k * x.
      beta: The largest absolute eigenvalue of (a + b * k). Under optimal
        policy, at timestep n the state tends to 0 like beta^n.

    Raises:
      RuntimeError: If the controlled system is unstable.
    """
    physics = env.physics
    model = physics.model
    num_dofs = model.nq  # treated here as the number of degrees of freedom
    num_controls = model.nu  # number of controls

    # Expand the mass matrix stored in `qM` into a dense square matrix.
    mass_matrix = np.zeros((num_dofs, num_dofs))
    wrapper.mjbindings.mjlib.mj_fullM(model.ptr, mass_matrix, physics.data.qM)

    time_step = model.opt.timestep

    # State transition matrix a. The state stacks positions on top of
    # velocities; `accel_jac` maps that state to accelerations produced by the
    # joint springs and dampers.
    spring = np.diag(model.jnt_stiffness.ravel())
    damper = np.diag(model.dof_damping.ravel())
    accel_jac = np.linalg.solve(
        -mass_matrix, np.concatenate((spring, damper), axis=1)
    )
    velocity_selector = np.concatenate(
        (np.zeros((num_dofs, num_dofs)), np.eye(num_dofs)), axis=1
    )
    position_rows = time_step * accel_jac + velocity_selector
    a = np.eye(2 * num_dofs) + time_step * np.concatenate(
        (position_rows, accel_jac), axis=0
    )

    # Control transition matrix b: accelerations per unit control, stacked in
    # the same position/velocity layout as the rows of a.
    accel_per_control = np.linalg.solve(mass_matrix, physics.data.actuator_moment.T)
    b = time_step * np.concatenate(
        (time_step * accel_per_control, accel_per_control), axis=0
    )

    # State cost Hessian q: unit cost on positions, none on velocities.
    q = np.diag(np.concatenate((np.ones(num_dofs), np.zeros(num_dofs))))

    # Control cost Hessian r.
    r = env.task.control_cost_coef * np.eye(num_controls)

    # Prefer scipy's faster DARE solver; otherwise fall back on the slower
    # internal implementation.
    riccati_solver = sp.solve_discrete_are if sp else _solve_dare

    # Solve the discrete algebraic Riccati equation and derive the gain.
    p = riccati_solver(a, b, q, r)
    p_b = p.dot(b)
    p_a = p.dot(a)
    k = -np.linalg.solve(b.T.dot(p_b) + r, b.T.dot(p_a))

    # Under optimal policy, state tends to 0 like beta^n_timesteps.
    closed_loop = a + b.dot(k)
    beta = np.abs(np.linalg.eigvals(closed_loop)).max()
    if beta >= 1.0:
        raise RuntimeError("Controlled system is unstable.")
    return p, k, beta
