import os
import sys
sys.path.append(os.getcwd())

import jax
import jax.numpy as np
import numpy as onp

import src.envs.mdps as mdps
import copy

import src.utils.tab_rl as tab_rl

import argparse
import time

# matplotlib stuff
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns

# Plot styling: no automatic padding between the data and the axes on either
# dimension, and a large default font size.
plt.rcParams.update({'axes.xmargin': 0, 'axes.ymargin': 0})
matplotlib.rcParams.update({'font.size': 40})


### ENVIRONMENTS
env = mdps.SwitchStay()
# env = mdps.AdaptChain()

n_states, n_actions = env.n_states, env.n_actions

# Initial-state distribution: every state listed in env.starting_states gets
# mass 1 / env.n_starting_states, all other states get 0. (Presumably
# n_starting_states == len(starting_states) so that rho sums to 1; not checked.)
rho = onp.zeros(n_states)
for start_state in env.starting_states:
    rho[start_state] = 1. / env.n_starting_states

# command line args
# All three arguments are mandatory. Each one sizes the result arrays or drives
# the update loop, so a missing value would otherwise only surface later as an
# opaque TypeError (e.g. when building arrays with a None dimension).
parser = argparse.ArgumentParser()
parser.add_argument("--T", type=int, required=True, help="number of timesteps")
parser.add_argument("--lr", type=float, required=True, help="learning rate")
parser.add_argument("--runs", type=int, required=True, help="number of runs")
parsed_args = parser.parse_args()
T = parsed_args.T
lr = parsed_args.lr
runs = parsed_args.runs
try:
    gamma = env.gamma
except AttributeError:
    # Only a missing attribute should trigger the fallback. A bare `except`
    # would also swallow KeyboardInterrupt/SystemExit and genuine errors
    # raised while computing `env.gamma`.
    print("Setting default gamma = 0.9...")
    gamma = 0.9
print(f"env = {env.name} | T = {T} | lr = {lr} | runs = {runs}")


def to_policy(logits):
    """Turn a table of logits into action probabilities.

    Applies a softmax along axis 1. In this file that is the action axis of
    the (n_states, n_actions) logit table, so each row becomes pi(. | s).
    """
    policy = jax.nn.softmax(logits, axis=1)
    return policy

@jax.jit
def get_full_trace(subtraj, logits, gamma):
    """Recompute the eligibility trace of a sub-trajectory at fixed ``logits``.

    Returns the gradient of ``collect_log_probs`` with respect to ``logits``:
    ``gamma ** (t - 1) * sum_k grad log pi(a_k | s_k)`` over the ``t`` steps of
    ``subtraj``. Every term is evaluated under the same policy
    ``to_policy(logits)``.

    Args:
        subtraj: array of shape (2, t). Row 0 holds the visited states and
            row 1 the actions taken. Entries may be stored as floats;
            ``collect_log_probs`` casts them to int before indexing.
        logits: policy logits. The returned gradient has the same shape.
        gamma: discount factor.

    Returns:
        Array with the same shape as ``logits``.
    """
    return jax.grad(collect_log_probs, argnums=1)(subtraj, logits, gamma)

@jax.jit
def collect_log_probs(subtraj, logits, gamma):
    """Discounted sum of action log-probabilities along a sub-trajectory.

    Args:
        subtraj: (2, t) array. Row 0 holds states and row 1 actions; both are
            cast to int before indexing.
        logits: policy logits, turned into probabilities with a softmax over
            axis 1 (the same transform as ``to_policy``).
        gamma: discount factor. A single weight ``gamma ** (t - 1)`` multiplies
            every log-probability.

    Returns:
        Scalar ``gamma ** (t - 1) * sum_k log pi(a_k | s_k)``.
    """
    index = subtraj.astype(int)
    states, actions = index[0, :], index[1, :]
    length = index.shape[1]
    chosen_probs = jax.nn.softmax(logits, axis=1)[states, actions]
    discount = gamma ** (length - 1)
    return np.sum(discount * np.log(chosen_probs))

@jax.jit
def get_true_pg(P, r, gamma, logits, rho, Q = None):
    """Exact policy gradient of the softmax policy with respect to ``logits``.

    Computes ``sum_{s, a} d_pi(s) * Q(s, a) * d pi(a | s) / d logits``, where
    ``pi = softmax(logits, axis=-1)``.

    Args:
        P, r: environment dynamics and rewards, passed through to ``tab_rl``.
        gamma: discount factor.
        logits: policy logits.
        rho: initial-state distribution, passed to ``tab_rl.get_d_pi``.
        Q: optional precomputed action values. When omitted they are obtained
            from ``tab_rl.get_Q`` under ``pi``.

    Returns:
        For 2-D ``logits``, an array with the same shape as ``logits``.
    """
    pi = jax.nn.softmax(logits, axis=-1)
    # assumes unnormalized stationary distribution
    dpi = tab_rl.get_d_pi(P, gamma, pi, rho)
    Q = tab_rl.get_Q(P, r, gamma, pi) if Q is None else Q
    # Jacobian of the policy w.r.t. the logits: jac[s, a, s2, a2] is
    # d pi(a | s) / d logits[s2, a2].
    jac = jax.jacobian(lambda z: jax.nn.softmax(z, axis=-1))(logits)
    # The indexing implies dpi is per-state (1-D) and Q is (state, action).
    weights = dpi[:, np.newaxis, np.newaxis, np.newaxis] * Q[:, :, np.newaxis, np.newaxis]
    return np.sum(weights * jac, axis=(0, 1))

def norm(a):
    """Euclidean norm of ``a``, taken over all of its entries."""
    sum_of_squares = np.sum(a * a)
    return np.sqrt(sum_of_squares)

def get_cos_sim(a, b, eps=1e-30):
    """Return the cosine similarity between ``a`` and ``b``.

    Both inputs are treated as flat vectors. The inner product sums the
    elementwise product over every entry, and ``norm`` likewise covers every
    entry.

    Each input is divided by its own norm *before* the inner product, so the
    result does not depend on the overall scale of either input. Adding a
    constant to the denominator would instead pull the similarity toward 0
    whenever ``norm(a) * norm(b)`` is not much larger than that constant. That
    happens for eligibility traces scaled by ``gamma ** t``, which shrink
    geometrically.

    Args:
        a, b: arrays of the same shape.
        eps: lower bound applied to each norm. It only guards the all-zero
            case, for which the similarity is reported as 0.

    Returns:
        Scalar in [-1, 1] (up to rounding).
    """
    a_unit = a / np.maximum(norm(a), eps)
    b_unit = b / np.maximum(norm(b), eps)
    return np.sum(a_unit * b_unit)

# rng
# One PRNG key drives the whole experiment. Every random draw below splits it
# first, so no key is ever reused.
rng = jax.random.PRNGKey(0)

# record data
# cos_sim[run, t]: cosine similarity between the incrementally updated trace
#   and the trace recomputed from scratch at the current logits.
# true_pg_cos_sim[run, t]: cosine similarity between the exact policy gradient
#   at the current logits and the running policy-gradient estimate.
# reinforce_cos_sim[run]: cosine similarity between the running estimate built
#   from initial-policy gradients and the exact gradient of the initial policy.
cos_sim = onp.zeros((runs, T))
true_pg_cos_sim = onp.zeros((runs, T))
reinforce_cos_sim = onp.zeros(runs)

# time it
t_start = time.time()


for run in range(runs):
    # draw starting state
    rng, rng_input = jax.random.split(rng)
    s = jax.random.choice(rng_input, n_states, p = rho)

    # hold trajectory data for recomputing the trace
    # Row 0 = states, row 1 = actions. Values are stored as floats and cast
    # back to int inside collect_log_probs.
    trajectory = onp.zeros((2, T))

    # draw a new initial policy
    rng, rng_input = jax.random.split(rng)
    logits = jax.random.uniform(rng_input, shape = (n_states, n_actions))
    # Frozen copy of the starting policy, used for the counterfactual
    # REINFORCE trace. NOTE(review): jax arrays are immutable and
    # `logits += ...` below rebinds the name, so a plain reference would
    # also suffice here.
    initial_logits = copy.deepcopy(logits)

    # initialize the trace
    trace = onp.zeros_like(logits)
    reinforce_trace = onp.zeros_like(logits)

    # running pg estimate to compare to true pg
    running_pg = onp.zeros_like(logits)
    reinforce_running_pg = onp.zeros_like(logits)

    print(f"run = {run} | runs/s = {run / (time.time() - t_start)}")
    for t in range(T):
        # Sample a ~ pi(. | s) from the CURRENT policy, then s' ~ P(. | s, a).
        rng, rng_input = jax.random.split(rng)
        probs = to_policy(logits)
        a = jax.random.choice(rng_input, n_actions, p = probs[s, :])
        rng, rng_input = jax.random.split(rng)
        sp = jax.random.choice(rng_input, n_states, p = env.P[s, a, :])
        r = env.r[s, a]

        trajectory[0, t] = s
        trajectory[1, t] = a

        
        # update trace
        # Gradient of log pi(a | s) under the current logits...
        grad_log = jax.grad(
            lambda l : np.log(to_policy(l)[s, a])
        )(logits)
        # ...and under the frozen initial logits, for the same sampled (s, a).
        reinforce_grad_log = jax.grad(
            lambda l : np.log(to_policy(l)[s, a])
        )(initial_logits)
        # assert grad_log.shape == logits.shape, (logits.shape, grad_log.shape)
        # Unrolled, this recursion gives trace_t = gamma^t * sum_{k<=t} g_k,
        # where g_k is the log-prob gradient at the logits in effect at step k.
        # get_full_trace computes the same expression but evaluates every g_k
        # at the current logits, which is why the two drift apart as the
        # policy changes.
        trace = trace * gamma + (gamma ** t) * grad_log
        reinforce_trace = reinforce_trace * gamma + (gamma ** t) * reinforce_grad_log
        
        running_pg += r * trace
        reinforce_running_pg += r * reinforce_trace

        # compare trace to true trace
        # NOTE(review): the slice has a new shape every step, so the jitted
        # get_full_trace is presumably retraced/recompiled on each iteration.
        full_trace = get_full_trace(trajectory[:, : t + 1], logits, gamma) # true trace up to this point for comparison
        assert full_trace.shape == trace.shape, (full_trace.shape, trace.shape)
        
        # Exact policy gradient at the current logits, recomputed every step.
        true_pg = get_true_pg(env.P, env.r, gamma, logits, rho)
        cos_sim[run, t] = get_cos_sim(trace, full_trace)
        true_pg_cos_sim[run, t] = get_cos_sim(true_pg, running_pg)
        if t == 0:
            # correlation should be perfect at the beginning
            # At t == 0, trace equals grad_log exactly. This check also needs
            # the eager and the jitted gradient to agree bit-for-bit.
            # NOTE(review): consider onp.allclose if it ever fails spuriously.
            assert onp.array_equal(full_trace, trace), (full_trace, trace)

        # update policy
        # Online step along r * trace, the same quantity accumulated in
        # running_pg.
        logits += lr * r * trace

        s = sp

    # compare to reinforce: calculate the counterfactual trace of the initial policy
    reinforce_cos_sim[run] = get_cos_sim(reinforce_running_pg, get_true_pg(env.P, env.r, gamma, initial_logits, rho))

data_dir = "data/grad_check/"
fig_dir = "figs/grad_check/"
os.makedirs(fig_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)
# The result arrays are plain NumPy arrays, so write them with NumPy itself.
# (`np` is `jax.numpy` in this file; saving through it only works because JAX
# re-exports NumPy's `save`.)
file_prefix = f"{data_dir}{env.name}_lr={lr}_"
onp.save(f"{file_prefix}cossim.npy", cos_sim)
onp.save(f"{file_prefix}truepgcossim.npy", true_pg_cos_sim)
onp.save(f"{file_prefix}reinforcecossim.npy", reinforce_cos_sim)
