import faiss
from sklearn.preprocessing import normalize
from OOD_score_utils import log_sum_exponential_score
from knn_search_GPU import knn_sim_IP_GPU, estimate_dense_k
from common_imports import np, tqdm
from common_use_functions import softmax

"""
The class for the Mahalanobis score.
"""
class Mahalanobis(object):
    """Mahalanobis-distance scorer over feature vectors.

    ``fit`` estimates one center per group plus inverse covariance(s):

    * supervised (labels passed to ``fit``): one group per distinct label and a
      single shared inverse covariance, the pseudo-inverse of the unweighted
      mean of the per-group covariances;
    * unsupervised (no labels): groups come from faiss K-means with
      ``num_clusters`` clusters (or a single group when ``num_clusters <= 1``)
      and every group gets its own pseudo-inverse covariance.

    ``score`` returns ``exp(-(d / distance_scale) / 2)`` where ``d`` is the
    smallest squared Mahalanobis distance to any center, or ``d`` itself when
    ``return_distance=True``.
    """

    def __init__(self, normalize_on=True, standardize_on=True, num_clusters=5,
                 distance_scale=2048):
        """
        Args:
            normalize_on: L2-normalize every row before further processing.
            standardize_on: per-dimension standardization with the training
                mean/std (measured after the optional L2 normalization).
            num_clusters: number of K-means clusters in unsupervised mode;
                replaced by the number of distinct labels in supervised mode.
            distance_scale: divisor applied to the minimum distance before
                exponentiation in ``score``. The default 2048 is the constant
                that used to be hard-coded (presumably the feature dimension
                it was tuned for -- TODO confirm).
        """
        self.normalize_on = normalize_on
        self.standardize_on = standardize_on
        self.num_clusters = num_clusters  # the number of K-means clusters
        self.distance_scale = distance_scale

    def fit(self, X, y=None):
        """Estimate preprocessing statistics, group centers and inverse covariances.

        Args:
            X: 2-D array-like of training features, one row per sample.
            y: optional per-sample labels. Any label values are accepted (they
                need not be 0..K-1); a column vector is flattened.
        """
        supervised = y is not None

        X = np.array(X)
        dim = X.shape[1]

        # The standardization statistics must live in the same space they are
        # applied to in _preprocess, i.e. after the optional L2 normalization.
        X_stats = normalize(X, axis=1) if self.normalize_on else X
        self.mean = np.mean(X_stats, axis=0, keepdims=True)
        self.std = np.std(X_stats, axis=0, keepdims=True)

        X = self._preprocess(X)

        # Assign every sample to a group and list the group ids in center order.
        if supervised:
            y = np.asarray(y).reshape(-1)
            labels = np.unique(y)
            self.num_clusters = len(labels)
        elif self.num_clusters > 1:
            # faiss only accepts C-contiguous float32 input.
            X_f32 = np.ascontiguousarray(X, dtype=np.float32)
            kmeans = faiss.Kmeans(d=dim, k=self.num_clusters, niter=100, verbose=False, gpu=False)
            kmeans.train(X_f32)
            y = np.array(kmeans.assign(X_f32)[1]).reshape(-1)
            labels = np.arange(self.num_clusters)
        else:
            # A single group containing every sample.
            self.num_clusters = 1
            y = np.zeros(len(X))
            labels = np.arange(1)

        self.center = np.zeros(shape=(self.num_clusters, dim))
        cov = np.zeros(shape=(self.num_clusters, dim, dim))

        for k in tqdm(range(self.num_clusters)):
            X_k = X[y == labels[k]]
            self.center[k] = np.mean(X_k, axis=0)
            cov[k] = np.cov(X_k.T, bias=True)

        if supervised:
            # One covariance shared by all classes: unweighted mean over classes.
            self.shared_icov = np.linalg.pinv(cov.mean(axis=0))
            self.icov = None
        else:
            self.shared_icov = None
            self.icov = np.zeros(shape=(self.num_clusters, dim, dim))
            for k in tqdm(range(self.num_clusters)):
                self.icov[k] = np.linalg.pinv(cov[k])

    def score(self, X, return_distance=False):
        """Score samples by their distance to the nearest fitted center.

        Args:
            X: 2-D array-like of features, one row per sample.
            return_distance: if True, return the raw minimum squared
                Mahalanobis distance instead of the exponentiated score.

        Returns:
            1-D array with one value per row of ``X``: the minimum distance,
            or ``exp(-(distance / distance_scale) / 2)``.
        """
        X = self._preprocess(np.array(X))

        if self.shared_icov is not None:
            # Expand (x - u)^T M (x - u) = x^T M x + u^T M u - 2 x^T M u so every
            # sample/center pair is computed with matrix products.
            M = self.shared_icov
            U = self.center
            XM = np.matmul(X, M)
            md = (XM * X).sum(axis=1)[:, None] \
                 + (np.matmul(U, M) * U).sum(axis=1)[None, :] \
                 - 2 * np.matmul(XM, U.T)
        else:
            md = []
            for k in tqdm(range(self.num_clusters)):
                delta_k = X - self.center[k][None, :]
                md.append((np.matmul(delta_k, self.icov[k]) * delta_k).sum(axis=1))
            md = np.array(md).T

        out = md.min(axis=1)

        if return_distance:
            return out

        # getattr keeps instances created before distance_scale existed
        # (e.g. unpickled ones) working with the old constant.
        scale = getattr(self, "distance_scale", 2048)
        return np.exp(-(out / scale) / 2)

    def _preprocess(self, X):
        """Optionally L2-normalize rows, then optionally standardize with the fitted mean/std."""
        if self.normalize_on:
            X = normalize(X, axis=1)    # normalize

        if self.standardize_on:
            X = (X - self.mean) / (self.std + 1e-8)     # standardize; epsilon guards zero-variance dims

        return X

    def _mahalanobis_score(self, x, center, icov):
        """Squared Mahalanobis distance of each row of ``x`` to ``center`` under ``icov``, clipped at 0.

        NOTE(review): not called anywhere in the visible code.
        """
        delta = x - center
        ms = (np.matmul(delta, icov) * delta).sum(axis=1)
        return np.maximum(ms, 0)
    
def get_mahalanobis_model(train_feats, train_labels):
    """Fit a default-configured Mahalanobis scorer on the given features/labels and return it."""
    model = Mahalanobis()
    model.fit(train_feats, train_labels)
    return model
    
def get_mahalanobis_score(input_feats, mahalanobis_model):
    """Return ``mahalanobis_model.score(input_feats)`` (thin functional wrapper)."""
    scores = mahalanobis_model.score(input_feats)
    return scores


"""
The class for the NNGuide score.
"""
class NNGuide(object):
    """Nearest-neighbor-guided OOD scorer.

    At construction the training features are searched against themselves
    (``k_max`` neighbors) to choose the neighborhood size ``knn_k`` with
    ``estimate_dense_k``. A feature bank is then built by multiplying every
    training feature by its ``log_sum_exponential_score`` confidence.
    """

    def __init__(self, train_logits, train_zs, train_labels, k_max, knn_batch_size=50, half_precision=False):
        # Keep the raw inputs and search settings.
        self.train_logits = train_logits
        self.train_zs = train_zs
        self.train_labels = train_labels  # stored but not used by this class
        self.knn_batch_size = knn_batch_size
        self.half_precision = half_precision

        # Pick k from the train-vs-train neighbor similarities.
        self_sims, _ = knn_sim_IP_GPU(
            self.train_zs,
            self.train_zs,
            k=k_max,
            batch_size=self.knn_batch_size,
            half_precision=self.half_precision,
            display=False,
        )
        self.knn_k = estimate_dense_k(
            self_sims, verify_steps=3, variation_threshold=0.1, min_k=5, smooth_sigma=0
        )

        # Confidence-scaled feature bank, one column of scales broadcast over features.
        train_confidence = log_sum_exponential_score(self.train_logits, sum_axis=1)
        self.scaled_train_zs = train_confidence.reshape(-1, 1) * train_zs

    def score(self, zs, logits):
        """Return per-sample scores: mean top-``knn_k`` similarity to the scaled bank times confidence.

        NOTE(review): assumes ``knn_sim_IP_GPU`` returns (similarities, indices)
        with one row per query in ``zs`` -- confirm against its definition.
        """
        confidence = log_sum_exponential_score(logits, sum_axis=1)
        neighbor_sims, _ = knn_sim_IP_GPU(
            self.scaled_train_zs,
            zs,
            k=self.knn_k,
            batch_size=self.knn_batch_size,
            half_precision=self.half_precision,
            display=False,
        )
        guidance = np.mean(neighbor_sims[:, :self.knn_k], axis=1)
        return guidance * confidence
    
def get_nnguide_model(train_logits, train_zs, train_labels, k_max, knn_batch_size=50, half_precision=False):
    """Construct and return an NNGuide scorer; every argument is forwarded unchanged."""
    return NNGuide(
        train_logits,
        train_zs,
        train_labels,
        k_max,
        knn_batch_size=knn_batch_size,
        half_precision=half_precision,
    )
    
def get_nnguide_score(input_zs, input_logits, nnguide_model):
    """Return ``nnguide_model.score(input_zs, input_logits)`` (thin functional wrapper)."""
    scores = nnguide_model.score(input_zs, input_logits)
    return scores

"""
Logit-based OOD score functions.
"""
def get_msp_score(logits):
    """Maximum softmax probability (MSP) of each row of ``logits``."""
    return np.max(softmax(logits), axis=1)

def get_maxlogit_score(logits):
    """Largest logit of each row (reduction along axis 1)."""
    return np.amax(logits, axis=1)

def get_kl_score(logits):
    """Cross-entropy from the uniform distribution to each row's softmax.

    Per row this is ``-sum_c (1/C) * log(p_c + 1e-12)``, which equals
    ``KL(uniform || p) + log(C)``. It therefore orders samples exactly as the
    KL divergence does.
    """
    num_classes = logits.shape[-1]
    log_probs = np.log(softmax(logits) + 1e-12)  # epsilon avoids log(0)
    uniform = np.ones_like(logits) / num_classes
    return -np.sum(uniform * log_probs, axis=1)

def get_energy_score(logits):
    """Energy-style score: delegates to ``log_sum_exponential_score`` over axis 1 (presumably log-sum-exp of the logits)."""
    energy = log_sum_exponential_score(logits, sum_axis=1)
    return energy

def get_score(logits, method):
    """Compute a logit-based OOD score selected by name.

    Args:
        logits: per-sample class logits, one row per sample.
        method: one of ``"msp"``, ``"maxlogit"``, ``"kl"`` or ``"energy"``.

    Returns:
        The per-sample scores produced by the matching ``get_*_score`` function.

    Raises:
        ValueError: if ``method`` is not one of the supported names. Previously
            an unknown name printed a message and then crashed with
            ``UnboundLocalError`` on the unassigned result.
    """
    if method == "msp":
        return get_msp_score(logits)
    if method == "maxlogit":
        return get_maxlogit_score(logits)
    if method == "kl":
        return get_kl_score(logits)
    if method == "energy":
        return get_energy_score(logits)
    raise ValueError(
        "score method error: unknown method {!r} "
        "(expected 'msp', 'maxlogit', 'kl' or 'energy')".format(method)
    )