"""
Considering the following arguments

Regularization:
input_dropout (float): Input dropout fraction.
layer_dropout (float): Layer dropout fraction.
l2_reg_strength (float): L2 regularization strength.
shift_images (bool): Whether to randomly shift images in training.

Model structure flags
num_towers (int): Number of towers (a.k.a. ensemble components).
hidden_dims (list of int): Dimension of hidden layers.

Optimization flags
num_epochs (int): Number of training epochs.
optimizer (string): One of "sgd", "adagrad".
learning_rate (float): learning rate for SGD or Adagrad.

Anti-distillation flags
ad_type (string): One of "no_ad", "cor", "batch_cov".
ad_strength (float): Multiplier of AD loss.
"""

def build_tower(hidden_dims,
                input_layer,
                input_dropout=0,
                layer_dropout=0,
                l2_reg_strength=0):
  """Stacks one tower (a single ensemble member) on top of `input_layer`.

  The tower is: optional input dropout, then one ReLU dense layer per entry
  of `hidden_dims` (each optionally followed by dropout), then a 10-unit
  dense logit layer and a softmax activation over those logits.

  Args:
    hidden_dims: List of dimensions of hidden layers.
    input_layer: Input layer.
    input_dropout: Dropout fraction to apply to input layer (0 disables it).
    layer_dropout: Dropout fraction to apply after every hidden layer
      (0 disables it).
    l2_reg_strength: L2 regularization strength for the hidden layers'
      kernels (0 disables it).

  Returns:
    The tower's dimension-10 logit layer and dimension-10 softmax layer.
  """
  tower = input_layer
  if input_dropout > 0:
    tower = tf.keras.layers.Dropout(input_dropout)(tower)

  for width in hidden_dims:
    # Only attach a kernel regularizer when a positive strength is requested.
    if l2_reg_strength > 0:
      kernel_reg = tf.keras.regularizers.l2(l2_reg_strength)
    else:
      kernel_reg = None

    hidden = tf.keras.layers.Dense(
        width, activation='relu', kernel_regularizer=kernel_reg)
    tower = hidden(tower)
    if layer_dropout > 0:
      tower = tf.keras.layers.Dropout(layer_dropout)(tower)

  logits = tf.keras.layers.Dense(10)(tower)
  probabilities = tf.keras.layers.Activation('softmax')(logits)
  return logits, probabilities


def get_ensemble(hidden_dims,
                 num_towers,
                 anti_distillation_outputs,
                 input_dropout=0,
                 layer_dropout=0,
                 l2_reg_strength=0):
  """Builds a functional Keras ensemble model to recognize MNIST digits.

  All towers read the same flattened 28x28 input. The model has a single
  output tensor which is the concatenation, in this order, of:
    * the average of the towers' softmax outputs,
    * the softmax output of every tower,
    * the logits of every tower (only if `anti_distillation_outputs`).

  Args:
    hidden_dims: List of dimensions of hidden layers.
    num_towers: Number of towers in the ensemble; must be greater than 1.
    anti_distillation_outputs: whether to have anti-distillation outputs.
    input_dropout: Dropout fraction to apply to input layer.
    layer_dropout: Dropout fraction to apply to hidden layers.
    l2_reg_strength: L2 Regularization strength.

  Returns:
    a Keras model used for MNIST

  Raises:
    ValueError: if `num_towers` is not greater than 1.
  """
  if not num_towers > 1:
    raise ValueError(
        'num_towers must be greater than 1, got %r.' % (num_towers,))

  net_input = tf.keras.layers.Input(shape=(28, 28))
  net = tf.keras.layers.Flatten(input_shape=(28, 28))(net_input)

  towers_logit_layers = []
  towers_softmax_layers = []
  for _ in range(num_towers):
    tower_logit, tower_softmax = build_tower(
        hidden_dims,
        input_layer=net,
        input_dropout=input_dropout,
        layer_dropout=layer_dropout,
        l2_reg_strength=l2_reg_strength)
    towers_logit_layers.append(tower_logit)
    towers_softmax_layers.append(tower_softmax)

  # Average prediction of the towers.
  avg_prediction = tf.keras.layers.Average(
      name='towers_avg_prediction')(towers_softmax_layers)

  # Per-tower predictions, used for the label loss. This layer needs its own
  # name: layer names must be unique within a Keras model.
  label_loss_out = tf.keras.layers.concatenate(
      towers_softmax_layers, name='towers_softmax_output')

  outputs = [avg_prediction, label_loss_out]
  if anti_distillation_outputs:
    outputs.append(
        tf.keras.layers.concatenate(towers_logit_layers,
                                    name='anti_distillation_logit_output'))

  return tf.keras.Model(
      net_input, tf.keras.layers.concatenate(outputs, name='output'))


def random_shift(x, p_shift=0.5, max_pixel_shift=3):
  """Randomly shifts images in the dataset.

  Each selected image is translated by an integer offset drawn uniformly
  from [-max_pixel_shift, max_pixel_shift] (both ends included),
  independently for each of the two image axes.

  Args:
    x: image dataset of shape (N, H, W).
    p_shift: probability to shift each image (independently of the others).
    max_pixel_shift: max pixel shift in each direction.

  Returns:
    a copy of the dataset, with shifted images.
  """
  shifted = np.copy(x)
  for j in range(shifted.shape[0]):
    if np.random.random() < p_shift:
      # `randint` excludes its upper bound, so `+ 1` is needed for the shift
      # to be able to reach +max_pixel_shift (and for 0 to be a valid max).
      shift_amount = np.random.randint(
          -max_pixel_shift, max_pixel_shift + 1, 2)
      shifted[j, :, :] = shift(shifted[j, :, :], shift_amount, order=0)
  return shifted


def fit_model_with_shifted_images(model, xtrain, ytrain, num_epochs=15):
  """Fit a model to shifted-images train set.

  The first epoch uses the images as given. From the second epoch on, the
  model is fit on a freshly drawn `random_shift` of the training images
  (using that function's default shift probability and max shift).

  Args:
    model: the model to fit.
    xtrain: training images.
    ytrain: one-hot labels.
    num_epochs: number of epochs.
  """
  for epoch in range(num_epochs):
    # `random_shift` already returns a shifted copy, so `xtrain` is never
    # modified and no additional per-epoch copy of the dataset is needed.
    epoch_images = random_shift(xtrain) if epoch > 0 else xtrain

    model.fit(
        epoch_images,
        ytrain,
        epochs=1,
        validation_split=0)



def off_diag_moment(x):
  """Sum of squared off-diagonal entries of the batch second-moment matrix.

  For every pair of columns (a, b) of `x`, the product x[:, a] * x[:, b] is
  averaged over axis 0 and squared; the squared values of all pairs with
  a != b are then summed.

  Args:
    x: rank-2 tensor; axis 0 is averaged over (the callers in this file pass
      one column per tower).

  Returns:
    Scalar tensor with the sum of the squared off-diagonal moments.
  """
  # Outer product of each row with itself: (B, n, 1) * (B, 1, n) -> (B, n, n).
  y = tf.expand_dims(x, 2) * tf.expand_dims(x, 1)
  z = tf.square(tf.reduce_mean(y, 0))
  # `tf.linalg.diag_part` is available in both TF1 and TF2, unlike the
  # TF1-only `tf.diag_part`; for a rank-2 input the result is the same.
  return tf.reduce_sum(z) - tf.reduce_sum(tf.linalg.diag_part(z))


def correlation_loss(flattened_distillation_outputs,
                     num_towers,
                     num_classes,
                     batch_cov=False):
  """Correlation anti-distillation loss.

  For each class, the outputs of all towers for that class are gathered and
  scored with `off_diag_moment`; the per-class scores are then averaged.

  Args:
    flattened_distillation_outputs: concatenation of anti-distillation heads,
      laid out tower after tower with `num_classes` columns per tower.
    num_towers: number of towers in the ensemble.
    num_classes: number of classes.
    batch_cov: if true, each gathered column is centered by its mean over
      axis 0 before scoring (batch-covariance loss).

  Returns:
    Anti-distillation loss.
  """
  per_class_losses = []
  for class_id in range(num_classes):
    # Column of `class_id` inside every tower's slice of the flat output.
    columns = [
        tower * num_classes + class_id for tower in range(num_towers)
    ]
    class_outputs = tf.gather(
        flattened_distillation_outputs, tf.constant(columns), axis=1)

    if batch_cov:
      class_outputs = class_outputs - tf.reduce_mean(class_outputs, 0)

    per_class_losses.append(off_diag_moment(class_outputs))

  # Average of the per-class losses.
  return sum(per_class_losses) / num_classes


def anti_distillation_loss(flattened_distillation_outputs,
                           num_towers,
                           num_classes,
                           ad_type='cor'):
  """Anti-distillation loss.

  Args:
    flattened_distillation_outputs: concatenation of anti-distillation heads.
    num_towers: number of towers in the ensemble.
    num_classes: number of classes.
    ad_type: type of anti-distillation: 'cor' for the correlation loss or
      'batch_cov' for the batch-covariance loss. The hyphenated spelling
      'batch-cov' is accepted as an alias of 'batch_cov'.

  Returns:
    Anti-distillation loss.

  Raises:
    ValueError: if `ad_type` is not one of the supported values.
  """
  batch_cov_names = ('batch_cov', 'batch-cov')
  if ad_type != 'cor' and ad_type not in batch_cov_names:
    # Explicit error (rather than an assert) so that an unknown type can
    # never fall through and silently return None.
    raise ValueError(
        "ad_type must be 'cor' or 'batch_cov', got %r." % (ad_type,))

  return correlation_loss(flattened_distillation_outputs,
                          num_towers,
                          num_classes,
                          batch_cov=ad_type in batch_cov_names)


def label_and_anti_distillation_loss(labels,
                                     model_output,
                                     anti_distillation,
                                     ad_type='cor',
                                     num_towers=None,
                                     ad_strength=0.0):
  """Label loss and anti-distillation loss.

  The label loss is the sum over towers of the categorical cross-entropy
  between `labels` and that tower's prediction. With anti-distillation, the
  anti-distillation loss times `ad_strength` is added on top.

  Args:
    labels: one hot true labels.
    model_output: concatenation of the average-prediction head, the
      per-tower prediction heads and (with anti-distillation) the
      anti-distillation heads, each 10 columns wide.
    anti_distillation: whether there is anti-distillation
    ad_type: type of anti-distillation; only used when `anti_distillation`
      is true.
    num_towers: number of towers in the ensemble (required).
    ad_strength: multiplier of anti-distillation strength.

  Returns:
    Loss.

  Raises:
    ValueError: if `num_towers` is missing, or is not greater than 1 while
      `anti_distillation` is requested.
  """
  if num_towers is None:
    raise ValueError('num_towers must be provided.')
  if anti_distillation and not num_towers > 1:
    raise ValueError(
        'Anti-distillation needs num_towers > 1, got %r.' % (num_towers,))

  num_classes = 10

  loss = 0
  first_id = num_classes  # first num_classes ids are for avg prediction

  for _ in range(num_towers):
    last_id = first_id + num_classes
    # Accumulate, so that every tower contributes to the label loss (plain
    # assignment would keep only the last tower's loss).
    loss += tf.keras.losses.categorical_crossentropy(
        labels, model_output[:, first_id:last_id])
    first_id = last_id

  if anti_distillation:
    # Everything after the per-tower predictions is the anti-distillation
    # output.
    ad_loss = anti_distillation_loss(model_output[:, first_id:],
                                     num_towers, num_classes,
                                     ad_type=ad_type)

    loss = loss + ad_strength * ad_loss

  return loss


def compile_model_with_anti_distillation_loss(model,
                                              anti_distillation=False,
                                              ad_type='cor',
                                              num_towers=None,
                                              ad_strength=0,
                                              optimizer='sgd',
                                              learning_rate=0.01):
  """Compiles an ensemble model with or without anti-distillation loss.

  The model is compiled in place with `label_and_anti_distillation_loss` as
  its loss function.

  Args:
    model: the model.
    anti_distillation: whether there is anti-distillation
    ad_type: type of anti-distillation; only used when `anti_distillation`
      is true.
    num_towers: number of towers in the ensemble
    ad_strength: multiplier of anti-distillation strength.
    optimizer: one of 'sgd' (SGD with momentum 0.9) or 'adagrad'
    learning_rate: learning rate

  Raises:
    ValueError: if `optimizer` is not 'sgd' or 'adagrad'.
  """
  if optimizer == 'sgd':
    # `learning_rate` is the supported keyword (`lr` is a deprecated alias)
    # and matches the Adagrad call below.
    opt = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
  elif optimizer == 'adagrad':
    opt = tf.keras.optimizers.Adagrad(
        learning_rate=learning_rate, initial_accumulator_value=0.1,
        epsilon=1e-07)
  else:
    raise ValueError(
        "optimizer must be 'sgd' or 'adagrad', got %r." % (optimizer,))

  model.compile(
      loss=lambda labels, model_output: label_and_anti_distillation_loss(
          labels, model_output,
          anti_distillation=anti_distillation,
          ad_type=ad_type,
          num_towers=num_towers,
          ad_strength=ad_strength),
      optimizer=opt,
      )


def get_mnist_input_datasets():
  """Loads MNIST through Keras and prepares it for training.

  Images are cast to float32 and divided by 255; the integer class labels
  are converted to one-hot vectors over 10 classes.

  Returns:
    A tuple (x_train, y_train, x_test, y_test) of rescaled images and
    one-hot labels for the train and test splits.
  """
  train_split, test_split = tf.keras.datasets.mnist.load_data()
  train_images, train_digits = train_split
  test_images, test_digits = test_split

  # Rescale the pixel values.
  train_images = train_images.astype('float32') / 255
  test_images = test_images.astype('float32') / 255

  # Convert class vectors to binary class matrices.
  train_one_hot = tf.keras.utils.to_categorical(train_digits, 10)
  test_one_hot = tf.keras.utils.to_categorical(test_digits, 10)

  return train_images, train_one_hot, test_images, test_one_hot


def train_ensemble(hidden_dims,
                   num_towers,
                   ad_type,
                   ad_strength=0.0,
                   input_dropout=0.0,
                   layer_dropout=0.0,
                   l2_reg_strength=0.0,
                   shift_images=False,
                   optimizer='sgd',
                   learning_rate=0.01,
                   num_epochs=20):
  """Trains an ensemble model to recognize MNIST digits.

  Args:
    hidden_dims: List of dimensions of hidden layers.
    num_towers: number of towers in the ensemble.
    ad_type: type of anti-distillation ("no_ad" means no anti-distillation
      loss).
    ad_strength: multiplier for anti-distillation loss.
    input_dropout: Dropout fraction to apply to input layer.
    layer_dropout: Dropout fraction to apply to hidden layers.
    l2_reg_strength: L2 Regularization strength.
    shift_images: whether to randomly shift images in training (from 2nd
      epoch).
    optimizer: one of "sgd" and "adagrad"
    learning_rate: learning rate for sgd or adagrad
    num_epochs: number of training epochs.

  Returns:
    The first 10 columns of the trained model's output (the towers' averaged
    prediction) on the train images and on the test images.
  """
  x_train, y_train, x_test, _ = get_mnist_input_datasets()

  # Every ad_type other than 'no_ad' enables the anti-distillation outputs
  # and loss.
  use_anti_distillation = ad_type != 'no_ad'

  ensemble = get_ensemble(
      hidden_dims,
      num_towers,
      use_anti_distillation,
      input_dropout=input_dropout,
      layer_dropout=layer_dropout,
      l2_reg_strength=l2_reg_strength)

  compile_model_with_anti_distillation_loss(
      ensemble,
      anti_distillation=use_anti_distillation,
      ad_type=ad_type,
      num_towers=num_towers,
      ad_strength=ad_strength,
      optimizer=optimizer,
      learning_rate=learning_rate)

  if not shift_images:
    ensemble.fit(x_train, y_train, epochs=num_epochs, validation_split=0)
  else:
    fit_model_with_shifted_images(
        ensemble, x_train, y_train, num_epochs=num_epochs)

  train_predictions = ensemble.predict(x_train)[:, :10]
  test_predictions = ensemble.predict(x_test)[:, :10]
  return train_predictions, test_predictions