"""Training file for large LR experiments on Cloud."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple,OrderedDict
from functools import partial
import math
from operator import mul, add
import os
import sys
import tempfile
import time

from absl import app
from absl import flags
from absl import logging
import jax
from jax.api import grad, jit, vmap
import jax.experimental.optimizers as optimizers
import jax.numpy as np
import jax.random as random
from jax.tree_util import tree_map, tree_reduce
# from neural_tangents.utils.empirical import taylor_expand
import numpy as onp
import pandas as pd
import tensorflow as tf
import yaml

import data_util
import model_util
import config
# Keep the star import last so that names exported by meas_utils still take
# precedence over the explicit imports above, exactly as before.
from meas_utils import *

# Command-line flags. `config_file` is the path of a YAML file with config
# overrides; `default_config` is the base-config name handed to `get_config`.
flags.DEFINE_string('config_file', None, 'Load YAML with config')
flags.DEFINE_string('default_config','fc', 'Default config to load')

FLAGS = flags.FLAGS



# THIS FUNCTION GENERATES ALL MEASUREMENTS. IT IS SPECIFIC TO THESE EXPERIMENTS SO BETTER TO HAVE IT HERE?
def meas_function_gen(loss, predict, samples, C, s0=None,params_shape=None):
  """Build the measurement callback used during training.

  Collects the measurements enabled by the flags in `C` into an ordered
  table, dumps the config to `<path><name>.yaml` (path and name come from
  `get_jobpath`) and returns a closure that evaluates the table on the
  current parameters. Batched measurements use the `minibatcher` that is in
  scope in this module.

  Args:
    loss: Loss function; evaluated through `meas_batch_generator` and passed
      to `Hv` / `gradfn`.
    predict: Model function, called as `predict(params, inputs)`.
    samples: Pair `(train_samples, test_samples)`; each one is indexed as
      `(inputs, targets)`.
    C: Config namedtuple. The flags `std_meas`, `meas_overlaps`, `L2`,
      `measNTKspec`, `meashesseval` and `measHigherOrders` select the
      measurements; `save_weights`, `save_freq`, `meas_freq`, `verbose`,
      `upload_to_cloud`, ... control saving and printing.
    s0: Initial step, forwarded as a string to `get_jobpath`.
    params_shape: Size argument forwarded to `compute_spectrumf` for the
      Hessian spectrum (`main` passes the flattened parameter count).

  Returns:
    Tuple `(make_meas, [])`: the measurement closure and an empty log list.

  Raises:
    ValueError: If `C.L2` is set and `C.loss_type` is neither 'mse' nor
      'xent'.
  """
  train_samples, test_samples = samples

  def accuracy(params, batch):
    return np.mean(
        np.argmax(predict(params, batch[0]), axis=1) == np.argmax(batch[1], axis=1))

  # `tomeas` maps a measurement name to either a constant or a callable of
  # the parameters. `multimeas` lists the column names of measurements that
  # return several values.
  tomeas = {}
  multimeas = {}

  def measf(f):
    return meas_batch_generator(f, minibatcher(train_samples, batch_size=C.meas_bs), sample_size=C.meas_samples)

  # Sample counts for the standard measurements. Computed up front because
  # the `C.L2` branch below needs `num_samples` even when `C.std_meas` is off.
  num_samples = train_samples[0].shape[0]
  num_samplest = test_samples[0].shape[0]
  if C.std_meas_size > 0:
    num_samples = min(num_samples, C.std_meas_size)
    num_samplest = min(num_samplest, C.std_meas_size)

  if C.std_meas:
    tomeas['learning_rate'] = C.learning_rate
    tomeas['batch_size'] = C.batch_size
    try:
      tomeas['width'] = C.wrn_widening_f
    except AttributeError:
      # Configs without a widening factor report `width` instead.
      tomeas['width'] = C.width

    tomeas['train_loss'] = meas_batch_generator(loss, minibatcher(train_samples, batch_size=C.meas_bs), sample_size=num_samples)
    tomeas['test_loss'] = meas_batch_generator(loss, minibatcher(test_samples, batch_size=C.meas_bs), sample_size=num_samplest)
    tomeas['train_acc'] = meas_batch_generator(accuracy, minibatcher(train_samples, batch_size=C.meas_bs), sample_size=num_samples)
    tomeas['test_acc'] = meas_batch_generator(accuracy, minibatcher(test_samples, batch_size=C.meas_bs), sample_size=num_samplest)
    # Output at index [0, 1] of the prediction for the first training input.
    tomeas['logit'] = lambda p: predict(p, train_samples[0][0:1])[0, 1]

    from jax.tree_util import tree_reduce
    l2_norm = lambda params: tree_map(lambda x: np.sum(x ** 2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))
    # Sum of squared entries over all parameter leaves.
    tomeas['weight_norm'] = lambda p: l2_reg(p)

  if C.meas_overlaps:
    flatHf = measf(Hv(loss))
    gradd = measf(gradfn(loss))
    tomeas['evalg'] = getHg_scalarsfn(gradd, flatHf)

  if C.L2:
    # Loss without the L2 penalty, so it can be tracked separately.
    def mse(params, data_tuple):
      x, y = data_tuple
      return 0.5 * np.mean((y - predict(params, x)) ** 2)

    def cross_entropy(params, data_tuple):
      x, y = data_tuple
      import jax.experimental.stax as stax
      return -np.mean(stax.logsoftmax(predict(params, x)) * y)

    if C.loss_type == 'mse':
      bare_loss = mse
    elif C.loss_type == 'xent':
      bare_loss = cross_entropy
    else:
      raise ValueError('not defined')
    tomeas['train_bareloss'] = meas_batch_generator(bare_loss, minibatcher(train_samples, batch_size=C.meas_bs), sample_size=num_samples)

  if C.measNTKspec:
    sizevntk = train_samples[1].shape[-1]*C.NTK_SAMPLES
    NTKv_general = NTKv(predict)
    if C.Hutch:
      tomeas['trNTK'] = trMvHutch(NTKv_general, minibatcher(train_samples, batch_size=C.NTK_SAMPLES), sizevntk, tol=C.Hutch_tol)
    else:
      tomeas['trNTK'] = measf(jit(trntk(predict)))
    tomeas['NTK_maxeval'] = compute_spectrumf(NTKv_general, minibatcher(train_samples, batch_size=C.NTK_SAMPLES), sizevntk, 1)
    multimeas['NTK_maxeval'] = ['NTK_max']

  if C.meashesseval:
    flatHf = Hv(loss)
    print(params_shape)
    tomeas['hess_maxeval'] = compute_spectrumf(flatHf, minibatcher(train_samples, batch_size=C.meas_bs), params_shape, 1)
    multimeas['hess_maxeval'] = ['hess_max']

  if C.measHigherOrders:
    sample = train_samples[0][0:1]
    target = train_samples[1][0:1]

    tomeas['Theta'],tomeas['O3'],tomeas['O4'],tomeas['Theta_p'],tomeas['O3_p'],tomeas['O4_p']=generate_scalingops(predict,sample)

    tomeas['Delta_f'] = generate_Delta_f(predict, sample, target)

  path, name = get_jobpath(C,s0=str(s0))

  if C.verbose:
    logging.info('Saving directory: '+path)

  with open(path+name+'.yaml', 'w') as yaml_file:
    yaml.dump(dict(C._asdict()), yaml_file, default_flow_style=False)

  save_folder = path
  save_weight_list = []
  if C.save_weights:
    save_weight_list = C.save_weights
    if C.savew_dir:
      save_folder = C.savew_dir
    if save_folder[-1] != '/':
      save_folder += '/'
    if not os.path.exists(save_folder):
      os.makedirs(save_folder)

  measurements = list(tomeas.keys())
  # Column names of the saved table: multi-valued measurements expand into
  # the names listed in `multimeas`.
  summary = []
  for el in measurements:
    if el not in multimeas:
      summary.append(el)
    else:
      summary += multimeas[el]

  summary = [C.time_unit]+summary
  # Header printed in verbose mode: only callable measurements are printed.
  # No positional slicing here, so the first measurement is not dropped when
  # the constant `std_meas` entries are absent.
  prsummary = [C.time_unit] + \
      [el+',dt' for el in measurements if callable(tomeas[el])]

  if C.savew_name:
    save_file0 = save_folder+C.savew_name
  else:
    save_file0 = save_folder+name+'model'

  def make_meas(time_step, params, log, force=False, name_suffix=None):
    """Measure all configured quantities at `time_step`.

    Args:
      time_step: Current step.
      params: Current parameters.
      log: List of measurement rows collected so far; a row is appended.
      force: If True, measure and save regardless of the step frequencies.
      name_suffix: Optional suffix appended to the CSV/YAML file name.

    Returns:
      Tuple `(log, isnan)`. `isnan` is True when the train loss is NaN/inf
      or the train accuracy exceeds `C.exit_acc`.
    """
    if time_step == 0 and C.verbose:
      print('------------------------------------------')
      print(' \t'.join(prsummary))
      print('------------------------------------------')

    file_name = name
    if name_suffix is not None:
      file_name += name_suffix

    if time_step in save_weight_list or (force and -1 in save_weight_list):
      save_file = save_file0+'_step'+str(time_step)+'.pkl'
      print('saving at',save_file )
      save_weights(save_file, time_step, params)
      if C.upload_to_cloud is not None:
        tf.io.gfile.copy(save_file,'gs://'+C.upload_to_cloud+'/'+save_file,overwrite=True)

    isnan = False
    meas_freq = C.meas_freq

    # Skip steps that are off the measurement grid, unless forced or the
    # step is one where weights are saved.
    if time_step % meas_freq != 0 and not force and (not time_step in save_weight_list):
      return log, isnan

    meas_list = [time_step]
    if C.verbose:
      print('{:}   '.format(time_step), end='\t')

    for el in measurements:
      if not callable(tomeas[el]):
        meas_list.append(tomeas[el])
        continue

      t0 = time.time()
      # Once a stop condition has been hit the remaining measurements are
      # filled with NaN instead of being evaluated.
      if not isnan:
        meas = tomeas[el](params)
      else:
        meas = np.nan
      t1 = time.time()

      if el in multimeas:
        if isinstance(meas, float):
          # Scalar NaN filler: expand to one value per column.
          meas = np.ones(len(multimeas[el]))*meas
        meas_list += list(meas)
        meas = ('%f' % float(meas[0])).rstrip('0')
      else:
        meas_list.append(meas)
        meas = ('%f' % float(meas)).rstrip('0')

        if el == 'train_loss' and (onp.isnan(float(meas)) or onp.isinf(float(meas))):
          # Do not break out of the loop: the row must keep one entry per
          # column of `summary` so the log stays rectangular for the CSV.
          isnan = True

        if el == 'train_acc' and float(meas) > C.exit_acc:
          isnan = True

      if C.verbose:
        print('{},{:.2f}s'.format(meas, t1-t0), end='\t')
    if C.verbose:
      print('')
    log.append(np.array(meas_list))

    if (time_step % C.save_freq == 0 and C.save_freq != -1) or force:
      if C.verbose:
        logging.info('Saving dataframe, step {}'.format(time_step))
      df = pd.DataFrame(log, columns=summary)

      df.to_csv(path+file_name+'.csv', index=False)

      if C.upload_to_cloud is not None:
        tf.io.gfile.copy(path+file_name+'.yaml','gs://'+C.upload_to_cloud+'/'+path+file_name+'.yaml',overwrite=True)
        tf.io.gfile.copy(path+file_name+'.csv','gs://'+C.upload_to_cloud+'/'+path+file_name+'.csv',overwrite=True)

    return log, isnan

  return make_meas, []


def get_config(base_config, **kwargs):
  """Load the configuration to initialize and run.

  Get config pulls default configurations from config then overwrites with
  values in kwargs. Overrides are converted to the type of the default
  value; afterwards several derived fields are adjusted (batch sizes capped
  by `n_train`, `meas_bs` capped by `meas_samples`, cloud data directory,
  `meas_simple` switches, rescaling of step counts when `physical` is set).

  Args:
    base_config: String giving base config to use.
                 Allowed valued: [fc,  cnn_real, wrn_original].
    kwargs: parameters. Keys may be aliases listed in `C.alias`. Values may
            be strings (command line) or already-typed values (YAML).

  Returns:
    C: namedtuple containing configuration.

  Raises:
    NotImplementedError: If `base_config` is not one of the allowed values.
    Exception: If a key is not a field of the base configuration.
  """

  if base_config =='fc':
    C = config.get_default_fc_config()
  elif base_config == 'cnn_real':
    C = config.get_default_cnn_real_config()
  elif base_config == 'wrn_original':
    C = config.get_default_wrn_config()
  else:
    raise NotImplementedError('No configuration: %s' % base_config)

  _C_dict = C._asdict()

  for key in kwargs:
    val = kwargs[key]

    if key in C.alias:
      key = C.alias[key]

    if key in _C_dict:
      deftype=type(_C_dict[key])
      if key == 'save_weights':
        # Accepts '[1,2]' / '1,2' strings as well as lists or scalars, which
        # is what a YAML config file produces.
        val = _parse_save_weights(val)
        deftype = type(None)

      if deftype != type(None):
        if deftype == bool and type(val) != bool:
          # SECURITY NOTE: `eval` runs arbitrary code from the override
          # string. Only use with trusted command lines / config files.
          deftype = eval
        if deftype == int:
          # Go through float so that values such as '1e3' are accepted.
          val = float(val)

        val = deftype(val)

      _C_dict[key] = val

    else:
      raise Exception('Invalid Argument: {}'.format(key))

  if _C_dict['n_train'] != 0:
    for el in ['batch_size','meas_samples','NTK_SAMPLES']:
      if _C_dict[el] > _C_dict['n_train']:
        _C_dict[el] = _C_dict['n_train']

  if _C_dict['meas_bs'] > _C_dict['meas_samples']:
    _C_dict['meas_bs'] = _C_dict['meas_samples']

  if _C_dict['upload_to_cloud']:
    _C_dict['data_dir'] = 'gs://' + _C_dict['upload_to_cloud'] + '/datasets'

  if _C_dict['meas_simple']:
    _C_dict['measHigherOrders'] = False
    _C_dict['measNTKspec'] = False
    _C_dict['meas_overlaps'] = False
  if 'XL' in _C_dict:
    if _C_dict['XL']:
      _C_dict['model']='cnn_real_XL'
  if _C_dict['physical']:
    # Step counts are divided by the learning rate; the measurement
    # frequency is additionally capped at 500 steps.
    _C_dict['meas_freq']=max(int(_C_dict['meas_freq']/_C_dict['learning_rate']),1)
    _C_dict['meas_freq']=min(_C_dict['meas_freq'],500)
    _C_dict['train_steps']=max(int(_C_dict['train_steps']/_C_dict['learning_rate']),1)

    # Keep the save frequency a multiple of the measurement frequency.
    _C_dict['save_freq']=max(int(_C_dict['save_freq']/10),1)*_C_dict['meas_freq']
  else:

    _C_dict['train_steps']=int(_C_dict['train_steps'])
    _C_dict['meas_freq']=int(_C_dict['meas_freq'])


  logging.info('Evolving for {:}steps and measuring every {:}'.format(_C_dict['train_steps'],_C_dict['meas_freq']))
  C = namedtuple('config', _C_dict.keys())(**_C_dict)

  logging.info('Config:')
  for key in _C_dict:
    logging.info('\t%s: %s' % (key, _C_dict[key]))

  return C


def _parse_save_weights(val):
  """Normalize a `save_weights` override to a list of integer steps."""
  if val is None:
    return None
  if isinstance(val, str):
    text = val.strip()
    if text.startswith('['):
      text = text[1:-1]
    items = [el for el in text.split(',') if el.strip()]
  elif isinstance(val, (list, tuple)):
    items = val
  else:
    items = [val]
  return [int(float(el)) for el in items]


def get_data(C):
  """Load the train and test sets described by the config.

  Args:
    C: Config. `dataset`, `output_dim`, `n_train`, `n_test`, `data_dir`,
      `shuffledata`, `randomlabels` and `randomdata` are read.

  Returns:
    Tuple `(train, test)`, each an `(inputs, targets)` pair. With
    `C.randomdata` the inputs are replaced by Gaussian noise scaled with the
    mean / standard deviation of the training inputs (taken over all
    non-leading axes); the targets are kept.
  """
  # No longer necessary to flatten images.
  train, test = data_util.load_and_process_dataset(C.dataset,output_dim=C.output_dim, n_train=C.n_train, n_test=C.n_test,data_dir=C.data_dir,permute_train=C.shuffledata,random_labels=C.randomlabels)

  if C.randomdata:
    # Reduce over every axis except the leading one.
    axes = tuple(range(1, len(train[0].shape)))
    mean = onp.mean(train[0], axis=axes, keepdims=True)
    std_dev = onp.std(train[0], axis=axes, keepdims=True)

    train=mean+std_dev*onp.random.randn(*train[0].shape),train[1]
    # NOTE(review): `mean` / `std_dev` keep the leading axis of the training
    # inputs, so this only broadcasts when the test set has the same number
    # of examples (or one of the two has a single example) - confirm intent.
    test=mean+std_dev*onp.random.randn(*test[0].shape),test[1]

  return train, test


def get_model(C, input_shape):
  """Build the model selected by `C.model` and draw its initial parameters.

  Args:
    C: Config; `model` picks the builder and `seed` (if not None) seeds the
      parameter initialization.
    input_shape: Shape of a single input, without the batch axis.

  Returns:
    init_fn: Initializer function.
    apply_fn: Model apply function, a function of (params, input).
    init_params: Initial parameters.
  """
  builders = {
      'fc': model_util.fc,
      'cnn_real': model_util.cnn_real,
      'cnn_real_XL': model_util.cnn_real_XL,
      'wrn_original': model_util.wrn_original,
  }
  build = builders[C.model]

  # The builder returns (init_fn, apply_fn, kernel_fn); the kernel function
  # is not used here.
  init_fn, nonlin_fn, _ = build(C)

  # Without an explicit seed, fall back to the process time in whole seconds.
  seed = C.seed if C.seed is not None else int(time.process_time())
  key = random.PRNGKey(seed)

  # A leading -1 stands for the batch dimension.
  _, init_params = init_fn(key, (-1,) + input_shape)

  n_params = model_util.count_parameters(init_params)
  logging.info('Number of model parameters: {}'.format(n_params))

  return init_fn, nonlin_fn, init_params


def get_optimizer(C, grad_loss, lr=None):
  """Generate the optimizer tuple.

  An optimizer is a triplet (init_func, update_func, get_params) which, resp.,
  initialize the opt state, update the opt state, and extract params from the
  opt state. Here we incorporate the loss dependence in update_func.

  Args:
    C: Config. `wrnsch`, `momcos`, `train_steps`, `opt` and (when `lr` is
      None) `learning_rate` are read.
    grad_loss: grad_loss function, called as `grad_loss(params, batch)`.
    lr: Optional overwrite for the learning rate in C.

  Returns:
    opt_init: Initialize optimizer state.
    opt_update_with_loss: Update the optimizer state; called as
      `opt_update_with_loss(step, params, batch, opt_state)`.
    get_params: Return params from optimizer state.

  Raises:
    NotImplementedError: If `C.opt` is neither 'sgd' nor 'momentum'.
  """
  if lr is None:
    lr = C.learning_rate

  def cosine(initial_step_size, train_steps):
    # Quarter-period cosine: decays from the initial value to 0 at
    # `train_steps`.
    k = np.pi / (2.0 * train_steps)
    def schedule(i):
        return initial_step_size * np.cos(k * i)
    return schedule

  def std_wrn_sch(initial_step_size):
    # Piecewise-constant schedule with hard-coded step boundaries. The
    # Python comparisons need a concrete step index.
    def schedule(i):
        if i<23400:
            return initial_step_size
        elif i<46800:
            return initial_step_size*0.1
        elif i<62400:
            return initial_step_size*0.01
        else:
            return initial_step_size*0.001

    return schedule

  # All schedules start from `lr`, so that the override argument is honoured.
  if C.wrnsch==True:
    lr_schedule = std_wrn_sch(lr)
  elif C.momcos==True:
    lr_schedule = cosine(lr, C.train_steps)
  else:
    lr_schedule = optimizers.make_schedule(lr)

  if C.opt == 'sgd':
    opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
  elif C.opt == 'momentum':
    opt_init, opt_update, get_params = optimizers.momentum(lr_schedule, 0.9)
  else:
    raise NotImplementedError('Have not implemented optimizer: %s' % C.opt)

  def opt_update_with_loss(s, params, batch, opt_state):
    return opt_update(s, grad_loss(params, batch), opt_state)

  return opt_init, opt_update_with_loss, get_params


def get_loss(C, apply_fn,L2):
  """Assemble the jitted training loss.

  Args:
    C: Config; `loss_type`, `linearize` and `L2` are read.
    apply_fn: Model apply function, bound as the `func` argument of the loss.
    L2: Coefficient multiplying the summed squared parameters when `C.L2` is
      truthy.

  Returns:
    loss: `jit`-wrapped loss with `func` bound to `apply_fn`.

  Raises:
    Exception: If `C.linearize` is set and `C.loss_type` is neither 'mse'
      nor 'xent'.
  """
  # Each entry takes three arguments: (params, data_tuple, func).
  standard_losses = {'xent': model_util.cross_entropy, 'mse': model_util.mse,'L1': model_util.L1}
  full_loss = standard_losses[C.loss_type]

  if C.linearize:
    # Linearized variants forward an extra argument `p2` to the model.
    def mse(params, data_tuple, p2, func):
      inputs, targets = data_tuple
      residual = targets - func(params, inputs, p2)
      return 0.5 * np.mean(residual ** 2)

    def cross_entropy(params, data_tuple, p2, func):
      inputs, targets = data_tuple
      import jax.experimental.stax as stax
      log_probs = stax.logsoftmax(func(params, inputs, p2))
      return -np.mean(log_probs * targets)

    linearized_losses = {'mse': mse, 'xent': cross_entropy}
    if C.loss_type not in linearized_losses:
      raise Exception('Loss not defined')
    full_loss = linearized_losses[C.loss_type]

  if not C.L2:
    full_lossb = full_loss
  else:
    from jax.tree_util import tree_reduce

    def l2_reg(params):
      # Sum of squared entries over every leaf of the parameter tree.
      squared = tree_map(lambda leaf: np.sum(leaf ** 2), params)
      return tree_reduce(lambda acc, term: acc + term, squared)

    if C.linearize:
      def full_lossb(params, data_tuple, p2, func):
        return full_loss(params, data_tuple, p2, func) + L2 * l2_reg(params)
    else:
      def full_lossb(params, data_tuple, func):
        return full_loss(params, data_tuple, func) + L2 * l2_reg(params)

  return jit(partial(full_lossb, func=apply_fn))


def initialize(C):
  """Initialize model, data, and optimizer.

  Args:
    C: Config.

  Returns:
    train: Training data.
    test: Test data.
    init_fn: Model init function.
    apply_fn: Model apply function (the linearized model when `C.linearize`
      is set).
    init_params: Initial parameters (loaded from file when `C.load_weights`
      is set).
    s0: Initial step.
    batcher: Minibatch generator, advanced by `s0` batches when weights are
      loaded.
    loss: Loss function.
    optimizer: Optimizer tuple (init_func, update_func, get_params).
  """

  logging.info('Loading data.')
  train, test = get_data(C)

  logging.info('Batching data.')
  s0 = 0
  batcher = data_util.minibatcher(train, C.batch_size, C.data_seed, C.augment)

  input_shape = train[0].shape[1:]

  logging.info('Building model.')
  init_fn, nonlin_fn, init_params = get_model(C, input_shape)

  # Load from saved weights
  if C.load_weights:
    logging.info('Loading weights.')
    # discarded_params is used just for the structure to reconstruct pytree.
    discarded_params = init_params
    load_dir=C.load_weights
    if C.upload_to_cloud:
      # Copy the weights from the bucket to a local temporary file first.
      load_dir=os.path.join(tempfile.gettempdir(), 'model-weights.pkl')
      to_load='gs://'+C.upload_to_cloud+'/'+C.load_weights
      tf.io.gfile.copy(to_load,load_dir,overwrite=True)
      logging.info("Loading weights from in "+to_load)

    s0, init_params = load_weights(load_dir, discarded_params)

    # Advance the batcher by the number of steps already taken.
    for s in range(s0):
      next(batcher)

  if C.linearize:

    def linearize(f):
      # First-order expansion of `f` around `p2`:
      #   f(p2, x) + J_f(p2, x) . (p - p2)
      # The third parameter must stay named `p2`; it is bound by keyword
      # further down.
      def f_lin(p, x, p2):
        dparams = jax.tree_util.tree_multimap(lambda a, b: a - b, p, p2)
        f_params_x, proj = jax.jvp(lambda param: f(param, x), (p2,), (dparams,))
        return f_params_x + proj
      return f_lin
    apply_fn = linearize(nonlin_fn)

  else:
    apply_fn = nonlin_fn

  logging.info('Building loss.')
  loss = get_loss(C, apply_fn,C.L2)

  # Rudimentary splitting of a batch into chunks; could be simplified.
  def dist(f, minibs=C.distgrad,samples=C.batch_size):
    """Wrap `f(params, batch)` so it is evaluated on chunks of `minibs`
    examples and combined with weights proportional to the chunk size."""

    full_batches, remain=samples // minibs, samples  % minibs
    mult=lambda x: x*remain/samples
    sum2=lambda x,y: x+y*minibs/samples
    print('Distributing measurement in',full_batches,'of ',minibs,' and one of',remain)

    def dist_grad(params,batch):
        x,y=batch
        if remain !=0:
          xs,ys=x[minibs*(full_batches):],y[minibs*(full_batches):]
        else:
          # No remainder: evaluate on one example only to obtain a tree of
          # the right structure; its weight `remain/samples` is zero.
          xs,ys=x[0:1],y[0:1]

        g=jax.tree_util.tree_multimap(mult,f(params,(xs,ys) ) )

        for i in range(full_batches):
          xs,ys=x[minibs*i:minibs*(i+1)],y[minibs*i:minibs*(i+1)]
          g=jax.tree_util.tree_multimap(sum2,g,f(params,(xs,ys) ) )

        return g

    return dist_grad

  if C.linearize:
      # Bind the expansion point of the linearized model to the initial
      # parameters.
      grad_loss=partial(jit(grad(loss)),p2=init_params)
      loss= partial(jit(loss),p2=init_params)
      apply_fn= partial(jit(apply_fn),p2=init_params)
  else:
      grad_loss = jit(grad(loss))

  if C.distgrad:
      grad_loss=dist(grad_loss)

  logging.info('Building optimizer.')
  optimizer = get_optimizer(C, grad_loss)

  return train, test, init_fn, apply_fn, init_params, s0, batcher, loss, optimizer


def fit(s_init, s_final, batcher, opt_state, opt_update, get_params, make_meas=None, log=None, metrics=None,verbose=0):
  """Perform training loop and measurements from s_init to s_final.

  Args:
    s_init: Initial step.
    s_final: Final step (inclusive).
    batcher: Minibatch generator.
    opt_state: Optimizer state.
    opt_update: Update optimizer state; called as
      `opt_update(step, params, batch, opt_state)`.
    get_params: Return optimizer state params.
    make_meas: Optional measurement callback, called as
      `make_meas(step, params, log, force=...)` and returning `(log, isnan)`.
    log: Measurement log to extend; a new list is created when None.
    metrics: Optional dict of `name -> fn(params, batch)` reported in the
      timing message.
    verbose: The timing message is emitted every 10 steps when equal to 2.

  Returns:
    log: Measurement log.
    opt_state: Final optimizer state.
    isnan: True if `make_meas` requested an early stop, else False.
  """
  # A fresh list per call; a mutable default would be shared between calls.
  if log is None:
    log = []
  # Defined up front so the return value exists without a measurement
  # callback or with an empty step range.
  isnan = False
  _timing_list = []
  params = get_params(opt_state)
  s = s_init

  if s_init != 0 and make_meas is not None:
    # Resuming: record a forced measurement of the starting point.
    log, _ = make_meas(s_init, params, log,force=True)

  for s in range(s_init, s_final + 1):

    t_start = time.time()

    params = get_params(opt_state)

    if make_meas is not None:
      log, isnan = make_meas(s, params, log)
      if isnan:
        logging.info('nan loss, done training')
        s = s_final
        break

    batch = next(batcher)
    opt_state = opt_update(s, params, batch, opt_state)

    _timing_list.append(time.time() - t_start)
    if len(_timing_list) == 10 and verbose==2:
      stat_string = 'step: %d' % s
      if metrics is not None:
        for key in metrics:
          mmeas=metrics[key](params, batch)
          stat_string += ', %s: %f' % (key,mmeas )

      stat_string += ', 10 step avg time: %fs' % onp.mean(_timing_list)
      logging.info(stat_string)
      _timing_list = []

  if make_meas is not None:
    # Final forced measurement; `params` are those read at the start of the
    # last iteration.
    log, _ = make_meas(s, params, log, force=True)

  logging.info('Training finished.')

  return log, opt_state, isnan


def get_accuracy(apply_fn):
  """Return `accuracy(params, batch)` for the given apply function.

  The returned function compares the argmax over axis 1 of the model output
  on `batch[0]` with the argmax over axis 1 of `batch[1]` and returns the
  mean agreement.
  """
  def accuracy(params, batch):
    outputs = apply_fn(params, batch[0])
    hits = np.argmax(outputs, axis=1) == np.argmax(batch[1], axis=1)
    return np.mean(hits)
  return accuracy


def main(unused_argv):
  """Script entry point: build the config, then train and measure.

  Config overrides come either from positional `key=value` arguments or
  from the YAML file given by `--config_file`, not both.

  Args:
    unused_argv: Positional arguments left after flag parsing; entries after
      the first are treated as `key=value` overrides.
  """
  overrides = {}

  # Process args from command line
  cli_args = unused_argv[1:]
  if cli_args:
    if FLAGS.config_file:
      raise Exception('Specify only either args or yaml')

    for arg in cli_args:
      pieces = arg.split('=')
      if len(pieces) != 2:
        raise Exception('Arguments should be given in the format arg1=valarg1')
      arg_name, arg_value = pieces
      if arg_name == 'model' and arg_value != 'cnn_flaten':
        raise Exception('Do not specify models using args, use instead default_config!')

      overrides[arg_name] = arg_value

  # Process args from yaml
  if FLAGS.config_file:
    with open(FLAGS.config_file) as yaml_file:
      overrides = yaml.load(yaml_file, Loader=yaml.FullLoader)

  C = get_config(FLAGS.default_config, **overrides)

  (train, test, init_fn, apply_fn, init_params, s0, batcher, loss,
   optimizer) = initialize(C)
  opt_init, opt_update, get_params = optimizer

  logging.info('Initializing optimizer state.')
  params = init_params
  opt_state = opt_init(params)

  logging.info('Setting up measurements.')
  n_flat_params = flatten_jax(params).shape[0]
  make_meas, log = meas_function_gen(
      loss, apply_fn, (train, test), C=C, s0=s0, params_shape=n_flat_params)

  logging.info('Setting up metrics.')
  metrics = {'batch_loss': loss, 'batch_acc': get_accuracy(apply_fn)}

  logging.info('Training network.')

  log, opt_state, _ = fit(
      s0, C.train_steps, batcher, opt_state, opt_update, get_params,
      make_meas=make_meas, log=log, metrics=metrics, verbose=C.verbose)


# Run `main` through absl's `app.run` when the file is executed as a script.
if __name__ == '__main__':
  app.run(main)
