"""
Correspondence between PD names here and PD names in paper
N-way p=1: Delta_1
N-way p=2: Delta_2
N-way p=1 Normalized PD: Delta^L_1
N-way p=2 Normalized PD: Delta^L_2

Example eval
Assumes that 12 iterations (IDs 0 to 11) were run for each config.
Given a base name such as base = 'anti_distillation/mnist/predictions/3_tower_128_64_32__no_ad'
and a run ID run_id, the predictions on the train and test sets are saved in:
base + '_' + str(run_id) + '_train'
base + '_' + str(run_id) + '_test'


input_dict = {
    'no_ad': 'anti_distillation/mnist/predictions/3_tower_128_64_32__no_ad',
    'cor': 'anti_distillation/mnist/predictions/3_tower_128_64_32_non_reg_ad_cor',
    'res-cor': 'anti_distillation/mnist/predictions/3_tower_128_64_32_ad_res_cor',
    'batch-cov': 'anti_distillation/mnist/predictions/3_tower_128_64_32_ad_batch_cov',
}

suffix_list = range(12)

metrics_file_name = 'xyz'
metrics = load_predictions_and_compute_average_metrics_train_test(input_dict, suffix_list,
                                                                  ytrain=y_train_one_hot, ytest=y_test_one_hot,
                                                                  return_metrics=True, save_metrics=True,
                                                                  metrics_file_name=metrics_file_name)
format_average_metrics_train_test(metrics)
"""

def compute_accuracy(truth_one_hot, predictions):
  """
  Returns the fraction of examples whose argmax prediction matches the label.

  All columns of `predictions` take part in the argmax.

  Args:
    truth_one_hot: one-hot ground-truth labels, indexed as [example, class].
    predictions: model scores, indexed as [example, column].
  """
  hits = np.argmax(predictions, axis=1) == np.argmax(truth_one_hot, axis=1)
  return np.sum(hits) / hits.shape[0]


def compute_errors(truth_one_hot, predictions, num_classes=10):
  """
  Counts misclassified examples.

  Args:
    truth_one_hot: one-hot ground-truth labels, indexed as [example, class].
    predictions: model outputs, indexed as [example, column]. Only the first
      `num_classes` columns are treated as class scores; any extra columns
      are ignored.
    num_classes: number of leading prediction columns holding class scores.
      Defaults to 10.

  Returns:
    The number of examples whose argmax class score differs from the label.
  """
  predicted_label = np.argmax(predictions[:, :num_classes], axis=1)
  truth = np.argmax(truth_one_hot, axis=1)
  return np.sum(predicted_label != truth)


def compute_logloss(truth_one_hot, predictions, num_classes=10):
  """
  Computes the mean log loss of `predictions` against one-hot labels.

  Both arrays are truncated to their first `num_classes` columns. The result
  of `logloss(...)` is evaluated with `K.eval` and averaged with `np.mean`.
  NOTE(review): `K` and `logloss` are defined outside this section;
  presumably `K` is the Keras backend -- confirm against the file's imports.

  Args:
    truth_one_hot: one-hot ground-truth labels, indexed as [example, class].
    predictions: model outputs, indexed as [example, column].
    num_classes: number of leading columns holding class scores. Defaults
      to 10.

  Returns:
    The mean of the evaluated log loss values.
  """
  truth_one_hot = truth_one_hot[:, :num_classes]
  predictions = predictions[:, :num_classes]
  return np.mean(K.eval(logloss(truth_one_hot, predictions)))


def compute_average_prediction(all_predictions):
  """
  Averages the predictions of N models element-wise.

  Args:
    all_predictions: a list whose i-th entry holds the predictions of model i;
      all entries must have the same shape.

  Returns:
    An array shaped like a single entry, holding the mean over the models.
  """
  stacked = np.asanyarray(all_predictions)
  return stacked.mean(axis=0)


def n_way_pd(all_predictions, ord=1, num_classes=10):
  """
  Computes PD for N models as the average of the Lp differences from the average
  prediction.

  For every example, the Lp distance (taken over the class columns) between a
  model's prediction and the average prediction of all models is computed.
  The result is the mean of these distances over all models and examples.

  Args:
    all_predictions: a list, where each element contains a model's predictions,
      indexed as [example, column].
    ord: the p for Lp norm.
    num_classes: number of leading columns of each prediction array holding
      class scores; any further columns are ignored. Defaults to 10.

  Returns:
    The mean Lp distance to the average prediction.
  """
  # Keep only the class-score columns.
  all_predictions_classes = [predictions[:, :num_classes]
                             for predictions in all_predictions]

  # Compute the average prediction across models.
  average_prediction = compute_average_prediction(all_predictions_classes)

  # One array per model: the Lp distance of each example to the average.
  pd = [np.linalg.norm(average_prediction - predictions, ord=ord, axis=1)
        for predictions in all_predictions_classes]

  return np.mean(pd)


def get_predictions_on_true_class(all_predictions, labels):
  """
  Selects, for every model, the score it assigns to each example's true class.

  Args:
    all_predictions: a list, where each element contains a model's predictions,
      indexed as [example, class].
    labels: ground truth labels (assumed to be one-hot), indexed as
      [example, class].

  Returns:
    A list with one 1-D array per model; entry i of each array is that model's
    score for the true class of example i.
  """
  true_class = np.argmax(labels, axis=1)
  return [model_predictions[np.arange(len(model_predictions)), true_class]
          for model_predictions in all_predictions]


def n_way_pd_normalized(all_predictions, labels, ord=1, num_classes=10):
  """
  Computes PD for N models as the average of the Lp differences from the average
  prediction, normalized by the average prediction on the label of interest
  (the true label in multi-class).

  Only each model's score on the true class is used. For each model, the
  relative deviation (average - model) / average is formed per example. Its
  Lp norm is taken across examples and divided by the number of examples.
  The result is the mean of that value over models.

  NOTE(review): for ord=1 this is the mean relative absolute deviation. For
  ord=2 it is the L2 norm of the relative deviations divided by N (not
  sqrt(N)). Confirm this matches the intended definition.

  Args:
    all_predictions: a list, where each element contains a model's predictions,
      indexed as [example, column].
    labels: ground truth labels (assumed to be one-hot labels).
    ord: the p for Lp norm.
    num_classes: number of leading columns of each prediction array holding
      class scores; any further columns are ignored. Defaults to 10.

  Returns:
    The normalized PD value.
  """
  # Keep only the class-score columns.
  all_predictions_classes = [predictions[:, :num_classes]
                             for predictions in all_predictions]

  all_predictions_on_true_class = (
      get_predictions_on_true_class(all_predictions_classes, labels)
  )
  # Compute the average prediction on the true class across models.
  average_prediction = compute_average_prediction(all_predictions_on_true_class)
  num_examples = len(average_prediction)

  pd = [np.linalg.norm((average_prediction - predictions) / average_prediction,
                       ord=ord) / num_examples
        for predictions in all_predictions_on_true_class]

  return np.mean(pd)


def compute_metrics_from_prerdictions(predictions, metric, truth=None):
  """
  Computes one metric for every group of model predictions.

  Args:
    predictions: a dict mapping a group (configuration) name to a list of
      model predictions, one entry per run.
    metric: one of 'Errors', 'Logloss', 'N-way p=1 PD', 'N-way p=2 PD',
      'N-way p=1 Normalized PD', 'N-way p=2 Normalized PD'.
    truth: one-hot ground-truth labels. Required by every metric except the
      two unnormalized N-way PD metrics.

  Returns:
    A dict keyed like `predictions`. For 'Errors' and 'Logloss' each value is
    a list with one entry per model. For the PD metrics each value is a
    single value computed across all models of the group.

  Raises:
    ValueError: if `metric` is unknown, or if it needs `truth` and none was
      given.
  """
  # metric name -> (needs truth labels, fn(group_predictions, truth)).
  # Lambdas defer the lookup of the module-level metric helpers to call time.
  metric_fns = {
      'Errors': (
          True, lambda preds, y: [compute_errors(y, pred) for pred in preds]),
      'Logloss': (
          True, lambda preds, y: [compute_logloss(y, pred) for pred in preds]),
      'N-way p=1 PD': (
          False, lambda preds, y: n_way_pd(preds, ord=1)),
      'N-way p=2 PD': (
          False, lambda preds, y: n_way_pd(preds, ord=2)),
      'N-way p=1 Normalized PD': (
          True, lambda preds, y: n_way_pd_normalized(preds, y, ord=1)),
      'N-way p=2 Normalized PD': (
          True, lambda preds, y: n_way_pd_normalized(preds, y, ord=2)),
  }
  if metric not in metric_fns:
    # Fail loudly instead of returning an empty dict for a misspelled metric.
    raise ValueError('Unknown metric: %r' % (metric,))
  needs_truth, metric_fn = metric_fns[metric]
  if needs_truth and truth is None:
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    raise ValueError('Metric %r requires `truth` labels.' % (metric,))
  return {group: metric_fn(predictions[group], truth) for group in predictions}


def compute_average_metrics_train_test(ptrain, ptest,
                                       ytrain, ytest,
                                       return_metrics=True,
                                       save_metrics=False,
                                       metrics_file_name=None):
  """
  Computes every metric on the train and test predictions and averages each
  one per configuration.

  Args:
    ptrain, ptest: dicts mapping a configuration name to its list of per-run
      predictions on the train / test set.
    ytrain, ytest: one-hot labels for the train / test set.
    return_metrics: whether to return the computed metrics.
    save_metrics, metrics_file_name: accepted but currently unused; saving is
      left as a stub below.

  Returns:
    When `return_metrics` is true, a dict mapping each configuration (taken
    from the test-error results) to a dict of metric key -> np.mean of that
    metric's values. Otherwise None.
  """
  # (prefix of the output key, metric name understood by
  #  compute_metrics_from_prerdictions), in evaluation order.
  metric_specs = [
      ('errors', 'Errors'),
      ('logloss', 'Logloss'),
      ('p1_pd', 'N-way p=1 PD'),
      ('p2_pd', 'N-way p=2 PD'),
      ('p1_norm_pd', 'N-way p=1 Normalized PD'),
      ('p2_norm_pd', 'N-way p=2 Normalized PD'),
  ]
  per_split = {}
  for prefix, metric_name in metric_specs:
    per_split[prefix + '_train'] = compute_metrics_from_prerdictions(
        ptrain, metric_name, truth=ytrain)
    per_split[prefix + '_test'] = compute_metrics_from_prerdictions(
        ptest, metric_name, truth=ytest)

  errors_test = per_split['errors_test']
  total = len(errors_test)
  metrics = {}
  for position, config in enumerate(errors_test, start=1):
    print('  computing for beta ' + str(position) + ' out of ' + str(total))
    metrics[config] = {key: np.mean(values[config])
                       for key, values in per_split.items()}

  # your code to save metrics
  # if save_metrics:
  #   xxx

  return metrics if return_metrics else None

def format_average_metrics_train_test(metrics):
  """
  Prints the averaged metrics as a space-separated table.

  The first printed line is the header. Each following line holds one model
  (configuration). Error counts are rounded to integers; every other metric
  is rounded to 5 decimal places.

  Args:
    metrics: dict mapping a model name (str) to a dict of metric values, as
      returned by compute_average_metrics_train_test.
  """
  columns = ['logloss_train', 'errors_train',
             'p1_pd_train', 'p2_pd_train', 'p1_norm_pd_train', 'p2_norm_pd_train',
             'logloss_test', 'errors_test',
             'p1_pd_test', 'p2_pd_test', 'p1_norm_pd_test', 'p2_norm_pd_test']
  print(' '.join(['model'] + columns))
  for model, values in metrics.items():
    cells = []
    for column in columns:
      if column.startswith('errors_'):
        cells.append(str(round(values[column])))
      else:
        cells.append(str(round(values[column], 5)))
    print(' '.join([model] + cells))

def load_all_predictions(input_dict, suffix_list):
  """
  Loads train/test predictions for every configuration in `input_dict`.

  Args:
    input_dict: dict mapping a configuration name to the base path(s) of its
      saved predictions.
    suffix_list: run suffixes, forwarded to `load_predictions`.

  Returns:
    A pair of dicts (ptrain, ptest), both keyed like `input_dict`, holding
    what `load_predictions` returns for the train and test sets.
  """
  ptrain, ptest = {}, {}
  for name, base_paths in input_dict.items():
    # `load_predictions` is not defined here; implement it for your system.
    ptrain[name], ptest[name] = load_predictions(base_paths,
                                                 suffix_list=suffix_list)
  return ptrain, ptest

def load_predictions_and_compute_average_metrics_train_test(input_dict, suffix_list,
                                                            ytrain, ytest,
                                                            return_metrics=True,
                                                            save_metrics=False,
                                                            metrics_file_name=None):
  """
  Loads the predictions listed in `input_dict` and computes their averaged
  train/test metrics.

  Args:
    input_dict: dict mapping a configuration name to its prediction base path.
    suffix_list: run suffixes to load for each configuration.
    ytrain, ytest: one-hot labels for the train / test set.
    return_metrics: whether to return the computed metrics.
    save_metrics, metrics_file_name: forwarded to
      compute_average_metrics_train_test.

  Returns:
    The metrics dict when `return_metrics` is true, otherwise None.
  """
  ptrain, ptest = load_all_predictions(input_dict, suffix_list=suffix_list)
  # Always request the metrics back; whether to return them is decided here.
  metrics = compute_average_metrics_train_test(
      ptrain, ptest, ytrain, ytest,
      return_metrics=True,
      save_metrics=save_metrics,
      metrics_file_name=metrics_file_name)
  return metrics if return_metrics else None
