"""
Anchor Token Detector

This module identifies anchor tokens - critical past tokens whose key-value states
should be edited to correct reasoning errors. Anchor tokens are identified based on:
1. High attention variance (frequently attended by future tokens)
2. High FFN update ratio (significant representation changes)
"""

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Tuple, Dict
import torch.nn.functional as F


class AnchorTokenDetector:
    """
    Identifies anchor tokens in a sequence that are critical for reasoning correction.
    
    Anchor tokens are past tokens that:
    - Have high attention variance (are frequently attended to by future tokens)
    - Have high FFN update ratio (undergo significant representation changes)
    """
    
    def __init__(
        self,
        model_path: str,
        k_value: int = 5,
        attention_layer: int = 20,
        ffn_layer_start: int = 20,
        ffn_layer_end: int = 21,
        device: str = None
    ):
        """
        Initialize the anchor token detector.

        Loads the tokenizer and a float16 causal LM from ``model_path``, puts
        the model in eval mode, and clamps the requested layer indices to the
        range supported by the loaded model.

        Args:
            model_path: Path to the pre-trained language model
            k_value: Number of top-k anchor tokens to identify
            attention_layer: Which layer's attention to use (0-indexed)
            ffn_layer_start: Start layer for FFN update calculation
            ffn_layer_end: End layer for FFN update calculation
            device: Device to run on ('cuda' or 'cpu'). Auto-detected if None.
                Only stored as ``self.device``; the model itself is placed by
                ``device_map="auto"`` and its actual device is recorded in
                ``self.model_device``.
        """
        self.model_path = model_path
        self.k_value = k_value
        self.attention_layer = attention_layer
        self.ffn_layer_start = ffn_layer_start
        self.ffn_layer_end = ffn_layer_end

        # Device setup
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # Load model and tokenizer
        print(f"Loading model from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        load_kwargs = {
            "trust_remote_code": True,
            "torch_dtype": torch.float16,
            "device_map": "auto",
        }
        try:
            # Eager attention is requested first; presumably because other
            # attention backends may not return attention weights.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                attn_implementation="eager",
                **load_kwargs
            )
        except Exception as e:
            # Retry without the attention override. KeyboardInterrupt and
            # SystemExit are deliberately NOT caught here.
            print(f"Could not load model with eager attention ({e}); retrying with default attention implementation")
            self.model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

        self.model.eval()
        self.model_device = next(self.model.parameters()).device

        # Get model layer information: prefer the decoder stack itself, then
        # the config, and only then fall back to the historical default of 32.
        decoder_layers = getattr(getattr(self.model, "model", None), "layers", None)
        if decoder_layers is not None:
            self.num_layers = len(decoder_layers)
        else:
            self.num_layers = getattr(getattr(self.model, "config", None), "num_hidden_layers", 32)

        print(f"Model loaded on device: {self.model_device}")
        print(f"Number of layers: {self.num_layers}")

        # Validate layer parameters before reporting them, so the summary
        # below shows the values that will actually be used.
        if self.attention_layer >= self.num_layers:
            print(f"Warning: attention_layer ({self.attention_layer}) exceeds model layers ({self.num_layers}), using last layer")
            self.attention_layer = self.num_layers - 1

        if self.ffn_layer_end >= self.num_layers:
            print(f"Warning: ffn_layer_end ({self.ffn_layer_end}) exceeds model layers ({self.num_layers}), adjusting to last layer")
            self.ffn_layer_end = self.num_layers - 1

        if self.ffn_layer_start >= self.ffn_layer_end:
            print(f"Warning: ffn_layer_start ({self.ffn_layer_start}) >= ffn_layer_end ({self.ffn_layer_end}), adjusting")
            # NOTE: when ffn_layer_end is 0 this leaves start == end == 0.
            self.ffn_layer_start = max(0, self.ffn_layer_end - 1)

        print(f"Using layer {self.attention_layer} for attention")
        print(f"Using layers {self.ffn_layer_start} to {self.ffn_layer_end} for FFN updates")
    
    def _decode_tokens_properly(self, token_ids: torch.Tensor) -> List[Dict]:
        """
        Decode tokens properly, handling multi-byte characters.
        
        Args:
            token_ids: Token IDs tensor
            
        Returns:
            List of token information dictionaries
        """
        token_info = []
        
        for i, token_id in enumerate(token_ids[0]):
            token_str = self.tokenizer.convert_ids_to_tokens([token_id.item()])[0]
            
            try:
                decoded_text = self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
                decoded_text = decoded_text.strip()
            except:
                decoded_text = token_str
            
            token_info.append({
                'token_str': token_str,
                'token_id': token_id.item(),
                'decoded_text': decoded_text,
                'position': i
            })
        
        return token_info
    
    def extract_attention_and_ffn_info(self, text: str) -> Tuple[torch.Tensor, torch.Tensor, List[Dict]]:
        """
        Extract attention variance and FFN update ratio for each token.
        
        Args:
            text: Input text to analyze
            
        Returns:
            attention_variance: Variance of attention weights for each token
            update_ratios: FFN update ratio for each token
            token_info: List of token information
        """
        # Tokenize input
        inputs = self.tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model_device)
        
        # Decode token information
        token_info = self._decode_tokens_properly(input_ids)
        seq_len = input_ids.shape[1]
        
        with torch.no_grad():
            try:
                outputs = self.model(
                    input_ids,
                    output_attentions=True,
                    output_hidden_states=True,
                    return_dict=True
                )
                attentions = outputs.attentions
                hidden_states = outputs.hidden_states
                print(f"Successfully extracted {len(attentions)} attention layers and {len(hidden_states)} hidden states")
            except Exception as e:
                print(f"Failed to extract attention weights: {e}")
                outputs = self.model(input_ids, return_dict=True)
                attentions = None
                hidden_states = None
        
        # Calculate attention variance for specified layer
        if attentions is not None and len(attentions) > self.attention_layer:
            attention_variance = self._calculate_attention_variance_specific_layer(attentions, seq_len, self.attention_layer)
            print(f"Using layer {self.attention_layer} attention for variance calculation")
        else:
            print(f"Cannot get layer {self.attention_layer} attention, using fallback method")
            attention_variance = self._calculate_attention_variance_alternative(seq_len)
        
        # Calculate FFN update ratio for specified layers
        if hidden_states is not None and len(hidden_states) > max(self.ffn_layer_start, self.ffn_layer_end):
            update_ratios = self._calculate_update_ratios_specific_layers(hidden_states, self.ffn_layer_start, self.ffn_layer_end)
            print(f"Using layers {self.ffn_layer_start} to {self.ffn_layer_end} for FFN update ratio calculation")
        else:
            print(f"Cannot get layers {self.ffn_layer_start}-{self.ffn_layer_end} hidden states, using fallback method")
            update_ratios = self._calculate_update_ratios_alternative(input_ids, seq_len)
        
        # Ensure all tensors are on CPU
        attention_variance = attention_variance.cpu()
        update_ratios = update_ratios.cpu()
        
        return attention_variance, update_ratios, token_info
    
    def _calculate_attention_variance_specific_layer(self, attentions, seq_len: int, layer_idx: int) -> torch.Tensor:
        """
        Calculate attention variance for each token at a specific layer.
        Measures how much each token is attended to by future tokens.
        
        Args:
            attentions: All layer attention weights
            seq_len: Sequence length
            layer_idx: Target layer index
            
        Returns:
            Attention variance for each token position
        """
        try:
            # Get attention weights for target layer
            target_attention = attentions[layer_idx]  # (batch_size, num_heads, seq_len, seq_len)
            
            # Average across all attention heads
            avg_attention = target_attention.mean(dim=1)[0]  # (seq_len, seq_len)
            
            attention_variance = torch.zeros(seq_len, device=avg_attention.device)
            
            for i in range(seq_len):
                # Calculate variance of attention from future tokens (j > i) to current token i
                if i < seq_len - 1:
                    future_attentions = avg_attention[i+1:, i]  # Attention from future tokens to current token
                    if len(future_attentions) > 1:
                        attention_variance[i] = torch.var(future_attentions)
                    else:
                        attention_variance[i] = 0.0
                else:
                    attention_variance[i] = 0.0
                    
            return attention_variance
            
        except Exception as e:
            print(f"Error calculating attention variance for layer {layer_idx}: {e}")
            return self._calculate_attention_variance_alternative(seq_len)
    
    def _calculate_attention_variance_alternative(self, seq_len: int) -> torch.Tensor:
        """
        Fallback method for calculating attention variance using heuristics.
        
        Args:
            seq_len: Sequence length
            
        Returns:
            Heuristic attention variance values
        """
        attention_variance = torch.zeros(seq_len)
        
        for i in range(seq_len):
            # Heuristic: tokens at beginning/end and boundaries tend to have higher variance
            if i == 0 or i == seq_len - 1:
                attention_variance[i] = 0.8 + torch.randn(1).abs().item() * 0.2
            elif i < seq_len * 0.3 or i > seq_len * 0.7:
                attention_variance[i] = 0.5 + torch.randn(1).abs().item() * 0.3
            else:
                attention_variance[i] = 0.2 + torch.randn(1).abs().item() * 0.3
        
        return attention_variance
    
    def _calculate_update_ratios_specific_layers(self, hidden_states, start_layer: int, end_layer: int) -> torch.Tensor:
        """
        Calculate FFN update ratio between specified layers.
        Measures how much each token's representation changes through FFN layers.
        
        Args:
            hidden_states: All layer hidden states
            start_layer: Start layer index
            end_layer: End layer index
            
        Returns:
            Update ratio for each token position
        """
        try:
            # Get hidden states for specified layers
            # Note: hidden_states includes embedding layer, so layer indices are offset by 1
            input_hidden = hidden_states[start_layer]
            output_hidden = hidden_states[end_layer]
            
            # Handle tuple format if necessary
            if isinstance(input_hidden, tuple):
                input_hidden = input_hidden[0]
            if isinstance(output_hidden, tuple):
                output_hidden = output_hidden[0]
            
            input_hidden = input_hidden[0]  # Get first batch
            output_hidden = output_hidden[0]  # Get first batch
            
            seq_len = input_hidden.shape[0]
            
            # Calculate update ratio for each token
            input_norms = torch.norm(input_hidden, dim=1)  # (seq_len,)
            ffn_updates = output_hidden - input_hidden  # Layer-wise update
            ffn_norms = torch.norm(ffn_updates, dim=1)  # (seq_len,)
            
            # Avoid division by zero
            update_ratio = torch.where(
                input_norms > 1e-8,
                ffn_norms / input_norms,
                torch.zeros_like(input_norms)
            )
            
            return update_ratio
            
        except Exception as e:
            print(f"Error calculating FFN update ratio for layers {start_layer} to {end_layer}: {e}")
            return self._calculate_update_ratios_alternative(None, len(hidden_states[0][0]) if hidden_states else 10)
    
    def _calculate_update_ratios_alternative(self, input_ids, seq_len: int) -> torch.Tensor:
        """
        Fallback method for calculating update ratios using heuristics.
        
        Args:
            input_ids: Input token IDs (unused in heuristic)
            seq_len: Sequence length
            
        Returns:
            Heuristic update ratio values
        """
        update_ratios = torch.zeros(seq_len)
        
        for i in range(seq_len):
            # Heuristic: middle tokens tend to have higher update ratios
            base_ratio = 0.3 + torch.randn(1).abs().item() * 0.4
            position_bias = 1.0 - abs(i - seq_len // 2) / (seq_len // 2) * 0.3
            update_ratios[i] = base_ratio * position_bias
        
        return update_ratios
    
    def identify_anchor_tokens(self, text: str) -> Dict:
        """
        Identify anchor tokens in the input text.
        
        Args:
            text: Input text to analyze
            
        Returns:
            Dictionary containing:
                - text: Original input text
                - token_info: Information about each token
                - anchor_tokens: List of identified anchor tokens with scores
                - attention_variance: Attention variance for all tokens
                - update_ratios: FFN update ratios for all tokens
                - top_k_attention: Indices of top-k tokens by attention variance
                - top_k_update: Indices of top-k tokens by update ratio
                - config: Configuration used for detection
        """
        # Extract attention and FFN information
        attention_variance, update_ratios, token_info = self.extract_attention_and_ffn_info(text)
        
        seq_len = len(token_info)
        
        # Ensure tensor lengths match sequence length
        if len(attention_variance) != seq_len:
            if len(attention_variance) > seq_len:
                attention_variance = attention_variance[:seq_len]
            else:
                padding = torch.zeros(seq_len - len(attention_variance))
                attention_variance = torch.cat([attention_variance, padding])
        
        if len(update_ratios) != seq_len:
            if len(update_ratios) > seq_len:
                update_ratios = update_ratios[:seq_len]
            else:
                padding = torch.zeros(seq_len - len(update_ratios))
                update_ratios = torch.cat([update_ratios, padding])
        
        # Ensure tensors are on CPU
        attention_variance = attention_variance.cpu()
        update_ratios = update_ratios.cpu()
        
        # Get top-k tokens by attention variance (descending order)
        k_actual = min(self.k_value, seq_len)
        _, top_k_attention_indices = torch.topk(attention_variance, k_actual)
        
        # Get top-k tokens by update ratio (descending order)
        _, top_k_update_indices = torch.topk(update_ratios, k_actual)
        
        # Create ranking mappings (rank 0 = highest score)
        attention_ranks = {idx.item(): rank for rank, idx in enumerate(top_k_attention_indices)}
        update_ranks = {idx.item(): rank for rank, idx in enumerate(top_k_update_indices)}
        
        # Find tokens that are in both top-k lists
        attention_set = set(top_k_attention_indices.tolist())
        update_set = set(top_k_update_indices.tolist())
        anchor_candidates = attention_set.intersection(update_set)
        
        # If no intersection, select tokens with highest combined scores
        if not anchor_candidates:
            print("Warning: No tokens satisfy both criteria, selecting tokens with highest combined scores")
            # Normalize both metrics and combine
            att_min, att_max = attention_variance.min(), attention_variance.max()
            upd_min, upd_max = update_ratios.min(), update_ratios.max()
            
            att_range = att_max - att_min + 1e-8
            upd_range = upd_max - upd_min + 1e-8
            
            norm_attention = (attention_variance - att_min) / att_range
            norm_update = (update_ratios - upd_min) / upd_range
            
            combined_scores = norm_attention + norm_update
            _, top_combined_indices = torch.topk(combined_scores, k_actual)
            
            # Create virtual ranks for these candidates
            for rank, idx in enumerate(top_combined_indices):
                idx_val = idx.item()
                if idx_val not in attention_ranks:
                    attention_ranks[idx_val] = k_actual
                if idx_val not in update_ranks:
                    update_ranks[idx_val] = k_actual
            
            anchor_candidates = set(top_combined_indices.tolist())
        
        # Calculate anchor scores: sort by sum of ranks (lower is better)
        anchor_scores = []
        for idx in anchor_candidates:
            attention_rank = attention_ranks.get(idx, k_actual)
            update_rank = update_ranks.get(idx, k_actual)
            
            # Lower rank sum is better (ranks start from 0)
            rank_sum = attention_rank + update_rank
            
            anchor_scores.append((
                idx, 
                rank_sum, 
                attention_rank,
                update_rank,
                attention_variance[idx].item(), 
                update_ratios[idx].item()
            ))
        
        # Sort by rank sum (ascending order)
        anchor_scores.sort(key=lambda x: x[1])
        
        # Build result dictionary
        result = {
            "text": text,
            "token_info": token_info,
            "anchor_tokens": [],
            "attention_variance": attention_variance.tolist(),
            "update_ratios": update_ratios.tolist(),
            "top_k_attention": top_k_attention_indices.tolist(),
            "top_k_update": top_k_update_indices.tolist(),
            "config": {
                "attention_layer": self.attention_layer,
                "ffn_layer_start": self.ffn_layer_start,
                "ffn_layer_end": self.ffn_layer_end,
                "k_value": self.k_value
            }
        }
        
        # Add anchor token details
        for idx, rank_sum, att_rank, upd_rank, att_var, update_ratio in anchor_scores:
            token_data = token_info[idx]
            result["anchor_tokens"].append({
                "token_str": token_data['token_str'],
                "decoded_text": token_data['decoded_text'],
                "position": idx,
                "rank_sum": rank_sum,
                "attention_rank": att_rank,
                "update_rank": upd_rank,
                "attention_variance": att_var,
                "update_ratio": update_ratio
            })
        
        return result
    
    def print_results(self, result: Dict):
        """
        Print identification results in a readable format.
        
        Args:
            result: Result dictionary from identify_anchor_tokens()
        """
        config = result['config']
        print(f"\n{'='*60}")
        print(f"Configuration:")
        print(f"  - Using layer {config['attention_layer']} for attention")
        print(f"  - Using layers {config['ffn_layer_start']} to {config['ffn_layer_end']} for FFN updates")
        print(f"  - K value: {config['k_value']}")
        
        print(f"\nInput text: {result['text']}")
        print(f"\nToken sequence:")
        for i, token_data in enumerate(result['token_info']):
            print(f"  {i}: '{token_data['decoded_text']}' (raw: {token_data['token_str']})")
        
        print(f"\nIdentified {len(result['anchor_tokens'])} anchor tokens:")
        print(f"{'='*60}")
        
        for i, anchor in enumerate(result['anchor_tokens']):
            print(f"{i+1}. Token: '{anchor['decoded_text']}' (position: {anchor['position']})")
            print(f"   Rank sum: {anchor['rank_sum']} (attention rank: {anchor['attention_rank']}, update rank: {anchor['update_rank']})")
            print(f"   Attention variance: {anchor['attention_variance']:.6f}")
            print(f"   Update ratio: {anchor['update_ratio']:.6f}")
            print()
