import os
import sys
sys.path.append(os.getcwd())

import jax
import flax
import flax.optim as optim
import flax.nn as nn
import jax.numpy as np
import numpy as onp

import gym
import argparse

import copy
import nets

import argparse
import time


### ENVIRONMENTS
# Swap the id below to try another task, e.g. "LunarLander-v2".
env = gym.make("CartPole-v1")

# Sizes read from the environment: number of entries in the action space
# and length of the first axis of the observation space.
action_space = env.action_space
observation_space = env.observation_space
n_actions = action_space.n
n_states = observation_space.shape[0]

# command line args
# --runs sizes the results array and --lr scales every policy update, so a
# missing value (None) can only fail later with an unrelated-looking
# TypeError. Marking them required makes argparse report the problem up front.
# NOTE(review): --hidden is forwarded to nets.Policy, which is not visible
# here; it is left optional -- confirm whether None is an acceptable size.
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, required=True, help="learning rate")
parser.add_argument("--runs", type=int, required=True, help="number of runs")
parser.add_argument("--hidden", type=int, help="hidden layer size")
parsed_args = parser.parse_args()
lr = parsed_args.lr
runs = parsed_args.runs
hidden = parsed_args.hidden

# maximum number of steps recorded per run
T = 100

# base of the geometric weights used in the trace update
gamma = 0.99
print(f"T = {T} | lr = {lr} | runs = {runs}")


def flatten_dict(d):
    """Flatten a two-level parameter dictionary into a single 1-D array.

    Leaves are visited in the dictionaries' iteration order (outer keys
    first, then inner keys), and each leaf is flattened before being
    appended.
    """
    values = [
        value
        for group in d
        for name in d[group]
        for value in d[group][name].flatten().tolist()
    ]
    return np.array(values)

def norm(a):
    """Return the square root of the sum of squared entries of ``a``."""
    sum_of_squares = np.sum(a * a)
    return np.sqrt(sum_of_squares)

def get_cos_sim(a, b):
    """Cosine similarity between two two-level parameter dictionaries.

    Both arguments are treated as one long vector each. ``b`` is indexed
    with the layer/key names found in ``a``, so it must contain at least
    those entries. The result is NaN when either vector has zero norm.
    """
    inner = 0
    sq_norm_a = 0
    sq_norm_b = 0
    for layer, entries in a.items():
        for key, leaf_a in entries.items():
            leaf_b = b[layer][key]
            inner += np.sum(leaf_a * leaf_b)
            sq_norm_a += np.sum(leaf_a * leaf_a)
            sq_norm_b += np.sum(leaf_b * leaf_b)
    return inner / np.sqrt(sq_norm_a) / np.sqrt(sq_norm_b)

# rng
rng = jax.random.PRNGKey(0)

# record data: one row per run, one column per time step. Entries start as
# NaN so steps that are never written stay distinguishable from real values.
# `onp.nan` is used because the `onp.NaN` alias no longer exists in NumPy 2.
cos_sim = onp.full((runs, T), onp.nan)

# time it
t_start = time.time()

def collect_log_probs(subtraj, model, gamma):
    """Return the weighted sum of log-probabilities of the recorded actions.

    Defined once, outside the step loop, because it depends on no per-step
    state.

    Args:
        subtraj: array whose rows are the state components followed by one
            final row of action indices, with one column per time step.
        model: callable mapping a single state to action probabilities.
        gamma: base of the weight; every term is scaled by
            ``gamma ** (t - 1)`` where ``t`` is the number of columns.

    Returns:
        Scalar ``gamma ** (t - 1) * sum_k log(model(s_k)[a_k])``.
    """
    t = subtraj.shape[1]
    states = subtraj[:-1, :].T
    actions = subtraj[-1, :].astype(int)

    probs = jax.vmap(
        lambda s: model(s)
    )(states)
    log_probs = np.log(probs[np.arange(t), actions])
    assert probs.shape == (t, n_actions), probs.shape
    assert log_probs.shape[0] == t
    return np.sum((gamma ** (t - 1)) * np.array(log_probs))


# For every run: roll out up to T steps, maintain the trace incrementally,
# recompute the same quantity from the stored trajectory under the current
# parameters, and record the cosine similarity between the two.
for run in range(runs):
    # draw starting state
    # NOTE: the key from this split is overwritten below before it is used;
    # the split is kept so the random stream stays the same.
    rng, rng_input = jax.random.split(rng)
    s = env.reset()

    # hold trajectory data for recomputing the trace: rows are the state
    # components plus one action row, columns are time steps
    trajectory = onp.zeros((n_states + 1, T)) + 0.1

    # draw a new initial policy
    rng, rng_input = jax.random.split(rng)
    module = nets.Policy.partial(hidden=hidden, n_actions=n_actions)
    _, params = module.init(rng_input, s)
    model = nn.Model(module, params)

    # initialize the trace to zeros with the same layout as the parameters
    trace = {
        layer: {key: onp.zeros_like(params[layer][key]) for key in params[layer]}
        for layer in params
    }

    t = 0
    done = False
    print(f"run = {run} | runs/s = {run / (time.time() - t_start)}")
    # Truthiness test rather than `is not True`: an identity check would keep
    # stepping a finished episode if the flag is a truthy non-`bool` object
    # (for example a numpy.bool_).
    while not done and t < T:
        rng, rng_input = jax.random.split(rng)

        # get policy probabilities and sample an action from them
        probs = model(s)
        a = jax.random.choice(rng_input, n_actions, p=probs)

        sp, r, done, _ = env.step(a.item())

        # store the state the action was taken in, together with the action
        trajectory[:-1, t] = s
        trajectory[-1, t] = a

        # gradient of the log-probability of the chosen action
        grad_log = jax.grad(
            lambda m: np.log(m(s)[a])
        )(model).params

        # update the trace
        for layer in trace:
            for key in trace[layer]:
                trace[layer][key] = trace[layer][key] * gamma + (gamma ** t) * grad_log[layer][key]

        # recompute the same quantity from the trajectory so far, using the
        # current model
        full_trace = jax.grad(collect_log_probs, argnums=1)(trajectory[:, : t + 1], model, gamma).params

        cos_sim[run, t] = get_cos_sim(trace, full_trace)
        if t == 0:
            # at time zero, traces should be equal
            for layer in trace:
                for key in trace[layer]:
                    assert np.array_equal(trace[layer][key], full_trace[layer][key]), f"layer = {layer} | key = {key}"

        t += 1

        # update policy
        # NOTE(review): `model` was built from this same `params` object, so
        # the in-place writes are presumably what later model calls see --
        # confirm against the flax version in use.
        for layer in params:
            for key in params[layer]:
                params[layer][key] = params[layer][key] + lr * r * trace[layer][key]

        s = sp


# Output locations; both directories are created if missing, although only
# the data directory is written to here.
data_dir = "data/grad_check/"
fig_dir = "figs/grad_check/"
for out_dir in (fig_dir, data_dir):
    os.makedirs(out_dir, exist_ok=True)

# The file name encodes the environment id and the hyper-parameters.
out_path = f"{data_dir}{env.spec.id}_lr={lr}_hidden={hidden}_cossim.npy"
np.save(out_path, cos_sim)

