"""Training file for large LR experiments on Cloud."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from jax import lax
from absl import app
from absl import flags
from operator import mul, add

import jax.experimental.stax as stax
from jax.api import grad, jit, vmap, jacobian, pmap, jacfwd, jacrev
from jax.tree_util import tree_map, tree_reduce
import jax.numpy as np
import jax.random as random
import numpy as onp

from neural_tangents import stax as stax_nt

import data_util


"""## Wide Resnet"""

@stax_nt._layer
def Dense(out_dim, W_std=1., b_std=0., W_init=stax_nt._randn(1.0), b_init=stax_nt._randn(1.0),
          parameterization='ntk',unit_norm=False):
  """Fully-connected layer with an optional whole-matrix weight normalisation.

  Args:
    out_dim: output width, forwarded to `stax.Dense`.
    W_std: weight scale.
    b_std: bias scale.
    W_init: weight initialiser forwarded to `stax.Dense`.
    b_init: bias initialiser forwarded to `stax.Dense`.
    parameterization: 'ntk' or 'standard' (case-insensitive). With 'ntk' the
      parameters are stored unscaled and `W_std / sqrt(fan_in)` and `b_std`
      are applied in the forward pass; with 'standard' those factors are
      folded into the initial parameter values instead.
    unit_norm: if True, the forward pass divides `W` by
      `np.linalg.norm(W, keepdims=True)` (one norm over the whole matrix)
      before using it.

  Returns:
    An `(init_fn, apply_fn, kernel_fn)` triple.

  Raises:
    ValueError: if `parameterization` is neither 'ntk' nor 'standard'.

  Note:
    `kernel_fn` does not take `unit_norm` into account.
  """
  parameterization = parameterization.lower()

  # Initialiser without any W_std / b_std scaling applied.
  ntk_init_fn, _ = stax.Dense(out_dim, W_init, b_init)

  def standard_init_fn(rng, input_shape):
    # Same draw as the NTK initialiser, with the scales baked into the values.
    output_shape, (W, b) = ntk_init_fn(rng, input_shape)
    return output_shape, (W * W_std / np.sqrt(input_shape[-1]), b * b_std)

  if parameterization == 'ntk':
    init_fn = ntk_init_fn
  elif parameterization == 'standard':
    init_fn = standard_init_fn
  else:
    raise ValueError('Parameterization not supported: %s' % parameterization)

  def apply_fn(params, inputs, **kwargs):
    W, b = params

    if unit_norm:
      W = W / np.linalg.norm(W, keepdims=True)

    if parameterization == 'ntk':
      norm = W_std / np.sqrt(inputs.shape[-1])
      return norm * np.dot(inputs, W) + b_std * b
    # Only 'standard' can reach this point; anything else was rejected above.
    return np.dot(inputs, W) + b

  def kernel_fn(kernels):
    """Compute the transformed kernels after a dense layer."""
    var1, nngp, var2, ntk = \
        kernels.var1, kernels.nngp, kernels.var2, kernels.ntk

    def affine(x):
      # The helper lives in neural_tangents.stax; a bare `_affine` is not
      # defined in this module and would raise NameError here.
      return stax_nt._affine(x, W_std, b_std)

    if parameterization == 'ntk':
      var1, nngp, var2 = map(affine, (var1, nngp, var2))
      if ntk is not None:
        # Uses the nngp that has already been passed through `affine`.
        ntk = nngp + W_std**2 * ntk
    elif parameterization == 'standard':
      input_width = kernels.shape1[1]
      if ntk is not None:
        # Uses the incoming nngp, i.e. before `affine` is applied below.
        ntk = input_width * nngp + 1. + W_std**2 * ntk
      var1, nngp, var2 = map(affine, (var1, nngp, var2))

    return kernels._replace(
        var1=var1, nngp=nngp, var2=var2, ntk=ntk, is_gaussian=True)

#   setattr(kernel_fn, _COVARIANCES_REQ,
#           {'marginal': M.OVER_ALL, 'cross': M.OVER_ALL})
  return init_fn, apply_fn, kernel_fn

def _tanh(x, **kwargs):
  return np.tanh(x)

@stax_nt._layer
def Tanh(do_backprop=False):
  """Tanh nonlinearity layer, built from `_tanh` via `stax_nt._elementwise`."""
  tanh_layer = stax_nt._elementwise(_tanh, do_backprop=do_backprop)
  return tanh_layer

# Ready-built nonlinearity layers, looked up by name.
nonlin_dict = {
    'Relu': stax_nt.Relu(),
    'Erf': stax_nt.Erf(),
    'Identity': stax_nt.Identity(),
    'Tanh': Tanh(),
}

# MODELS

def fc(config):
  """Build a plain fully-connected network.

  `config` is a NamedTuple. The input is assumed to be 4D, so the network
  starts with a flatten layer. It is followed by `config.depth` copies of
  dense -> (optional batch norm) -> nonlinearity, and a final dense readout
  of width `config.output_dim`. Only the hidden dense layers receive
  `config.unit_norm`; the readout uses the default.
  """
  parameterization = 'ntk' if config.NTK_norm else 'standard'

  # One hidden unit; the same layer objects are repeated `depth` times below.
  hidden_unit = [Dense(config.width, config.w_std, config.b_std,
                       parameterization=parameterization,
                       unit_norm=config.unit_norm)]
  if config.batch_norm:
    hidden_unit.append(_batch_norm_internal(axis=(0,)))
  hidden_unit.append(nonlin_dict[config.nonlinearity])

  network = [stax_nt.Flatten()]
  network.extend(hidden_unit * config.depth)
  network.append(Dense(config.output_dim, config.w_std, config.b_std,
                       parameterization=parameterization))
  return stax_nt.serial(*network)




def _max_pool_internal(window_shape, padding):
  """Layer constructor for a stax.MaxPool layer with dummy kernel computation.
  Do not use kernels for architectures that include this function."""
  init_fn, apply_fn = stax.MaxPool(window_shape, padding=padding)
  kernel_fn = lambda kernels: kernels
  return init_fn, apply_fn, kernel_fn
 

# cnn keras wider
def cnn_real(config):
  """CNN derived from the Keras model, with widths scaled by `config.width`.

  Layout: 3x3 conv (SAME) -> activation -> 3x3 conv (VALID) -> activation ->
  6x6 pooling (VALID) -> flatten -> dense(5 * width) -> dense(output_dim).
  Both convolutions have `3 * config.width` channels. "activation" is the
  configured nonlinearity, preceded by batch norm when `config.batch_norm` is
  set. There is no nonlinearity between the two dense layers.

  Raises:
    Exception: if `config.pooling` is neither 'avg' nor 'max'.
  """
  parameterization = 'ntk' if config.NTK_norm else 'standard'

  # Reject unknown pooling modes before any layer is constructed.
  if config.pooling not in ('avg', 'max'):
    raise Exception('Pooling not defined.', config.pooling)
  pooling = stax_nt.AvgPool if config.pooling == 'avg' else _max_pool_internal

  activation = [_batch_norm_internal()] if config.batch_norm else []
  activation.append(nonlin_dict[config.nonlinearity])

  def conv(padding):
    return stax_nt.Conv(3 * config.width, (3, 3), padding=padding,
                        W_std=config.w_std, b_std=config.b_std,
                        parameterization=parameterization)

  def dense(out_dim):
    return stax_nt.Dense(out_dim, config.w_std, config.b_std,
                         parameterization=parameterization)

  network = [conv('SAME')] + activation
  network += [conv('VALID')] + activation
  network.append(pooling((6, 6), padding='VALID'))
  network += [stax_nt.Flatten(), dense(5 * config.width)]
  network.append(dense(config.output_dim))
  return stax_nt.serial(*network)


def cnn_real_XL(config):
  """Larger fixed-width variant of the Keras-derived CNN.

  Two identical stages of 3x3 conv (SAME) -> activation -> 3x3 conv (VALID)
  -> activation -> 6x6 pooling (VALID), all convolutions with 300 channels,
  followed by flatten -> dense(500) -> dense(output_dim). "activation" is the
  configured nonlinearity, preceded by batch norm when `config.batch_norm` is
  set. The original author's note says a missing Relu was fixed relative to
  the Keras model and that the model is meant to be as wide as possible.

  Raises:
    Exception: if `config.pooling` is neither 'avg' nor 'max'.
  """
  parameterization = 'ntk' if config.NTK_norm else 'standard'

  # Reject unknown pooling modes before any layer is constructed.
  if config.pooling not in ('avg', 'max'):
    raise Exception('Pooling not defined.', config.pooling)
  pooling = stax_nt.AvgPool if config.pooling == 'avg' else _max_pool_internal

  activation = [_batch_norm_internal()] if config.batch_norm else []
  activation.append(nonlin_dict[config.nonlinearity])

  def conv(padding):
    return stax_nt.Conv(300, (3, 3), padding=padding,
                        W_std=config.w_std, b_std=config.b_std,
                        parameterization=parameterization)

  def dense(out_dim):
    return stax_nt.Dense(out_dim, config.w_std, config.b_std,
                         parameterization=parameterization)

  network = []
  for _ in range(2):
    network += [conv('SAME')] + activation
    network += [conv('VALID')] + activation
    network.append(pooling((6, 6), padding='VALID'))
  network += [stax_nt.Flatten(), dense(500)]
  network.append(dense(config.output_dim))
  return stax_nt.serial(*network)

def wrn_original(config):
  """WideResnet modelled on the paper, with or without BatchNorm.

  (Set config.wrn_block_size=3, config.wrn_widening_f=10 to match the paper.)
  Weight and bias initialisation are left at the layer defaults.

  Layout: 3x3 conv with 16 channels, then three groups of
  `config.wrn_block_size` blocks with 16/32/64 times `config.wrn_widening_f`
  channels (the last two groups use strides (2, 2)), then 8x8 average
  pooling, flatten and a dense readout. A final batch norm + Relu is inserted
  before the pooling only when `config.batch_norm` is set.
  """
  parameterization = 'ntk' if config.NTK_norm else 'standard'

  network = [stax_nt.Conv(16, (3, 3), padding='SAME',
                          parameterization=parameterization)]
  # (base channel count, strides) for each of the three groups.
  for base_channels, strides in ((16, (1, 1)), (32, (2, 2)), (64, (2, 2))):
    network.append(_wrn_group(config.wrn_block_size,
                              int(base_channels * config.wrn_widening_f),
                              strides,
                              batch_norm=config.batch_norm,
                              parameterization=parameterization))

  if config.batch_norm:
    network.extend([_batch_norm_internal(), stax_nt.Relu()])
  network.extend([
      stax_nt.AvgPool((8, 8)),
      stax_nt.Flatten(),
      stax_nt.Dense(config.output_dim, parameterization=parameterization),
  ])
  return stax_nt.serial(*network)

# INTERNAL METHODS
def _wrn_group(num_blocks, channels, strides=(1,1), batch_norm=True, parameterization='ntk'):
  """Chain `num_blocks` WideResnet blocks into one group.

  The first block uses `strides` and a convolutional shortcut
  (`channel_mismatch=True`); the remaining `num_blocks - 1` blocks use
  strides (1, 1) and the block's default shortcut.
  """
  entry_block = _wrn_block(channels, strides, channel_mismatch=True,
                           batch_norm=batch_norm,
                           parameterization=parameterization)
  later_blocks = [
      _wrn_block(channels, (1, 1), batch_norm=batch_norm,
                 parameterization=parameterization)
      for _ in range(num_blocks - 1)
  ]
  return stax_nt.serial(entry_block, *later_blocks)

def _wrn_block(channels, strides=(1,1), channel_mismatch=False, batch_norm=True, parameterization='ntk'):
  """Build one residual WideResnet block, with or without BatchNorm.

  The main branch is two rounds of (optional batch norm) -> Relu -> 3x3 conv;
  only the first conv receives `strides`. The shortcut is an identity, or a
  strided 3x3 conv when `channel_mismatch` is set. Both branches are summed.
  """
  def conv3x3(*stride_args):
    # `stride_args` is empty to leave the strides at Conv's own default.
    return stax_nt.Conv(channels, (3, 3), *stride_args, padding='SAME',
                        parameterization=parameterization)

  main_layers = []
  for stride_args in ((strides,), ()):
    if batch_norm:
      main_layers.append(_batch_norm_internal())
    main_layers.append(stax_nt.Relu())
    main_layers.append(conv3x3(*stride_args))
  main_branch = stax_nt.serial(*main_layers)

  if channel_mismatch:
    shortcut_branch = conv3x3(strides)
  else:
    shortcut_branch = stax_nt.Identity()

  return stax_nt.serial(stax_nt.FanOut(2),
                        stax_nt.parallel(main_branch, shortcut_branch),
                        stax_nt.FanInSum())

def _batch_norm_internal(axis=(0, 1, 2)):
	"""Layer constructor for a stax.BatchNorm layer with dummy kernel computation.
	Do not use kernels for architectures that include this function."""
    
	init_fn, apply_fn = stax.BatchNorm(axis)
    
	kernel_fn = lambda kernels: kernels
	return init_fn, apply_fn, kernel_fn

def Frozen(layer):
  """Wrap a layer so that no gradient flows into its parameters.

  Args:
    layer: an `(init_fn, apply_fn, kernel_fn)` triple; the kernel function is
      discarded.

  Returns:
    `(init_fn, frozen_apply_fn, kernel_fn)`, where `frozen_apply_fn` passes
    every parameter leaf through `lax.stop_gradient` before calling the
    wrapped `apply_fn`, and `kernel_fn` returns its argument unchanged.
  """
  init_fn, wrapped_apply_fn, _discarded_kernel_fn = layer

  def frozen_apply_fn(params, xs, **unused_kwargs):
    # Extra keyword arguments are accepted but not forwarded.
    detached_params = tree_map(lax.stop_gradient, params)
    return wrapped_apply_fn(detached_params, xs)

  def passthrough_kernel_fn(kernels):
    return kernels

  return init_fn, frozen_apply_fn, passthrough_kernel_fn




# OTHER UTIL FUNCTIONS

def mse(params, data_tuple, func):
  """Half the mean squared error between `func(params, x)` and `y`.

  `data_tuple` is an `(x, y)` pair; the mean runs over all entries.
  """
  inputs, targets = data_tuple
  residual = targets - func(params, inputs)
  return 0.5 * np.mean(residual ** 2)

def L1(params, data_tuple, func):
  """Label-weighted gap to the largest output: -mean(y * (f - max_j f_j)).

  Args:
    params: model parameters, passed straight to `func`.
    data_tuple: `(x, y)` pair of inputs and targets; `y` must broadcast
      against the model output.
    func: callable `func(params, x)` returning the model output. The maximum
      is taken along axis 1, so the output needs at least two dimensions.

  Returns:
    The negated mean, over all entries, of `y * (f - max_j f_j)`.

  Notes (paraphrased from the original author's comments):
    * For output dim 2, both an absolute-value and a subtraction variant
      worked well.
    * For output dim 10, the `f_y - max(f)` form works; the absolute-value
      variant needs more care.
    * The form implemented here works well for arbitrary output dim.

  Despite the function name, no absolute value is taken here.
  """
  x, y = data_tuple
  # Evaluate the model once and reuse the result for both terms.
  outputs = func(params, x)
  gap_to_max = outputs - np.max(outputs, axis=1, keepdims=True)
  return -np.mean(y * gap_to_max)


def cross_entropy(params, data_tuple, func, alpha=1):
  """Cross-entropy loss on optionally rescaled model outputs.

  Args:
    params: model parameters, passed straight to `func`.
    data_tuple: `(x, y)` pair of inputs and targets.
    func: callable `func(params, x)` returning the model outputs (logits).
    alpha: factor the logits are multiplied by before the log-softmax.
      The default of 1 leaves them unscaled.

  Returns:
    `-mean(logsoftmax(alpha * func(params, x)) * y)`, where the mean runs
    over every entry of the product (class axis included).
  """
  x, y = data_tuple
  log_probs = stax.logsoftmax(alpha * func(params, x))
  return -np.mean(log_probs * y)

def count_parameters(params):
  """Count the scalar entries across all array leaves of `params`.

  Args:
    params: a pytree whose leaves expose a `.size` attribute (e.g. arrays).

  Returns:
    The total number of entries as a Python int; 0 for a tree with no leaves.
  """
  # Function-scope import: the module only imports tree_map / tree_reduce.
  from jax.tree_util import tree_leaves

  # `.size` is a plain Python int, so the sum needs no array computation and
  # cannot overflow a fixed-width integer dtype.
  return sum(leaf.size for leaf in tree_leaves(params))

def accuracy(y, y_hat):
  """Fraction of rows whose argmax agrees between `y` and `y_hat`.

  Both arguments have shape (batch, # classes).
  """
  label_index = np.argmax(y, axis=1)
  guess_index = np.argmax(y_hat, axis=1)
  return np.mean(label_index == guess_index)





