import asyncio
import torch
import numpy as np
from dataclasses import dataclass
from typing import Dict, Any
from sae_auto_interp.scorers.scorer import Scorer, ScorerResult
from sae_auto_interp.features.features import FeatureRecord
from sae_auto_interp.clients import Client
from sae_auto_interp.logger import logger

@dataclass
class TokenEntropyOutput:
    """Output from token entropy scoring.

    The field names mirror keys of the score dict built by
    ``TokenEntropyScorer``. Nothing in the visible code constructs this
    class (the scorer returns plain dicts), so the per-field notes below are
    inferred from that dict and should be confirmed against real callers.
    """
    # Entropy of the token distribution; presumably the value the scorer
    # reports under "entropy" (normalized when a tokenizer is set) - TODO confirm.
    entropy: float
    # NOTE(review): the scorer stores a float sum of activation values under
    # the "total_activations" key, not an integer count - confirm intended type.
    total_activations: int
    # Presumably the number of distinct token ids that contributed activation.
    unique_tokens: int
    # Presumably token id -> probability. NOTE(review): the scorer stringifies
    # dict keys before returning, so its emitted keys are str, not int.
    token_distribution: Dict[int, float]

class TokenEntropyScorer(Scorer):
    """Score a feature by the entropy of its activation mass over token ids.

    For every example (up to ``max_examples``), each position whose
    activation is strictly greater than ``min_activation`` adds its
    activation value to a running total for the token id at that position.
    The per-token totals are normalised into a probability distribution and
    its Shannon entropy (natural log) is reported.

    The ``score`` of the returned ``ScorerResult`` is a dict with keys:

    * ``entropy`` - raw entropy divided by ``log(len(tokenizer))`` when a
      tokenizer with more than one entry is set, otherwise the raw entropy.
    * ``raw_entropy`` - unnormalised entropy in nats.
    * ``total_activations`` - sum of all counted activation values.
    * ``unique_tokens`` - number of distinct token ids that were counted.
    * ``token_distribution`` - token id (as ``str``) -> probability.
    * ``examples_processed`` - number of examples visited.
    * ``top_tokens`` - only when ``verbose`` and a tokenizer is set: decoded
      text -> probability for the 10 most probable token ids.

    If scoring raises, the dict instead holds ``error`` (the exception text)
    and ``examples_processed`` set to 0.
    """

    def __init__(
        self,
        client: Client = None,  # Not used but kept for interface consistency
        tokenizer=None,
        verbose: bool = False,
        min_activation: float = 0.1,  # Increased default threshold to focus on stronger activations
        max_examples: int = 100,  # Limit number of examples processed
        **kwargs
    ):
        """Configure the scorer.

        Args:
            client: Accepted for interface consistency; not stored or used.
            tokenizer: Optional object supporting ``len()`` and
                ``decode([token_id])``. Enables entropy normalisation and,
                with ``verbose``, the readable ``top_tokens`` output.
            verbose: When true (and a tokenizer is set) include ``top_tokens``.
            min_activation: Only activations strictly above this are counted.
            max_examples: Maximum number of examples read from a record.
            **kwargs: Accepted and ignored.
        """
        self.tokenizer = tokenizer
        self.verbose = verbose
        self.min_activation = min_activation
        self.max_examples = max_examples

    def _convert_to_python_types(self, obj):
        """Recursively convert numpy values to JSON-serialisable Python types.

        Dict keys are stringified and tuples become lists; objects of any
        other type are returned unchanged.
        """
        if isinstance(obj, dict):
            return {str(k): self._convert_to_python_types(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [self._convert_to_python_types(x) for x in obj]
        if isinstance(obj, np.ndarray):
            return self._convert_to_python_types(obj.tolist())
        if isinstance(obj, np.bool_):
            return bool(obj)
        # np.integer / np.floating are the abstract bases of every sized
        # numpy int / float type, so they cover int32, int64, float32, ...
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return obj

    @staticmethod
    def _empty_score(examples_processed: int) -> Dict[str, Any]:
        """Build the score dict used when no activation mass was collected."""
        return {
            "entropy": 0.0,
            "raw_entropy": 0.0,
            "total_activations": 0,
            "unique_tokens": 0,
            "token_distribution": {},
            "examples_processed": examples_processed,
        }

    def _collect_token_mass(self, record: FeatureRecord):
        """Sum above-threshold activations per token id.

        Returns:
            A ``(token_mass, examples_processed)`` tuple, where ``token_mass``
            maps token id -> summed activation in first-seen order.
        """
        token_mass: Dict[int, float] = {}
        examples_processed = 0

        for example in record.examples[:self.max_examples]:
            tokens = example.tokens
            activations = example.activations

            # The comparison stays in torch so the threshold is applied in
            # the activations' own dtype; only the indices become Python ints.
            active_indices = torch.where(activations > self.min_activation)[0]

            for idx in active_indices.tolist():
                token_id = int(tokens[idx])
                token_mass[token_id] = (
                    token_mass.get(token_id, 0.0) + float(activations[idx])
                )

            examples_processed += 1

        return token_mass, examples_processed

    def _score(self, record: FeatureRecord) -> ScorerResult:
        """Compute the entropy score synchronously; never raises.

        Any exception is logged and reported through an ``error`` score so a
        single bad record does not abort a scoring run.
        """
        try:
            token_mass, examples_processed = self._collect_token_mass(record)
            total_activation = float(sum(token_mass.values()))

            # No token passed the threshold, or (only possible with a
            # non-positive threshold) the counted mass sums to zero and
            # cannot be normalised into a distribution.
            if not token_mass or total_activation == 0.0:
                return ScorerResult(
                    record=record,
                    score=self._empty_score(examples_processed),
                )

            token_probs = {
                token: float(mass / total_activation)
                for token, mass in token_mass.items()
            }

            # Shannon entropy in nats. Zero-probability terms are skipped
            # (0 * log 0 is taken as 0) to avoid NaN, and each term is
            # negated before summing so a one-token distribution gives 0.0
            # rather than -0.0.
            entropy = float(
                sum(-p * np.log(p) for p in token_probs.values() if p > 0)
            )

            # Normalise by the maximum possible entropy, log(vocab size).
            # A vocabulary of one entry has log(1) == 0, so fall back to the
            # raw value instead of dividing by zero.
            normalized_entropy = entropy
            if self.tokenizer:
                vocab_size = len(self.tokenizer)
                if vocab_size > 1:
                    normalized_entropy = float(entropy / np.log(vocab_size))

            result = {
                "entropy": normalized_entropy,
                "raw_entropy": entropy,
                "total_activations": total_activation,
                "unique_tokens": len(token_mass),
                "token_distribution": token_probs,
                "examples_processed": examples_processed,
            }

            if self.verbose and self.tokenizer:
                # Only include the top 10 tokens in readable form.
                ranked = sorted(
                    token_probs.items(),
                    key=lambda item: item[1],
                    reverse=True,
                )[:10]
                readable_dist: Dict[str, float] = {}
                for token_id, prob in ranked:
                    text = self.tokenizer.decode([token_id])
                    # Distinct ids can decode to the same text; add their
                    # probabilities instead of letting one overwrite another.
                    readable_dist[text] = readable_dist.get(text, 0.0) + float(prob)
                if readable_dist:
                    result["top_tokens"] = readable_dist

            # Convert any remaining numpy types to Python types
            result = self._convert_to_python_types(result)

            return ScorerResult(record=record, score=result)

        except Exception as e:
            logger.error(f"Error computing token entropy: {e}")
            return ScorerResult(
                record=record,
                score={
                    "error": str(e),
                    "examples_processed": 0
                }
            )

    async def __call__(self, record: FeatureRecord) -> ScorerResult:
        """Score a feature record by computing token activation entropy.

        Async only to match the scorer calling convention; the work itself
        is synchronous and awaits nothing.
        """
        return self._score(record)

    def call_sync(self, record: FeatureRecord) -> ScorerResult:
        """Synchronous wrapper for scoring.

        Calls the synchronous implementation directly rather than going
        through ``asyncio.run``, so it also works when an event loop is
        already running in the calling thread.
        """
        return self._score(record)