import argparse
import random
from copy import deepcopy

import gym
from gym import spaces
import numpy as np
import wandb
import itertools
from gym_repoman.envs import CollectEnv
from wrappers import *

import jax
import jax.numpy as jnp
from jax import jit, grad
from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.example_libraries.stax import GeneralConv, Dense, Relu, LogSoftmax, Flatten
from jax.nn.initializers import normal
from jax.tree_util import tree_map
from jax.flatten_util import ravel_pytree

def parse_args():
    """
    Parse the command-line hyper-parameters for the DQN / EWC experiments.

    :return: an ``argparse.Namespace`` holding every option defined below
    """

    def _str2bool(value):
        # ``type=bool`` would turn every non-empty string (including "False")
        # into True, so the usual spellings are parsed explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser("DQN experiments for Atari games")
    parser.add_argument("--seed", type=int, default=np.random.randint(10000), help="which seed to use")
    # Environment
    parser.add_argument("--env", type=str, default="ALE/Breakout-v5", help="name of the game")
    # Core DQN parameters
    parser.add_argument("--replay-buffer-size", type=int, default=int(1e5), help="replay buffer size")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate for Adam optimizer") #1e-4
    parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
    parser.add_argument("--num-steps", type=int, default=int(5e4), #1e5
                        help="total number of steps to run the environment for")
    parser.add_argument("--batch-size", type=int, default=64, help="number of transitions to optimize at the same time")
    parser.add_argument("--learning-starts", type=int, default=1000, help="number of steps before learning starts")
    parser.add_argument("--learning-freq", type=int, default=1,
                        help="number of iterations between every optimization step")
    parser.add_argument("--target-update-freq", type=int, default=1000,
                        help="number of iterations between every target network update")
    parser.add_argument("--use-double-dqn", type=_str2bool, default=True, help="use double deep Q-learning")
    # e-greedy exploration parameters
    parser.add_argument("--eps-start", type=float, default=1.0, help="e-greedy start threshold")
    parser.add_argument("--eps-end", type=float, default=0.02, help="e-greedy end threshold")
    parser.add_argument("--eps-timesteps", type=float, default=100000,
                        help="number of steps over which epsilon is annealed linearly from eps-start to eps-end")
    # Reporting
    parser.add_argument("--print-freq", type=int, default=10, help="print frequency.")

    return parser.parse_args()

# Leaf-wise helpers meant to be mapped over parameter pytrees with tree_map, so tree operations need no flattening first
def add_two_trees(tree1, tree2):
    """Leaf-wise sum ``tree1 + tree2``, for use with ``tree_map``."""
    total = tree1 + tree2
    return total

def subtract_two_trees(tree1, tree2):
    """Leaf-wise difference ``tree1 - tree2``, for use with ``tree_map``."""
    difference = tree1 - tree2
    return difference

def ewc_reg_diag(params, old_params, fisher_tree):
    """
    Diagonal EWC gradient for one leaf: ``(params - old_params) * fisher_tree``.

    No step size is applied here (compare ``ewc_reg_diag_no_l_rate``).
    """
    offset = params - old_params
    return offset * fisher_tree

def ewc_reg_diag_no_l_rate(params, old_params, diag_tree):
    """
    Diagonal EWC gradient for one leaf, scaled by the module-level
    ``diag_reg_step`` but not by the learning rate:
    ``diag_reg_step * (params - old_params) * diag_tree``.
    """
    offset = params - old_params
    return diag_reg_step * offset * diag_tree

def add_two_trees_ewc_full(grad_tree, reg_tree):
    """
    Combine a loss-gradient leaf with a full-EWC leaf:
    ``lr * grad_tree + lr * reg_step * reg_tree`` (module-level ``lr``, ``reg_step``).
    """
    scaled_loss = lr * grad_tree
    scaled_reg = lr * reg_step * reg_tree
    return scaled_loss + scaled_reg

def add_two_trees_ewc_full_diag(grad_tree, ewc_tree, diag_tree):
    """
    Combine a loss-gradient leaf with the full-EWC and diagonal-EWC leaves:
    ``lr*grad + lr*reg_step*ewc + lr*diag_reg_step*diag``
    (module-level ``lr``, ``reg_step``, ``diag_reg_step``).
    """
    loss_part = lr * grad_tree
    full_part = lr * reg_step * ewc_tree
    diag_part = lr * diag_reg_step * diag_tree
    return loss_part + full_part + diag_part

def add_two_trees_ewc_full_no_step(grad_tree, ewc_tree):
    """
    Combine a loss-gradient leaf with a full-EWC leaf without a learning rate:
    ``grad_tree + reg_step * ewc_tree`` (module-level ``reg_step``).
    """
    weighted_ewc = reg_step * ewc_tree
    return grad_tree + weighted_ewc

def add_two_trees_selective_l2(grad_tree, l2_tree, selection_tree, max_tree):
    """
    Leaf-wise ``lr * grad_tree`` plus an L2 term applied only where the selection signal is small.

    Entries with ``|selection_tree| >= 1e-6 * max_tree`` get no L2 term. Every other
    entry gets ``selective_l2_step * l2_tree``. Unlike the gradient term, the L2 term
    is not multiplied by ``lr``.

    Reads module-level ``lr`` and ``selective_l2_step``.
    NOTE(review): ``selective_l2_step`` is not assigned anywhere in this script as
    shown, so calling this raises NameError unless it is defined elsewhere. The
    original def line carried a bare ``#0.0001``, which is possibly the intended value.
    """
    scaled_grad = lr * grad_tree
    is_selected = jnp.abs(selection_tree) >= 1e-6 * max_tree
    l2_term = selective_l2_step * l2_tree * (1 - is_selected)
    return scaled_grad + l2_term

def add_two_trees_selective_l2_ewc_full_diag(grad_tree, l2_tree, selection_tree, max_tree, ewc_tree, diag_tree):
    """
    Leaf-wise sum of four update terms:

    * ``lr * grad_tree``;
    * ``selective_l2_step * l2_tree``, only where ``|selection_tree| < 1e-6 * max_tree``
      (no ``lr`` factor);
    * ``lr * reg_step * ewc_tree`` (full EWC);
    * ``lr * diag_reg_step * diag_tree`` (diagonal EWC).

    Reads module-level ``lr``, ``selective_l2_step``, ``reg_step`` and ``diag_reg_step``.
    NOTE(review): ``selective_l2_step`` is not assigned anywhere in this script as
    shown, so calling this raises NameError unless it is defined elsewhere. The
    original def line carried a bare ``#0.005``, which is possibly the intended value.
    """
    grad_part = lr * grad_tree
    is_selected = jnp.abs(selection_tree) >= 1e-6 * max_tree
    l2_part = selective_l2_step * l2_tree * (1 - is_selected)
    ewc_part = lr * reg_step * ewc_tree
    diag_part = lr * diag_reg_step * diag_tree
    return grad_part + l2_part + ewc_part + diag_part

# Memory class
class ReplayBuffer:
    """
    A fixed-capacity FIFO experience replay buffer of
    ``(obs, action, reward, next_obs, done)`` transitions, used by the DQN
    training loop in this script.

    Once ``max_size`` transitions are stored, each new one overwrites the oldest.
    """

    def __init__(self, obs_dim, size):
        """
        :param obs_dim: shape of a single observation (tuple), or an int for flat observations
        :param size: maximum number of transitions kept
        """
        # Accept a bare int as a 1-D observation shape, in addition to a shape tuple.
        obs_shape = (obs_dim,) if np.isscalar(obs_dim) else tuple(obs_dim)
        self.obs_buf = np.zeros((size, *obs_shape), dtype=np.float32)
        self.obs2_buf = np.zeros((size, *obs_shape), dtype=np.float32)
        self.act_buf = np.zeros(size, dtype=np.int32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size

    def add(self, obs, act, rew, next_obs, done):
        """Store one transition at the write pointer, wrapping around when full."""
        self.obs_buf[self.ptr] = obs
        self.obs2_buf[self.ptr] = next_obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.done_buf[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size=32):
        """
        Draw ``batch_size`` stored transitions uniformly *with* replacement.

        :return: tuple ``(obs, act, rew, next_obs, done)`` of arrays with leading dimension ``batch_size``
        :raises ValueError: if the buffer holds no transitions yet
        """
        if self.size == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        idxs = np.random.randint(0, self.size, size=batch_size)
        return self.obs_buf[idxs], self.act_buf[idxs], self.rew_buf[idxs], self.obs2_buf[idxs], self.done_buf[idxs]

# Network functions
@jit
def loss(policy_params, states, actions, target_q_values):
    """
    TD loss of the policy network on a batch.

    For each row, this takes ``Q(s, a)`` of the action actually taken from
    ``policy_predict`` (a module-level name). It returns the Euclidean norm of
    ``Q(s, a) - target`` over the batch.

    NOTE(review): this is the unsquared L2 norm, not the mean-squared (or Huber)
    error usual for DQN. Its gradient can be NaN when every TD error is exactly
    zero. Confirm this is intended.
    """
    all_q = policy_predict(policy_params, states)
    rows = jnp.arange(all_q.shape[0])
    chosen_q = all_q[rows, actions]
    td_error = chosen_q - target_q_values
    return jnp.linalg.norm(td_error, ord=2)

@jit
def act(state, policy_params):
    """
    Choose the greedy action for a single state under the policy network.

    :param state: one observation; it is divided by 255 and given a leading batch axis here
    :param policy_params: parameters passed to the module-level ``policy_predict``
    :return: index of the highest-valued action (a 0-d array)
    """
    batched = jnp.expand_dims(state / 255.0, axis=0)
    q_row = policy_predict(policy_params, batched)
    best_per_row = q_row.argmax(1)
    return best_per_row[0]

def _sample_td_batch(memory, policy_params, target_params):
    """
    Sample a minibatch from ``memory`` and compute its one-step TD targets.

    Reads the module-level ``batch_size``, ``use_double_dqn``, ``gamma``,
    ``policy_predict`` and ``target_predict``.

    :return: ``(states, actions, target_q_values)``, with states divided by 255
    """
    states, actions, rewards, next_states, dones = memory.sample(batch_size)
    states = np.array(states) / 255.0
    next_states = np.array(next_states) / 255.0

    next_q_values = target_predict(target_params, next_states)
    if use_double_dqn:
        # Double DQN: the online network picks the next action and the target
        # network evaluates it.
        max_next_action = policy_predict(policy_params, next_states).argmax(1)
        max_next_q_values = next_q_values[np.arange(next_q_values.shape[0]), max_next_action]
    else:
        max_next_q_values = next_q_values.max(1)
    # Terminal transitions (done == 1) do not bootstrap.
    target_q_values = rewards + (1 - dones) * gamma * max_next_q_values
    return states, actions, target_q_values


def optimise_td_loss(update, opt_state, memory, policy_params, target_params):
    """
    Optimise the TD-error over a single minibatch of transitions.

    :param update: ``update(step, opt_state, states, actions, target_q_values) -> opt_state``
    :return: ``(loss, opt_state, policy_params)``. The loss is evaluated on the
        same minibatch with the *updated* parameters.
    """
    states, actions, target_q_values = _sample_td_batch(memory, policy_params, target_params)
    opt_state = update(next(itercount), opt_state, states, actions, target_q_values)
    policy_params = get_params(opt_state)
    lossr = loss(policy_params, states, actions, target_q_values)
    return lossr, opt_state, policy_params


def optimise_td_loss_reg(update, opt_state, reg_grads, reg_pos, diag_tree, memory, policy_params, target_params):
    """
    Optimise the regularised TD-error over a single minibatch of transitions.

    ``reg_grads``, ``reg_pos`` and ``diag_tree`` are forwarded unchanged to ``update``.

    :param update: ``update(step, opt_state, reg_grads, reg_pos, diag_tree, states, actions,
        target_q_values) -> opt_state``
    :return: ``(loss, opt_state, policy_params)``. The loss is the plain TD loss
        evaluated with the *updated* parameters.
    """
    states, actions, target_q_values = _sample_td_batch(memory, policy_params, target_params)
    opt_state = update(next(itercount), opt_state, reg_grads, reg_pos, diag_tree, states, actions, target_q_values)
    policy_params = get_params(opt_state)
    lossr = loss(policy_params, states, actions, target_q_values)
    return lossr, opt_state, policy_params


def DQN(action_space, init_var):
    """
    Build the convolutional Q-network with ``stax``.

    The network is three VALID convolutions (32@8x8 stride 4, 64@4x4 stride 2,
    64@3x3 stride 1), each followed by ReLU, then a 512-unit ReLU layer and a
    linear head with one output per action.

    :param action_space: object exposing ``n``, the number of discrete actions
    :param init_var: passed to ``jax.nn.initializers.normal``, i.e. used as the
        *standard deviation* of every weight initialiser. Biases keep the stax defaults.
    :return: ``(init_fun, apply_fun)`` pair from ``stax.serial``
    """
    dims = ("NCWH", "IWHO", "NCWH")
    w_init = normal(init_var)
    conv_specs = (
        (32, (8, 8), (4, 4)),
        (64, (4, 4), (2, 2)),
        (64, (3, 3), (1, 1)),
    )
    layers = []
    for out_chan, kernel, stride in conv_specs:
        layers.extend([GeneralConv(dims, out_chan, kernel, stride, W_init=w_init), Relu])
    layers.extend([
        Flatten,
        Dense(512, W_init=w_init), Relu,
        Dense(action_space.n, W_init=w_init),
    ])
    return stax.serial(*layers)

def evaluate(env, policy_params):
    """
    Run one greedy rollout of at most 25 steps and count the rewarding steps.

    The environment is reset first. The rollout stops early when the episode ends.

    :return: number of steps whose reward was strictly positive
    """
    positive_steps = 0
    observation = env.reset()
    for _ in range(25):
        chosen = act(np.array(observation), policy_params)
        observation, reward, done, _ = env.step(int(chosen))
        positive_steps += reward > 0
        if done:
            break
    return positive_steps

@jit
def mahalanobis_dist(params, states, actions, target_q_values, params_constant_ravel):
    """
    Project the gradient of ``loss`` onto a fixed flat direction.

    Despite the name, this computes the inner product
    ``params_constant_ravel . ravel(grad loss(params))``, a directional derivative,
    rather than a Mahalanobis distance.

    :param params_constant_ravel: flat vector with as many entries as ``params`` has scalars
    :return: scalar inner product
    """
    loss_grad_tree = grad(loss)(params, states, actions, target_q_values)
    flat_loss_grad = ravel_pytree(loss_grad_tree)[0]
    return jnp.dot(params_constant_ravel.T, flat_loss_grad)

@jit
def mahalanobis_dist_loss_and_ewc(params, reg_grads, reg_pos, states, actions, target_q_values, params_constant_ravel):
    """
    Project the combined TD-loss + full-EWC gradient onto a fixed flat direction.

    The combined tree is ``lr*loss_grad + lr*reg_step*elastic_grad`` per leaf (via
    ``add_two_trees_ewc_full``). Here ``elastic_grad`` is the gradient of
    ``riem_dist_ewc_from_reg_grad`` w.r.t. ``reg_grads``, evaluated at the flat
    offset ``params - reg_pos``.

    :return: scalar ``params_constant_ravel . ravel(combined)``
    """
    offset_tree = tree_map(subtract_two_trees, params, reg_pos)
    offset_flat = ravel_pytree(offset_tree)[0]
    loss_grads = grad(loss)(params, states, actions, target_q_values)
    elastic_grads = grad(riem_dist_ewc_from_reg_grad)(reg_grads, offset_flat)
    combined = tree_map(add_two_trees_ewc_full, loss_grads, elastic_grads)
    combined_flat = ravel_pytree(combined)[0]
    return jnp.dot(params_constant_ravel.T, combined_flat)

@jit
def l2_regularizer(params):
    """Return the squared L2 norm (sum of squares over every leaf) of the pytree ``params``."""
    flat = ravel_pytree(params)[0]
    return jnp.dot(flat.T, flat)

# No Hessian is formed here: this is a plain inner product between a flat
# parameter offset and a flattened gradient pytree.
@jit
def riem_dist_ewc_from_reg_grad(reg_grads, const_flat_diff):
    """
    Return ``const_flat_diff . ravel(reg_grads)``.

    The result is linear in ``reg_grads``. Differentiating it w.r.t. the first
    argument (the ``jax.grad`` default) therefore gives ``const_flat_diff``
    reshaped like ``reg_grads``, whatever values ``reg_grads`` holds.
    """
    flat_reg = ravel_pytree(reg_grads)[0]
    return jnp.dot(const_flat_diff.T, flat_reg)

if __name__ == '__main__':
    # Hyper-parameters
    args = parse_args()
    np.random.seed(args.seed)
    random.seed(args.seed)
    use_double_dqn = args.use_double_dqn
    lr = args.lr
    batch_size = args.batch_size
    gamma = args.gamma
    itercount = itertools.count()  # step counter handed to the optimizer updates
    rng = jax.random.PRNGKey(args.seed)
    network_init_var = 0.1 #0.01 #0.005 #0.1
    reg_step = 0.0 #5e-1$ #5e-1* #1e-2   (weight of the full-EWC term)
    diag_reg_step = 0.0 #1e1$ #1e1*      (weight of the diagonal-EWC term)

    wandb.init(project='RL_EWC', entity='jax_rail')
    config = wandb.config
    config.learning_rate = lr
    config.seed = args.seed
    config.batch_size = batch_size
    config.gamma = gamma
    config.reg_step = reg_step
    config.diag_reg_step = diag_reg_step
    config.network_init_var = network_init_var

    # Exploration decay: epsilon is annealed linearly over ``eps_timesteps`` steps.
    eps_timesteps = args.eps_timesteps
    episode_rewards = [0.0]
    losses = [0.0]

    # Setup Environment
    # assert "NoFrameskip" in args.env, "Require environment with no frameskip"
    # Names of the ALE (Atari) environments; only these get the Atari wrapper stack.
    # The lookup relies on a private gym registry attribute, so fall back to "no ALE
    # environments" when it is unavailable instead of crashing; the task list below
    # contains no ALE environments anyway.
    try:
        ale_tree = gym.envs.registry.all()._mapping.tree["ALE"]
        all_ale_env_names = [f"ALE/{name}-v5" for name in ale_tree if "ram" not in name]
    except (AttributeError, KeyError):
        all_ale_env_names = []

    # env_names_list = [args.env]
    env_names_list = ['RepoManBlue-v0', 'RepoManBeige-v0', 'RepoManPurple-v0']
    num_tasks = len(env_names_list)

    def _make_env(env_name):
        """Create, seed and wrap one environment (Atari wrapper stack for ALE names, colour warp otherwise)."""
        new_env = gym.make(env_name)
        new_env.seed(args.seed)
        # new_env = NoopResetEnv(new_env, noop_max=30)
        is_atari = env_name in all_ale_env_names
        if is_atari:
            new_env = MaxAndSkipEnv(new_env, skip=4)
            new_env = EpisodicLifeEnv(new_env)
            new_env = FireResetEnv(new_env)
            new_env = WarpFrameGray(new_env)
        else:
            new_env = WarpFrameColour(new_env)
        new_env = PyTorchFrame(new_env)
        if is_atari:
            new_env = ClipRewardEnv(new_env)
            new_env = FrameStack(new_env, 4)
        return new_env

    # Separate, identically built instances for training and for evaluation, so that
    # evaluation rollouts never disturb a training episode. Creation order is kept
    # interleaved (train, test) per task.
    train_env_list = []
    test_env_list = []
    for env_name in env_names_list:
        train_env_list.append(_make_env(env_name))
        test_env_list.append(_make_env(env_name))

    # Initialise memory, policy and target networks.
    # Only the last test env's action/observation spaces are used to build a single
    # network, which assumes every task shares them.
    env = test_env_list[-1]
    policy_init_random_params, policy_predict = DQN(env.action_space, network_init_var)
    _, policy_params = policy_init_random_params(rng, (-1, *env.observation_space.shape))
    target_init_random_params, target_predict = DQN(env.action_space, network_init_var)
    # The target network starts as a copy of the policy network.
    target_params = policy_params.copy()
    # Same architecture with zero-stddev weight initialisation, used below to build
    # parameter-shaped trees.
    init_zero_params, _ = DQN(env.action_space, 0)

    # Setup optimizer for policy network
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_state = opt_init(policy_params)

    @jit
    def update(i, opt_state, states, actions, target_q_values):
        """Take one Adam step (step index ``i``) on the plain TD loss; used while training the first task."""
        current = get_params(opt_state)
        td_grads = grad(loss)(current, states, actions, target_q_values)
        return opt_update(i, td_grads, opt_state)
 
    @jit
    def update_ewc_full_diag(i, opt_state, reg_grads, reg_pos, diag_tree, states, actions, target_q_values):
        """
        Take one Adam step on the TD loss plus both EWC regularisers (used for tasks after the first).

        The combined gradient per leaf is
        ``lr*loss_grad + lr*reg_step*elastic_grad + lr*diag_reg_step*diag_grad``
        (see ``add_two_trees_ewc_full_diag``).
        """
        params = get_params(opt_state)

        # Offset of the current parameters from the previous task's parameters, flattened.
        offset_flat, _ = ravel_pytree(tree_map(subtract_two_trees, params, reg_pos))
        loss_grads = grad(loss)(params, states, actions, target_q_values)
        # NOTE(review): ``grad`` differentiates w.r.t. the first argument (``reg_grads``).
        # riem_dist_ewc_from_reg_grad is linear in it, so this equals ``params - reg_pos``
        # shaped like ``reg_grads``, and the stored gradient values do not enter.
        # Confirm this is the intended penalty.
        elastic_grads = grad(riem_dist_ewc_from_reg_grad)(reg_grads, offset_flat)
        diag_grads = tree_map(ewc_reg_diag, params, reg_pos, diag_tree)

        combined = tree_map(add_two_trees_ewc_full_diag, loss_grads, elastic_grads, diag_grads)
        return opt_update(i, combined, opt_state)

    # Running-mean update used to build the diagonal Fisher estimate, i.e. the mean
    # of squared per-sample gradients over a dataset. Working leaf-wise with
    # tree_map avoids flattening the parameters at every step (see the tree_map docs).
    def update_mean_tree_wrap(m, full_grads, datum_idx):
        """Fold the squared ``full_grads`` into the running mean ``m`` as sample number ``datum_idx`` (0-based)."""
        count = datum_idx + 1
        # Progress report every 1000 samples.
        if count % 1000 == 0:
            print(count)

        def update_mean_tree(old_mean, new_obs):
            return old_mean + (1 / count) * (new_obs ** 2 - old_mean)

        return tree_map(update_mean_tree, m, full_grads)

    # Diagonal Fisher estimate: running mean (via update_mean_tree_wrap) of squared
    # per-sample gradients over the given samples.
    def get_diag_fisher(opt_state, diag_tree, reg_pos, states, actions, target_q_values, task_index):
        """
        Estimate the diagonal Fisher information at the current parameters.

        :param opt_state: optimizer state holding the current parameters
        :param diag_tree: previous diagonal estimate; only used when ``task_index > 0``
        :param reg_pos: previous task's parameters; only used when ``task_index > 0``
        :param states, actions, target_q_values: samples, one per leading index
        :param task_index: index of the task just trained. For later tasks, the
            diagonal-EWC gradient is added to every per-sample loss gradient.
        :return: parameter-shaped tree holding the mean squared gradient
            (all zeros if there are no samples)
        """
        params = get_params(opt_state)
        # Start the running mean from exact zeros. ``init_zero_params`` only sets
        # ``W_init=normal(0)``, so its biases come from stax's default initialisers
        # and it is not guaranteed to be all zeros.
        m = tree_map(jnp.zeros_like, params)
        diag_grads = None
        if task_index > 0:
            diag_grads = tree_map(ewc_reg_diag_no_l_rate, params, reg_pos, diag_tree)
        for datum_idx in range(states.shape[0]):
            loss_grads = grad(loss)(params,
                                    jnp.expand_dims(states[datum_idx], axis=0),
                                    jnp.expand_dims(actions[datum_idx], axis=0),
                                    jnp.expand_dims(target_q_values[datum_idx], axis=0))
            if diag_grads is not None:
                full_grads = tree_map(add_two_trees, loss_grads, diag_grads)
            else:
                full_grads = loss_grads
            m = update_mean_tree_wrap(m, full_grads, datum_idx)
        return m

    _, reg_pos = init_zero_params(rng, (-1, *env.observation_space.shape))
    _, diag_tree = init_zero_params(rng, (-1, *env.observation_space.shape))

    # test_accs[j] holds the evaluation score on task j, recorded every 100 steps
    # throughout the training of every task.
    test_accs = [[] for _ in range(num_tasks)]
    # Training loop: tasks are learned one after another. From the second task on,
    # updates add the EWC terms anchored at the previous task's parameters.
    for i, env in enumerate(train_env_list):
        print("train_env: ", env_names_list[i])
        state = env.reset()
        memory = ReplayBuffer(env.observation_space.shape, args.replay_buffer_size)
        for t in range(args.num_steps):
            # Epsilon moves linearly from eps_start to eps_end over eps_timesteps
            # steps (restarting for each task), then stays at eps_end.
            fraction = min(1.0, float(t) / eps_timesteps)
            eps_threshold = args.eps_start + fraction * (args.eps_end - args.eps_start)
            sample = random.random()
            if sample > eps_threshold:
                action = act(np.array(state), policy_params)
            else:
                action = env.action_space.sample()

            next_state, reward, done, _ = env.step(int(action))
            memory.add(state, action, reward, next_state, float(done))
            state = next_state

            episode_rewards[-1] += reward
            if done:
                state = env.reset()
                episode_rewards.append(0.0)

            if t > args.learning_starts and t % args.learning_freq == 0:
                if i == 0:
                    lossr, opt_state, policy_params = optimise_td_loss(
                        update, opt_state, memory, policy_params, target_params)
                else:
                    lossr, opt_state, policy_params = optimise_td_loss_reg(
                        update_ewc_full_diag, opt_state, reg_grads, reg_pos, diag_tree,
                        memory, policy_params, target_params)

            if t > args.learning_starts and t % args.target_update_freq == 0:
                target_params = policy_params.copy()

            # Evaluate the current policy on every task, including ones not trained yet.
            if t % 100 == 0:
                for j, test_env in enumerate(test_env_list):
                    test_accr = evaluate(test_env, policy_params)
                    test_accs[j].append(test_accr)

            num_episodes = len(episode_rewards)
            if done and args.print_freq is not None and len(episode_rewards) % args.print_freq == 0:
                # [-101:-1] skips the newest entry (for rewards: the episode that just started).
                mean_100ep_reward = round(np.mean(episode_rewards[-101:-1]), 2)
                print("********************************************************")
                print('Current Env: ', env.unwrapped.spec.id)
                print("steps: {}".format(t))
                print("episodes: {}".format(num_episodes))
                print("mean 100 episode reward: {}".format(mean_100ep_reward))
                print("% time spent exploring: {}".format(int(100 * eps_threshold)))
                # One line per task rather than a hard-coded T0..T2, so the report
                # follows the length of env_names_list.
                for j in range(num_tasks):
                    print(f"Recent Accuracy T{j}: ", round(np.mean(test_accs[j][-101:-1]), 2))
                print("********************************************************")
                wandb.log({'step': t, 'episodes': num_episodes, 'Mean Reward': mean_100ep_reward,
                           'Time Explored': int(100 * eps_threshold)})

        # End of task: compute what the next task's regularisers use. These are the
        # gradient of the (regularised) loss at the final parameters (reg_grads), the
        # final parameters themselves (reg_pos) and, except after the last task, the
        # diagonal Fisher estimate (diag_tree).
        params = get_params(opt_state)
        batch_sample_size = 10000
        states, actions, rewards, next_states, dones = memory.sample(batch_sample_size)
        states = np.array(states) / 255.0
        next_states = np.array(next_states) / 255.0

        if use_double_dqn:
            max_next_action = policy_predict(policy_params, next_states).argmax(1)
            next_q_values = target_predict(target_params, next_states)
            max_next_q_values = next_q_values[np.arange(next_q_values.shape[0]), max_next_action]
        else:
            next_q_values = target_predict(target_params, next_states)
            max_next_q_values = next_q_values.max(1)
        target_q_values = rewards + (1 - dones) * gamma * max_next_q_values

        print("After DQN Bit")
        next_grads = grad(loss)(params, states, actions, target_q_values)
        print("1")
        if i > 0:
            params_diff = tree_map(subtract_two_trees, params, reg_pos)
            const_flat_diff, _ = ravel_pytree(params_diff)
            elastic_grads = grad(riem_dist_ewc_from_reg_grad)(reg_grads, const_flat_diff)
            next_grads = tree_map(add_two_trees_ewc_full_no_step, next_grads, elastic_grads)
        if i != num_tasks - 1:
            diag_tree = get_diag_fisher(opt_state, diag_tree, reg_pos, states, actions, target_q_values, i)
        print("2")
        reg_grads = next_grads
        reg_pos = params

        # Checkpoint the accuracy history after every task.
        np.save(str(args.seed) + "_test_accs", test_accs)
