"""Gravitational Outlier Factor (GOF)
"""

import torch
import torch.nn.functional as F
import torch.nn as nn

import numpy as np

def get_dcfod_score(feat, means, device='cpu'):
    """Score samples by their distance to the nearest cluster centroid.

    Each sample is assigned to its closest centroid (Euclidean distance). Its
    score is that distance divided by the largest such distance among the
    samples assigned to the same centroid, so scores lie in [0, 1] per cluster.

    Args:
        feat: (n_samples, n_dim) samples, as an array-like or tensor.
        means: (n_clusters, n_dim) centroids, as an array-like or tensor.
        device: torch device on which the computation runs.

    Returns:
        numpy array of shape (n_samples,) with the per-cluster normalised
        scores. If every member of a cluster coincides with its centroid,
        those members keep a score of 0 instead of 0/0 = NaN.
    """
    with torch.no_grad():
        # as_tensor avoids the copy/warning torch.tensor() emits for tensor input.
        feat = torch.as_tensor(feat, device=device)
        means = torch.as_tensor(means, device=device)

        # -----obtain all instances' distance to cluster centroids----- #
        xe = feat.unsqueeze(1) - means
        dist = torch.sqrt(torch.sum(xe * xe, dim=2))

        outlier_score, position = torch.min(dist, dim=1)
        for c in range(dist.shape[1]):
            members = position == c
            if not members.any():
                continue
            cluster_max = outlier_score[members].max()
            # Skip normalisation when all members sit on the centroid (avoids NaN).
            if cluster_max > 0:
                outlier_score[members] = outlier_score[members] / cluster_max

    return outlier_score.detach().cpu().numpy()


class GOFScorer():
    """Score samples with the Gravitational Outlier Factor (GOF).

    Every mixture component (mean, diagonal covariance, weight) exerts an
    attraction on each sample with magnitude

        weight / (pi * sqrt(prod(covar)) * (1 + d2))

    where ``d2`` is the squared Mahalanobis distance under the component's
    diagonal covariance. The score is the reciprocal of the total attraction.
    Samples weakly attracted by all components therefore score high.
    """

    def __init__(self, device='cpu'):
        # Device on which all intermediate tensors are placed.
        self.device = device

    def get_score(self, feat, means, covars=None, weights=None, score_type='vector'):
        """Compute GOF scores for ``feat``.

        Args:
            feat: (n_samples, n_dim) samples (tensor or array-like).
            means: (n_components, n_dim) component centres.
            covars: optional (n_components, n_dim) diagonal covariance
                entries. When omitted, all entries are 1 (identity covariance).
            weights: optional (n_components,) component weights. When omitted,
                each component's weight is the fraction of samples whose
                nearest component (by Mahalanobis distance) it is.
            score_type: ``'scalar'`` sums the force magnitudes. ``'vector'``
                sums the force vectors (each pointing from the sample towards
                the component mean) and uses the norm of the resultant.

        Returns:
            numpy array of shape (n_samples,); larger values mean weaker total
            attraction.

        Raises:
            ValueError: if ``score_type`` is neither ``'scalar'`` nor ``'vector'``.
        """
        if score_type not in ('scalar', 'vector'):
            raise ValueError(
                f"score_type must be 'scalar' or 'vector', got {score_type!r}")

        feat = torch.as_tensor(feat, device=self.device)
        means = torch.as_tensor(means, device=self.device)
        n_components = means.shape[0]

        if covars is None:
            # Identity covariance for every component: all diagonal entries are 1.
            covars = torch.ones((n_components, feat.shape[1]), device=self.device)
        else:
            covars = torch.as_tensor(covars, device=self.device)

        maha = self._mahalanobis_distance_mix_diag(feat, means, covars)

        if weights is None:
            # Weight each component by the share of samples closest to it.
            _, nearest = torch.min(maha, dim=1)
            weights = F.one_hot(nearest.long(), num_classes=n_components).float().sum(dim=0) / feat.shape[0]
        weights = torch.as_tensor(weights, device=self.device).reshape(1, -1)

        # Product of the diagonal entries is the determinant of a diagonal covariance.
        det_covars = torch.prod(covars, dim=1)
        denom = torch.pi * torch.sqrt(det_covars) * (1 + maha)
        forces = weights / denom  # (n_samples, n_components)

        if score_type == 'scalar':
            score = 1. / forces.sum(dim=1)
        else:
            # Unit vectors from each sample towards each component mean.
            # F.normalize yields a zero vector when a sample sits on a mean.
            delta = means.unsqueeze(0) - feat.unsqueeze(1)
            unit_vec = F.normalize(delta, p=2, dim=-1)
            resultant = (forces.unsqueeze(2) * unit_vec).sum(dim=1)
            score = 1. / torch.norm(resultant, dim=-1)

        return score.detach().cpu().numpy()

    def _mahalanobis_distance_mix_diag(self, X, means, covars):
        """Squared Mahalanobis distance from every sample to every component.

        Args:
            X: (n_samples, n_dim) samples.
            means: (n_components, n_dim) component centres.
            covars: (n_components, n_dim) diagonal covariance entries.

        Returns:
            (n_samples, n_components) tensor on ``self.device`` in the default
            dtype, holding ``sum_d (x_d - mu_d)^2 / var_d``. No square root is
            taken.
        """
        n_samples = X.shape[0]
        n_components = means.shape[0]
        # Allocate directly on the target device instead of CPU-then-copy.
        result = torch.empty((n_samples, n_components), device=self.device)

        # Per-component loop keeps peak memory at O(n_samples * n_dim).
        for c in range(n_components):
            centred = X - means[c]
            result[:, c] = (centred * (1.0 / covars[c]) * centred).sum(dim=1)

        return result
    