import torch
import numpy as np
from typing import Dict

from ood_detectors.interface import OODDetector
from ood_detectors.assets import knn_score

def normalizer(x):
    """Scale each row of ``x`` to unit L2 norm along the last axis.

    A small epsilon (1e-10) is added to the norm so all-zero rows map to
    zeros instead of producing NaNs from 0 / 0.
    """
    return x / (np.linalg.norm(x, ord=2, axis=-1, keepdims=True) + 1e-10)


def _to_numpy(x):
    """Return ``x`` as a NumPy array, moving torch tensors to host memory first.

    NumPy cannot consume CUDA tensors or tensors that require grad directly,
    so tensors are detached and copied to the CPU explicitly.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def prepos_feat(x, start=448, stop=960):
    """Select feature columns ``[start, stop)`` of ``x`` and L2-normalise rows.

    The default window 448:960 corresponds to the last layer's features only
    (per the original inline comment). ``x`` may be a NumPy array or a torch
    tensor (on any device). Returns a C-contiguous NumPy array.

    Raises:
        IndexError: if ``x`` has fewer than ``stop`` columns. The original
            fancy-indexing version raised here as well; a bare slice would
            silently truncate instead.
    """
    arr = _to_numpy(x)
    if arr.ndim >= 2 and arr.shape[1] < stop:
        raise IndexError(
            f"expected at least {stop} feature columns, got {arr.shape[1]}"
        )
    return np.ascontiguousarray(normalizer(arr[:, start:stop]))

class KNNOODDetectorPlus(OODDetector):
    """k-nearest-neighbour OOD detector over preprocessed features.

    ``setup`` caches the preprocessed (column-sliced, L2-normalised) training
    features; ``infer`` scores new samples against them with ``knn_score``.
    """

    # k used when ``args.detector['knn_k']`` is not available.
    DEFAULT_KNN_K = 50

    def setup(self, args, train_model_outputs):
        """Read the k hyper-parameter and cache preprocessed training features.

        Args:
            args: configuration object; if ``args.detector['knn_k']`` exists it
                sets k, otherwise ``DEFAULT_KNN_K`` is used.
            train_model_outputs: mapping with a ``'feas'`` entry holding the
                training-set features.
        """
        feas_train = train_model_outputs['feas']
        try:
            self.knn_k = args.detector['knn_k']
        except (AttributeError, KeyError, TypeError):
            # No ``detector`` attribute, no 'knn_k' key, or ``detector`` is not
            # subscriptable (e.g. None): fall back to the default k. Narrow
            # except so interrupts and unrelated bugs are not swallowed.
            self.knn_k = self.DEFAULT_KNN_K

        # Cache training features in the same preprocessed form used at
        # inference time so the kNN distances are comparable.
        self.feas_train = prepos_feat(feas_train)

    def infer(self, model_outputs: Dict):
        """Score ``model_outputs['feas']`` against the cached training features.

        Returns a torch tensor of scores placed on ``feas.device``.
        NOTE(review): assumes ``feas`` is a torch tensor and that ``knn_score``
        returns a NumPy array (required by ``torch.from_numpy``); the meaning of
        ``min=True`` is defined in ``knn_score`` — confirm there.
        """
        feas = model_outputs['feas']
        scores = knn_score(self.feas_train, prepos_feat(feas), k=self.knn_k, min=True)
        return torch.from_numpy(scores).to(feas.device)

