import tensorflow as tf
    
import detectors
import torch
import torchvision as tv

import numpy as np

from typing import Any

import warnings
# Module-level side effect: importing this file installs a filter that ignores
# every warning of category UserWarning, for all modules and not just this one.
warnings.filterwarnings("ignore", category=UserWarning)


# Method names accepted by ``BaselineDetector.create``; any other name makes
# it raise ``ValueError``.
DETECTORS_BASELINES = [
    'max_logits',
    'msp',
    'odin',
    'energy',
    'dice',
    'react',
    'mahalanobis',
    'knn_euclides',
]

# Methods whose detector has ``fit(data_loader)`` called on it by
# ``BaselineDetector.create`` before it is returned.
TRAINED_DETECTORS = [
    'mahalanobis',
    'knn_euclides',
    'react',
]

class BaselineDetector:
    """Bundle a model with a baseline detector and return NumPy scores.

    Instances are normally obtained through :meth:`create`, which builds the
    detector with ``detectors.create_detector`` and fits it when the method
    is listed in ``TRAINED_DETECTORS``. Calling the instance forwards the
    input to the detector and converts the result to a NumPy array.
    """

    # Name of the baseline method this instance was built for.
    method_name: str
    # Model that was handed to ``detectors.create_detector``; only stored here.
    model: Any
    # Callable detector. Its output must support ``.detach().cpu().numpy()``
    # (presumably a torch tensor -- confirm against the ``detectors`` API).
    detector: Any

    def __init__(self, method_name: str, model: Any, detector: Any) -> None:
        """Store the method name, the model and an already-built detector."""
        self.method_name = method_name
        self.model = model
        self.detector = detector

    @classmethod
    def create(
        cls,
        method_name: str,
        model: Any,
        data_loader: Any,
        **kwargs: Any,
    ) -> "BaselineDetector":
        """Build (and, when required, fit) a detector for ``method_name``.

        Args:
            method_name: Must be one of ``DETECTORS_BASELINES``.
            model: Passed to ``detectors.create_detector`` as ``model=``.
            data_loader: Passed to ``detector.fit`` when ``method_name`` is in
                ``TRAINED_DETECTORS``; otherwise unused.
            **kwargs: Forwarded unchanged to ``detectors.create_detector``.

        Returns:
            A new instance wrapping the created detector.

        Raises:
            ValueError: If ``method_name`` is not in ``DETECTORS_BASELINES``.
        """
        # Reject unknown names before any detector is constructed.
        if method_name not in DETECTORS_BASELINES:
            raise ValueError("Unknown baseline detector.")

        detector = detectors.create_detector(method_name, model=model, **kwargs)

        if method_name in TRAINED_DETECTORS:
            detector.fit(data_loader)

        return cls(method_name, model, detector)

    def __call__(self, x: Any) -> np.ndarray:
        """Run the detector on ``x`` and return its scores as a NumPy array.

        Args:
            x: Input forwarded unchanged to the wrapped detector.

        Returns:
            The detector output as a NumPy array. If the output has more than
            one dimension, it is reduced with a maximum over axis 1.
        """
        # ``Tensor.numpy()`` raises on tensors that require grad, so detach
        # first; this is a no-op for outputs computed without autograd.
        scores = self.detector(x).detach().cpu().numpy()
        if scores.ndim > 1:
            scores = np.max(scores, axis=1)
        return scores
    
        
