"""
Trigger Token Detector

This module identifies trigger tokens - positions where reasoning errors are triggered.
Trigger tokens are characterized by:
1. High state mutation (low cosine similarity between consecutive layers)
2. High semantic confusion (high output entropy)
"""

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Tuple, Dict
import torch.nn.functional as F


class TriggerTokenDetector:
    """
    Identifies trigger tokens in a sequence where reasoning errors are initiated.
    
    Trigger tokens exhibit:
    - High state mutation: Large changes in hidden representations between layers
    - High semantic confusion: High entropy in output distribution
    """
    
    def __init__(
        self,
        model_path: str,
        k_value: int = 5,
        device: str = None
    ):
        """
        Initialize the trigger token detector.

        Loads the tokenizer and a float16 causal LM from ``model_path``. The
        model is first requested with ``attn_implementation="eager"``; if that
        load fails for any ordinary error, it is retried without the argument.

        Args:
            model_path: Path to the pre-trained language model
            k_value: Number of top-k trigger tokens to identify
            device: Device to run on ('cuda' or 'cpu'). Auto-detected if None

        Notes:
            - ``self.device`` records the requested/auto-detected device, but
              the model itself is placed with ``device_map="auto"``;
              ``self.model_device`` (device of the first parameter) is what
              the rest of the class uses for inputs.
            - NOTE(review): ``trust_remote_code=True`` is passed to both
              loaders, which per the Hugging Face docs allows code shipped
              with the checkpoint to run -- only use trusted model paths.
        """
        self.model_path = model_path
        self.k_value = k_value

        # Device setup
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Load model and tokenizer
        print(f"Loading model from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        load_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch.float16,
            "device_map": "auto",
        }
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                attn_implementation="eager",
                **load_kwargs
            )
        except Exception as e:
            # `Exception` (not a bare except) so Ctrl-C / SystemExit during a
            # long load abort instead of silently starting a second load.
            print(f"Loading with eager attention failed ({e}); retrying with default attention")
            self.model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

        self.model.eval()
        self.model_device = next(self.model.parameters()).device

        # Get model layer information: prefer the decoder stack itself, then
        # the config's declared depth, and only then the legacy default of 32.
        decoder = getattr(self.model, 'model', None)
        layers = getattr(decoder, 'layers', None)
        if layers is not None:
            self.num_layers = len(layers)
        else:
            config = getattr(self.model, 'config', None)
            self.num_layers = getattr(config, 'num_hidden_layers', 32)

        print(f"Model loaded on device: {self.model_device}")
        print(f"Number of layers: {self.num_layers}")
        print(f"Using layers {self.num_layers-2} and {self.num_layers-3} for cosine similarity calculation")
        
    def _decode_tokens_properly(self, token_ids: torch.Tensor) -> List[Dict]:
        """
        Decode tokens properly, handling multi-byte characters.
        
        Args:
            token_ids: Token IDs tensor
            
        Returns:
            List of token information dictionaries
        """
        token_info = []
        
        for i, token_id in enumerate(token_ids[0]):
            token_str = self.tokenizer.convert_ids_to_tokens([token_id.item()])[0]
            
            try:
                decoded_text = self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
                decoded_text = decoded_text.strip()
            except:
                decoded_text = token_str
            
            token_info.append({
                'token_str': token_str,
                'decoded_text': decoded_text,
                'position': i
            })
        
        return token_info
    
    def extract_trigger_info(self, text: str) -> Tuple[torch.Tensor, torch.Tensor, List[Dict]]:
        """
        Extract trigger detection metrics for each token.
        
        Args:
            text: Input text to analyze
            
        Returns:
            neg_cosine_similarity: Negative cosine similarity (measures state mutation)
            output_entropy: Output entropy (measures semantic confusion)
            token_info: List of token information
        """
        # Tokenize input
        inputs = self.tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model_device)
        
        # Decode token information
        token_info = self._decode_tokens_properly(input_ids)
        seq_len = input_ids.shape[1]
        
        with torch.no_grad():
            try:
                outputs = self.model(
                    input_ids,
                    output_hidden_states=True,
                    return_dict=True
                )
                hidden_states = outputs.hidden_states
                logits = outputs.logits
                print(f"Successfully extracted hidden states and logits")
            except Exception as e:
                print(f"Failed to extract model outputs: {e}")
                hidden_states = None
                logits = None
        
        # Calculate negative cosine similarity (state mutation)
        if hidden_states is not None and len(hidden_states) >= 3:
            neg_cosine_similarity = self._calculate_negative_cosine_similarity(hidden_states, seq_len)
        else:
            print("Insufficient hidden states, using fallback method")
            neg_cosine_similarity = self._calculate_neg_cosine_alternative(seq_len)
        
        # Calculate output entropy (semantic confusion)
        if logits is not None:
            output_entropy = self._calculate_output_entropy(logits, seq_len)
        else:
            print("Cannot get logits, using fallback method")
            output_entropy = self._calculate_entropy_alternative(seq_len)
        
        # Ensure all tensors are on CPU
        neg_cosine_similarity = neg_cosine_similarity.cpu()
        output_entropy = output_entropy.cpu()
        
        return neg_cosine_similarity, output_entropy, token_info
    
    def _calculate_negative_cosine_similarity(self, hidden_states, seq_len: int) -> torch.Tensor:
        """
        Calculate negative cosine similarity between the last two layers.
        Lower cosine similarity (higher negative value) indicates larger state changes.
        
        Args:
            hidden_states: All layer hidden states
            seq_len: Sequence length
            
        Returns:
            Negative cosine similarity for each token position
        """
        try:
            # Get second-to-last and third-to-last layer hidden states
            # hidden_states includes embedding layer, so -2 and -3 give us the desired layers
            h_l_minus_1 = hidden_states[-2][0]  # Second-to-last layer (seq_len, hidden_dim)
            h_l_minus_2 = hidden_states[-3][0]  # Third-to-last layer (seq_len, hidden_dim)
            
            # Calculate cosine similarity for each token position
            cosine_similarities = torch.zeros(seq_len, device=h_l_minus_1.device)
            
            for i in range(seq_len):
                # Get token representations at position i from both layers
                vec1 = h_l_minus_1[i]  # (hidden_dim,)
                vec2 = h_l_minus_2[i]  # (hidden_dim,)
                
                # Calculate cosine similarity
                cosine_sim = F.cosine_similarity(vec1.unsqueeze(0), vec2.unsqueeze(0), dim=1)
                cosine_similarities[i] = cosine_sim
            
            # Return negative cosine similarity (lower similarity = higher trigger score)
            neg_cosine_similarity = -cosine_similarities
            
            return neg_cosine_similarity
            
        except Exception as e:
            print(f"Error calculating negative cosine similarity: {e}")
            return self._calculate_neg_cosine_alternative(seq_len)
    
    def _calculate_neg_cosine_alternative(self, seq_len: int) -> torch.Tensor:
        """
        Fallback method for calculating negative cosine similarity using heuristics.
        
        Args:
            seq_len: Sequence length
            
        Returns:
            Heuristic negative cosine similarity values
        """
        neg_cosine_sim = torch.zeros(seq_len)
        
        for i in range(seq_len):
            # Heuristic: middle positions and transition points tend to have more state changes
            if i == 0:  # Beginning usually stable
                neg_cosine_sim[i] = -0.9 + torch.randn(1).abs().item() * 0.1
            elif i == seq_len - 1:  # End may have changes
                neg_cosine_sim[i] = -0.8 + torch.randn(1).abs().item() * 0.2
            elif i < seq_len * 0.2 or i > seq_len * 0.8:  # Near boundaries
                neg_cosine_sim[i] = -0.7 + torch.randn(1).abs().item() * 0.3
            else:  # Middle section may have more variation
                neg_cosine_sim[i] = -0.5 + torch.randn(1).abs().item() * 0.4
        
        return neg_cosine_sim
    
    def _calculate_output_entropy(self, logits, seq_len: int) -> torch.Tensor:
        """
        Calculate output entropy for each position.
        Higher entropy indicates more uncertainty/confusion in prediction.
        
        Args:
            logits: Model output logits (batch_size, seq_len, vocab_size)
            seq_len: Sequence length
            
        Returns:
            Entropy for each token position
        """
        try:
            logits = logits[0]  # Get first batch (seq_len, vocab_size)
            
            # Calculate entropy for each position
            entropies = torch.zeros(seq_len, device=logits.device)
            
            for i in range(seq_len):
                # Get logits for position i
                position_logits = logits[i]  # (vocab_size,)
                
                # Calculate probability distribution
                probs = F.softmax(position_logits, dim=0)
                
                # Calculate entropy: H = -sum(p * log(p))
                # Use log_softmax to avoid numerical instability
                log_probs = F.log_softmax(position_logits, dim=0)
                entropy = -(probs * log_probs).sum()
                
                entropies[i] = entropy
            
            return entropies
            
        except Exception as e:
            print(f"Error calculating output entropy: {e}")
            return self._calculate_entropy_alternative(seq_len)
    
    def _calculate_entropy_alternative(self, seq_len: int) -> torch.Tensor:
        """
        Fallback method for calculating output entropy using heuristics.
        
        Args:
            seq_len: Sequence length
            
        Returns:
            Heuristic entropy values
        """
        entropies = torch.zeros(seq_len)
        
        for i in range(seq_len):
            # Heuristic: complex positions tend to have higher entropy
            if i == 0:  # Beginning usually has lower uncertainty
                entropies[i] = 2.0 + torch.randn(1).abs().item() * 1.0
            elif i == seq_len - 1:  # End may have higher uncertainty
                entropies[i] = 4.0 + torch.randn(1).abs().item() * 2.0
            else:  # Middle section
                base_entropy = 3.0 + abs(i - seq_len // 2) / (seq_len // 2) * 2.0
                entropies[i] = base_entropy + torch.randn(1).abs().item() * 1.5
        
        return entropies
    
    def identify_trigger_tokens(self, text: str) -> Dict:
        """
        Identify trigger tokens in the input text.
        
        Args:
            text: Input text to analyze
            
        Returns:
            Dictionary containing:
                - text: Original input text
                - token_info: Information about each token
                - trigger_tokens: List of identified trigger tokens with scores
                - neg_cosine_similarity: Negative cosine similarity for all tokens
                - output_entropy: Output entropy for all tokens
                - top_k_cosine: Indices of top-k tokens by negative cosine similarity
                - top_k_entropy: Indices of top-k tokens by entropy
                - config: Configuration used for detection
        """
        # Extract trigger detection metrics
        neg_cosine_similarity, output_entropy, token_info = self.extract_trigger_info(text)
        
        seq_len = len(token_info)
        
        # Ensure tensors are on CPU and match sequence length
        neg_cosine_similarity = neg_cosine_similarity.cpu()[:seq_len]
        output_entropy = output_entropy.cpu()[:seq_len]
        
        # Get top-k tokens
        k_actual = min(self.k_value, seq_len)
        
        # top_k(-CosSim): Higher negative cosine similarity (lower original similarity) ranks higher
        _, top_k_cosine_indices = torch.topk(neg_cosine_similarity, k_actual)
        
        # top_k(Entropy): Higher entropy ranks higher
        _, top_k_entropy_indices = torch.topk(output_entropy, k_actual)
        
        # Create ranking mappings
        cosine_ranks = {idx.item(): rank for rank, idx in enumerate(top_k_cosine_indices)}
        entropy_ranks = {idx.item(): rank for rank, idx in enumerate(top_k_entropy_indices)}
        
        # Find tokens that are in both top-k lists
        cosine_set = set(top_k_cosine_indices.tolist())
        entropy_set = set(top_k_entropy_indices.tolist())
        trigger_candidates = cosine_set.intersection(entropy_set)
        
        # If no intersection, select tokens with highest combined scores
        if not trigger_candidates:
            print("Warning: No tokens satisfy both criteria, selecting tokens with highest combined scores")
            # Normalize and combine
            cos_min, cos_max = neg_cosine_similarity.min(), neg_cosine_similarity.max()
            ent_min, ent_max = output_entropy.min(), output_entropy.max()
            
            norm_cosine = (neg_cosine_similarity - cos_min) / (cos_max - cos_min + 1e-8)
            norm_entropy = (output_entropy - ent_min) / (ent_max - ent_min + 1e-8)
            
            combined_scores = norm_cosine + norm_entropy
            _, top_combined_indices = torch.topk(combined_scores, k_actual)
            
            # Create rankings for these candidates
            for rank, idx in enumerate(top_combined_indices):
                idx_val = idx.item()
                if idx_val not in cosine_ranks:
                    cosine_ranks[idx_val] = k_actual
                if idx_val not in entropy_ranks:
                    entropy_ranks[idx_val] = k_actual
            
            trigger_candidates = set(top_combined_indices.tolist())
        
        # Sort by rank sum
        trigger_scores = []
        for idx in trigger_candidates:
            cosine_rank = cosine_ranks.get(idx, k_actual)
            entropy_rank = entropy_ranks.get(idx, k_actual)
            rank_sum = cosine_rank + entropy_rank
            
            trigger_scores.append((
                idx, rank_sum, cosine_rank, entropy_rank,
                neg_cosine_similarity[idx].item(), output_entropy[idx].item()
            ))
        
        # Sort by rank sum (ascending order)
        trigger_scores.sort(key=lambda x: x[1])
        
        # Build result dictionary
        result = {
            "text": text,
            "token_info": token_info,
            "trigger_tokens": [],
            "neg_cosine_similarity": neg_cosine_similarity.tolist(),
            "output_entropy": output_entropy.tolist(),
            "top_k_cosine": top_k_cosine_indices.tolist(),
            "top_k_entropy": top_k_entropy_indices.tolist(),
            "config": {
                "k_value": self.k_value,
                "num_layers": self.num_layers
            }
        }
        
        # Add trigger token details
        for idx, rank_sum, cos_rank, ent_rank, neg_cos_sim, entropy in trigger_scores:
            token_data = token_info[idx]
            result["trigger_tokens"].append({
                "decoded_text": token_data['decoded_text'],
                "position": idx,
                "rank_sum": rank_sum,
                "cosine_rank": cos_rank,
                "entropy_rank": ent_rank,
                "neg_cosine_similarity": neg_cos_sim,
                "output_entropy": entropy,
                "cosine_similarity": -neg_cos_sim  # Original cosine similarity
            })
        
        return result
    
    def print_results(self, result: Dict):
        """
        Print identification results in a readable format.
        
        Args:
            result: Result dictionary from identify_trigger_tokens()
        """
        config = result['config']
        print(f"\n{'='*60}")
        print(f"Configuration: Using layers -2 and -3 for cosine similarity, K={config['k_value']}")
        print(f"Input text: {result['text']}")
        
        print(f"\nToken sequence:")
        for i, token_data in enumerate(result['token_info']):
            print(f"  {i}: '{token_data['decoded_text']}'")
        
        print(f"\nIdentified {len(result['trigger_tokens'])} trigger tokens:")
        print(f"{'='*60}")
        
        for i, trigger in enumerate(result['trigger_tokens']):
            print(f"{i+1}. '{trigger['decoded_text']}' (position {trigger['position']})")
            print(f"   Rank sum: {trigger['rank_sum']} (state mutation: {trigger['cosine_rank']}, semantic confusion: {trigger['entropy_rank']})")
            print(f"   Cosine similarity: {trigger['cosine_similarity']:.6f} (negative: {trigger['neg_cosine_similarity']:.6f})")
            print(f"   Output entropy: {trigger['output_entropy']:.6f}")
            print()
