"""Algorithm 1.

TODO: add more general nominal trajectories.
"""

import jax
import jax.numpy as jnp
import numpy as onp
from tqdm import tqdm

import scripts.agent as r_agent
import scripts.solver as r_solver


def min_distance(a_arr, B, M, w, obstacles):  # pylint: disable=invalid-name
    """Return the smallest obstacle clearance for the shifted offsets.

    The first two entries of ``B @ M @ w`` are added (as a column vector) to
    every column of ``a_arr``; the Euclidean norm of each resulting column is
    taken, the last row of ``obstacles`` is subtracted from those norms, and
    the minimum of the result is returned.

    Args:
        a_arr: Array whose columns are the offsets to measure (norm is taken
            over axis 0).
        B: Left factor of the ``B @ M @ w`` product.
        M: Middle factor of the ``B @ M @ w`` product.
        w: Right factor of the ``B @ M @ w`` product.
        obstacles: JAX array whose last row is subtracted from the column
            norms (presumably obstacle radii -- confirm against the caller).

    Returns:
        Scalar JAX array holding the minimum value.
    """
    control_effect = jnp.matmul(jnp.matmul(B, M), w)
    shift = control_effect[:2].reshape(-1, 1)
    column_norms = jnp.linalg.norm(a_arr + shift, axis=0)
    last_row = obstacles.at[-1].get().reshape(1, -1)
    clearance = column_norms - last_row
    return jnp.min(clearance)


def transpose(pytree):
    """Split a pytree of stacked leaves into a list of per-index pytrees.

    Every leaf is indexed along its first axis; element ``i`` of the result
    has the same tree structure as ``pytree`` with each leaf replaced by
    ``leaf[i]``.  The number of output elements is the length of the first
    leaf, so all leaves are expected to share that leading length.

    Args:
        pytree: Any JAX pytree whose leaves support ``len`` and integer
            indexing (e.g. the stacked outputs of ``jax.lax.scan``).

    Returns:
        A list of pytrees, one per index along the leading axis.  An input
        with no leaves yields an empty list.
    """
    # `jax.tree_flatten` was a deprecated top-level alias that newer JAX
    # releases removed; `jax.tree_util.tree_flatten` is the stable location.
    leaves, treedef = jax.tree_util.tree_flatten(pytree)
    if not leaves:
        # Nothing to index into; avoids an IndexError on `leaves[0]`.
        return []
    return [
        treedef.unflatten([leaf[i] for leaf in leaves])
        for i in range(len(leaves[0]))
    ]
    
def _generate_disturbance(key, T, noise, w_mean, w_std):  # pylint: disable=invalid-name
    """Build the default disturbance sequence of shape ``(4, T + 1)``.

    Only the first row (and, for Gaussian noise, also the third row) carries
    a signal; the remaining rows are zero.

    Args:
        key: JAX PRNG key; used by the ``'gaussian'`` and ``'uniform'`` modes.
        T: Number of simulation steps; the sequence has ``T + 1`` columns.
        noise: One of ``'gaussian'``, ``'uniform'``, ``'pulse'``, ``'sin'``
            or ``'adversarial'``.
        w_mean: Mean of the Gaussian noise (``'gaussian'`` mode only).
        w_std: Standard deviation of the Gaussian noise (``'gaussian'`` only).

    Returns:
        A JAX array of shape ``(4, T + 1)``.

    Raises:
        ValueError: If ``noise`` is not one of the supported modes.
    """
    shape = (1, T + 1)

    if noise == 'gaussian':
        print('Using Gaussian Noise')
        w_a = jax.random.normal(
            jax.random.fold_in(key, 0), shape=shape) * w_std + w_mean
        w_b = jax.random.normal(
            jax.random.fold_in(key, 1), shape=shape) * w_std + w_mean
        zeros = jnp.zeros_like(w_a)
        return jnp.concatenate((w_a, zeros, w_b, zeros))

    if noise == 'uniform':
        print('Using Uniform Noise')
        w_a = jax.random.uniform(jax.random.fold_in(key, 0), shape=shape) - 0.5
        zeros = jnp.zeros_like(w_a)
        return jnp.concatenate((w_a, zeros, zeros, zeros))

    # The remaining modes fill the first row with NumPy; column 0 stays zero
    # for 'pulse' and 'sin' because they write to index t + 1.
    first_row = onp.zeros(shape)
    if noise == 'pulse':
        # Uses the global NumPy RNG outside the two constant windows, so this
        # mode is not reproducible from `key`.
        print('Using Pulse Noise')
        for t in range(T):
            phase = (float(t) / 2.) % 50.
            if 4.49 <= phase <= 11.51:
                first_row[0, t + 1] = 0.5
            elif 11.51 <= phase <= 18.51:
                first_row[0, t + 1] = -0.5
            else:
                # Scalar draw: assigning a size-1 array to a single element
                # is deprecated in recent NumPy.
                first_row[0, t + 1] = onp.random.rand() - 0.5
    elif noise == 'sin':
        print('Using Sin Noise')
        first_row[0, 1:] = [0.3 * onp.sin(0.05 * float(t)) for t in range(T)]
    elif noise == 'adversarial':
        # NOTE(review): this mode writes index t (not t + 1 like the others),
        # so the last column stays zero -- confirm this is intended.
        print('Using adversarial noise!')
        for t in range(T):
            first_row[0, t] = float(onp.random.randint(0, 2)) - 0.5
    else:
        raise ValueError(
            f"Unknown noise mode {noise!r}; expected 'gaussian', 'uniform', "
            "'pulse', 'sin' or 'adversarial'.")

    zeros = jnp.zeros(shape)
    return jnp.concatenate((jnp.array(first_row), zeros, zeros, zeros))


def algorithm_1(  # pylint: disable=invalid-name
    env,
    key,
    env_state=None,
    w=None,
    w_mean=0.0,
    w_std=0.5,
    eta=1e-3,
    d_x=4,
    d_u=2,
    H=10,
    H_p=5,
    K=0. * jnp.array([[0.02786103, 0., 0.23769369, 0.],
                      [0., 0.00303891, 0., 0.07801963]]),
    T=100,
    scan=False,
    noise='sin',
    ):
    """Run Algorithm 1 for ``T`` steps and return the per-step results.

    At every step the agent's matrix ``M`` is re-optimised with
    ``r_solver.min_max_solver``, the agent produces an action, the action is
    restricted to its first component and clipped to ``[-0.75, 0.75]``, and
    the environment is advanced with the current disturbance column.

    Args:
        env: Environment object.  Must provide ``init()``, the attributes
            ``A``, ``B`` and ``B_inv``, and be callable as
            ``env(env_state, action, w=...)`` returning ``(env_state, _)``.
        key: JAX PRNG key, used only when ``w`` is generated with the
            ``'gaussian'`` or ``'uniform'`` noise mode.
        env_state: Initial environment state; ``env.init()`` when ``None``.
        w: Disturbance array indexed as ``w[:, t]`` for ``t`` in ``0..T``.
            When ``None`` a ``(4, T + 1)`` sequence is generated according to
            ``noise``.
        w_mean: Mean of the generated Gaussian noise.
        w_std: Standard deviation of the generated Gaussian noise.
        eta: Passed through to ``r_solver.min_max_solver``.
        d_x: Forwarded to ``r_agent.Agent.create``.
        d_u: Forwarded to ``r_agent.Agent.create``; also the row count of the
            dummy initial control.
        H: Forwarded to ``r_agent.Agent.create``.
        H_p: Column count of the dummy initial control passed to
            ``agent.init``.
        K: Forwarded to ``r_agent.Agent.create``; the closed-loop matrix is
            ``env.A + env.B @ agent.K``.
        T: Number of steps to simulate.
        scan: If true, run the loop with ``jax.lax.scan``; otherwise use a
            Python loop that prints the step index every 40 steps.
        noise: Noise mode used when ``w`` is ``None`` (see
            ``_generate_disturbance``).  Defaults to ``'sin'``.

    Returns:
        A pair ``(results, w)`` where ``results`` is a list with one
        ``(loss, action, env_state, agent_state)`` tuple per step and ``w`` is
        the disturbance array that was used.

    Raises:
        ValueError: If ``w`` is ``None`` and ``noise`` is not a known mode.
    """
    if w is None:
        w = _generate_disturbance(key, T, noise, w_mean, w_std)

    # init env and agent
    if env_state is None:
        env_state = env.init()
    agent = r_agent.Agent.create(d_x=d_x, d_u=d_u, H=H, K=K)
    init_u0 = jnp.zeros((d_u, H_p))  # dummy u0
    agent_state = agent.init(u_0=init_u0)
    A_cl = env.A + jnp.matmul(env.B, agent.K)  # pylint: disable=invalid-name

    # Action post-processing constants, built once rather than every step.
    x_only = jnp.array([[1], [0]])  # zeroes the second action component
    action_hi = jnp.array([[0.75], [0]])
    action_lo = jnp.array([[-0.75], [0]])

    def loop(carry, t):
        """Advance agent and environment by one step (``t`` is 1-based)."""
        agent_state, env_state = carry
        w_flat = agent_state.w_history.flatten()

        # Offset between the first two entries of the closed-loop state
        # update and the first two rows of the observed obstacles.
        a_arr = jnp.matmul(
            A_cl, env_state.arr)[:2] - env_state.ten_obstacles_observed[:2]
        loss = min_distance(a_arr, env.B, agent_state.M, w_flat,
                            env_state.ten_obstacles_observed)

        M_opt, _ = r_solver.min_max_solver(  # pylint: disable=invalid-name
            jnp.matmul(env.B_inv, a_arr),
            agent_state.M,
            w_flat,
            T,
            eta)

        # update agent state; the agent is given the previous disturbance
        agent_state = agent_state.replace(M=M_opt)
        agent_state, action = agent(
            agent_state, env_state, w=w.at[:, t - 1].get())

        # evolve environment: keep only the first action component and clip it
        action = action * x_only
        action = jnp.maximum(jnp.minimum(action, action_hi), action_lo)
        env_state, _ = env(env_state, action, w=w.at[:, t].get())

        return (agent_state, env_state), (loss, action, env_state, agent_state)

    carry = (agent_state, env_state)
    if scan:
        carry, results = jax.lax.scan(loop, carry, jnp.array(range(1, T + 1)))
        # scan stacks outputs along a leading axis; split back into steps
        results = transpose(results)
    else:
        results = []
        for t in range(1, T + 1):
            if t % 40 == 0:
                print(t)

            carry, result = loop(carry, t)
            results.append(result)

    return results, w