import os
import sys
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from pathlib import Path

sys.path.append(str(Path(__file__).parent.parent))

from .style_detectors import VanGoghStyleDetector, PicassoStyleDetector
from .category_clip_detectors import CategoryCLIPDetector


class AttackDetector(ABC):
    """Abstract interface implemented by every attack detector in this module.

    The concrete detectors defined in this module return a ``detect`` result
    dict with the keys ``score``, ``is_detected`` and ``details``. They also
    expose a scalar reward and an availability probe.
    """

    @abstractmethod
    def __init__(self, **kwargs):
        """Construct the underlying model/backend; subclasses must override."""

    @abstractmethod
    def detect(self, image_path: str) -> Dict[str, Any]:
        """Analyse the image at ``image_path`` and return a result dict."""

    @abstractmethod
    def get_reward_score(self, image_path: str) -> float:
        """Return a scalar score for the image at ``image_path``."""

    @abstractmethod
    def is_available(self) -> bool:
        """Report whether the detector's backend is usable."""

    def should_early_stop(self, image_path: str, threshold: float = 0.6) -> bool:
        """Return the ``is_detected`` entry of ``detect(image_path)``.

        Falls back to ``False`` when the result has no ``is_detected`` key.

        NOTE(review): ``threshold`` is accepted for interface compatibility
        but is not consulted. The decision comes entirely from the
        detector's own ``is_detected`` verdict.
        """
        outcome = self.detect(image_path)
        verdict = outcome.get('is_detected', False)
        return verdict


class NudeDetector(AttackDetector):
    """Nudity detector wrapping ``asr_evaluation.nude_detector.NudeDetector``.

    The backend is imported inside ``__init__``, so importing this module
    does not import ``asr_evaluation``. Results are tagged with the
    detector type ``'nudenet'``.
    """

    def __init__(self, model_path: Optional[str] = None, **kwargs):
        # Deferred import: the backend package is only needed once a detector is built.
        from asr_evaluation.nude_detector import NudeDetector as _BackendNudeDetector
        self.detector = _BackendNudeDetector(model_path)

    def detect(self, image_path: str) -> Dict[str, Any]:
        """Score ``image_path`` for nudity.

        Returns:
            A dict whose keys are:

            * ``score``: the largest value in the backend's score mapping, or
              0.0 when that mapping is empty/falsy.
            * ``is_detected``: the backend's ``is_nude`` verdict.
            * ``details``: the raw score mapping plus the ``'nudenet'`` tag.
        """
        # NOTE(review): two separate backend calls; if each runs inference the
        # image is processed twice — confirm against asr_evaluation.
        label_scores = self.detector.get_nude_score(image_path)
        flagged = self.detector.is_nude(image_path)

        if label_scores:
            peak = max(label_scores.values())
        else:
            peak = 0.0

        details = {'nude_scores': label_scores, 'detector_type': 'nudenet'}
        return {'score': peak, 'is_detected': flagged, 'details': details}

    def get_reward_score(self, image_path: str) -> float:
        """Return the ``score`` produced by :meth:`detect` (0.0 if missing)."""
        return self.detect(image_path).get('score', 0.0)

    def is_available(self) -> bool:
        """True unless the backend reports being in ``simulation_mode``.

        A backend without a ``simulation_mode`` attribute is treated as
        simulated, and therefore as unavailable.
        """
        simulated = getattr(self.detector, 'simulation_mode', True)
        return not simulated


class ViolenceDetector(AttackDetector):
    """Harmful-content detector wrapping ``q16_evaluation.q16_detector.Q16Detector``."""

    def __init__(self, device: str = "cuda", prompts_path: Optional[str] = None, **kwargs):
        # Deferred import: q16_evaluation is only loaded when a detector is built.
        from q16_evaluation.q16_detector import Q16Detector
        backend = Q16Detector(device=device, prompts_path=prompts_path)
        self.detector = backend
        self.device = device

    def detect(self, image_path: str) -> Dict[str, Any]:
        """Classify ``image_path`` with the Q16 backend.

        Returns:
            A dict whose keys are:

            * ``score``: the backend's ``harmful_score``, defaulting to 0.0.
            * ``is_detected``: the backend's ``is_harmful``, defaulting to False.
            * ``details``: the raw backend result, its ``label`` (defaulting
              to ``'unknown'``) and the ``'q16'`` tag.
        """
        raw = self.detector.get_detection_score(image_path)
        return {
            'score': raw.get('harmful_score', 0.0),
            'is_detected': raw.get('is_harmful', False),
            'details': {
                'q16_scores': raw,
                'label': raw.get('label', 'unknown'),
                'detector_type': 'q16',
            },
        }

    def get_reward_score(self, image_path: str) -> float:
        """Return the ``score`` produced by :meth:`detect` (0.0 if missing)."""
        return self.detect(image_path).get('score', 0.0)

    def is_available(self) -> bool:
        """Delegate the availability check to the Q16 backend."""
        backend = self.detector
        return backend.is_available()


class VanGoghDetector(AttackDetector):
    """Van Gogh style detector wrapping ``VanGoghStyleDetector``."""

    def __init__(self, classifier_path: str, device: str = "cuda", **kwargs):
        self.style_detector = VanGoghStyleDetector(classifier_path, device)
        self.device = device

    def detect(self, image_path: str) -> Dict[str, Any]:
        """Run the style classifier on ``image_path``.

        Returns:
            A dict whose keys are:

            * ``score``: the classifier's ``vangogh_score``, defaulting to 0.0.
            * ``is_detected``: its ``is_vangogh_top3`` flag, defaulting to False.
            * ``details``: the classifier's full result dict, passed through
              unchanged.
        """
        report = self.style_detector.get_detailed_results(image_path)
        style_score = report.get('vangogh_score', 0.0)
        in_top3 = report.get('is_vangogh_top3', False)
        return {'score': style_score, 'is_detected': in_top3, 'details': report}

    def get_reward_score(self, image_path: str) -> float:
        """Return the ``score`` produced by :meth:`detect` (0.0 if missing)."""
        return self.detect(image_path).get('score', 0.0)

    def is_available(self) -> bool:
        """Delegate the availability check to the wrapped style detector."""
        wrapped = self.style_detector
        return wrapped.is_available()


class PicassoDetector(AttackDetector):
    """Picasso style detector wrapping ``PicassoStyleDetector``."""

    def __init__(self, classifier_path: str, device: str = "cuda", **kwargs):
        self.style_detector = PicassoStyleDetector(classifier_path, device)
        self.device = device

    def detect(self, image_path: str) -> Dict[str, Any]:
        """Run the style classifier on ``image_path``.

        Returns:
            A dict whose keys are:

            * ``score``: the classifier's ``picasso_score``, defaulting to 0.0.
            * ``is_detected``: its ``is_picasso_top3`` flag, defaulting to False.
            * ``details``: the classifier's full result dict, passed through
              unchanged.
        """
        report = self.style_detector.get_detailed_results(image_path)
        style_score = report.get('picasso_score', 0.0)
        in_top3 = report.get('is_picasso_top3', False)
        return {'score': style_score, 'is_detected': in_top3, 'details': report}

    def get_reward_score(self, image_path: str) -> float:
        """Return the ``score`` produced by :meth:`detect` (0.0 if missing)."""
        return self.detect(image_path).get('score', 0.0)

    def is_available(self) -> bool:
        """Delegate the availability check to the wrapped style detector."""
        wrapped = self.style_detector
        return wrapped.is_available()


class DetectorFactory:
    """Registry mapping attack-type names to detector classes.

    ``create_detector`` instantiates the class registered under a name, and
    ``register_detector`` adds or replaces a registration at runtime.
    """

    # attack-type name -> detector class. This dict lives on the class, so
    # register_detector() changes it for every user of the factory.
    _detectors = {
        'nude': NudeDetector,
        'violence': ViolenceDetector,
        'vangogh': VanGoghDetector,
        'pablo_picasso': PicassoDetector,
        'category': CategoryCLIPDetector,
    }

    @classmethod
    def create_detector(cls, attack_type: str, **kwargs) -> AttackDetector:
        """Instantiate the detector registered under ``attack_type``.

        Args:
            attack_type: A registered name (see :meth:`get_supported_types`).
            **kwargs: Forwarded unchanged to the detector class's constructor.

        Returns:
            A new instance of the registered detector class.

        Raises:
            KeyError: If ``attack_type`` is not registered. The message lists
                the registered names. ``KeyError`` is kept so that existing
                ``except KeyError`` handlers keep working.
        """
        try:
            detector_class = cls._detectors[attack_type]
        except KeyError:
            supported = ', '.join(repr(name) for name in cls._detectors)
            raise KeyError(
                f"Unknown attack type {attack_type!r}; supported types: {supported}"
            ) from None
        # Constructed outside the try so constructor errors are not misreported
        # as an unknown attack type.
        return detector_class(**kwargs)

    @classmethod
    def get_supported_types(cls) -> List[str]:
        """Return the registered attack-type names in registration order."""
        return list(cls._detectors.keys())

    @classmethod
    def register_detector(cls, attack_type: str, detector_class: type):
        """Register ``detector_class`` under ``attack_type``, replacing any existing entry."""
        cls._detectors[attack_type] = detector_class

# Public API of this module. Every name listed here must be bound at module
# level; otherwise `from <module> import *` raises AttributeError. The previous
# 'CLIPDetector' entry was removed because nothing in this module defines or
# imports it.
__all__ = [
    'AttackDetector',
    'NudeDetector',
    'ViolenceDetector',
    'VanGoghDetector',
    'PicassoDetector',
    'CategoryCLIPDetector',
    'DetectorFactory',
]