

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import pdb
from problems import problem_spec as prob_spec

# Command-line flags. `hessian_itrs` and `reg_option` are read by
# Problem.gradients() to size and select the curvature regularizer.
tf.app.flags.DEFINE_integer("hessian_itrs", 10,
                            """number of iterations for hessian information calculation.""")
# Values handled by Problem.gradients(): 'hessian', 'hessian-ev',
# 'hessian-esd', 'jacob' and 'hessian-nonneg-ev'. Any other value leaves the
# regularizer as None.
tf.app.flags.DEFINE_string("reg_option", 'hessian',
                            """which regularization to use.""")
# NOTE(review): l2_reg_scale is not read by any active code visible in this
# file -- confirm whether it is still needed.
tf.app.flags.DEFINE_float("l2_reg_scale", 1e-3,
                          """Scaling factor for parameter value regularization
                             in softmax classifier problems.""")
FLAGS = tf.app.flags.FLAGS

# Small positive constant.
EPSILON = 1e-6
# Exclusive upper bound (2**32 - 1) used when Problem draws a random seed.
MAX_SEED = 4294967295
# Variable scope under which Problem.init_variables() creates its variables.
PARAMETER_SCOPE = "parameters"

# Short alias for the problem specification type.
_Spec = prob_spec.Spec


class Problem(object):
  """Base class for optimization problems.

  This defines an interface for optimization problems: an objective, a
  (noisy) gradient computation, and helpers that build curvature-based
  regularization terms (Hessian trace, top Hessian eigenvalues, Lanczos
  eigenvalue estimates and the mean squared gradient).

  Subclasses of Problem must (at the minimum) override the objective method,
  which computes the objective/loss/cost to minimize, and specify the desired
  shape of the parameters in a list in the param_shapes attribute.
  """

  def __init__(self, param_shapes, random_seed, noise_stdev, init_fn=None):
    """Initializes a global random seed for the problem.

    Args:
      param_shapes: A list of tuples defining the expected shapes of the
        parameters for this problem
      random_seed: Either an integer (Python or NumPy) or None, in which case
        the seed is randomly drawn
      noise_stdev: Strength (standard deviation) of added gradient noise
      init_fn: A function taking a tf.Session object that is used to
        initialize the problem's variables.

    Raises:
      ValueError: If the random_seed is not an integer and not None
    """
    if random_seed is not None and not isinstance(random_seed,
                                                  (int, np.integer)):
      raise ValueError("random_seed must be an integer or None")

    if random_seed is None:
      # MAX_SEED does not fit in a 32-bit integer, which is NumPy's default
      # integer type on some platforms; request int64 explicitly so the draw
      # cannot fail with an out-of-bounds error.
      self.random_seed = int(np.random.randint(MAX_SEED, dtype=np.int64))
    else:
      self.random_seed = int(random_seed)

    # Store the noise level.
    self.noise_stdev = noise_stdev

    # Set the random seed to ensure any random data in the problem is the same.
    np.random.seed(self.random_seed)

    # Store the parameter shapes.
    self.param_shapes = param_shapes

    # Total number of scalar parameters across all shapes (a scalar shape `()`
    # counts as one parameter).
    num_params = 0
    for shape in param_shapes:
      size = 1
      for dim in shape:
        size *= dim
      num_params += size
    self.num_params = num_params

    # Fall back to a no-op initializer when none is supplied.
    self.init_fn = init_fn if init_fn is not None else (lambda _: None)

  def init_tensors(self, seed=None):
    """Returns a list of standard-normal tensors, one per parameter shape."""
    return [tf.random_normal(shape, seed=seed) for shape in self.param_shapes]

  def init_variables(self, seed=None):
    """Returns a list of variables initialized from init_tensors(seed)."""
    with tf.variable_scope(PARAMETER_SCOPE):
      params = [tf.Variable(param) for param in self.init_tensors(seed)]
    return params

  def objective(self, parameters, data=None, labels=None):
    """Computes the objective given a list of parameters.

    Args:
      parameters: The parameters to optimize (as a list of tensors)
      data: An optional batch of data for calculating objectives
      labels: An optional batch of corresponding labels

    Returns:
      A scalar tensor representing the objective value
    """
    raise NotImplementedError

  def jacob_trace(self, grads):
    """Computes the mean of the squared gradient entries.

    Args:
      grads: A list of gradient tensors of the problem

    Returns:
      A scalar tensor: the mean over every entry of every gradient, squared
    """
    squared = [tf.reshape(tf.square(grad), [1, -1]) for grad in grads]
    flat = tf.concat(squared, axis=1)
    return tf.reduce_mean(flat)

  def group_add(self, params, update, alpha=1):
    """Returns params + alpha * update, element by element.

    Args:
      params: A list of tensors (or arrays)
      update: A list with (at least) one entry per element of params
      alpha: Scalar multiplier applied to every entry of update

    Returns:
      A new list holding the sums; neither params nor update is modified
    """
    # Build a fresh list instead of assigning into `params`, so the caller's
    # list (and any arrays inside it) is left untouched.
    return [p + update[i] * alpha for i, p in enumerate(params)]

  def group_product(self, xs, ys):
    """Returns the inner product of two lists of tensors.

    Args:
      xs: A list of tensors
      ys: A list of tensors, paired element-wise with xs

    Returns:
      A scalar tensor: the sum over all pairs of sum(x * y)
    """
    return tf.reduce_sum(
        tf.stack([tf.reduce_sum(x * y) for (x, y) in zip(xs, ys)]))

  def normalization(self, v):
    """Scales a list of tensors to (approximately) unit joint norm.

    Args:
      v: A list of tensors, treated as one concatenated vector

    Returns:
      A list with every tensor divided by (norm + EPSILON)
    """
    norm = tf.sqrt(self.group_product(v, v))
    # EPSILON keeps the division finite when the norm is zero.
    return [vi / (norm + EPSILON) for vi in v]

  def orthnormal(self, w, v_list):
    """Removes from w its component along each vector in v_list, then normalizes.

    Args:
      w: A list of tensors, treated as one vector
      v_list: A list of vectors (each a list of tensors shaped like w)

    Returns:
      The normalized result, as a list of tensors
    """
    for v in v_list:
      # The projection coefficient is computed against the already-updated w.
      w = self.group_add(w, v, alpha=-self.group_product(w, v))
    return self.normalization(w)

  def _rademacher_like(self, params):
    """Draws one constant +/-1 tensor per parameter.

    The signs are sampled with NumPy when this method is called and stored as
    tf.constant values, so they are fixed for the lifetime of the graph.
    """
    signs = [
        2 * np.random.randint(0, 2, tuple(p.get_shape().as_list())) - 1
        for p in params
    ]
    return [tf.constant(s, tf.float32) for s in signs]

  def hessain_trace(self, params, grads, max_iter=10, tol=1e-3):
    """Estimates the Hessian trace with Hutchinson's method.

    The (misspelled) name is kept for existing callers; hessian_trace is an
    equivalent, correctly spelled entry point.

    Args:
      params: A list of parameter tensors
      grads: Gradients of the objective with respect to params
      max_iter: Number of random probe vectors to average over
      tol: Unused; kept for interface compatibility

    Returns:
      A scalar tensor: the mean of v^T H v over the probe vectors

    Raises:
      ValueError: If max_iter is smaller than 1
    """
    if max_iter < 1:
      # With no probes the estimate would be the mean of an empty tensor.
      raise ValueError("max_iter must be at least 1")

    trace_vhv = []
    with tf.variable_scope('hessian'):
      for _ in range(max_iter):
        v = self._rademacher_like(params)
        # Hessian-vector product: differentiate grads weighted by v.
        Hv = tf.gradients(grads, params, grad_ys=v)
        trace_vhv.append(self.group_product(Hv, v))

    return tf.reduce_mean(tf.stack(trace_vhv))

  def hessian_trace(self, params, grads, max_iter=10, tol=1e-3):
    """Correctly spelled alias of hessain_trace."""
    return self.hessain_trace(params, grads, max_iter=max_iter, tol=tol)

  def _top_eigenvalues(self, params, grads, max_iter, top_n, clip_negative):
    """Sums the top_n Hessian eigenvalue estimates found by power iteration.

    Args:
      params: A list of parameter tensors
      grads: Gradients of the objective with respect to params
      max_iter: Number of power iterations per eigenvalue
      top_n: Number of eigenvalues to estimate
      clip_negative: If True, negative estimates are replaced by zero

    Returns:
      A scalar tensor: the sum of the eigenvalue estimates

    Raises:
      ValueError: If top_n or max_iter is smaller than 1
    """
    if top_n < 1:
      raise ValueError("top_n must be at least 1")
    if max_iter < 1:
      raise ValueError("max_iter must be at least 1")

    eigenvalues = []
    eigenvectors = []
    for _ in range(top_n):
      # Start from a random unit vector.
      v = self.normalization([tf.random_normal(p.get_shape()) for p in params])

      eigenvalue = None
      for _ in range(max_iter):
        # Deflate against the eigenvectors found so far.
        v = self.orthnormal(v, eigenvectors)
        Hv = tf.gradients(grads, params, grad_ys=v)
        eigenvalue = self.group_product(Hv, v)
        v = self.normalization(Hv)

      if clip_negative:
        estimate = eigenvalue
        # Keep positive estimates, replace the others by zero.
        eigenvalue = tf.cond(estimate > 0,
                             lambda: estimate,
                             lambda: estimate * 0)

      eigenvalues.append(eigenvalue)
      eigenvectors.append(v)

    return tf.reduce_sum(tf.stack(eigenvalues))

  def hessian_eigenvalues(self, params, grads, max_iter=10, tol=1e-3, top_n=1):
    """Sums the top_n Hessian eigenvalues estimated by power iteration.

    Args:
      params: A list of parameter tensors
      grads: Gradients of the objective with respect to params
      max_iter: Number of power iterations used for each eigenvalue
      tol: Unused; kept for interface compatibility
      top_n: Number of eigenvalues to estimate

    Returns:
      A scalar tensor: the sum of the top_n eigenvalue estimates

    Raises:
      ValueError: If top_n or max_iter is smaller than 1
    """
    return self._top_eigenvalues(params, grads, max_iter, top_n,
                                 clip_negative=False)

  def hessian_eigenvalues_nonneg(self, params, grads, max_iter=10, tol=1e-3,
                                 top_n=1):
    """Like hessian_eigenvalues, but negative estimates count as zero.

    Args:
      params: A list of parameter tensors
      grads: Gradients of the objective with respect to params
      max_iter: Number of power iterations used for each eigenvalue
      tol: Unused; kept for interface compatibility
      top_n: Number of eigenvalues to estimate

    Returns:
      A scalar tensor: the sum of the clipped top_n eigenvalue estimates

    Raises:
      ValueError: If top_n or max_iter is smaller than 1
    """
    return self._top_eigenvalues(params, grads, max_iter, top_n,
                                 clip_negative=True)

  def density(self, params, grads, max_iter=10, n_v=1):
    """Sums Hessian eigenvalue estimates obtained with the Lanczos algorithm.

    For each of the n_v runs a max_iter x max_iter tridiagonal matrix is built
    by Lanczos iterations started from a random +/-1 vector; the eigenvalues
    of those matrices are summed.

    Args:
      params: A list of parameter tensors
      grads: Gradients of the objective with respect to params
      max_iter: Number of Lanczos iterations per run
      n_v: Number of independent runs

    Returns:
      A scalar tensor: the sum of all tridiagonal eigenvalues over all runs

    Raises:
      ValueError: If max_iter is smaller than 1
    """
    if max_iter < 1:
      raise ValueError("max_iter must be at least 1")

    eigen_list_full = []
    for _ in range(n_v):
      v = self.normalization(self._rademacher_like(params))

      v_list = [v]
      alpha_list = []  # diagonal of the tridiagonal matrix
      beta_list = []   # off-diagonal of the tridiagonal matrix
      w = None
      beta = None
      for i in range(max_iter):
        if i > 0:
          beta = tf.sqrt(self.group_product(w, w))
          beta_list.append(beta)
          # NOTE(review): the original guarded this with `beta != 0.` and
          # restarted from a random vector otherwise. Under TF1 graph
          # semantics that Python comparison on a tensor is always true, so
          # the restart never ran; only the branch that executed is kept.
          # A real breakdown check would need tf.cond.
          v = self.orthnormal(w, v_list)
          v_list.append(v)

        w_prime = tf.gradients(grads, params, grad_ys=v)
        alpha = self.group_product(w_prime, v)
        alpha_list.append(alpha)
        w = self.group_add(w_prime, v, alpha=-alpha)
        if i > 0:
          w = self.group_add(w, v_list[-2], alpha=-beta)

      # Assemble T = diag(alpha) + sub-diagonal(beta) + super-diagonal(beta).
      diag = tf.matrix_diag(alpha_list)
      off_diag = tf.matrix_diag(beta_list)
      # Padding a row on top and a column on the right shifts the entries
      # one step below the main diagonal; the other padding shifts them above.
      lower = tf.pad(off_diag, tf.constant([[1, 0], [0, 1]]))
      upper = tf.pad(off_diag, tf.constant([[0, 1], [1, 0]]))
      T = diag + lower + upper

      ritz_values, _ = tf.self_adjoint_eig(T)
      eigen_list_full.append(ritz_values)

    return tf.reduce_sum(tf.stack(eigen_list_full))

  def gradients(self, objective, parameters):
    """Compute gradients of the objective with respect to the parameters.

    Also builds the regularization term selected by FLAGS.reg_option from the
    noise-free gradients.

    Args:
      objective: The objective op (e.g. output of self.objective())
      parameters: A list of tensors (the parameters to optimize)

    Returns:
      A tuple (noisy_grads, reg). noisy_grads is a list of tensors holding the
        gradient for each parameter plus Gaussian noise scaled by noise_stdev,
        in the same order as the given list. reg is the scalar regularization
        tensor, or None when FLAGS.reg_option names no known regularizer.
    """
    grads = tf.gradients(objective, list(parameters))
    noisy_grads = []

    regularization = FLAGS.reg_option
    reg = None
    for grad in grads:
      if isinstance(grad, tf.IndexedSlices):
        noise = self.noise_stdev * tf.random_normal(tf.shape(grad.values))
        # Carry dense_shape over so the sparse gradient keeps its shape info.
        new_grad = tf.IndexedSlices(grad.values + noise, grad.indices,
                                    grad.dense_shape)
      else:
        new_grad = grad + self.noise_stdev * tf.random_normal(grad.get_shape())
      noisy_grads.append(new_grad)

    print('Hessian iterations = ', FLAGS.hessian_itrs)

    # The regularizers below are built from `grads`, not from `noisy_grads`.
    if regularization == 'hessian-ev':
      print('use hessian ev regularization')
      reg = self.hessian_eigenvalues(parameters, grads,
                                     max_iter=FLAGS.hessian_itrs)

    elif regularization == 'hessian':
      print('use hessian trace regularization')
      reg = self.hessain_trace(parameters, grads, max_iter=FLAGS.hessian_itrs)

    elif regularization == 'hessian-esd':
      print('use hessian esd regularization')
      reg = self.density(parameters, grads, max_iter=FLAGS.hessian_itrs)

    elif regularization == 'jacob':
      print('use jacob regularization')
      reg = self.jacob_trace(grads)

    elif regularization == 'hessian-nonneg-ev':
      print('use hessian nonnegative ev regularization')
      reg = self.hessian_eigenvalues_nonneg(parameters, grads,
                                            max_iter=FLAGS.hessian_itrs)
    return noisy_grads, reg



class SoftmaxClassifier(Problem):
  """Shared loss and accuracy logic for supervised classification problems.

  Subclasses supply `inference`, which maps parameters and data to logits.
  """

  def init_tensors(self, seed=None):
    """Returns one small random-normal tensor per parameter shape.

    Each tensor is drawn with stddev 0.01 and then scaled by
    1.2 / sqrt(shape[0]).
    """
    tensors = []
    for shape in self.param_shapes:
      draw = tf.random_normal(shape, mean=0, stddev=0.01, seed=seed)
      tensors.append(draw * 1.2 / np.sqrt(shape[0]))
    return tensors

  def inference(self, params, data):
    """Computes logits given parameters and data.

    Args:
      params: List of parameter tensors or variables
      data: Batch of features with samples along the first dimension

    Returns:
      logits: Un-normalized logits with shape (num_samples, num_classes)
    """
    raise NotImplementedError

  def objective(self, params, data, labels):
    """Computes the mean cross-entropy loss over the batch.

    When the logits have exactly two columns, a sigmoid cross entropy on the
    first logit column is used; otherwise a softmax cross entropy over all
    columns is used.

    Args:
      params: List of parameter tensors or variables
      data: Batch of features with samples along the first dimension
      labels: Labels for the batch, passed to the cross-entropy op

    Returns:
      loss: Cross entropy loss averaged over the samples in the batch
    """
    # Forward pass.
    logits = self.inference(params, data)
    n_outputs = int(logits.get_shape()[1])

    if n_outputs != 2:
      per_example = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                            labels=labels)
      return tf.reduce_mean(per_example)

    # Two-class case: treat the first logit as a single binary logit.
    binary_labels = tf.cast(labels, tf.float32)
    per_example = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=binary_labels, logits=logits[:, 0])
    return tf.reduce_mean(per_example)

  def argmax(self, logits):
    """Returns the most likely class label for each row of logits.

    Args:
      logits: Un-normalized logits with shape (num_samples, num_classes)

    Returns:
      predictions: Predicted class labels as int32, shape (num_samples,)
    """
    probabilities = tf.nn.softmax(logits)
    best = tf.argmax(probabilities, 1)
    return tf.cast(best, tf.int32)

  def accuracy(self, params, data, labels):
    """Computes the accuracy (fraction of correct classifications).

    Args:
      params: List of parameter tensors or variables
      data: Batch of features with samples along the first dimension
      labels: Vector of labels with the same number of samples as the data

    Returns:
      accuracy: Fraction of correct classifications across the batch
    """
    logits = self.inference(params, data)
    predictions = self.argmax(logits)
    targets = tf.cast(labels, tf.int32)
    return tf.contrib.metrics.accuracy(predictions, targets)


class SoftmaxRegression(SoftmaxClassifier):
  """Linear classifier: logits = features * W + b."""

  def __init__(self, n_features, n_classes, activation=tf.identity,
               random_seed=None, noise_stdev=0.0):
    """Sets up a weight matrix (n_features, n_classes) and a bias (n_classes,).

    `activation` is stored on the instance; `inference` does not apply it.
    """
    self.n_features = n_features
    self.activation = activation

    weight_shape = (n_features, n_classes)
    bias_shape = (n_classes,)
    super(SoftmaxRegression, self).__init__([weight_shape, bias_shape],
                                            random_seed,
                                            noise_stdev)

  def inference(self, params, data):
    """Flattens data to (-1, n_features) and applies the affine map."""
    weights, bias = params[0], params[1]
    flat = tf.reshape(data, (-1, self.n_features))
    return tf.matmul(flat, weights) + bias


class SparseSoftmaxRegression(SoftmaxClassifier):
  """Softmax regression on sparse (integer id) inputs via embedding lookup."""

  def __init__(self,
               n_features,
               n_classes,
               activation=tf.identity,
               random_seed=None,
               noise_stdev=0.0):
    """Declares three parameters: embeddings, softmax weights, softmax bias.

    Shapes are (n_classes, n_features), (n_features, n_classes) and
    (n_classes,) respectively. `activation` is stored but not applied by
    `inference`.
    """
    self.n_features = n_features
    self.activation = activation

    embedding_shape = (n_classes, n_features)
    weight_shape = (n_features, n_classes)
    bias_shape = (n_classes,)
    super(SparseSoftmaxRegression, self).__init__(
        [embedding_shape, weight_shape, bias_shape], random_seed, noise_stdev)

  def inference(self, params, data):
    """Looks up an embedding per id, sums over axis 1, applies the affine map."""
    table, weights, bias = params
    ids = tf.cast(data, tf.int32)
    looked_up = tf.nn.embedding_lookup(table, ids)
    pooled = tf.reduce_sum(looked_up, 1)
    return tf.matmul(pooled, weights) + bias


class OneHotSparseSoftmaxRegression(SoftmaxClassifier):
  """Builds a sparse input softmax regression problem.

  Same parameter layout as SparseSoftmaxRegression, but the embedding rows
  are selected with a one-hot matrix product instead of an embedding op.
  """

  def __init__(self,
               n_features,
               n_classes,
               activation=tf.identity,
               random_seed=None,
               noise_stdev=0.0):
    """Declares three parameters: embeddings, softmax weights, softmax bias.

    Shapes are (n_classes, n_features), (n_features, n_classes) and
    (n_classes,) respectively. `activation` is stored but not applied by
    `inference`.
    """
    self.n_classes = n_classes
    self.n_features = n_features
    self.activation = activation

    embedding_shape = (n_classes, n_features)
    weight_shape = (n_features, n_classes)
    bias_shape = (n_classes,)
    super(OneHotSparseSoftmaxRegression, self).__init__(
        [embedding_shape, weight_shape, bias_shape],
        random_seed,
        noise_stdev)

  def inference(self, params, data):
    """Embeds ids through a one-hot matmul, sums per sample, applies affine map."""
    table, weights, bias = params

    # Number of ids per sample, read from the data's second dimension.
    ids_per_sample = tf.shape(data)[1]

    # One row per id: (num_samples * ids_per_sample, n_classes).
    one_hot = tf.one_hot(tf.cast(data, tf.int32), self.n_classes)
    one_hot = tf.reshape(one_hot, [-1, self.n_classes])

    # Multiplying by the table selects the embedding row of each id.
    per_id = tf.matmul(one_hot, table)
    per_sample = tf.reshape(per_id, [-1, ids_per_sample, self.n_features])
    pooled = tf.reduce_sum(per_sample, 1)
    return tf.matmul(pooled, weights) + bias


class FullyConnected(SoftmaxClassifier):
  """Builds a multi-layer perceptron classifier."""

  def __init__(self, n_features, n_classes, hidden_sizes=(32, 64),
               activation=tf.nn.sigmoid, random_seed=None, noise_stdev=0.0):
    """Initializes a multi-layer perceptron classification problem.

    Args:
      n_features: Number of input features per sample
      n_classes: Number of output units
      hidden_sizes: Sizes of the hidden layers, as any sequence of integers
        (tuple or list)
      activation: Function applied between layers
      random_seed: Integer seed, or None to draw one at random
      noise_stdev: Standard deviation of the added gradient noise
    """
    # Store the number of features and activation function.
    self.n_features = n_features
    self.activation = activation

    # Coerce to a tuple so that a list of hidden sizes can be concatenated
    # with the input and output sizes.
    layer_sizes = (n_features,) + tuple(hidden_sizes) + (n_classes,)

    # Define the network as a weight shape followed by a bias shape for each
    # pair of consecutive layer sizes.
    param_shapes = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
      param_shapes.append((fan_in, fan_out))
      param_shapes.append((fan_out,))

    super(FullyConnected, self).__init__(param_shapes, random_seed, noise_stdev)

  def init_tensors(self, seed=None):
    """Returns one random-normal tensor (stddev 0.01) per parameter shape."""
    return [tf.random_normal(shape, mean=0., stddev=0.01, seed=seed)
            for shape in self.param_shapes]

  def inference(self, params, data):
    """Computes the network's output pre-activations (logits).

    Args:
      params: List of alternating weight and bias tensors, one pair per layer
      data: Batch of features with samples along the first dimension

    Returns:
      The last layer's pre-activations; no activation is applied to them
    """
    # Flatten the features into a vector.
    features = tf.reshape(data, (-1, self.n_features))

    # First layer.
    preactivations = tf.nn.bias_add(tf.matmul(features, params[0]), params[1])

    # Remaining layers: activate the previous output, then apply the next
    # weight/bias pair.
    for layer in range(2, len(self.param_shapes), 2):
      net = self.activation(preactivations)
      preactivations = tf.nn.bias_add(tf.matmul(net, params[layer]),
                                      params[layer + 1])
    return preactivations

  def accuracy(self, params, data, labels):
    """Computes the accuracy (fraction of correct classifications).

    Args:
      params: List of parameter tensors or variables
      data: Batch of features with samples along the first dimension
      labels: Vector of labels with the same number of samples as the data

    Returns:
      accuracy: Fraction of correct classifications across the batch
    """
    # NOTE(review): predictions are taken from activation(logits), whereas
    # objective() uses the raw logits. With an activation that is not strictly
    # increasing the two can disagree -- confirm this is intended.
    predictions = self.argmax(self.activation(self.inference(params, data)))
    return tf.contrib.metrics.accuracy(predictions, tf.cast(labels, tf.int32))








class ConvNet(SoftmaxClassifier):
  """Stack of stride-1 conv layers followed by one dense layer."""

  def __init__(self,
               image_shape,
               n_classes,
               filter_list,
               activation=tf.nn.relu,
               random_seed=None,
               noise_stdev=0.0):
    """Derives the parameter shapes from the image shape and filter list.

    Args:
      image_shape: Triple (n_channels, px, py)
      n_classes: Number of output units of the dense layer
      filter_list: One entry per conv layer; entries 0 and 1 are the first two
        filter dimensions and entry 2 is the number of output channels
      activation: Function applied after every layer
      random_seed: Integer seed, or None to draw one at random
      noise_stdev: Standard deviation of the added gradient noise
    """
    n_channels, px, py = image_shape
    self.activation = activation

    shapes = []
    in_channels = n_channels
    for spec in filter_list:
      out_channels = spec[2]
      shapes.append((spec[0], spec[1], in_channels, out_channels))  # filter
      shapes.append((out_channels,))  # bias
      in_channels = out_channels

    # The dense layer sees every pixel of the last conv output; the convs in
    # `inference` use stride 1 with SAME padding.
    self.affine_size = in_channels * px * py

    shapes.append((self.affine_size, n_classes))  # dense weights
    shapes.append((n_classes,))  # dense bias

    super(ConvNet, self).__init__(shapes, random_seed, noise_stdev)

  def init_tensors(self, seed=None):
    """Returns one random-normal tensor (stddev 0.01) per parameter shape."""
    tensors = []
    for shape in self.param_shapes:
      tensors.append(tf.random_normal(shape, mean=0., stddev=0.01, seed=seed))
    return tensors

  def inference(self, params, data):
    """Runs the conv stack and the dense layer.

    Args:
      params: Alternating conv filters and biases, followed by the dense
        weights and bias
      data: Batch of images fed to the first tf.nn.conv2d

    Returns:
      The activation applied to the dense layer's output
    """
    conv_params = params[:-2]
    dense_w, dense_b = params[-2:]

    hidden = data
    for idx in range(0, len(conv_params), 2):
      kernel, bias = conv_params[idx], conv_params[idx + 1]
      conv = tf.nn.conv2d(hidden, kernel, strides=[1] * 4, padding="SAME")
      hidden = self.activation(tf.nn.bias_add(conv, bias))

    # Flatten before the dense layer.
    flat = tf.reshape(hidden, (-1, self.affine_size))

    # The activation is applied to the dense output as well.
    dense = tf.nn.bias_add(tf.matmul(flat, dense_w), dense_b)
    return self.activation(dense)

