import numpy as np

class MahaDistNormalizer:
    """Running min-max scaler that maps values linearly into ``[left, right]``.

    The smallest and largest values seen so far are kept on the instance
    (``self.min`` / ``self.max``) and updated on every call to :meth:`run`,
    so later batches are scaled with the accumulated range.
    """

    def __init__(self):
        super().__init__()
        # Start from an empty range so the first batch always defines both
        # bounds, whatever the magnitude or sign of the data.
        self.min, self.max = np.inf, -np.inf

    def run(self, x, left, right):
        """Update the running range with ``x`` and rescale ``x``.

        Args:
            x: Values to rescale; must provide ``min()`` and ``max()``
                (e.g. a NumPy array).
            left: Output value assigned to the running minimum.
            right: Output value assigned to the running maximum.

        Returns:
            ``left + (right - left) * (x - self.min) / (self.max - self.min)``.
            When the running range has zero width (all values seen so far
            are identical) every element is mapped to ``left`` instead of
            dividing by zero.
        """
        self.min = min(x.min(), self.min)
        self.max = max(x.max(), self.max)
        span = self.max - self.min
        # A zero-width range would otherwise give inf * 0 = NaN.
        k = (right - left) / span if span != 0 else 0.0
        return left + k * (x - self.min)

# class MahaDistNormalizer:
#     def __init__(self,):
#         super(MahaDistNormalizer, self).__init__()
#         self.min, self.max = 1e20, -1e10

#     def run(self, x):
#         self.min = min(x.min(), self.min)
#         self.max = max(x.max(), self.max)
#         return (x - self.min) / (self.max - self.min)

def maha_distance(xs, cov_inv_in, mean_in, norm_type=None):
    """Reduce the per-coordinate Mahalanobis terms of each row of ``xs``.

    For every row ``d = x - mean_in`` the per-coordinate terms are
    ``(d @ cov_inv_in) * d``; ``norm_type`` selects how they are reduced.

    Args:
        xs: Sample(s) to score. ``mean_in`` is reshaped to a single row and
            broadcast against ``xs``, so a 1-D sample is treated as one row.
        cov_inv_in: Inverse covariance matrix multiplied on the right of the
            centred samples.
        mean_in: Mean subtracted from every sample; needs ``reshape``.
        norm_type: ``None`` or ``"L2"`` sums the terms (the squared
            Mahalanobis distance, no square root is taken); ``"L1"`` sums
            the square roots of their absolute values; ``"Linfty"`` takes
            the largest (signed) term.

    Returns:
        One value per row of the centred samples.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported names.
    """
    diffs = xs - mean_in.reshape([1, -1])
    second_powers = np.matmul(diffs, cov_inv_in) * diffs

    if norm_type in (None, "L2"):
        return np.sum(second_powers, axis=1)
    if norm_type in ("L1",):
        return np.sum(np.sqrt(np.abs(second_powers)), axis=1)
    if norm_type in ("Linfty",):
        return np.max(second_powers, axis=1)
    # Previously an unknown name fell through and returned None silently.
    raise ValueError(
        "Unsupported norm_type {!r}; expected None, 'L2', 'L1' or 'Linfty'".format(norm_type)
    )


def maha(
    indist_train_embeds_in,
    indist_train_labels_in,
    subtract_mean=False,
    normalize_to_unity=False,
    indist_classes=100,
):
    """Fit the Gaussian statistics used for Mahalanobis-distance scoring.

    Args:
        indist_train_embeds_in: Training embeddings, one sample per row.
            The array is never modified.
        indist_train_labels_in: Label per training sample; compared with
            ``== c`` for ``c`` in ``range(indist_classes)`` to select rows.
        subtract_mean: If true, subtract the per-feature mean of the whole
            training set before fitting.
        normalize_to_unity: If true, divide every (possibly centred) row by
            its L2 norm before fitting.
        indist_classes: Number of class labels to fit, ``0 .. n - 1``.

    Returns:
        dict with the keys

        * ``"class_cov_invs"``: list of length ``indist_classes`` in which
          every entry is the same matrix, the inverse of the average of the
          per-class covariances (classes whose covariance contains NaN,
          e.g. classes without samples, are left out of the average);
        * ``"class_means"``: list with the mean embedding of each class;
        * ``"cov_inv"``: inverse covariance of the whole training set;
        * ``"mean"``: mean embedding of the whole training set.

    Raises:
        ValueError: If no class yields a NaN-free covariance.
        numpy.linalg.LinAlgError: If a covariance to invert is singular.
    """
    embeds = indist_train_embeds_in

    if subtract_mean:
        # Build a new array instead of using ``-=``: the in-place form
        # silently shifted the caller's training data.
        embeds = embeds - np.mean(indist_train_embeds_in, axis=0, keepdims=True)

    if normalize_to_unity:
        embeds = embeds / np.linalg.norm(embeds, axis=1, keepdims=True)

    # Single Gaussian fitted to the full training set.
    mean = np.mean(embeds, axis=0)
    cov = np.cov((embeds - mean.reshape([1, -1])).T)
    cov_inv = np.linalg.inv(cov)

    # Per-class means and covariances.
    class_means = []
    class_covs = []
    for c in range(indist_classes):
        members = embeds[indist_train_labels_in == c]
        mean_now = np.mean(members, axis=0)
        class_means.append(mean_now)
        class_covs.append(np.cov((members - mean_now.reshape([1, -1])).T))

    # Classes without usable samples produce NaN covariances; keep them out
    # of the shared (averaged) covariance.
    valid_covs = [cov_c for cov_c in class_covs if not np.isnan(cov_c).any()]
    if not valid_covs:
        raise ValueError(
            "No class produced a valid covariance; check the labels and indist_classes"
        )

    # One shared inverse, repeated so it can be indexed by class label.
    shared_cov_inv = np.linalg.inv(np.mean(np.stack(valid_covs, axis=0), axis=0))
    class_cov_invs = [shared_cov_inv] * len(class_covs)

    return {
        "class_cov_invs": class_cov_invs,
        "class_means": class_means,
        "cov_inv": cov_inv,
        "mean": mean,
    }


'''
Calculate the Mahalanobis distance and the relative Mahalanobis distance using ground-truth labels.
'''
def get_maha_distance(train_logits, class_cov_invs, class_means, targets, norm_name="L2"):
    """Score every sample against the statistics of its own target class.

    Args:
        train_logits: Indexable collection of samples; entry ``i`` is
            reshaped to a single row before scoring.
        class_cov_invs: Inverse covariances, indexed by class label.
        class_means: Class means, indexed by class label.
        targets: Class label of each sample; its length sets how many
            samples are scored.
        norm_name: Reduction name forwarded to ``maha_distance``.

    Returns:
        ``np.array`` stacking the per-sample results of ``maha_distance``
        (each result covers a single row).
    """
    per_sample = []
    for idx in range(len(targets)):
        label = targets[idx]
        row = train_logits[idx].reshape([1, -1])
        per_sample.append(
            maha_distance(row, class_cov_invs[label], class_means[label], norm_name)
        )
    return np.array(per_sample)


def get_maha_distance_cov(train_logits, cov_invs, class_means, targets, norm_name="L2"):
    """Score every sample against its target-class mean with one shared inverse covariance.

    Args:
        train_logits: Indexable collection of samples; entry ``i`` is
            reshaped to a single row before scoring.
        cov_invs: Inverse covariance passed unchanged for every sample.
        class_means: Class means, indexed by class label.
        targets: Class label of each sample; its length sets how many
            samples are scored.
        norm_name: Reduction name forwarded to ``maha_distance``.

    Returns:
        ``np.array`` stacking the per-sample results of ``maha_distance``
        (each result covers a single row).
    """
    per_sample = []
    for idx in range(len(targets)):
        row = train_logits[idx].reshape([1, -1])
        centre = class_means[targets[idx]]
        per_sample.append(maha_distance(row, cov_invs, centre, norm_name))
    return np.array(per_sample)

def get_relative_maha_distance(train_logits, cov_invs, class_cov_invs, means, class_means, targets, norm_name="L2"):
    """Return the class-specific distance minus the whole-dataset distance per sample.

    Args:
        train_logits: Indexable collection of samples.
        cov_invs: Inverse covariance used for the whole-dataset term.
        class_cov_invs: Inverse covariances, indexed by class label.
        means: Mean used for the whole-dataset term.
        class_means: Class means, indexed by class label.
        targets: Class label of each sample; its length sets how many
            samples are scored.
        norm_name: Reduction name forwarded to ``maha_distance``.

    Returns:
        ``maha_k - maha_0`` where ``maha_k`` is the distance to the target
        class and ``maha_0`` the distance to the whole-dataset statistics.
    """
    sample_ids = range(len(targets))

    # Whole-dataset ("background") term, computed for all samples first.
    background = np.array(
        [maha_distance(train_logits[idx], cov_invs, means, norm_name) for idx in sample_ids]
    )

    # Class-conditional term using each sample's ground-truth label.
    class_specific = np.array(
        [
            maha_distance(
                train_logits[idx].reshape([1, -1]),
                class_cov_invs[targets[idx]],
                class_means[targets[idx]],
                norm_name,
            )
            for idx in sample_ids
        ]
    )

    return class_specific - background

'''
Calculate the Mahalanobis distance and the relative Mahalanobis distance on test data, without ground-truth labels.
'''
def get_maha_predict(train_logits, class_cov_invs, class_means, num_classes, norm_name="L2"):
    """Predict, for every sample, the class with the smallest distance.

    Args:
        train_logits: Samples, passed as a whole to ``maha_distance``.
        class_cov_invs: Inverse covariances, indexed by class label.
        class_means: Class means, indexed by class label.
        num_classes: Number of classes to compare, ``0 .. num_classes - 1``.
        norm_name: Reduction name forwarded to ``maha_distance``.

    Returns:
        Index of the minimising class per sample (``argmin`` over the
        stacked per-class distances).
    """
    distances = []
    for c in range(num_classes):
        distances.append(
            maha_distance(train_logits, class_cov_invs[c], class_means[c], norm_name)
        )
    stacked = np.stack(distances, axis=0)
    return np.argmin(stacked, axis=0)

def get_relative_maha_predict(train_logits, cov_invs, class_cov_invs, means, class_means, num_classes, norm_name="L2"):
    """Compute per-class relative distances for every sample.

    Args:
        train_logits: Samples; scored one by one for the whole-dataset term
            and as a whole for the per-class terms.
        cov_invs: Inverse covariance used for the whole-dataset term.
        class_cov_invs: Inverse covariances, indexed by class label.
        means: Mean used for the whole-dataset term.
        class_means: Class means, indexed by class label.
        num_classes: Number of classes to compare, ``0 .. num_classes - 1``.
        norm_name: Reduction name forwarded to ``maha_distance``.

    Returns:
        Array with one row per class holding the class distance minus the
        whole-dataset distance of each sample. Note that this is the
        distance matrix, not the argmin over classes.
    """
    background = []
    for idx in range(len(train_logits)):
        background.append(maha_distance(train_logits[idx], cov_invs, means, norm_name))
    # One row, so it broadcasts across the per-class rows below.
    background = np.array(background).reshape(1, -1)

    per_class = np.array(
        [
            maha_distance(train_logits, class_cov_invs[c], class_means[c], norm_name)
            for c in range(num_classes)
        ]
    )
    relative = per_class - background

    # NOTE(review): the argmin is computed but never returned, unlike
    # get_maha_predict. Kept as-is because callers may rely on receiving the
    # relative-distance matrix; confirm the intended return value.
    _predicted = np.argmin(relative, axis=0)

    return relative



  
def get_maha_distance_feature(feature, cov_invs, means):
    """Return the squared Mahalanobis distance of each row of ``feature``.

    Args:
        feature: Samples, one per row (the reduction runs over axis 1).
        cov_invs: Inverse covariance matrix.
        means: Mean subtracted from ``feature`` (broadcast by NumPy).

    Returns:
        Per-row sum of ``((feature - means) @ cov_invs) * (feature - means)``.
    """
    centered = feature - means
    weighted = np.dot(centered, cov_invs)
    return np.sum(weighted * centered, axis=1)