from .baselines import BaselineDetector, DETECTORS_BASELINES
from .gradpca import GradPCA
from .gaia import GAIADetector

from .config import MODEL_LAYER_CONFIG, DETECTOR_KWARGS

from typing import Any

class DetectorsFactory:
    """Thin wrapper that builds a detector by method name and forwards calls to it.

    Use :meth:`create` to pick the detector implementation from the method
    name; the resulting instance is callable and delegates to the wrapped
    detector.
    """

    # Name of the detection method this instance wraps.
    method_name: str
    # Name of the model the detector was built for.
    model_name: str
    # The wrapped detector object; it is invoked by ``__call__``.
    detector: Any
    # Framework tag chosen by ``create`` ('jax' or 'torch').
    framework: str

    # Method names dispatched to ``GradPCA.create``.
    _GRAD_METHODS = ('GradPCA', 'GradOrth', 'GradPCA-Batch', 'GradPCA+DICE')
    # Method names dispatched to ``GAIADetector.create``; these are also the
    # only methods whose detector gets ``end()`` called on it.
    _GAIA_METHODS = ('GAIA-A', 'GAIA-Z')

    def __init__(self, method_name, model, detector, framework):
        """Store the wrapped detector and its metadata.

        Args:
            method_name: Name of the detection method.
            model: Model identifier; ``create`` passes ``model_name`` here.
            detector: The detector object to delegate to.
            framework: Framework tag associated with the detector.
        """
        self.method_name = method_name
        # The class declares ``model_name`` but the original code only ever
        # set ``model``; populate both so either spelling works.
        self.model_name = model
        self.model = model
        self.detector = detector
        self.framework = framework

    @classmethod
    def create(cls, method_name: str, model_name: str, num_classes, models,
               data_loader, device) -> "DetectorsFactory":
        """Build the detector matching ``method_name`` and wrap it.

        Args:
            method_name: Selects the detector implementation.
            model_name: Key into ``MODEL_LAYER_CONFIG`` for baseline
                detectors; also forwarded to the GAIA detector.
            num_classes: Forwarded to ``GradPCA.create`` only.
            models: Mapping indexed with ``'jax'`` (GradPCA family) or
                ``'torch'`` (baselines and GAIA).
            data_loader: Forwarded to the GradPCA and baseline detectors.
            device: Forwarded to the GAIA detector only.

        Returns:
            A new instance wrapping the created detector.

        Raises:
            ValueError: If ``method_name`` matches no known detector.
        """
        if method_name in cls._GRAD_METHODS:
            kwargs = DETECTOR_KWARGS[method_name]
            detector = GradPCA.create(
                num_classes=num_classes,
                state=models['jax'],
                data_loader=data_loader,
                **kwargs,
            )
            framework = 'jax'
        elif method_name in DETECTORS_BASELINES:
            kwargs = MODEL_LAYER_CONFIG[model_name]
            # NOTE(review): the *return value* of set_framework is handed to
            # the detector, so this assumes it returns a loader - confirm.
            detector = BaselineDetector.create(
                method_name=method_name,
                model=models['torch'],
                data_loader=data_loader.set_framework('torch'),
                **kwargs,
            )
            framework = 'torch'
        elif method_name in cls._GAIA_METHODS:
            detector = GAIADetector.create(
                method_name, model_name, models['torch'], device
            )
            framework = 'torch'
        else:
            raise ValueError("Unknown detector name.")
        return cls(method_name, model_name, detector, framework)

    def __call__(self, x):
        """Apply the wrapped detector to ``x`` and return its result."""
        return self.detector(x)

    def end(self):
        """Call the detector's ``end()`` for GAIA methods; otherwise do nothing."""
        if self.method_name in self._GAIA_METHODS:
            self.detector.end()

    
