                                      
import sys                     
import os                    
from typing import Dict, Any, List, Optional, Tuple
import logging                    

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
import time
from fortress.common.data_models import InputPromptRecord, QueryFeatures, DetectionResult, FinalDetectionOutput, DecisionLabel
from fortress.data_management.prompt_processor import PromptProcessor
from fortress.detection_pipeline.primary_detector import PrimaryDetector
from fortress.detection_pipeline.secondary_analyzer import SecondaryAnalyzer                                           
from fortress.config import get_config                                  

logger = logging.getLogger(__name__)

class DetectionPipeline:
    """
    Orchestrates the full detection flow from query processing to final decision,
    integrating vector search (PrimaryDetector) and perplexity analysis (SecondaryAnalyzer).
    """
    def __init__(
        self,
        prompt_processor: PromptProcessor,
        primary_detector: PrimaryDetector,
        secondary_analyzer: SecondaryAnalyzer,
    ):
        """
        Store the pipeline collaborators and cache the ensemble settings.

        Settings are read from the 'detection_pipeline' section of the mapping
        returned by get_config(); every setting has a built-in default that is
        used when the key (or the whole section) is absent.

        Args:
            prompt_processor: Stored on the instance as ``self.prompt_processor``.
            primary_detector: Stored on the instance as ``self.primary_detector``.
            secondary_analyzer: Stored on the instance as ``self.secondary_analyzer``.
        """
        self.prompt_processor = prompt_processor
        self.primary_detector = primary_detector
        self.secondary_analyzer = secondary_analyzer
        self.config = get_config()

        settings = self.config.get('detection_pipeline', {})
        self.ensemble_strategy = settings.get('ensemble_strategy', 'vector_dominant')

        # (config key, default) pairs; the tuple order fixes the key order of
        # self.thresholds, which is echoed in the start-up log line below.
        threshold_defaults = (
            ('perplexity_dominant_unsafe_threshold', 0.75),
            ('vector_dominant_safe_distance_threshold', 0.1),
            ('vector_dominant_safe_perplexity_threshold', 0.25),
        )
        self.thresholds = {
            key: settings.get(key, fallback) for key, fallback in threshold_defaults
        }

        # Parameters consumed by the 'weighted_majority_vote' strategy.
        self.weighted_majority_mixed_label_ratio_threshold = settings.get(
            "weighted_majority_mixed_label_ratio_threshold", 0.3
        )
        default_vote_weights = {
            "default_primary": 0.6,
            "default_perplexity": 0.4,
            "mixed_primary": 0.3,
            "mixed_perplexity": 0.7,
        }
        self.weighted_majority_vote_weights = settings.get(
            "weighted_majority_vote_weights", default_vote_weights
        )

        logger.info(f"DetectionPipeline initialized with ensemble strategy: {self.ensemble_strategy}, parameters: {self.thresholds}")

    def _apply_ensemble_strategy(
        self,
        primary_similar_docs: List[Dict[str, Any]],
        query_features: QueryFeatures,
        perplexity_detection_result: DetectionResult
    ) -> Tuple[DecisionLabel, float, str]:
        """
        Combine the primary detector's neighbours with the perplexity result.

        The strategy name and the 'vector_dominant' thresholds are re-read from
        the 'detection_pipeline' section of get_config() on every call; the
        'weighted_majority_vote' parameters come from the values cached on the
        instance in __init__.

        Supported strategies: 'weighted_majority_vote', 'vector_dominant'
        (the default) and 'perplexity_dominant'. With no neighbour documents,
        or with an unrecognised strategy name, the perplexity result is
        returned as-is.

        Args:
            primary_similar_docs: Neighbour documents from the primary detector.
                Each is read as a dict with optional 'id', 'distance' and
                'metadata' entries; the label is taken from
                metadata['label_str'] or, failing that, from the integer
                metadata['label'] (0 = safe, 1 = unsafe). Presumably ordered
                closest-first: 'vector_dominant' treats element 0 as the
                closest match.
            query_features: Not read by this method; kept for interface
                compatibility.
            perplexity_detection_result: Result of the perplexity analysis;
                also the fallback decision for every inconclusive path.

        Returns:
            A ``(label, confidence, justification)`` tuple.
        """
        pipeline_config = get_config().get("detection_pipeline", {})
        strategy = pipeline_config.get("ensemble_strategy", "vector_dominant")

        logger.info(f"Applying ensemble strategy: '{strategy}'")

        if not primary_similar_docs:
            logger.warning("No similar documents found by primary detector. Falling back to perplexity analysis.")
            return self._perplexity_fallback(
                perplexity_detection_result, "No primary results. Fallback"
            )

        if strategy == "weighted_majority_vote":
            return self._weighted_majority_vote(
                primary_similar_docs, perplexity_detection_result
            )

        if strategy == "vector_dominant":
            return self._vector_dominant(
                primary_similar_docs, perplexity_detection_result, pipeline_config
            )

        if strategy == "perplexity_dominant":
            logger.info("Applying 'perplexity_dominant' ensemble strategy.")
            final_label = perplexity_detection_result.decision
            confidence = perplexity_detection_result.confidence
            justification = (
                f"Perplexity dominant: Decision based on perplexity analysis. "
                f"({final_label.value}, Conf: {confidence:.3f}). Explanation: {perplexity_detection_result.explanation}"
            )
            logger.info(f"Perplexity dominant result: {final_label.value}, Confidence: {confidence:.4f}")
            return final_label, confidence, justification

        logger.warning(
            f"Unknown or unhandled ensemble strategy: '{strategy}'. Defaulting to perplexity result."
        )
        return self._perplexity_fallback(
            perplexity_detection_result, f"Unknown strategy '{strategy}'. Fallback"
        )

    @staticmethod
    def _perplexity_fallback(
        perplexity_detection_result: DetectionResult,
        reason: str
    ) -> Tuple[DecisionLabel, float, str]:
        """Return the perplexity decision unchanged, justified as '<reason>: <explanation>'."""
        return (
            perplexity_detection_result.decision,
            perplexity_detection_result.confidence,
            f"{reason}: {perplexity_detection_result.explanation}"
        )

    @staticmethod
    def _resolve_doc_label(doc: Dict[str, Any]) -> Optional[str]:
        """Return the document's label string, or None when it carries no usable label.

        metadata['label_str'] wins when present; otherwise the integer
        metadata['label'] is mapped (0 -> SAFE, 1 -> UNSAFE).
        """
        # `or {}` also covers a 'metadata' key that is present but set to None.
        metadata = doc.get("metadata") or {}
        label_str = metadata.get("label_str")
        if label_str is not None:
            return label_str

        label_int = metadata.get("label")
        if label_int == 0:
            return DecisionLabel.SAFE.value
        if label_int == 1:
            return DecisionLabel.UNSAFE.value
        return None

    @staticmethod
    def _doc_distance(doc: Dict[str, Any]) -> float:
        """Return the document's 'distance', treating a missing or None value as 1.0."""
        distance = doc.get("distance")
        return 1.0 if distance is None else distance

    def _weighted_majority_vote(
        self,
        primary_similar_docs: List[Dict[str, Any]],
        perplexity_detection_result: DetectionResult
    ) -> Tuple[DecisionLabel, float, str]:
        """Vote between neighbour labels (weighted by 1 - distance) and the perplexity decision."""
        safe_weight = 0.0
        unsafe_weight = 0.0
        num_safe_docs = 0
        num_unsafe_docs = 0

        for doc in primary_similar_docs:
            # Closer documents vote more strongly; distances >= 1 contribute a
            # zero weight but are still counted towards the label ratio below.
            weight = max(0.0, 1.0 - self._doc_distance(doc))

            label_str = self._resolve_doc_label(doc)
            if label_str is None:
                logger.debug(f"Document ID {doc.get('id')} has unclear label (neither 0 nor 1, nor string label). Skipping in weighted vote.")
                continue

            if label_str == DecisionLabel.SAFE.value:
                safe_weight += weight
                num_safe_docs += 1
            elif label_str == DecisionLabel.UNSAFE.value:
                unsafe_weight += weight
                num_unsafe_docs += 1
            # Any other label string is ignored without being counted.

        total_docs_considered = num_safe_docs + num_unsafe_docs
        if total_docs_considered == 0:
            logger.warning("Weighted majority vote: No documents with clear labels found. Falling back to perplexity.")
            return self._perplexity_fallback(
                perplexity_detection_result,
                "Weighted majority vote no clear labels. Fallback"
            )

        # Share of the less frequent label among the counted neighbours; a high
        # share means the neighbourhood disagrees with itself.
        minority_label_ratio = min(num_safe_docs, num_unsafe_docs) / total_docs_considered
        is_mixed_results = minority_label_ratio >= self.weighted_majority_mixed_label_ratio_threshold
        logger.info(f"Minority label ratio: {minority_label_ratio:.2f}, Mixed results threshold: {self.weighted_majority_mixed_label_ratio_threshold:.2f}, Is mixed: {is_mixed_results}")

        current_weights = self.weighted_majority_vote_weights
        if is_mixed_results:
            primary_weight_factor = current_weights.get("mixed_primary", 0.3)
            perplexity_weight_factor = current_weights.get("mixed_perplexity", 0.7)
            logger.info(f"Using MIXED weights: Primary={primary_weight_factor}, Perplexity={perplexity_weight_factor}")
        else:
            primary_weight_factor = current_weights.get("default_primary", 0.6)
            perplexity_weight_factor = current_weights.get("default_perplexity", 0.4)
            logger.info(f"Using DEFAULT weights: Primary={primary_weight_factor}, Perplexity={perplexity_weight_factor}")

        # NOTE(review): the primary scores are sums over all neighbours and are
        # not normalised, while the perplexity score is a single confidence
        # value; with several close neighbours the primary side can outweigh
        # perplexity regardless of the weight factors. Confirm this is intended.
        primary_safe_score = safe_weight * primary_weight_factor
        primary_unsafe_score = unsafe_weight * primary_weight_factor

        # The perplexity result votes only for the side it decided on; any
        # other decision (e.g. AMBIGUOUS) contributes nothing to either side.
        perplexity_safe_score = 0.0
        perplexity_unsafe_score = 0.0
        if perplexity_detection_result.decision == DecisionLabel.SAFE:
            perplexity_safe_score = perplexity_detection_result.confidence * perplexity_weight_factor
        elif perplexity_detection_result.decision == DecisionLabel.UNSAFE:
            perplexity_unsafe_score = perplexity_detection_result.confidence * perplexity_weight_factor

        final_safe_score = primary_safe_score + perplexity_safe_score
        final_unsafe_score = primary_unsafe_score + perplexity_unsafe_score
        total_combined_score = final_safe_score + final_unsafe_score

        if total_combined_score == 0:
            logger.warning("Weighted majority vote: Total combined score is zero. Defaulting to AMBIGUOUS or perplexity if decisive.")
            if perplexity_detection_result.decision != DecisionLabel.AMBIGUOUS:
                return self._perplexity_fallback(
                    perplexity_detection_result,
                    "Weighted majority vote total score zero. Fallback to perplexity"
                )
            return (
                DecisionLabel.AMBIGUOUS,
                0.5,
                "Weighted majority vote: All scores zero, resulting in ambiguous."
            )

        if final_safe_score > final_unsafe_score:
            final_label = DecisionLabel.SAFE
            confidence = final_safe_score / total_combined_score
        elif final_unsafe_score > final_safe_score:
            final_label = DecisionLabel.UNSAFE
            confidence = final_unsafe_score / total_combined_score
        else:
            logger.info("Weighted majority vote (combined) resulted in a tie. Defaulting to perplexity analysis decision.")
            return self._perplexity_fallback(
                perplexity_detection_result,
                "Weighted majority vote (combined) tie. Fallback"
            )

        justification = (
            f"Weighted vote ({'mixed' if is_mixed_results else 'default'} weights): {final_label.value}. "
            f"Primary (S:{primary_safe_score:.2f}, U:{primary_unsafe_score:.2f}), "
            f"Perplexity (S:{perplexity_safe_score:.2f}, U:{perplexity_unsafe_score:.2f}). "
            f"Final (S:{final_safe_score:.2f}, U:{final_unsafe_score:.2f})"
        )
        logger.info(f"Weighted majority vote result: {final_label.value}, Confidence: {confidence:.4f}. Justification: {justification}")
        return final_label, confidence, justification

    def _vector_dominant(
        self,
        primary_similar_docs: List[Dict[str, Any]],
        perplexity_detection_result: DetectionResult,
        pipeline_config: Dict[str, Any]
    ) -> Tuple[DecisionLabel, float, str]:
        """Declare SAFE when the first neighbour is a very close SAFE doc and perplexity is low; otherwise defer to perplexity."""
        logger.info("Applying 'vector_dominant' ensemble strategy.")
        vd_safe_dist_thresh = pipeline_config.get("vector_dominant_safe_distance_threshold", 0.1)
        vd_safe_perp_thresh = pipeline_config.get("vector_dominant_safe_perplexity_threshold", 0.25)

        closest_doc = primary_similar_docs[0]
        # Same label resolution as the weighted vote, so documents that only
        # carry the integer 'label' are recognised here as well.
        closest_doc_label_str = self._resolve_doc_label(closest_doc)
        closest_doc_distance = self._doc_distance(closest_doc)

        if (closest_doc_label_str == DecisionLabel.SAFE.value and
            closest_doc_distance < vd_safe_dist_thresh and
            (perplexity_detection_result.sentence_adversarial_probability is not None and
             perplexity_detection_result.sentence_adversarial_probability < vd_safe_perp_thresh)):
            final_label = DecisionLabel.SAFE
            confidence = 1.0 - closest_doc_distance
            justification = (
                f"Vector dominant: Closest SAFE doc ({closest_doc.get('id')}) very similar (dist: {closest_doc_distance:.3f}) "
                f"and query perplexity ({perplexity_detection_result.sentence_adversarial_probability:.3f}) low."
            )
            logger.info(f"Vector dominant result: {final_label.value}, Confidence: {confidence:.4f}")
            return final_label, confidence, justification

        logger.info("Vector dominant conditions not met for a direct decision. Falling back to perplexity.")
        return self._perplexity_fallback(
            perplexity_detection_result,
            "Vector dominant conditions not met. Fallback"
        )

    def run_batch(self, prompt_records: List[InputPromptRecord]) -> List[FinalDetectionOutput]:
        """
        Run the detection pipeline over a batch of prompt records.

        The primary detector is invoked once for the whole batch via
        ``detect_batch``; the perplexity analysis and ensemble decision are then
        made per item. A failure while processing one item is captured in that
        item's output (decision ERROR, ``error_info`` set) and does not abort
        the rest of the batch.

        Args:
            prompt_records: Records whose ``original_prompt`` is the query text.

        Returns:
            One FinalDetectionOutput per result returned by ``detect_batch``,
            in that order.
        """
        logger.debug(f"DetectionPipeline running for batch of {len(prompt_records)} prompts.")
        batch_query_texts = [pr.original_prompt for pr in prompt_records]
        final_outputs: List[FinalDetectionOutput] = []

        primary_detector_results = self.primary_detector.detect_batch(batch_query_texts)

        for i, (query_features_from_primary, _, primary_similar_docs) in enumerate(primary_detector_results):
            query_text = batch_query_texts[i]

            # Pessimistic defaults; overwritten once the item is processed.
            final_decision = DecisionLabel.ERROR
            overall_confidence = 0.0
            is_ambiguous = True
            justification = "Error during processing single item in batch."
            detection_stages_summary = {}
            error_info = None

            try:
                if query_features_from_primary is None:
                    justification = "Primary detector did not return valid query features for this item."
                else:
                    if query_features_from_primary.token_source_log_probabilities is None:
                        # Without token log-probabilities there is nothing to
                        # analyse; use the same neutral result that run() builds
                        # for this case.
                        perplexity_result = DetectionResult(
                            decision=DecisionLabel.AMBIGUOUS,
                            confidence=0.0,
                            explanation="Perplexity analysis skipped due to missing token log probabilities.",
                            perplexity_analysis_details="Missing token_source_log_probabilities."
                        )
                    else:
                        perplexity_result = self.secondary_analyzer.perform_perplexity_analysis(
                            query_features=query_features_from_primary
                        )
                    detection_stages_summary["secondary_analyzer_perplexity_result"] = perplexity_result.model_dump()

                    final_decision, overall_confidence, justification = self._apply_ensemble_strategy(
                        primary_similar_docs=primary_similar_docs,
                        query_features=query_features_from_primary,
                        perplexity_detection_result=perplexity_result
                    )
                    is_ambiguous = final_decision == DecisionLabel.AMBIGUOUS
            except Exception as e:
                # Batch boundary: record the failure on this item and keep going.
                logger.exception("Error in DetectionPipeline run_batch item:")
                error_info = str(e)
                final_decision = DecisionLabel.ERROR
                overall_confidence = 0.0
                is_ambiguous = True
                justification = f"Exception occurred: {error_info}"

            output_item = FinalDetectionOutput(
                query_text=query_text,
                final_decision=final_decision,
                overall_confidence=overall_confidence,
                is_ambiguous=is_ambiguous,
                justification=justification,
                detection_stages_summary=detection_stages_summary,
                error_info=error_info,
                query_features=query_features_from_primary,
                primary_detector_top_k_results=primary_similar_docs,
            )
            final_outputs.append(output_item)

        return final_outputs

    def run(self, query_text: str) -> FinalDetectionOutput:
        """
        Processes a query string through the full detection pipeline.

        Steps: primary detection (query features plus similar documents),
        perplexity analysis (skipped when token log-probabilities are missing),
        then the ensemble decision. Any exception raised along the way is
        logged and reported in the returned object (decision ERROR,
        ``error_info`` set) instead of being propagated.

        Args:
            query_text: The raw query to classify.

        Returns:
            A FinalDetectionOutput describing the decision, its confidence and
            justification, and a per-stage summary.
        """
        logger.debug(f"DetectionPipeline running for query: {query_text[:5]}...")
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments.
        start_time = time.perf_counter()

        detection_stages_summary = {}
        error_info = None
        final_decision = DecisionLabel.ERROR
        overall_confidence = 0.0
        is_ambiguous = False
        justification = "Error during processing."
        # Pre-bound so the final return is valid even if detect() raises.
        query_features_from_primary = None
        primary_detector_top_k_results = None
        primary_ensemble_strategy_used = None
        # Never populated in this method; always reported as None.
        primary_ensemble_confidence = None

        try:
            query_features_from_primary, _, primary_similar_docs = self.primary_detector.detect(query_text)

            if query_features_from_primary is None:
                logger.error("Primary detector failed to return query features.")
                final_decision = DecisionLabel.ERROR
                justification = "Primary detector did not return valid query features."
                overall_confidence = 0.0
                is_ambiguous = True
            else:
                # Treat "no result" like "no neighbours" so the ensemble step
                # can fall back to perplexity instead of failing on len(None).
                if primary_similar_docs is None:
                    primary_similar_docs = []

                logger.info(f"Primary detector found {len(primary_similar_docs)} similar documents.")
                detection_stages_summary["primary_detector_retrieved_count"] = len(primary_similar_docs)
                primary_detector_top_k_results = primary_similar_docs

                if query_features_from_primary.token_source_log_probabilities is None:
                    logger.warning("Skipping perplexity analysis: token_source_log_probabilities are missing.")
                    perplexity_result = DetectionResult(
                        decision=DecisionLabel.AMBIGUOUS,
                        confidence=0.0,
                        explanation="Perplexity analysis skipped due to missing token log probabilities.",
                        perplexity_analysis_details="Missing token_source_log_probabilities."
                    )
                else:
                    perplexity_result = self.secondary_analyzer.perform_perplexity_analysis(
                        query_features=query_features_from_primary
                    )
                logger.info(f"Secondary analyzer (perplexity) result: {perplexity_result.decision.value} with confidence {perplexity_result.confidence:.4f}")
                detection_stages_summary["secondary_analyzer_perplexity_result"] = perplexity_result.model_dump()

                final_decision, overall_confidence, justification = self._apply_ensemble_strategy(
                    primary_similar_docs=primary_similar_docs,
                    query_features=query_features_from_primary,
                    perplexity_detection_result=perplexity_result
                )
                is_ambiguous = final_decision == DecisionLabel.AMBIGUOUS

                # NOTE(review): this is the strategy cached in __init__, whereas
                # _apply_ensemble_strategy re-reads the strategy from
                # get_config() on each call; the two can differ if the config
                # changes after construction.
                primary_ensemble_strategy_used = self.ensemble_strategy
                detection_stages_summary["ensemble_strategy_applied"] = primary_ensemble_strategy_used

                # Classify by the justification prefixes that
                # _apply_ensemble_strategy emits. A plain substring test for
                # "Perplexity" would also match the weighted-vote breakdown
                # ("... Perplexity (S:..., U:...)") and mislabel it.
                if justification.startswith("Vector dominant:"):
                    decision_source = "PrimaryDetector"
                elif justification.startswith(
                    ("Weighted vote", "Weighted majority vote: All scores zero")
                ):
                    decision_source = "EnsembleLogic"
                else:
                    # 'perplexity_dominant' and every fallback path return the
                    # perplexity analyzer's decision unchanged.
                    decision_source = "PerplexityAnalyzer"
                detection_stages_summary["final_decision_source"] = decision_source

        except Exception as e:
            logger.exception("Error in DetectionPipeline run:")
            error_info = str(e)
            justification = f"Exception occurred: {error_info}"
            is_ambiguous = True
            final_decision = DecisionLabel.ERROR
            overall_confidence = 0.0

        processing_time_ms = (time.perf_counter() - start_time) * 1000
        logger.info(f"DetectionPipeline finished in {processing_time_ms:.2f} ms.")

        return FinalDetectionOutput(
            query_text=query_text,
            final_decision=final_decision,
            overall_confidence=overall_confidence,
            is_ambiguous=is_ambiguous,
            justification=justification,
            detection_stages_summary=detection_stages_summary,
            error_info=error_info,
            query_features=query_features_from_primary,
            primary_detector_top_k_results=primary_detector_top_k_results,
            primary_ensemble_strategy_used=primary_ensemble_strategy_used,
            primary_ensemble_confidence=primary_ensemble_confidence
        )
