import random
from typing import Any, Dict, List, Optional

import numpy as np

from src.feature import Example, Feature
from src.samplers import BaseSampler

class PromptBuilder:
    """Assemble the system/user prompts that ask an LLM to explain a feature.

    The user prompt lists examples drawn by ``sampler`` from
    ``feature.positive_examples``. Each example is decoded token by token
    with ``tokenizer`` and activated tokens are wrapped in ``<<`` / ``>>``.

    Attributes:
        tokenizer: Object exposing ``decode(list_of_ids) -> str``.
        max_example_length: Nominal number of tokens kept per example.
        sampler: Object exposing ``sample(examples)``.
        verbose: When true, the user prompt starts with a summary line.
        include_statistics: Stored as given; no method of this class reads it.
        context_window: Offset added to the last in-range activation position
            when the truncation point is pushed past ``max_example_length``.
        system_prompt_template: Text returned as the system prompt.
    """

    # Used when no (or an empty) ``system_prompt_template`` is supplied.
    DEFAULT_SYSTEM_PROMPT = """You are a meticulous AI researcher conducting an important investigation into patterns found in language. Your task is to analyze text and provide an explanation that thoroughly encapsulates possible patterns found in it.

Guidelines:

You will be given a list of text examples on which special words are selected and between delimiters like <<this>>. If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <<just like this>>. How important each token is for the behavior is listed after each example in parentheses.

- Provide your explanation as a concise STANDALONE PHRASE describing the common contexts or concepts (e.g., "legal terminology in contracts" not "The feature detects legal terminology")
- Do not make lists of possible explanations
- Focus on the essence of what patterns are present in the examples
- Do NOT mention the texts, examples or the feature itself in your explanation
- Do NOT write "these texts", "feature detects", "the patterns suggest" or something like that
- The last line must be exactly formatted as: [EXPLANATION]: your description

"""

    # Marker the LLM is instructed to put in front of its final answer.
    EXPLANATION_MARKER = "[EXPLANATION]:"

    # At most this many "token (value)" entries are listed per example.
    MAX_LISTED_ACTIVATIONS = 5

    # An in-range activation positioned beyond this fraction of
    # ``max_example_length`` is treated as being close to the cutoff.
    _NEAR_CUTOFF_FRACTION = 0.8

    def __init__(
        self,
        *,
        tokenizer: Any,  # Should have a decode() method
        max_example_length: int = 50,
        sampler: "BaseSampler",
        system_prompt_template: Optional[str] = None,
        verbose: bool = False,
        include_statistics: bool = False,
        context_window: int = 5,  # tokens around activation to show
    ):
        """Store the configuration; all arguments are keyword-only.

        Args:
            tokenizer: Object with a ``decode()`` method taking a list of ids.
            max_example_length: Nominal per-example token limit.
            sampler: Used to pick which examples go into the user prompt.
            system_prompt_template: System prompt text. ``None`` or an empty
                string selects ``DEFAULT_SYSTEM_PROMPT``.
            verbose: Prepend a summary line to the user prompt.
            include_statistics: Stored only; not consulted by this class.
            context_window: Offset used when extending the truncation point.
        """
        self.tokenizer = tokenizer
        self.max_example_length = max_example_length
        self.sampler = sampler
        self.verbose = verbose
        self.include_statistics = include_statistics
        self.context_window = context_window
        self.system_prompt_template = (
            system_prompt_template or self.DEFAULT_SYSTEM_PROMPT
        )

    def _get_system_prompt(self) -> str:
        """Return the configured system prompt text."""
        return self.system_prompt_template

    def _truncation_point(self, example: "Example") -> int:
        """Return how many leading tokens of ``example.context`` to keep."""
        total = len(example.context)
        limit = self.max_example_length
        if total <= limit:
            return total

        # Last activation that would still be visible under the plain limit.
        last_activation = max(
            (p for p in example.activation_positions if p < limit),
            default=-1,
        )
        keep = limit
        if last_activation > self._NEAR_CUTOFF_FRACTION * limit:
            # Only ever extend: the previous code could cut *below* the limit
            # when last_activation + context_window was smaller than it.
            keep = max(limit, last_activation + self.context_window)
        return min(total, keep)

    def _format_example(self, example: "Example") -> str:
        """Render one example as highlighted text plus an activation list.

        Reads ``example.context`` (token ids), ``example.activation_positions``
        and ``example.activation_values`` (paired up positionally).

        Args:
            example: The example to render.

        Returns:
            ``"Text: <repr>\\nActivations: <repr>"``. In the text, each run of
            consecutive activated tokens is wrapped in one ``<<...>>`` pair and
            ``[...]`` is appended when the context was truncated. The
            activation list holds ``"token (value)"`` entries with two
            decimals, or ``"None"`` when no kept token is activated. If more
            than ``MAX_LISTED_ACTIVATIONS`` entries exist, a random subset of
            that size is listed.
        """
        context = example.context
        keep = self._truncation_point(example)
        activation_map = dict(
            zip(example.activation_positions, example.activation_values)
        )

        # Decode only the tokens that survive truncation.
        tokens = [self.tokenizer.decode([token_id]) for token_id in context[:keep]]

        pieces = []
        activation_info = []
        inside_span = False
        for position, token in enumerate(tokens):
            activated = position in activation_map
            # Emit delimiters at run boundaries rather than wrapping each token
            # and stripping ">><<" afterwards, which also mangled token text
            # that happened to contain those characters.
            if activated and not inside_span:
                pieces.append("<<")
            elif inside_span and not activated:
                pieces.append(">>")
            pieces.append(token)
            if activated:
                activation_info.append(
                    f"{token} ({activation_map[position]:.2f})"
                )
            inside_span = activated
        if inside_span:
            pieces.append(">>")

        if len(activation_info) > self.MAX_LISTED_ACTIVATIONS:
            activation_info = random.sample(
                activation_info, self.MAX_LISTED_ACTIVATIONS
            )

        if keep < len(context):
            pieces.append("[...]")

        text = "".join(pieces)
        activations = ", ".join(activation_info) if activation_info else "None"
        return f"Text: {text!r}\nActivations: {activations!r}"

    def _get_user_prompt(self, feature: "Feature") -> str:
        """Build the user prompt from sampled, formatted examples.

        Samples from ``feature.positive_examples`` via ``self.sampler``. With
        ``verbose`` set, a first line reports ``feature.index`` and
        ``len(feature.examples)``. Parts are separated by blank lines.

        NOTE(review): the verbose count uses ``feature.examples`` while
        sampling uses ``feature.positive_examples`` - confirm that both
        attributes exist on ``Feature`` and that the difference is intended.
        """
        sampled_examples = self.sampler.sample(feature.positive_examples)

        prompt_parts = []
        if self.verbose:
            prompt_parts.append(
                f"Analyzing feature {feature.index} with {len(feature.examples)} total examples."
            )
        prompt_parts.append("Examples of activations:")
        prompt_parts.extend(self._format_example(ex) for ex in sampled_examples)

        return "\n\n".join(prompt_parts)

    def build_prompt(self, feature: "Feature") -> Dict[str, str]:
        """Return the full prompt as ``{"system": ..., "user": ...}``."""
        return {
            "system": self._get_system_prompt(),
            "user": self._get_user_prompt(feature),
        }

    def parse_explanation(self, llm_output: str) -> str:
        """Extract the explanation from LLM output.

        Returns the stripped text after the last ``[EXPLANATION]:`` marker,
        or the whole stripped output when the marker is absent.
        """
        # rpartition yields the full string as the tail when the marker is
        # missing, so both cases reduce to the same expression.
        return llm_output.rpartition(self.EXPLANATION_MARKER)[2].strip()