"""Agent / policy implementation.

Policy parametrization:
- d_u: dimension of the action
- d_x: dimension of the state
- H: horizon (number of past disturbances the policy conditions on)
- P: exponential distribution; its rate lambda is an additional hyperparameter.
"""
import deluca.core
import jax
import jax.numpy as jnp


class AgentState(deluca.Obj):
  """State carried between successive calls of Agent."""
  # Disturbance-to-action weights. Agent.init creates this with shape
  # (d_u, d_x * (H + 1)): one weight per entry of the flattened w_history.
  M: jnp.array = deluca.field()

  # Disturbance history, shape (d_x, H + 1). After each Agent.__call__,
  # column 0 holds the most recent disturbance and the last column is a
  # constant column of ones (it is all zeros in the freshly initialised state).
  w_history: jnp.array = deluca.field()
  # Incremented by 1 on every Agent.__call__; used as the column index into
  # u_0. Stored as a float, cast to int when indexing.
  step: float = deluca.field(0.)

  # Nominal action sequence, shape d_u x H_p (planner horizon). Column `step`
  # is added to the action as its feed-forward term.
  u_0: jnp.array = deluca.field()


class Agent(deluca.Agent):
  """Disturbance-action policy with a fixed state-feedback gain.

  At step t (= agent_state.step) the action is the (d_u, 1) column

      u_t = u_0[:, t] + K x_t + M vec(W_t)

  where u_0 is the nominal action sequence stored in the agent state, x_t is
  racer_state.arr, K is a fixed feedback gain, and W_t is the (d_x, H + 1)
  matrix [w_t, w_{t-1}, ..., w_{t-H+1}, 1] of the last H disturbances plus a
  constant column of ones, flattened row-major before multiplying by M.

  Attributes:
    d_u: Action dimension.
    d_x: State dimension.
    H: Number of past disturbances kept in the history.
    seed: PRNG seed used to initialise M.
    K: Feedback gain of shape (d_u, d_x). It is currently scaled by 0, so the
      K x_t term is identically zero.
  """
  d_u: int = deluca.field(2)
  d_x: int = deluca.field(4)
  H: int = deluca.field(10)
  seed: int = deluca.field(0)
  # Alternative gain, kept for reference:
  # K: jnp.array = -1 * jnp.array([[0.07956252, 0., 0.40676188, 0.],
  #                                [0., 0.00303891, 0., 0.07801963]])
  K: jnp.array = 0 * jnp.array([[0.15746095, 0., 0.58285148, 0.],
                                 [0., 0.00303891, 0., 0.07801963]])

  def init(self, u_0):
    """Create the initial agent state.

    Args:
      u_0: Nominal action sequence, indexed as u_0[:, step] (presumably shape
        (d_u, H_p) with one column per planner step -- confirm with caller).

    Returns:
      An AgentState with M drawn uniformly from [0, 1) with shape
      (d_u, d_x * (H + 1)), the given u_0, an all-zero disturbance history of
      shape (d_x, H + 1), and the default step counter.
    """
    return AgentState(
        M=jax.random.uniform(
            jax.random.PRNGKey(self.seed), (self.d_u, self.d_x * (self.H + 1))),
        u_0=u_0,
        w_history=jnp.zeros((self.d_x, self.H + 1)))

  def __call__(self, agent_state, racer_state, w=None):
    """Compute the action for the current step and advance the agent state.

    Args:
      agent_state: Current AgentState (M, w_history, step, u_0).
      racer_state: Object whose ``arr`` attribute holds the state vector
        ([x, y, x_dot, y_dot] for the racer).
      w: Most recent disturbance, shape (d_x,). Defaults to all zeros.

    Returns:
      A tuple (new_agent_state, action). The action has shape (d_u, 1).
      new_agent_state carries the updated disturbance history and step + 1.
    """
    if w is None:
      # Size the default from d_x instead of a hard-coded 4 so it stays valid
      # for any state dimension.
      w = jnp.zeros(self.d_x)

    # Feed-forward term: column `step` of the nominal sequence, as a column.
    # NOTE: with JAX's default indexing mode an out-of-range step is clipped
    # rather than raising, so the last column is reused past the end of u_0.
    step_idx = jnp.array(agent_state.step, int)
    ut_o = agent_state.u_0.at[:, step_idx].get().reshape(-1, 1)

    # State-feedback term K x_t, reshaped to a (d_u, 1) column like the other
    # two terms. Without this, a flat (d_x,) state gives a (d_u,) vector that
    # broadcasts against the (d_u, 1) columns into a (d_u, d_u) matrix.
    Kxt = jnp.matmul(self.K, racer_state.arr).reshape(-1, 1)  # pylint: disable=invalid-name

    # Update the disturbance history:
    # 1. drop the trailing ones column;
    # 2. shift the remaining H columns right by one (the oldest wraps to
    #    column 0);
    # 3. overwrite column 0 with the newest disturbance w.
    w_history = jnp.roll(
        agent_state.w_history.at[:, :-1].get(), shift=1, axis=1).at[:, 0].set(w)
    # Re-append a constant column of ones; it acts as a bias input to M.
    w_history = jnp.hstack((w_history, jnp.ones((self.d_x, 1))))
    # M multiplies the row-major flattening of the (d_x, H + 1) history.
    Mw = jnp.matmul(agent_state.M, w_history.reshape(-1)).reshape(-1, 1)  # pylint: disable=invalid-name

    action = ut_o + Kxt + Mw
    return agent_state.replace(
        w_history=w_history, step=agent_state.step + 1), action