import time
import itertools
import sys

import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax
from jax.nn.initializers import variance_scaling, normal, glorot_normal, he_normal
from jax.tree_util import tree_map
from jax.flatten_util import ravel_pytree
import datasets

# These are leaf-wise functions we can pass to tree_map (or the older tree_multimap) to perform operations on trees without flattening them first
def add_two_trees(tree1, tree2):
    """Return ``tree1 + tree2``; intended to be applied leaf by leaf."""
    total = tree1 + tree2
    return total

def subtract_two_trees(tree1, tree2):
    """Return ``tree1 - tree2``; intended to be applied leaf by leaf."""
    difference = tree1 - tree2
    return difference

def scale_grads_sgd(grad_tree):
    """Multiply a gradient leaf by the global ``step_size``.

    ``step_size`` is not a parameter: it is read from module scope, where the
    ``__main__`` section assigns it.
    """
    scaled = step_size * grad_tree
    return scaled

def ewc_reg_diag(params, old_params, fisher_tree):
    """Return ``(params - old_params) * fisher_tree`` for one leaf."""
    displacement = params - old_params
    return displacement * fisher_tree

def ewc_reg_diag_no_l_rate(params, old_params, diag_tree):
    """Return ``diag_reg_step * (params - old_params) * diag_tree`` for one leaf.

    ``diag_reg_step`` is read from module scope (assigned in the ``__main__``
    section); ``step_size`` is deliberately not applied here.
    """
    displacement = params - old_params
    return diag_reg_step * displacement * diag_tree

def add_two_trees_ewc_full(grad_tree, reg_tree):
    """Combine a loss-gradient leaf with a regulariser-gradient leaf.

    Both terms are scaled by the global ``step_size``; the regulariser term is
    additionally scaled by the global ``reg_step``.
    """
    loss_term = step_size * grad_tree
    reg_term = step_size * reg_step * reg_tree
    return loss_term + reg_term

def add_two_trees_ewc_full_diag(grad_tree, ewc_tree, diag_tree):
    """Combine loss, EWC and diagonal-regulariser gradient leaves.

    Every term is scaled by the global ``step_size``; the EWC term also by
    ``reg_step`` and the diagonal term also by ``diag_reg_step``.
    """
    loss_term = step_size * grad_tree
    ewc_term = step_size * reg_step * ewc_tree
    diag_term = step_size * diag_reg_step * diag_tree
    return loss_term + ewc_term + diag_term

def add_two_trees_ewc_full_no_step(grad_tree, ewc_tree):
    """Return ``grad_tree + reg_step * ewc_tree`` without any ``step_size`` scaling.

    ``reg_step`` is read from module scope (assigned in the ``__main__`` section).
    """
    ewc_term = reg_step * ewc_tree
    return grad_tree + ewc_term

def add_two_trees_selective_l2(grad_tree, l2_tree, selection_tree, max_tree):
    """Add an L2 gradient term only where the selection score is small.

    An entry counts as selected when ``|selection_tree| >= 0.0001 * max_tree``;
    the L2 term is zeroed for selected entries and kept for all others. The
    scales ``step_size`` and ``selective_l2_step`` are read from module scope.
    """
    is_selected = jnp.abs(selection_tree) >= 0.0001 * max_tree
    # 1 - mask turns the boolean mask into a 0/1 multiplier for the L2 term.
    keep_l2 = 1 - is_selected
    loss_term = step_size * grad_tree
    l2_term = selective_l2_step * l2_tree * keep_l2
    return loss_term + l2_term

def add_two_trees_selective_l2_ewc_full_diag(grad_tree, l2_tree, selection_tree, max_tree, ewc_tree, diag_tree):
    """Combine loss, selective-L2, EWC and diagonal-regulariser gradient leaves.

    The L2 term is kept only where ``|selection_tree| < 0.0001 * max_tree``
    (it is multiplied by ``1 - mask``). All scale factors (``step_size``,
    ``selective_l2_step``, ``reg_step``, ``diag_reg_step``) are module globals.
    """
    is_selected = jnp.abs(selection_tree) >= 0.0001 * max_tree
    keep_l2 = 1 - is_selected
    loss_term = step_size * grad_tree
    l2_term = selective_l2_step * l2_tree * keep_l2
    ewc_term = step_size * reg_step * ewc_tree
    diag_term = step_size * diag_reg_step * diag_tree
    return loss_term + l2_term + ewc_term + diag_term

# Standard Cross-Entropy Loss
def loss(params, batch):
    """Return the batch-mean of ``-sum(predict(params, inputs) * targets, axis=1)``.

    ``batch`` is an ``(inputs, targets)`` pair; ``predict`` is the network's
    apply function, created in the ``__main__`` section.
    """
    inputs, targets = batch
    outputs = predict(params, inputs)
    per_example = jnp.sum(outputs * targets, axis=1)
    return -jnp.mean(per_example)

# Online loss used to calculate Fisher Information
def online_loss(params, batch):
    """Return ``-sum(predict(params, inputs) * targets)`` over every element.

    Unlike ``loss`` there is no averaging; ``batch`` is an ``(inputs, targets)``
    pair.
    """
    inputs, targets = batch
    outputs = predict(params, inputs)
    weighted = outputs * targets
    return -jnp.sum(weighted)

def accuracy(params, batch):
    """Return the fraction of rows whose arg-max prediction matches the arg-max target.

    ``batch`` is an ``(inputs, targets)`` pair; both the targets and the output
    of ``predict`` are reduced with ``argmax`` along axis 1.
    """
    inputs, targets = batch
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    target_class = jnp.argmax(targets, axis=1)
    hits = predicted_class == target_class
    return jnp.mean(hits)

@jit
def mahalanobis_dist(params, batch, params_constant_ravel):
    """Return the dot product of a fixed flat vector with the flattened loss gradient.

    Args:
      params: parameter pytree at which ``grad(loss)`` is evaluated.
      batch: ``(inputs, targets)`` pair forwarded to ``loss``.
      params_constant_ravel: flat vector that multiplies the flattened gradient.

    Differentiating this scalar with respect to ``params`` (as the update
    functions do) therefore gives a Hessian-vector product of ``loss``.
    """
    inputs, targets = batch  # values are not used below; loss() unpacks the batch itself
    flat_loss_grads = ravel_pytree(grad(loss)(params, batch))[0]
    return jnp.dot(params_constant_ravel.T, flat_loss_grads)

@jit
def mahalanobis_dist_loss_and_ewc(params, reg_grads, reg_pos, batch, params_constant_ravel):
  """Dot a fixed flat vector with the flattened loss-plus-EWC gradient.

  Args:
    params: current parameter pytree.
    reg_grads: gradient pytree stored for the regularisation point.
    reg_pos: parameter pytree of the regularisation point.
    batch: ``(inputs, targets)`` pair forwarded to ``loss``.
    params_constant_ravel: flat vector that multiplies the flattened gradient.

  Returns:
    The scalar ``params_constant_ravel . ravel(full_grads)``, where
    ``full_grads`` merges the loss gradient and the elastic gradient leaf by
    leaf with ``add_two_trees_ewc_full``.
  """
  loss_grads = grad(loss)(params, batch)
  # tree_map accepts several trees; the former tree_multimap is not imported
  # by this file and no longer exists in current JAX releases.
  params_diff = tree_map(subtract_two_trees, params, reg_pos)
  const_flat_diff, _ = ravel_pytree(params_diff)
  elastic_grads = grad(riem_dist_ewc_from_reg_grad)(reg_grads, const_flat_diff)
  full_grads = tree_map(add_two_trees_ewc_full, loss_grads, elastic_grads)
  grads_flat, _ = ravel_pytree(full_grads)
  return jnp.dot(params_constant_ravel.T, grads_flat)

@jit
def l2_regularizer(params):
    """Return the squared Euclidean norm of all parameters in the pytree.

    The tree is flattened into one vector with ``ravel_pytree`` and dotted with
    itself.
    """
    flat_vector = ravel_pytree(params)[0]
    return jnp.dot(flat_vector.T, flat_vector)

# Uses the full Hessian to get the Riemannian distance
@jit
def riem_dist_ewc_from_reg_grad(reg_grads, const_flat_diff):
    """Return the dot product of ``const_flat_diff`` with the flattened ``reg_grads``.

    Args:
      reg_grads: gradient pytree; flattened with ``ravel_pytree``.
      const_flat_diff: flat vector with one entry per leaf element of ``reg_grads``.
    """
    flat_reg_grads = ravel_pytree(reg_grads)[0]
    return jnp.dot(const_flat_diff.T, flat_reg_grads)

if __name__ == "__main__":
  # Hyper-parameters. step_size, reg_step, diag_reg_step and selective_l2_step
  # are read as module globals by the leaf-wise helpers at the top of the file.
  step_size = 0.01
  reg_step = 0 #1e-3 #1e-3 #5e-4*
  diag_reg_step = 0 #1e2 #2e3 #1e3*
  selective_l2_step = 1e-5 #1e-5 #0.00001 #0.00005
  num_epochs = 150
  batch_size = 128
  init_var = 0.005 #0.1 #0.005
  momentum_mass = 0.0
  num_tasks = 5

  # An optional first command-line argument is the run index; it seeds both
  # JAX and NumPy and prefixes the saved accuracy logs. Without it a random
  # seed is drawn. (The PRNG key is created here only; an earlier assignment
  # that was always overwritten has been dropped.)
  if len(sys.argv) > 1:
      run_idx = int(sys.argv[1])
      rng = random.PRNGKey(run_idx)
      np.random.seed(run_idx)
  else:
      run_idx = 0
      rng = random.PRNGKey(np.random.randint(1000))
      np.random.seed(np.random.randint(1000))

  # Initialize our network
  init_random_params, predict = stax.serial(
    Dense(1024, W_init = normal(init_var)), Relu,
    Dense(1024, W_init = normal(init_var)), Relu,
    Dense(10, W_init = normal(init_var)), LogSoftmax)

  # Init a tree of zeros with same shape as our network (used to store fisher information and regularization point)
  init_zero_tree, _ = stax.serial(
    Dense(1024, W_init = normal(0), b_init = normal(0)), Relu, #1024
    Dense(1024, W_init = normal(0), b_init = normal(0)), Relu, #1024
    Dense(10, W_init = normal(0), b_init = normal(0)), LogSoftmax)

  # The optimizer's own step size is 1.0 because the update functions scale
  # the gradients by step_size themselves before calling opt_update.
  opt_init, opt_update, get_params = optimizers.momentum(1.0, mass=momentum_mass)

  # Gen data for all our tasks
  # datasets.mnist is a project-local loader; presumably the second argument
  # enables a pixel permutation seeded by the third — TODO confirm.
  permute_seeds = [42, 33, 897, 90, 543] #np.random.randint(0,1000000,5) #[31, 37, 24,  5, 20, 49, 63, 46, 84, 68]
  task_loaders = []
  task_ims = []
  task_labels = []
  test_ims = []
  test_labs = []
  for task_idx in range(num_tasks):
      # num_batches is overwritten on every iteration; the training loop uses
      # the value returned for the last task.
      train_images, train_labels, test_images, test_labels, num_batches, batches = datasets.mnist(batch_size, True, permute_seeds[task_idx])
      task_loaders.append(batches)
      task_ims.append(train_images)
      task_labels.append(train_labels)
      test_ims.append(test_images)
      test_labs.append(test_labels)

  
  # Standard SGD used to train the first task
  @jit
  def update_standard(i, opt_state, batch):
    """Take one plain SGD step on ``loss`` for ``batch``.

    The loss gradient is scaled by the global ``step_size`` (via
    ``scale_grads_sgd``) before being handed to ``opt_update`` together with
    the step index ``i``.
    """
    loss_grads = grad(loss)(get_params(opt_state), batch)
    scaled_grads = tree_map(scale_grads_sgd, loss_grads)
    return opt_update(i, scaled_grads, opt_state)
 
  # Learning rule which uses SGD and ewc regularization where ewc now uses the full fisher info (Hessian)
  @jit
  def update_ewc_full_diag(i, opt_state, reg_grads, reg_pos, diag_tree, batch):
    """Take one step on the loss plus the EWC and diagonal regularisers.

    Args:
      i: step index forwarded to ``opt_update``.
      opt_state: current optimizer state.
      reg_grads: gradient pytree stored for the regularisation point.
      reg_pos: parameter pytree of the regularisation point.
      diag_tree: pytree of per-parameter weights for the diagonal term.
      batch: ``(inputs, targets)`` pair.

    Returns:
      The optimizer state after the update.
    """
    params = get_params(opt_state)
    # tree_map accepts several trees; the former tree_multimap is not imported
    # by this file and no longer exists in current JAX releases.
    params_diff = tree_map(subtract_two_trees, params, reg_pos)
    const_flat_diff, _ = ravel_pytree(params_diff)
    loss_grads = grad(loss)(params, batch)
    elastic_grads = grad(riem_dist_ewc_from_reg_grad)(reg_grads, const_flat_diff)
    diag_grads = tree_map(ewc_reg_diag, params, reg_pos, diag_tree)
    full_grads = tree_map(add_two_trees_ewc_full_diag, loss_grads, elastic_grads, diag_grads)
    return opt_update(i, full_grads, opt_state)
  

  @jit
  def update_selective(i, opt_state, batch):
    """Take one step on the loss plus the selective L2 regulariser.

    Args:
      i: step index forwarded to ``opt_update``.
      opt_state: current optimizer state.
      batch: ``(inputs, targets)`` pair.

    Returns:
      The optimizer state after the update.
    """
    params = get_params(opt_state)
    const_flat_params, _ = ravel_pytree(params)
    # Gradient of (flat params . flat loss gradient) with the flat params held
    # constant, i.e. a Hessian-vector product; used as a per-parameter score.
    important_params = grad(mahalanobis_dist)(params, batch, const_flat_params)
    # Per-leaf maximum of the score, broadcast to the leaf's shape. Mapping over
    # the leaves works for any network layout, not just layers 0, 2 and 4.
    # NOTE(review): the maximum is taken over signed values while the mask in
    # add_two_trees_selective_l2 thresholds absolute values — confirm whether
    # max(|score|) was intended.
    max_params = tree_map(lambda leaf: jnp.ones(leaf.shape)*jnp.max(leaf), important_params)
    loss_grads = grad(loss)(params, batch)
    reg_grads = grad(l2_regularizer)(params)
    # tree_map accepts several trees (tree_multimap no longer exists in JAX).
    full_grads = tree_map(add_two_trees_selective_l2, loss_grads, reg_grads, important_params, max_params)
    # Pass the step index given by the caller; it is no longer shadowed by a
    # loop variable, and the global epoch counter is not used.
    return opt_update(i, full_grads, opt_state)

  @jit
  def update_selective_ewc_full_diag(i, opt_state, reg_grads, reg_pos, diag_tree, batch):
    """Take one step on the loss plus selective L2, EWC and diagonal regularisers.

    Args:
      i: step index forwarded to ``opt_update``.
      opt_state: current optimizer state.
      reg_grads: gradient pytree stored for the regularisation point.
      reg_pos: parameter pytree of the regularisation point.
      diag_tree: pytree of per-parameter weights for the diagonal term.
      batch: ``(inputs, targets)`` pair.

    Returns:
      The optimizer state after the update.
    """
    params = get_params(opt_state)

    # Selective Reg Bit
    const_flat_params, _ = ravel_pytree(params)
    # Gradient w.r.t. params of (flat params . flat loss-plus-EWC gradient),
    # with the flat params held constant; used as a per-parameter score.
    important_params = grad(mahalanobis_dist_loss_and_ewc)(params, reg_grads, reg_pos, batch, const_flat_params)
    # Per-leaf maximum of the score, broadcast to the leaf's shape. Mapping over
    # the leaves works for any network layout, not just layers 0, 2 and 4.
    # NOTE(review): the maximum is taken over signed values while the mask in
    # add_two_trees_selective_l2_ewc_full_diag thresholds absolute values —
    # confirm whether max(|score|) was intended.
    max_params = tree_map(lambda leaf: jnp.ones(leaf.shape)*jnp.max(leaf), important_params)
    l2_grads = grad(l2_regularizer)(params)

    # Loss and EWC Bit
    # tree_map accepts several trees; the former tree_multimap is not imported
    # by this file and no longer exists in current JAX releases.
    params_diff = tree_map(subtract_two_trees, params, reg_pos)
    const_flat_diff, _ = ravel_pytree(params_diff)
    loss_grads = grad(loss)(params, batch)
    elastic_grads = grad(riem_dist_ewc_from_reg_grad)(reg_grads, const_flat_diff)
    diag_grads = tree_map(ewc_reg_diag, params, reg_pos, diag_tree)

    full_grads = tree_map(add_two_trees_selective_l2_ewc_full_diag,
                          loss_grads, l2_grads, important_params, max_params, elastic_grads, diag_grads)
    # Pass the step index given by the caller; it is no longer shadowed by a
    # loop variable, and the global epoch counter is not used.
    return opt_update(i, full_grads, opt_state)

  # Function used to iteratively update the Fisher information diagonal (since it's the mean of the squared gradients, averaged over the dataset).
  # The code complexity is just because I use tree operations so that we don't have to flatten every time (saves a lot of time).
  # See the jax.tree_util.tree_map docs (formerly tree_multimap) for more info.
  def update_mean_tree_wrap(m, full_grads, datum_idx):
    """Fold one more squared-gradient observation into a running mean, leaf by leaf.

    Args:
      m: pytree holding the running mean over the first ``datum_idx`` observations.
      full_grads: pytree of gradients for observation number ``datum_idx`` (0-based).
      datum_idx: 0-based index of the new observation.

    Returns:
      A pytree with ``m + (full_grads**2 - m) / (datum_idx + 1)`` in every leaf.
    """
    num_seen = datum_idx + 1
    if num_seen % 1000 == 0:
        print(num_seen)  # progress indicator, printed every 1000 observations

    def update_mean_tree(old_mean, new_obs):
        return old_mean + (1/num_seen)*(new_obs**2 - old_mean)

    # tree_map accepts several trees; the former tree_multimap is not imported
    # by this file and no longer exists in current JAX releases.
    return tree_map(update_mean_tree, m, full_grads)

  # Uses the update_mean_tree function to calculate fisher information
  def get_diag_fisher(opt_state, diag_tree, reg_pos, batch, task_index):
    """Estimate the mean squared per-example gradient at the current parameters.

    Args:
      opt_state: optimizer state whose parameters are evaluated.
      diag_tree: diagonal weights from the previous tasks.
      reg_pos: parameter pytree of the current regularisation point.
      batch: ``(inputs, targets)`` pair; examples are read one at a time by
        indexing the leading axis.
      task_index: index of the task just trained; for ``task_index > 0`` the
        diagonal regulariser gradient is added to every per-example gradient.

    Returns:
      A pytree shaped like the parameters holding the running mean of the
      squared gradients over the examples that were used.
    """
    params = get_params(opt_state)
    _, m = init_zero_tree(rng, (-1, 28 * 28))
    # tree_map accepts several trees; the former tree_multimap is not imported
    # by this file and no longer exists in current JAX releases.
    diag_grads = tree_map(ewc_reg_diag_no_l_rate, params, reg_pos, diag_tree)
    # Build the gradient function once instead of on every iteration.
    datum_grad = grad(online_loss)
    # Use at most 5000 examples, but never index past the end of the data.
    num_samples = min(5000, batch[0].shape[0])
    for datum_idx in range(num_samples):
        loss_grads = datum_grad(params, (batch[0][datum_idx], batch[1][datum_idx]))
        if task_index > 0:
            full_grads = tree_map(add_two_trees, loss_grads, diag_grads)
        else:
            full_grads = loss_grads
        m = update_mean_tree_wrap(m, full_grads, datum_idx)
    return m

  _, init_params = init_random_params(rng, (-1, 28 * 28))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  # Row i holds the accuracy on task i after every epoch of every task.
  train_accs = np.zeros((num_tasks, num_tasks*num_epochs))
  test_accs = np.zeros((num_tasks, num_tasks*num_epochs))

  print("\nStarting training...")
  # Zero placeholders until the first task has produced a regularisation point
  # and a Fisher diagonal.
  _, reg_pos = init_zero_tree(rng, (-1, 28 * 28))
  _, diag_tree = init_zero_tree(rng, (-1, 28 * 28))
  for task_idx in range(num_tasks): # For each task we train with sgd and ewc regularization
      batches = task_loaders[task_idx]
      for epoch in range(num_epochs):
          start_time = time.time()
          for _ in range(num_batches):
              if task_idx == 0: # If its the first task we just use sgd as normal
                  opt_state = update_standard(next(itercount), opt_state, next(batches))
                  #opt_state = update_selective(next(itercount), opt_state, next(batches))
              else:
                  #opt_state = update_ewc_full_diag(next(itercount), opt_state, ewc_grads, reg_pos, diag_tree, next(batches))
                  opt_state = update_selective_ewc_full_diag(next(itercount), opt_state, ewc_grads, reg_pos, diag_tree, next(batches))
          epoch_time = time.time() - start_time

          # Track train and test accuracy on every task at the end of each epoch
          params = get_params(opt_state)
          for i in range(num_tasks):
              train_accr = accuracy(params, (task_ims[i], task_labels[i]))
              test_accr = accuracy(params, (test_ims[i], test_labs[i]))
              train_accs[i, task_idx*num_epochs + epoch] = train_accr
              test_accs[i, task_idx*num_epochs + epoch] = test_accr
              if task_idx == i:
                  # Only the task currently being trained is reported on stdout.
                  train_acc = train_accr
                  test_acc = test_accr
          print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
          print("Training set accuracy {}".format(train_acc))
          print("Test set accuracy {}".format(test_acc))

      # Get fisher info and reg point used for next task
      # Note for the efficient Hessian-Vector method of getting the riemannian distance we just need to store the first derivative
      # of the network for the current task.
      params = get_params(opt_state)
      next_grads = grad(loss)(params, (task_ims[task_idx], task_labels[task_idx]))
      if task_idx > 0:
          # tree_map accepts several trees; the former tree_multimap is not
          # imported by this file and no longer exists in current JAX releases.
          params_diff = tree_map(subtract_two_trees, params, reg_pos)
          const_flat_diff, _ = ravel_pytree(params_diff)
          elastic_grads = grad(riem_dist_ewc_from_reg_grad)(ewc_grads, const_flat_diff)
          next_grads = tree_map(add_two_trees_ewc_full_no_step, next_grads, elastic_grads)
      if task_idx != num_tasks - 1:
          # No Fisher estimate is needed after the final task.
          diag_tree = get_diag_fisher(opt_state, diag_tree, reg_pos, (task_ims[task_idx], task_labels[task_idx]), task_idx)
      ewc_grads = next_grads
      reg_pos = params
  
  # Save logs and graph
  # The accuracy tables are written as text files prefixed with the run index.
  np.savetxt(f"{run_idx}_train_accs.txt", train_accs)
  np.savetxt(f"{run_idx}_test_accs.txt", test_accs)

  # One figure per task and per split, in train-then-test order.
  curves = (
      ("Train Accuracy", "train_task", train_accs),
      ("Test Accuracy", "test_task", test_accs),
  )
  for i in range(num_tasks):
      for title, file_prefix, accs in curves:
          plt.plot(accs[i])
          plt.title(title)
          plt.xlabel("epochs")
          plt.ylabel("Accuracy")
          plt.savefig(file_prefix + str(i) + ".png")
          plt.close()
