

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np
import tensorflow as tf
import pdb
from problems import problem_spec as prob_spec

import uuid
import pickle
import itertools

# Command-line flag: weight of the L2 penalty that
# SoftmaxClassifier.objective adds to the cross-entropy loss.
tf.app.flags.DEFINE_float("l2_reg_scale", 1e-3,
                          """Scaling factor for parameter value regularization
                             in softmax classifier problems.""")

# Parsed flag values; read lazily (FLAGS.l2_reg_scale) when the objective is
# built.
FLAGS = tf.app.flags.FLAGS

# Small constant; not referenced in the visible part of this file.
EPSILON = 1e-6
# Upper bound (2**32 - 1) passed to np.random.randint when a Problem has to
# draw its own random seed.
MAX_SEED = 4294967295
# Name of the variable scope under which Problem.init_variables creates the
# problem parameters.
PARAMETER_SCOPE = "parameters"

# Short alias for the problem specification type from problems.problem_spec.
_Spec = prob_spec.Spec

def expand_list(_len, x):
  """Stretches `x` to exactly `_len` entries by repeating its elements in order.

  Every element except the last is repeated `space` times, where `space` is
  `_len / len(x)` rounded to the nearest integer; the last element fills all
  remaining slots. If rounding up would leave no slot for the last element
  (or overshoot `_len` altogether), `space` falls back to the floored quotient
  so that the result always has length `_len` and contains every element of
  `x` at least once.

  Args:
    _len: Desired length of the returned list.
    x: Non-empty sequence with at most `_len` elements.

  Returns:
    A new list of length `_len`.

  Raises:
    AssertionError: If `x` is empty or has more than `_len` elements.
  """
  items = list(x)
  count = len(items)
  # Raised explicitly (instead of using `assert`) so that the check survives
  # `python -O`; the exception type and message are unchanged for callers.
  if count == 0 or count > _len:
    raise AssertionError("couldn't expand")

  space = int(round(float(_len) / count))
  # Rounding up can make the leading elements alone fill (or exceed) `_len`,
  # which would drop the last element or produce a list of the wrong length.
  if space * (count - 1) >= _len:
    space = _len // count

  expand_x = []
  for item in items[:-1]:
    expand_x.extend([item] * space)
  # The last element absorbs whatever is left over.
  expand_x.extend([items[-1]] * (_len - space * (count - 1)))
  return expand_x

class Problem(object):
  """Base class for optimization problems.

  This defines an interface for optimization problems, including objective and
  gradients functions and a feed_generator function that yields data to pass to
  feed_dict in tensorflow.

  Subclasses of Problem must (at the minimum) override the objective method,
  which computes the objective/loss/cost to minimize, and specify the desired
  shape of the parameters in a list in the param_shapes attribute.

  Subclasses may set `random_sparse_method` and `random_sparse_prob` before
  calling this constructor to control the random per-parameter gradient
  dropping performed in `gradients`.
  """

  def __init__(self, param_shapes, random_seed, noise_stdev, init_fn=None):
    """Initializes a global random seed for the problem.

    Args:
      param_shapes: A list of tuples defining the expected shapes of the
        parameters for this problem
      random_seed: Either an integer (or None, in which case the seed is
        randomly drawn)
      noise_stdev: Strength (standard deviation) of added gradient noise
      init_fn: A function taking a tf.Session object that is used to
        initialize the problem's variables.

    Raises:
      ValueError: If the random_seed is not an integer and not None
    """
    if random_seed is not None and not isinstance(random_seed, int):
      raise ValueError("random_seed must be an integer or None")

    # Pick a random seed.
    self.random_seed = (np.random.randint(MAX_SEED) if random_seed is None
                        else random_seed)

    # Store the noise level.
    self.noise_stdev = noise_stdev

    # Set the random seed to ensure any random data in the problem is the same.
    np.random.seed(self.random_seed)

    # Store the parameter shapes.
    self.param_shapes = param_shapes

    # Fall back to a no-op initializer when none is supplied.
    self.init_fn = init_fn if init_fn is not None else (lambda _: None)

    # gradients() reads these two attributes. Subclasses that want gradient
    # sparsification assign them *before* calling this constructor, so only
    # fill in defaults ("always keep every gradient") when they are missing.
    if not hasattr(self, "random_sparse_method"):
      self.random_sparse_method = None
    if not hasattr(self, "random_sparse_prob"):
      self.random_sparse_prob = [1.0]

  def init_tensors(self, seed=None):
    """Returns a list of tensors with the given shape."""
    return [tf.random_normal(shape, seed=seed) for shape in self.param_shapes]

  def init_variables(self, seed=None, pretrained_model_path=None):
    """Returns a list of variables with the given shape.

    Args:
      seed: Optional seed forwarded to `init_tensors`.
      pretrained_model_path: Accepted for interface compatibility but
        currently ignored; the initial values always come from
        `self.init_tensors(seed)`.
        NOTE(review): some subclasses define an `init_tensors` that accepts
        this path; forwarding it here is left disabled as in the original
        code -- confirm the intended behavior before enabling it.

    Returns:
      A list of tf.Variable objects created under the PARAMETER_SCOPE
      variable scope, one per entry of `init_tensors`.
    """
    with tf.variable_scope(PARAMETER_SCOPE):
      params = [tf.Variable(param) for param in self.init_tensors(seed)]
    return params

  def objective(self, parameters, data=None, labels=None):
    """Computes the objective given a list of parameters.

    Args:
      parameters: The parameters to optimize (as a list of tensors)
      data: An optional batch of data for calculating objectives
      labels: An optional batch of corresponding labels

    Returns:
      A scalar tensor representing the objective value
    """
    raise NotImplementedError

  def gradients(self, objective, parameters):
    """Compute gradients of the objective with respect to the parameters.

    Each parameter's gradient is kept with some probability and otherwise
    replaced by zeros; the decision is drawn independently per parameter each
    time the returned ops are evaluated. The keep probabilities come from
    `self.random_sparse_prob` (stretched over the parameter list with
    `expand_list`) when `self.random_sparse_method` is None or "layer_wise";
    for any other method every gradient is always kept. Gaussian noise with
    standard deviation `self.noise_stdev` is then added to every gradient.

    Args:
      objective: The objective op (e.g. output of self.objective())
      parameters: A list of tensors (the parameters to optimize)

    Returns:
      A tuple `(noisy_grads, grad_flag_list)`:
        noisy_grads: A list of tensors representing the (noisy) gradient for
          each parameter, returned in the same order as the given list
        grad_flag_list: A list of scalar boolean tensors, one per parameter,
          that are True when the real gradient was used
    """
    if (self.random_sparse_method is None or
        self.random_sparse_method == "layer_wise"):
      keep_probs = self.random_sparse_prob
    else:
      # Other sparsification methods do not drop whole gradients here.
      keep_probs = [1.0]

    parameters_list = list(parameters)
    expanded_keep_probs = expand_list(len(parameters_list), keep_probs)
    assert len(parameters_list) == len(expanded_keep_probs), (
        "Unsuccessful expand")

    grads = []
    grad_flag_list = []
    for parameter, keep_prob in zip(parameters_list, expanded_keep_probs):
      rd = tf.random.uniform(shape=[], maxval=1)
      grad_flag = tf.math.less_equal(rd, keep_prob)
      # Bind the current parameter as a default argument so each branch refers
      # to its own parameter rather than to the loop variable.
      grad_to_append = tf.cond(
          grad_flag,
          lambda p=parameter: tf.gradients(objective, p)[0],
          lambda p=parameter: tf.constant(0.0, shape=p.shape,
                                          dtype=tf.float32))
      grad_flag_list.append(grad_flag)
      grads.append(grad_to_append)

    noisy_grads = []
    for grad in grads:
      if isinstance(grad, tf.IndexedSlices):
        # Sparse gradient: perturb only the stored values.
        noise = self.noise_stdev * tf.random_normal(tf.shape(grad.values))
        new_grad = tf.IndexedSlices(grad.values + noise, grad.indices)
      else:
        new_grad = grad + self.noise_stdev * tf.random_normal(grad.get_shape())
      noisy_grads.append(new_grad)

    return noisy_grads, grad_flag_list




class SoftmaxClassifier(Problem):
  """Helper functions for supervised softmax classification problems."""

  def init_tensors(self, seed=None):
    """Returns a list of tensors with the given shape.

    Each tensor is drawn from a normal distribution with standard deviation
    0.01 and then scaled by 1.2 / sqrt(shape[0]).
    """
    return [tf.random_normal(shape, mean=0, stddev=0.01, seed=seed) * 1.2 / np.sqrt(shape[0])
            for shape in self.param_shapes]

  def inference(self, params, data):
    """Computes logits given parameters and data.

    Args:
      params: List of parameter tensors or variables
      data: Batch of features with samples along the first dimension

    Returns:
      logits: Un-normalized logits with shape (num_samples, num_classes)
    """
    raise NotImplementedError

  def objective(self, params, data, labels, is_training=None):
    """Computes the cross entropy loss plus an L2 penalty on the parameters.

    When the logits have exactly two columns, the first column is treated as
    a single binary logit and sigmoid cross entropy is used; otherwise sparse
    softmax cross entropy is used. The second dimension of the logits must be
    statically known.

    Args:
      params: List of parameter tensors or variables
      data: Batch of features with samples along the first dimension
      labels: Vector of labels with the same number of samples as the data
      is_training: Optional scalar boolean tensor forwarded to `inference`
        when the subclass accepts an `is_training` keyword. Defaults to a
        constant True created in the current graph.

    Returns:
      loss: Cross entropy loss averaged over the samples in the batch, plus
        FLAGS.l2_reg_scale times the mean (over parameters) of each
        parameter's sum of squares.
    """
    # Build the default lazily: creating ops in the signature would bind them
    # to whatever graph happened to be the default at import time.
    if is_training is None:
      is_training = tf.constant(True, dtype=tf.bool)

    # Forward pass. Not every subclass accepts `is_training`; only a
    # signature mismatch (TypeError) triggers the fallback so that genuine
    # errors inside `inference` are not swallowed.
    try:
      logits = self.inference(params, data, is_training=is_training)
    except TypeError:
      logits = self.inference(params, data)

    # Compute the loss.
    l2reg = [tf.reduce_sum(param ** 2) for param in params]
    if int(logits.get_shape()[1]) == 2:
      labels = tf.cast(labels, tf.float32)
      losses = tf.nn.sigmoid_cross_entropy_with_logits(
          labels=labels, logits=logits[:, 0])
    else:
      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits)

    return tf.reduce_mean(losses) + tf.reduce_mean(l2reg) * FLAGS.l2_reg_scale

  def argmax(self, logits):
    """Samples the most likely class label given the logits.

    Args:
      logits: Un-normalized logits with shape (num_samples, num_classes)

    Returns:
      predictions: Predicted class labels, has shape (num_samples,)
    """
    return tf.cast(tf.argmax(tf.nn.softmax(logits), 1), tf.int32)

  def accuracy(self, params, data, labels):
    """Computes the accuracy (fraction of correct classifications).

    Args:
      params: List of parameter tensors or variables
      data: Batch of features with samples along the first dimension
      labels: Vector of labels with the same number of samples as the data

    Returns:
      accuracy: Fraction of correct classifications across the batch
    """
    predictions = self.argmax(self.inference(params, data))
    return tf.contrib.metrics.accuracy(predictions, tf.cast(labels, tf.int32))

def lstm_func(x, h, c, wx, wh, b):
  """Runs a single LSTM cell step.

  The four gate pre-activations are computed with one affine map and are laid
  out along the last axis in the order input, forget, output, candidate.

  Args:
    x: Input at this step, shape (N, D)
    h: Previous hidden state, shape (N, H)
    c: Previous cell state, shape (N, H)
    wx: Input-to-hidden weights, shape (D, 4H)
    wh: Hidden-to-hidden weights, shape (H, 4H)
    b: Bias, shape (4H,)

  Returns:
    A tuple (next_h, next_c) with the new hidden and cell states.
  """
  batch_size = tf.shape(h)[0]
  hidden_size = tf.shape(h)[1]

  # One affine map for all gates, then split the 4H axis into (4, H).
  pre_gates = tf.matmul(x, wx) + tf.matmul(h, wh) + b
  pre_gates = tf.reshape(pre_gates, (batch_size, -1, hidden_size))

  input_gate = tf.sigmoid(pre_gates[:, 0, :])
  forget_gate = tf.sigmoid(pre_gates[:, 1, :])
  output_gate = tf.sigmoid(pre_gates[:, 2, :])
  candidate = tf.tanh(pre_gates[:, 3, :])

  next_c = forget_gate * c + input_gate * candidate
  next_h = output_gate * tf.tanh(next_c)
  return next_h, next_c




class SoftmaxRegression(SoftmaxClassifier):
  """Builds a softmax regression problem.

  The parameters are a weight matrix of shape (n_features, n_classes) and a
  bias vector of shape (n_classes,).
  """

  def __init__(self, n_features, n_classes, activation=tf.identity,
               random_seed=None, noise_stdev=0.0):
    """Initializes a softmax regression problem.

    Args:
      n_features: Number of input features per sample
      n_classes: Number of output classes
      activation: Stored on the instance; not applied by `inference`
      random_seed: Integer seed, or None to draw one at random
      noise_stdev: Standard deviation of the noise added to gradients
    """
    self.activation = activation
    self.n_features = n_features
    # The inherited gradients() reads these attributes; without them it would
    # fail with AttributeError. These values keep every gradient.
    self.random_sparse_method = None
    self.random_sparse_prob = [1.0]
    param_shapes = [(n_features, n_classes), (n_classes,)]
    super(SoftmaxRegression, self).__init__(param_shapes,
                                            random_seed,
                                            noise_stdev)

  def inference(self, params, data):
    """Returns the logits `features * W + b` for a batch of data.

    Args:
      params: List [weights, bias]
      data: Batch of inputs; reshaped to (-1, n_features)

    Returns:
      Logits produced by the affine map.
    """
    features = tf.reshape(data, (-1, self.n_features))
    return tf.matmul(features, params[0]) + params[1]


class SparseSoftmaxRegression(SoftmaxClassifier):
  """Builds a sparse input softmax regression problem.

  The parameters are an embedding table of shape (n_classes, n_features), a
  softmax weight matrix of shape (n_features, n_classes) and a softmax bias of
  shape (n_classes,).
  """

  def __init__(self,
               n_features,
               n_classes,
               activation=tf.identity,
               random_seed=None,
               noise_stdev=0.0):
    """Initializes a sparse softmax regression problem.

    Args:
      n_features: Width of each embedding row
      n_classes: Number of output classes (also the number of embedding rows)
      activation: Stored on the instance; not applied by `inference`
      random_seed: Integer seed, or None to draw one at random
      noise_stdev: Standard deviation of the noise added to gradients
    """
    self.activation = activation
    self.n_features = n_features
    # The inherited gradients() reads these attributes; without them it would
    # fail with AttributeError. These values keep every gradient.
    self.random_sparse_method = None
    self.random_sparse_prob = [1.0]
    param_shapes = [(n_classes, n_features), (n_features, n_classes), (
        n_classes,)]
    super(SparseSoftmaxRegression, self).__init__(param_shapes, random_seed,
                                                  noise_stdev)

  def inference(self, params, data):
    """Looks up and sums embeddings, then applies the softmax layer.

    Args:
      params: List [embedding table, softmax weights, softmax bias]
      data: Batch of ids (cast to int32) used to index the embedding table

    Returns:
      Logits computed from the embeddings summed over axis 1.
    """
    all_embeddings, softmax_weights, softmax_bias = params
    embeddings = tf.nn.embedding_lookup(all_embeddings, tf.cast(data, tf.int32))
    embeddings = tf.reduce_sum(embeddings, 1)
    return tf.matmul(embeddings, softmax_weights) + softmax_bias


class OneHotSparseSoftmaxRegression(SoftmaxClassifier):
  """Builds a sparse input softmax regression problem.

  This is identical to SparseSoftmaxRegression, but without using embedding
  ops.
  """

  def __init__(self,
               n_features,
               n_classes,
               activation=tf.identity,
               random_seed=None,
               noise_stdev=0.0):
    """Initializes a one-hot sparse softmax regression problem.

    Args:
      n_features: Width of each embedding row
      n_classes: Number of output classes (also the one-hot depth)
      activation: Stored on the instance; not applied by `inference`
      random_seed: Integer seed, or None to draw one at random
      noise_stdev: Standard deviation of the noise added to gradients
    """
    self.activation = activation
    self.n_features = n_features
    self.n_classes = n_classes
    # The inherited gradients() reads these attributes; without them it would
    # fail with AttributeError. These values keep every gradient.
    self.random_sparse_method = None
    self.random_sparse_prob = [1.0]
    param_shapes = [(n_classes, n_features), (n_features, n_classes), (
        n_classes,)]
    super(OneHotSparseSoftmaxRegression, self).__init__(param_shapes,
                                                        random_seed,
                                                        noise_stdev)

  def inference(self, params, data):
    """Computes logits by multiplying one-hot ids with the embedding table.

    Args:
      params: List [embedding table, softmax weights, softmax bias]
      data: Batch of ids (cast to int32); its second dimension is the number
        of ids per sample

    Returns:
      Logits computed from the per-sample sum of the selected embedding rows.
    """
    all_embeddings, softmax_weights, softmax_bias = params
    num_ids = tf.shape(data)[1]
    # One-hot encode every id and flatten so a single matmul selects the rows.
    one_hot_embeddings = tf.one_hot(tf.cast(data, tf.int32), self.n_classes)
    one_hot_embeddings = tf.reshape(one_hot_embeddings, [-1, self.n_classes])
    embeddings = tf.matmul(one_hot_embeddings, all_embeddings)
    # Restore the per-sample grouping and sum over the ids of each sample.
    embeddings = tf.reshape(embeddings, [-1, num_ids, self.n_features])
    embeddings = tf.reduce_sum(embeddings, 1)
    return tf.matmul(embeddings, softmax_weights) + softmax_bias


class FullyConnected(SoftmaxClassifier):
  """Builds a multi-layer perceptron classifier."""

  def __init__(self, n_features, n_classes, hidden_sizes=(32, 64),
               activation=tf.nn.sigmoid, random_seed=None, noise_stdev=0.0):
    """Initializes an multi-layer perceptron classification problem.

    Args:
      n_features: Number of input features per sample
      n_classes: Number of output classes
      hidden_sizes: Sequence (tuple or list) with the size of each hidden
        layer
      activation: Activation function applied between layers
      random_seed: Integer seed, or None to draw one at random
      noise_stdev: Standard deviation of the noise added to gradients
    """
    # Store the number of features and activation function.
    self.n_features = n_features
    print('n_features = ', n_features)
    self.activation = activation

    # This problem never drops gradients: gradients() keeps every one.
    self.random_sparse_method = None
    self.random_sparse_prob = [1.0]

    # Accept any sequence of sizes (a list would break tuple concatenation).
    hidden_sizes = tuple(hidden_sizes)

    # Define the network as a list of weight + bias shapes for each layer.
    param_shapes = []
    for ix, sz in enumerate(hidden_sizes + (n_classes,)):

      # The previous layer"s size (n_features if input).
      prev_size = n_features if ix == 0 else hidden_sizes[ix - 1]

      # Weight shape for this layer.
      param_shapes.append((prev_size, sz))

      # Bias shape for this layer.
      param_shapes.append((sz,))
    super(FullyConnected, self).__init__(param_shapes, random_seed, noise_stdev)

  def inference(self, params, data):
    """Runs the network forward and returns its (activated) outputs.

    Args:
      params: Flat list [w0, b0, w1, b1, ...] of layer weights and biases
      data: Batch of inputs; reshaped to (-1, n_features)

    Returns:
      Output of the last layer. When there is at least one hidden layer the
      activation function has already been applied to it.
    """
    # Flatten the features into a vector.
    features = tf.reshape(data, (-1, self.n_features))
    # Pass the data through the network.
    preactivations = tf.nn.bias_add(tf.matmul(features, params[0]), params[1])
    for layer in range(2, len(self.param_shapes), 2):
      net = self.activation(preactivations)
      preactivations = tf.nn.bias_add(tf.matmul(net, params[layer]), params[layer + 1])
      # NOTE(review): the activation is applied here and again at the top of
      # the next iteration, so from the second hidden layer onwards it is
      # applied twice. Left unchanged to keep the model numerics identical --
      # confirm whether this is intended.
      preactivations = self.activation(preactivations)

    return preactivations

  def accuracy(self, params, data, labels):
    """Computes the accuracy (fraction of correct classifications).

    Args:
      params: List of parameter tensors or variables
      data: Batch of features with samples along the first dimension
      labels: Vector of labels with the same number of samples as the data

    Returns:
      accuracy: Fraction of correct classifications across the batch
    """
    predictions = self.argmax(self.activation(self.inference(params, data)))
    return tf.contrib.metrics.accuracy(predictions, tf.cast(labels, tf.int32))

class MLP(SoftmaxClassifier):
  """Builds a multi-layer perceptron whose output layer is also activated."""

  def __init__(self,
               input_dim,
               n_classes,
               hidden_layer_size_list,
               activation=tf.nn.relu,
               random_seed=None,
               noise_stdev=0.0,
               random_sparse_method=None,
               random_sparse_prob=[1.0]):
    """Initializes the MLP problem.

    Args:
      input_dim: Number of input features per sample
      n_classes: Number of output classes
      hidden_layer_size_list: Sizes of the hidden layers, in order
      activation: Activation function applied after every layer
      random_seed: Integer seed, or None to draw one at random
      noise_stdev: Standard deviation of the noise added to gradients
      random_sparse_method: Gradient sparsification method read by
        `gradients` (None or "layer_wise" enable per-parameter dropping)
      random_sparse_prob: Sequence of keep probabilities read by `gradients`;
        None is treated as [1.0]
    """
    # Store the activation.
    self.activation = activation
    self.random_sparse_method = random_sparse_method
    # Copy so that instances never share (or mutate) the default list object
    # or the caller's list.
    self.random_sparse_prob = (
        [1.0] if random_sparse_prob is None else list(random_sparse_prob))

    param_shapes = []
    input_size = input_dim
    for hidden_layer_size in hidden_layer_size_list:
      # Add FC layer filters.
      param_shapes.append((input_size, hidden_layer_size))
      param_shapes.append((hidden_layer_size,))
      input_size = hidden_layer_size

    # the final FC before softmax
    param_shapes.append((input_size, n_classes))  # affine weights
    param_shapes.append((n_classes,))  # affine bias

    super(MLP, self).__init__(param_shapes, random_seed, noise_stdev)

  def init_tensors(self, seed=None, pretrained_model_path=None):
    """Returns a list of initial values, optionally from a pretrained model.

    Args:
      seed: Optional seed for the random normal initializers
      pretrained_model_path: Optional path to a pickle file holding a
        sequence of array-like parameters, in the same order as
        `self.param_shapes`

    Returns:
      A list with one entry per parameter: random normal tensors (stddev
      0.01), with entries replaced by pretrained values where they are loaded.
    """
    random_params = [tf.random_normal(shape, mean=0., stddev=0.01, seed=seed)
                     for shape in self.param_shapes]
    if pretrained_model_path is None:
      return random_params

    # SECURITY: pickle.load executes arbitrary code from the file; only use
    # this with trusted model files.
    with open(pretrained_model_path, "rb") as params_file:
      pretrained_params = pickle.load(params_file)

    init_params = random_params
    for k_id, k in enumerate(pretrained_params):
      # only load before the last FC layer: skip an entry only when it is one
      # of the last two and its last dimension is 5.
      # NOTE(review): 5 looks like a hard-coded class count -- confirm.
      if k.shape[-1] != 5 or (k_id < len(pretrained_params) - 2):
        init_params[k_id] = k
        print("Loading weight shape:", k.shape)
      else:
        print("Not loading weight shape:", k.shape)
    return init_params

  def inference(self, params, data):
    """Runs the network forward.

    Args:
      params: Flat list [w0, b0, ..., output_w, output_b]
      data: Batch of inputs fed to the first matmul

    Returns:
      Activated output of the final fully connected layer.
    """
    # Unpack.
    pre_FC_list = params[:-2]
    output_w, output_b = params[-2:]

    MLP_input = data
    # Hidden layers: weights and biases are interleaved in pre_FC_list.
    for i in range(0, len(pre_FC_list), 2):
      w_FC = pre_FC_list[i]
      b_FC = pre_FC_list[i + 1]
      MLP_input = self.activation(
          tf.nn.bias_add(tf.matmul(MLP_input, w_FC), b_FC))

    # Fully connected output layer (the activation is applied here as well).
    return self.activation(
        tf.nn.bias_add(tf.matmul(MLP_input, output_w), output_b))




class ConvNet(SoftmaxClassifier):
  """Builds an N-layer convnet for image classification."""

  def __init__(self,
               image_shape,
               n_classes,
               filter_list,
               activation=tf.nn.relu,
               random_seed=None,
               noise_stdev=0.0,
               random_sparse_method=None,
               random_sparse_prob=[1.0]):
    """Initializes the convnet problem.

    Args:
      image_shape: Tuple (n_channels, px, py)
      n_classes: Number of output classes
      filter_list: Sequence of filters; each entry provides
        (height, width, output channels) at indices 0, 1 and 2
      activation: Activation function applied after every layer
      random_seed: Integer seed, or None to draw one at random
      noise_stdev: Standard deviation of the noise added to gradients
      random_sparse_method: Gradient sparsification method read by
        `gradients` (None or "layer_wise" enable per-parameter dropping)
      random_sparse_prob: Sequence of keep probabilities read by `gradients`;
        None is treated as [1.0]
    """
    # Number of channels, number of pixels in x- and y- dimensions.
    n_channels, px, py = image_shape

    # Store the activation.
    self.activation = activation
    self.random_sparse_method = random_sparse_method
    # Copy so that instances never share (or mutate) the default list object
    # or the caller's list.
    self.random_sparse_prob = (
        [1.0] if random_sparse_prob is None else list(random_sparse_prob))

    param_shapes = []
    input_size = n_channels
    for fltr in filter_list:
      # Add conv2d filters.
      param_shapes.append((fltr[0], fltr[1], input_size, fltr[2]))
      param_shapes.append((fltr[2],))
      input_size = fltr[2]

    # Number of units in the final (dense) layer. The convolutions use stride
    # 1 with SAME padding, so the spatial size stays px * py.
    self.affine_size = int(input_size * px * py)

    param_shapes.append((self.affine_size, n_classes))  # affine weights
    param_shapes.append((n_classes,))  # affine bias

    super(ConvNet, self).__init__(param_shapes, random_seed, noise_stdev)

  def init_tensors(self, seed=None, pretrained_model_path=None):
    """Returns a list of initial values, optionally from a pretrained model.

    Args:
      seed: Optional seed for the random normal initializers
      pretrained_model_path: Optional path to a pickle file holding a
        sequence of array-like parameters, in the same order as
        `self.param_shapes`

    Returns:
      A list with one entry per parameter: random normal tensors (stddev
      0.01), with entries replaced by pretrained values where they are loaded.
    """
    random_params = [tf.random_normal(shape, mean=0., stddev=0.01, seed=seed)
                     for shape in self.param_shapes]
    if pretrained_model_path is None:
      return random_params

    # SECURITY: pickle.load executes arbitrary code from the file; only use
    # this with trusted model files.
    with open(pretrained_model_path, "rb") as params_file:
      pretrained_params = pickle.load(params_file)

    init_params = random_params
    for k_id, k in enumerate(pretrained_params):
      # only load before the last FC layer: skip an entry only when it is one
      # of the last two and its last dimension is 5.
      # NOTE(review): 5 looks like a hard-coded class count -- confirm.
      if k.shape[-1] != 5 or (k_id < len(pretrained_params) - 2):
        init_params[k_id] = k
        print("Loading weight shape:", k.shape)
      else:
        print("Not loading weight shape:", k.shape)
    return init_params

  def inference(self, params, data):
    """Runs the convnet forward.

    Args:
      params: Flat list [w_conv0, b_conv0, ..., output_w, output_b]
      data: Batch of images fed to the first tf.nn.conv2d

    Returns:
      Activated output of the final fully connected layer.
    """
    # Unpack.
    w_conv_list = params[:-2]
    output_w, output_b = params[-2:]

    conv_input = data
    # Convolutional layers: filters and biases are interleaved in w_conv_list.
    for i in range(0, len(w_conv_list), 2):
      w_conv = w_conv_list[i]
      b_conv = w_conv_list[i + 1]
      layer = tf.nn.conv2d(conv_input, w_conv, strides=[1] * 4, padding="SAME")
      layer = tf.nn.bias_add(layer, b_conv)
      conv_input = self.activation(layer)

    # Flatten.
    flattened = tf.reshape(conv_input, (-1, self.affine_size))

    # Fully connected output layer (the activation is applied here as well).
    return self.activation(
        tf.nn.bias_add(tf.matmul(flattened, output_w), output_b))


