                                                    
import math
import logging
from typing import List, Tuple, Dict, Any, Optional

import sys
import os
# Append the directory two levels above this file's directory to the module
# search path. Presumably this lets the `fortress.*` imports below resolve when
# the file is executed directly rather than as an installed package — TODO confirm.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))

                                                                                                           
                                                                                                   
                                                                                       

from fortress.common.data_models import QueryFeatures, DetectionResult, DecisionLabel
from fortress.config import get_global_config                               

# Module-level logger, named after this module, used by the functions and classes below.
logger = logging.getLogger(__name__)

                  
def logsumexp(arr: List[float]) -> float:
    """Return ``log(sum(exp(x) for x in arr))`` computed in a numerically stable way.

    The maximum element is subtracted before exponentiating so that the
    largest exponent is exactly 0 and ``math.exp`` cannot overflow.

    Args:
        arr: Log-domain values. May contain ``-inf`` (a zero-probability term)
            or ``+inf``.

    Returns:
        The log of the summed exponentials. ``-inf`` for an empty list or when
        every element is ``-inf``; ``+inf`` when any element is ``+inf``.
    """
    if not arr:
        # An empty sum is 0, and log(0) is -inf.
        return -math.inf

    max_val = max(arr)
    if math.isinf(max_val):
        # -inf: every term is exp(-inf) == 0, so the result is -inf.
        # +inf: that term dominates the sum, so the result is +inf. Shifting by
        # an infinite maximum would compute inf - inf == nan, so return early.
        return max_val

    # The element equal to max_val contributes exp(0) == 1, so the sum is >= 1
    # and math.log can never receive 0 here.
    return max_val + math.log(sum(math.exp(x - max_val) for x in arr))


class PerplexityAnalyzerEngine:
    """
    Engine for performing perplexity-based analysis using a CRF/HMM-like model.

    Every token carries a hidden binary label ``c_i`` (0 = normal, 1 = adversarial).
    The engine scores label sequences with log-domain potentials (emission,
    transition, per-token prior) and runs a forward-backward pass to obtain
    token-level and sentence-level adversarial probabilities.
    """

    def __init__(self,
                 adversarial_token_uniform_log_prob: float,
                 lambda_smoothness_penalty: float,
                 mu_adversarial_token_prior: float,
                 apply_first_token_neutral_bias: bool):
        """
        Store the model parameters.

        Args:
            adversarial_token_uniform_log_prob: Emission log-potential used for
                every token under label 1.
            lambda_smoothness_penalty: Subtracted from the log-score whenever two
                neighbouring tokens have different labels.
            mu_adversarial_token_prior: Subtracted from the log-score of every
                token labelled 1 (a negative value therefore favours label 1).
            apply_first_token_neutral_bias: When True, both emission potentials
                of the first token are set to 0.0.
        """
        self.adversarial_token_uniform_log_prob = adversarial_token_uniform_log_prob
        self.lambda_smoothness_penalty = lambda_smoothness_penalty
        self.mu_adversarial_token_prior = mu_adversarial_token_prior
        self.apply_first_token_neutral_bias = apply_first_token_neutral_bias
        logger.info(f"PerplexityAnalyzerEngine initialized with: uniform_log_prob={adversarial_token_uniform_log_prob}, "
                    f"lambda={lambda_smoothness_penalty}, mu={mu_adversarial_token_prior}, "
                    f"bias={apply_first_token_neutral_bias}")

    def calculate_adversarial_probabilities(self, token_source_log_probs: List[float]) -> Tuple[List[float], float]:
        """
        Calculates token and sentence adversarial probabilities.

        Args:
            token_source_log_probs: List of log p_LLM(x_i|x_1,...,x_{i-1}) for each token.

        Returns:
            A tuple containing:
                - List[float]: p(c_i=1|vec{x}) for each token, clamped to [0, 1].
                - float: 1 - p(all c_i=0|vec{x}) for the sentence, clamped to [0, 1].

            An empty input yields ``([], 0.0)``. If the partition function is
            ``-inf`` (no label sequence has non-zero score) the result is
            ``([0.0] * N, 1.0)``.
        """
        N = len(token_source_log_probs)
        if N == 0:
            logger.warning("Empty token_source_log_probs received. Cannot perform perplexity analysis.")
            return [], 0.0

        neg_inf = -math.inf

        # Emission potentials: [label 0, label 1] for every token.
        log_phi_emission = [
            [source_log_prob, self.adversarial_token_uniform_log_prob]
            for source_log_prob in token_source_log_probs
        ]
        if self.apply_first_token_neutral_bias:
            # The first token gives no evidence for either label.
            log_phi_emission[0] = [0.0, 0.0]

        # Transition potentials: staying in a label is free, switching costs lambda.
        log_phi_transition = [
            [0.0, -self.lambda_smoothness_penalty],
            [-self.lambda_smoothness_penalty, 0.0],
        ]

        # Per-token prior potentials: label 1 costs mu at every position.
        log_phi_prior = [0.0, -self.mu_adversarial_token_prior]

        # Forward pass: log_alpha[t][j] is the log-score of all label prefixes
        # ending with label j at position t.
        log_alpha = [[neg_inf, neg_inf] for _ in range(N)]
        log_alpha[0][0] = log_phi_emission[0][0] + log_phi_prior[0]
        log_alpha[0][1] = log_phi_emission[0][1] + log_phi_prior[1]
        for t in range(1, N):
            for j in range(2):
                log_alpha[t][j] = (log_phi_emission[t][j] + log_phi_prior[j] +
                                   logsumexp([log_alpha[t-1][k] + log_phi_transition[k][j] for k in range(2)]))

        # Backward pass: log_beta[t][i] is the log-score of all label suffixes
        # after position t given label i at position t.
        log_beta = [[neg_inf, neg_inf] for _ in range(N)]
        log_beta[N-1][0] = 0.0
        log_beta[N-1][1] = 0.0
        for t in range(N-2, -1, -1):
            for i in range(2):
                log_beta[t][i] = logsumexp([log_phi_transition[i][k] + log_phi_emission[t+1][k] + log_phi_prior[k] + log_beta[t+1][k] for k in range(2)])

        # Log partition function: total score over all label sequences.
        log_Z = logsumexp([log_alpha[N-1][j] for j in range(2)])
        if log_Z == neg_inf:
            logger.warning("Log-likelihood of sequence is -inf. Probabilities will be ill-defined.")
            # No label sequence is possible; report the sentence as fully adversarial.
            return [0.0] * N, 1.0

        # Token marginals p(c_t = 1 | x) = exp(alpha + beta - log_Z).
        token_adversarial_probabilities = [0.0] * N
        for t in range(N):
            log_p_ct_1 = log_alpha[t][1] + log_beta[t][1] - log_Z
            if log_p_ct_1 == neg_inf:
                continue  # already 0.0
            try:
                token_probability = math.exp(log_p_ct_1)
            except OverflowError:
                logger.error(f"Overflow calculating token adversarial probability for token {t}")
                token_probability = 1.0
            # alpha + beta and log_Z are rounded independently, so the marginal
            # can land a few ULPs above 1.0; clamp it. NaN compares False and
            # is deliberately passed through unchanged.
            if token_probability > 1.0:
                token_probability = 1.0
            token_adversarial_probabilities[t] = token_probability

        # Log-score of the single path in which every token has label 0.
        log_p_all_zero_path_unnormalized = log_phi_emission[0][0] + log_phi_prior[0]
        for t in range(1, N):
            log_p_all_zero_path_unnormalized = (log_phi_emission[t][0] + log_phi_prior[0] +
                                                log_p_all_zero_path_unnormalized + log_phi_transition[0][0])

        log_p_all_zero_conditional = log_p_all_zero_path_unnormalized - log_Z

        if log_p_all_zero_conditional == neg_inf:
            prob_all_zero = 0.0
        else:
            try:
                prob_all_zero = math.exp(log_p_all_zero_conditional)
            except OverflowError:
                # Fail closed: treat an un-representable value as "not all normal".
                prob_all_zero = 0.0

        sentence_adversarial_probability = 1.0 - prob_all_zero
        # Keep the result a valid probability; NaN is passed through unchanged.
        if sentence_adversarial_probability < 0.0:
            sentence_adversarial_probability = 0.0
        elif sentence_adversarial_probability > 1.0:
            sentence_adversarial_probability = 1.0

        return token_adversarial_probabilities, sentence_adversarial_probability


class SecondaryAnalyzer:
    """
    Performs secondary analysis on query features using a perplexity-based model.
    Calculates token-level and sentence-level adversarial probabilities.
    Applies category-specific or default thresholds for decision making.
    Supports a weighted majority vote when several categories are supplied.
    """

    # Built-in settings used when the configuration provides no default engine
    # settings, and as a last resort when a required engine key is absent.
    _FALLBACK_ENGINE_PARAMS: Dict[str, Any] = {
        "model_for_log_probs": "google/gemma-3-1b-it",
        "adversarial_token_uniform_log_prob": -5.0,
        "lambda_smoothness_penalty": 2.5,
        "mu_adversarial_token_prior": -2.0,
        "apply_first_token_neutral_bias": False,
        "sentence_adversarial_probability_threshold": 0.8,
    }

    # Keyword arguments that PerplexityAnalyzerEngine.__init__ requires.
    _REQUIRED_ENGINE_KEYS: Tuple[str, ...] = (
        "adversarial_token_uniform_log_prob",
        "lambda_smoothness_penalty",
        "mu_adversarial_token_prior",
        "apply_first_token_neutral_bias",
    )

    def __init__(self):
        """Load default, category-specific and dynamic-perplexity settings from the global config."""
        config = get_global_config()

        self.default_engine_params = config.default_perplexity_engine_settings
        if not self.default_engine_params:
            logger.error("Default perplexity engine settings not found in config.")
            # Copy so that later mutation cannot alter the class-level constant.
            self.default_engine_params = dict(self._FALLBACK_ENGINE_PARAMS)

        self.default_sentence_adversarial_probability_threshold = self.default_engine_params.get(
            "sentence_adversarial_probability_threshold", 0.8
        )

        self.category_specific_settings = config.category_specific_perplexity_settings
        if not self.category_specific_settings:
            logger.warning("Category-specific perplexity settings not found in config. Will use defaults only.")
            self.category_specific_settings = {}

        self.dynamic_perplexity_enabled = config.dynamic_perplexity_enabled
        self.dynamic_perplexity_top_k = config.dynamic_perplexity_top_k

        semantic_search_top_k = config.top_k_semantic_search
        cluster_assignment_top_k = config.cluster_assignment_top_k

        # When no explicit top_k is configured, prefer the cluster-assignment
        # value and fall back to the semantic-search value.
        if self.dynamic_perplexity_top_k is None:
            if cluster_assignment_top_k is not None:
                self.dynamic_perplexity_top_k = cluster_assignment_top_k
            else:
                self.dynamic_perplexity_top_k = semantic_search_top_k

        logger.info(f"SecondaryAnalyzer initialized. Default threshold: {self.default_sentence_adversarial_probability_threshold}")
        logger.info(f"Dynamic perplexity enabled: {self.dynamic_perplexity_enabled}, top_k: {self.dynamic_perplexity_top_k}")
        logger.debug(f"Default engine params: {self.default_engine_params}")
        logger.debug(f"Category-specific settings loaded for {len(self.category_specific_settings)} categories.")

    def perform_perplexity_analysis(self, query_features: QueryFeatures) -> DetectionResult:
        """
        Run perplexity analysis for a query and return the resulting decision.

        Uses the weighted multi-category vote when dynamic perplexity is enabled
        and more than one weighted category is present; otherwise uses the
        single-category analysis. The computed probabilities are also written
        onto ``query_features``.

        Args:
            query_features: Query features; ``token_source_log_probabilities``
                must be non-empty for the analysis to run.

        Returns:
            The detection result. AMBIGUOUS with zero confidence when token
            log probabilities are missing or empty.
        """
        if not query_features.token_source_log_probabilities:
            logger.warning("Token source log probabilities are missing or empty. Cannot perform perplexity analysis.")
            query_features.token_adversarial_probabilities = None
            query_features.sentence_adversarial_probability = None
            return DetectionResult(
                decision=DecisionLabel.AMBIGUOUS,
                confidence=0.0,
                explanation="Perplexity analysis skipped: missing token log probabilities.",
                details={"reason": "Missing token log probabilities"},
                predicted_label_by_perplexity=None,
                confidence_from_perplexity=0.0,
                perplexity_analysis_details="Token log probabilities not available."
            )

        if (self.dynamic_perplexity_enabled and
            query_features.prompt_categories_with_weights and
            len(query_features.prompt_categories_with_weights) > 1):

            logger.info("Using dynamic perplexity with weighted majority vote")
            return self._perform_weighted_perplexity_analysis(query_features)

        logger.info("Using single category perplexity analysis")
        return self._perform_single_category_analysis(query_features)

    def _resolve_category_settings(self, category: Optional[str]) -> Tuple[Dict[str, Any], float, bool]:
        """
        Return ``(engine_params, threshold, is_category_specific)`` for a category.

        Category settings override the defaults only for keys that already exist
        in the default engine parameters. Unknown or missing categories get the
        default parameters and default threshold.
        """
        if category and category in self.category_specific_settings:
            category_settings = self.category_specific_settings[category]
            engine_params = self.default_engine_params.copy()
            engine_params.update({k: v for k, v in category_settings.items() if k in self.default_engine_params})
            threshold = category_settings.get(
                "sentence_adversarial_probability_threshold",
                self.default_sentence_adversarial_probability_threshold,
            )
            return engine_params, threshold, True

        return self.default_engine_params, self.default_sentence_adversarial_probability_threshold, False

    @staticmethod
    def _decide(sentence_adv_prob: float, threshold: float) -> Tuple[DecisionLabel, float]:
        """
        Return ``(label, confidence)`` for a sentence adversarial probability.

        UNSAFE when the probability meets or exceeds the threshold, else SAFE.
        Confidence is the probability of the chosen side scaled by 1.2 and
        capped at 1.0 (unrounded).
        """
        if sentence_adv_prob >= threshold:
            return DecisionLabel.UNSAFE, min(1.0, sentence_adv_prob * 1.2)
        return DecisionLabel.SAFE, min(1.0, (1.0 - sentence_adv_prob) * 1.2)

    def _perform_single_category_analysis(self, query_features: QueryFeatures) -> DetectionResult:
        """Analyse the query with the settings of its single ``prompt_category`` (or the defaults)."""
        prompt_category = query_features.prompt_category
        selected_engine_params, selected_threshold, is_category_specific = self._resolve_category_settings(prompt_category)

        if is_category_specific:
            logger.info(f"Using category-specific perplexity parameters and threshold for category '{prompt_category}'. Threshold: {selected_threshold}")
        else:
            logger.info(f"Using default perplexity parameters and threshold. Category '{prompt_category}' not found or not specified. Threshold: {selected_threshold}")

        token_adv_probs, sentence_adv_prob = self._run_perplexity_engine(
            query_features.token_source_log_probabilities,
            selected_engine_params
        )

        # Expose the computed probabilities on the query features as well.
        query_features.token_adversarial_probabilities = token_adv_probs
        query_features.sentence_adversarial_probability = sentence_adv_prob

        return self._make_decision(
            sentence_adv_prob,
            selected_threshold,
            token_adv_probs,
            prompt_category,
            selected_engine_params
        )

    def _perform_weighted_perplexity_analysis(self, query_features: QueryFeatures) -> DetectionResult:
        """
        Performs weighted majority vote across multiple categories.

        The first ``dynamic_perplexity_top_k`` entries of
        ``prompt_categories_with_weights`` are analysed independently with their
        own engine parameters and threshold. Each category votes SAFE or UNSAFE
        with ``normalized_weight * confidence``; the larger total wins (ties go
        to SAFE). Weighted averages of the token and sentence probabilities are
        written onto ``query_features``.
        """
        categories_with_weights = query_features.prompt_categories_with_weights

        categories_to_analyze = categories_with_weights[:self.dynamic_perplexity_top_k]

        if not categories_to_analyze:
            # E.g. top_k == 0. Voting over nothing would silently report SAFE
            # with no probabilities, so analyse the primary category instead.
            logger.warning(f"No categories selected for weighted analysis (top_k={self.dynamic_perplexity_top_k}). "
                           f"Falling back to single category analysis.")
            return self._perform_single_category_analysis(query_features)

        logger.info(f"Analyzing {len(categories_to_analyze)} categories: {categories_to_analyze}")

        total_weight = sum(weight for _, weight in categories_to_analyze)
        use_uniform_weights = not (total_weight > 0)
        uniform_weight = 1.0 / len(categories_to_analyze)
        if use_uniform_weights:
            # Zero weights would erase every vote and always yield SAFE with
            # zero confidence; weight all categories equally instead.
            logger.warning(f"Category weights sum to {total_weight}; using uniform weights.")

        category_results = []
        for category, raw_weight in categories_to_analyze:
            normalized_weight = uniform_weight if use_uniform_weights else raw_weight / total_weight

            engine_params, threshold, _ = self._resolve_category_settings(category)

            token_adv_probs, sentence_adv_prob = self._run_perplexity_engine(
                query_features.token_source_log_probabilities,
                engine_params
            )

            category_decision, category_confidence = self._decide(sentence_adv_prob, threshold)

            category_results.append({
                'category': category,
                'weight': normalized_weight,
                'decision': category_decision,
                'confidence': category_confidence,
                'sentence_adv_prob': sentence_adv_prob,
                'threshold': threshold,
                'token_adv_probs': token_adv_probs
            })

        # Accumulate the votes and the weighted-average probabilities.
        safe_score = 0.0
        unsafe_score = 0.0
        weighted_token_adv_probs = None
        weighted_sentence_adv_prob = 0.0

        for result in category_results:
            weight = result['weight']
            if result['decision'] == DecisionLabel.SAFE:
                safe_score += weight * result['confidence']
            else:
                unsafe_score += weight * result['confidence']

            weighted_sentence_adv_prob += weight * result['sentence_adv_prob']

            if weighted_token_adv_probs is None:
                weighted_token_adv_probs = [0.0] * len(result['token_adv_probs'])
            for i, token_prob in enumerate(result['token_adv_probs']):
                weighted_token_adv_probs[i] += weight * token_prob

        query_features.token_adversarial_probabilities = weighted_token_adv_probs
        query_features.sentence_adversarial_probability = weighted_sentence_adv_prob

        # Final decision: the side with the larger score wins; its share of the
        # total score is the confidence.
        total_score = safe_score + unsafe_score
        if unsafe_score > safe_score:
            final_decision = DecisionLabel.UNSAFE
            final_confidence = unsafe_score / total_score if total_score > 0 else 0.0
        else:
            final_decision = DecisionLabel.SAFE
            final_confidence = safe_score / total_score if total_score > 0 else 0.0

        explanation_parts = [
            f"{result['category']} (w={result['weight']:.2f}): {result['decision'].value} "
            f"(SAP={result['sentence_adv_prob']:.3f}, thr={result['threshold']:.3f})"
            for result in category_results
        ]

        explanation = (
            f"Dynamic perplexity weighted vote: {final_decision.value}. "
            f"Safe score: {safe_score:.3f}, Unsafe score: {unsafe_score:.3f}. "
            f"Categories analyzed: {', '.join(explanation_parts)}"
        )

        return DetectionResult(
            decision=final_decision,
            confidence=final_confidence,
            explanation=explanation,
            details={
                "analysis_type": "weighted_majority_vote",
                "categories_analyzed": len(category_results),
                "safe_score": safe_score,
                "unsafe_score": unsafe_score,
                "weighted_sentence_adversarial_probability": weighted_sentence_adv_prob,
                "category_results": category_results
            },
            predicted_label_by_perplexity=final_decision,
            confidence_from_perplexity=final_confidence,
            perplexity_analysis_details=explanation,
            sentence_adversarial_probability=weighted_sentence_adv_prob
        )

    def _run_perplexity_engine(self, token_source_log_probs: List[float], engine_params: Dict[str, Any]) -> Tuple[List[float], float]:
        """
        Run perplexity analysis with given engine parameters.

        Only the keys required by ``PerplexityAnalyzerEngine`` are forwarded.
        If any required key is missing from ``engine_params``, all required
        values are taken from the default engine parameters instead; a key that
        is also absent there falls back to the built-in constant, so this
        recovery path cannot raise ``KeyError``.

        Returns:
            ``(token_adversarial_probabilities, sentence_adversarial_probability)``.
        """
        missing_keys = [key for key in self._REQUIRED_ENGINE_KEYS if key not in engine_params]

        if missing_keys:
            logger.error(f"Missing required engine parameters: {missing_keys}. Using defaults.")
            engine_args = {
                key: self.default_engine_params.get(key, self._FALLBACK_ENGINE_PARAMS[key])
                for key in self._REQUIRED_ENGINE_KEYS
            }
        else:
            engine_args = {key: engine_params[key] for key in self._REQUIRED_ENGINE_KEYS}

        engine = PerplexityAnalyzerEngine(**engine_args)
        return engine.calculate_adversarial_probabilities(token_source_log_probs)

    def _make_decision(self,
                      sentence_adv_prob: float,
                      threshold: float,
                      token_adv_probs: List[float],
                      prompt_category: Optional[str],
                      engine_params: Dict[str, Any]) -> DetectionResult:
        """
        Make decision based on sentence adversarial probability and threshold.

        Args:
            sentence_adv_prob: Sentence-level adversarial probability.
            threshold: Probability at or above which the query is UNSAFE.
            token_adv_probs: Token-level probabilities, reported in the details.
            prompt_category: Category whose settings were used, if any.
            engine_params: Engine parameters used, reported in the details.

        Returns:
            A DetectionResult whose confidence is rounded to four decimals.
        """
        logger.debug(f"Calculated sentence adversarial probability: {sentence_adv_prob:.4f} using threshold: {threshold:.4f}")

        decision_label, confidence = self._decide(sentence_adv_prob, threshold)
        if decision_label == DecisionLabel.UNSAFE:
            explanation_detail = f"Sentence adversarial probability ({sentence_adv_prob:.4f}) met or exceeded threshold ({threshold:.4f})."
        else:
            explanation_detail = f"Sentence adversarial probability ({sentence_adv_prob:.4f}) is below threshold ({threshold:.4f})."

        confidence = round(confidence, 4)

        return DetectionResult(
            decision=decision_label,
            confidence=confidence,
            explanation=f"Perplexity analysis: {decision_label.value}. {explanation_detail}",
            details={
                "prompt_category": prompt_category,
                "selected_threshold": threshold,
                "sentence_adversarial_probability": sentence_adv_prob,
                "token_adversarial_probabilities": token_adv_probs,
                "engine_parameters_used": engine_params
            },
            predicted_label_by_perplexity=decision_label,
            confidence_from_perplexity=confidence,
            perplexity_analysis_details=explanation_detail,
            sentence_adversarial_probability=sentence_adv_prob
        )


if __name__ == '__main__':
    # Placeholder entry point: running this file directly performs no action here.
    pass
