# Lint as: python3
# Copyright 2018 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.
"""Utility functions for SPL."""

import numpy as np


def label_map_assignment(confusion_matrix, label_map, subclass_type):
  """Re-assign label map to define different subclass.

  This is the key of the proposed method.

  Args:
    confusion_matrix: nxn matrix where each column is preds of one class;
      the diagonal cell of column c is at row c (presumably rows index the
      predicted class -- confirm against the caller).
    label_map: label map of each class in the confusion matrix; column c
      holds the labels of the cells of column c. For 'indiag' its values are
      used as indices into `confusion_matrix.ravel()`. It is modified in
      place for 'offdiag' and 'indiag'.
    subclass_type: label_map reassignment strategy, one of:
      'all': return label_map unchanged.
      'offdiag': keep the label of each diagonal cell [c, c] and merge all
        off-diagonal cells of column c into the label n + c.
      'indiag': merge every cell of column c into label c, except the
        largest off-diagonal cell, which becomes the sub class n + c.

  Returns:
    The reassigned label map (the same array object as `label_map`).

  Raises:
    NotImplementedError: if `subclass_type` is not one of the strategies
      above.
  """
  n = confusion_matrix.shape[1]
  if subclass_type == 'all':
    return label_map
  elif subclass_type == 'offdiag':
    for c in range(n):
      # Merge every off-diagonal cell of column c into the extra label n + c;
      # the diagonal cell [c, c] keeps its original label.
      label_map[:c, c] = n + c
      label_map[c + 1:, c] = n + c
  elif subclass_type == 'indiag':
    for c in range(n):
      # Gather column c's scores through its labels. astype(float) yields a
      # private copy, so masking below never touches confusion_matrix.
      cm_col = confusion_matrix.ravel()[label_map[:, c]].astype(float)
      # Exclude the diagonal cell (row c, as in the 'offdiag' branch) from the
      # search; -inf stays smallest even when scores can be negative.
      cm_col[c] = -np.inf
      biggest_offdiag = int(np.argmax(cm_col))
      label_map[:, c] = c  # merge the whole column into class c
      # The biggest off-diagonal cell becomes its own sub class.
      label_map[biggest_offdiag, c] = n + c
  else:
    raise NotImplementedError(
        'Unsupported subclass_type: %r' % (subclass_type,))

  return label_map
