from dataclasses import dataclass
from collections import defaultdict
import numpy as np
from transformers import PreTrainedTokenizerBase

from .collector import FeatureActivations


@dataclass
class TokenFeatureAssociation:
    """Statistics describing how one feature behaves on one vocabulary token.

    The record is a plain value object; ``to_dict`` flattens it into a
    mapping keyed by field name.
    """

    # Which token / feature pair this record describes.
    token_id: int
    token_str: str
    feature_index: int

    # Feature activation aggregated over the token's occurrences.
    mean_activation: float
    max_activation: float

    # Occurrence counts: overall, then split into reasoning / non-reasoning.
    occurrence_count: int
    occurrence_count_reasoning: int
    occurrence_count_nonreasoning: int

    # Association scores between seeing the token and the feature firing.
    pmi: float
    activation_ratio: float

    # Mean activation restricted to reasoning / non-reasoning occurrences.
    mean_activation_in_reasoning: float
    mean_activation_in_nonreasoning: float

    def to_dict(self) -> dict:
        """Return the record as a dict with one entry per field, in field order."""
        exported = (
            "token_id",
            "token_str",
            "feature_index",
            "mean_activation",
            "max_activation",
            "occurrence_count",
            "occurrence_count_reasoning",
            "occurrence_count_nonreasoning",
            "pmi",
            "activation_ratio",
            "mean_activation_in_reasoning",
            "mean_activation_in_nonreasoning",
        )
        return {name: getattr(self, name) for name in exported}


@dataclass
class NgramFeatureAssociation:
    """Statistics describing how one feature behaves on one token n-gram.

    ``token_ids`` and ``token_strs`` are parallel tuples with one entry per
    token of the n-gram; ``to_dict`` exports them as lists.
    """

    # The n-gram itself: ids, per-token strings, and the combined string.
    token_ids: tuple
    token_strs: tuple
    ngram_str: str
    feature_index: int
    n: int

    # Feature activation aggregated over the n-gram's occurrences.
    mean_activation: float
    max_activation: float

    # Occurrence counts: overall, then split into reasoning / non-reasoning.
    occurrence_count: int
    occurrence_count_reasoning: int
    occurrence_count_nonreasoning: int

    # Mean activation restricted to reasoning / non-reasoning occurrences.
    mean_activation_in_reasoning: float
    mean_activation_in_nonreasoning: float

    def to_dict(self) -> dict:
        """Return the record as a dict; the two tuple fields become lists."""
        payload = {
            "token_ids": list(self.token_ids),
            "token_strs": list(self.token_strs),
        }
        scalar_fields = (
            "ngram_str",
            "feature_index",
            "n",
            "mean_activation",
            "max_activation",
            "occurrence_count",
            "occurrence_count_reasoning",
            "occurrence_count_nonreasoning",
            "mean_activation_in_reasoning",
            "mean_activation_in_nonreasoning",
        )
        for name in scalar_fields:
            payload[name] = getattr(self, name)
        return payload


class TopTokenAnalyzer:
    
    def __init__(
        self,
        activations: FeatureActivations,
        tokenizer: PreTrainedTokenizerBase,
        activation_threshold: float = 0.1,
    ):
        self.activations = activations
        self.tokenizer = tokenizer
        self.activation_threshold = activation_threshold
        
        self.reasoning_mask = activations.get_reasoning_mask()
        self._precompute_token_stats()
    
    def _precompute_token_stats(self):
        tokens = self.activations.tokens.numpy()
        
        self.token_counts = defaultdict(int)
        self.token_positions = defaultdict(list)
        
        for sample_idx in range(tokens.shape[0]):
            for pos in range(tokens.shape[1]):
                token_id = tokens[sample_idx, pos]
                self.token_counts[token_id] += 1
                self.token_positions[token_id].append((sample_idx, pos))
        
        self.total_tokens = tokens.size
        self.unique_tokens = len(self.token_counts)
    
    def get_top_tokens_for_feature(
        self,
        feature_index: int,
        top_k: int = 50,
        min_occurrences: int = 5,
    ) -> list[TokenFeatureAssociation]:
        acts = self.activations.activations[:, :, feature_index].numpy()
        tokens = self.activations.tokens.numpy()
        reasoning_mask = self.reasoning_mask.numpy()
        
        feature_max = acts.max()
        threshold = self.activation_threshold * max(feature_max, 1e-10)
        feature_fires_prob = (acts > threshold).mean()
        
        associations = []
        
        for token_id, positions in self.token_positions.items():
            if len(positions) < min_occurrences:
                continue
            
            token_acts = []
            token_acts_reasoning = []
            token_acts_nonreasoning = []
            count_reasoning = 0
            count_nonreasoning = 0
            
            for sample_idx, pos in positions:
                act = acts[sample_idx, pos]
                token_acts.append(act)
                
                if reasoning_mask[sample_idx]:
                    token_acts_reasoning.append(act)
                    count_reasoning += 1
                else:
                    token_acts_nonreasoning.append(act)
                    count_nonreasoning += 1
            
            token_acts = np.array(token_acts)
            
            mean_act = float(np.mean(token_acts))
            max_act = float(np.max(token_acts))
            
            p_token = len(positions) / self.total_tokens
            p_feature_fires = feature_fires_prob
            p_joint = (token_acts > threshold).mean()
            
            if p_joint > 0 and p_token > 0 and p_feature_fires > 0:
                pmi = np.log2(p_joint / (p_token * p_feature_fires))
            else:
                pmi = -np.inf
            
            p_feature_given_token = (token_acts > threshold).mean()
            if p_feature_fires > 0:
                activation_ratio = p_feature_given_token / p_feature_fires
            else:
                activation_ratio = 0.0
            
            mean_reasoning = np.mean(token_acts_reasoning) if token_acts_reasoning else 0.0
            mean_nonreasoning = np.mean(token_acts_nonreasoning) if token_acts_nonreasoning else 0.0
            
            try:
                token_str = self.tokenizer.decode([token_id])
            except Exception:
                token_str = f"<token_{token_id}>"
            
            associations.append(TokenFeatureAssociation(
                token_id=int(token_id),
                token_str=token_str,
                feature_index=feature_index,
                mean_activation=mean_act,
                max_activation=max_act,
                occurrence_count=len(positions),
                occurrence_count_reasoning=count_reasoning,
                occurrence_count_nonreasoning=count_nonreasoning,
                pmi=float(pmi) if not np.isinf(pmi) else -100.0,
                activation_ratio=float(activation_ratio),
                mean_activation_in_reasoning=float(mean_reasoning),
                mean_activation_in_nonreasoning=float(mean_nonreasoning),
            ))
        
        associations.sort(key=lambda x: x.mean_activation, reverse=True)
        
        return associations[:top_k]
    
    def get_reasoning_specific_tokens(
        self,
        feature_index: int,
        top_k: int = 30,
        min_occurrences: int = 5,
    ) -> list[TokenFeatureAssociation]:
        all_tokens = self.get_top_tokens_for_feature(
            feature_index, top_k=top_k * 3, min_occurrences=min_occurrences
        )
        
        reasoning_tokens = []
        for assoc in all_tokens:
            if assoc.mean_activation_in_reasoning < 0.01:
                continue
            
            specificity = assoc.mean_activation_in_reasoning / (
                assoc.mean_activation_in_nonreasoning + 1e-10
            )
            
            if specificity > 1.5:
                reasoning_tokens.append((assoc, specificity))
        
        reasoning_tokens.sort(key=lambda x: x[1], reverse=True)
        
        return [assoc for assoc, _ in reasoning_tokens[:top_k]]
    
    def analyze_feature_token_dependency(
        self,
        feature_index: int,
        top_k_tokens: int = 20,
    ) -> dict:
        acts = self.activations.activations[:, :, feature_index].numpy()
        tokens = self.activations.tokens.numpy()
        
        top_tokens = self.get_top_tokens_for_feature(feature_index, top_k=top_k_tokens)
        top_token_ids = {t.token_id for t in top_tokens}
        
        threshold = self.activation_threshold * max(acts.max(), 1e-10)
        high_act_mask = acts > threshold
        
        high_acts_from_top_tokens = 0
        total_high_acts = high_act_mask.sum()
        
        for sample_idx in range(acts.shape[0]):
            for pos in range(acts.shape[1]):
                if high_act_mask[sample_idx, pos]:
                    if tokens[sample_idx, pos] in top_token_ids:
                        high_acts_from_top_tokens += 1
        
        if total_high_acts > 0:
            token_concentration = high_acts_from_top_tokens / total_high_acts
        else:
            token_concentration = 0.0
        
        token_act_sums = defaultdict(float)
        for sample_idx in range(acts.shape[0]):
            for pos in range(acts.shape[1]):
                token_id = tokens[sample_idx, pos]
                token_act_sums[token_id] += acts[sample_idx, pos]
        
        total_act = sum(token_act_sums.values())
        if total_act > 0:
            probs = np.array([v / total_act for v in token_act_sums.values()])
            entropy = -np.sum(probs * np.log2(probs + 1e-10))
            max_entropy = np.log2(len(token_act_sums))
            normalized_entropy = entropy / max(max_entropy, 1e-10)
        else:
            normalized_entropy = 0.0
        
        return {
            "feature_index": feature_index,
            "top_tokens": [t.to_dict() for t in top_tokens],
            "token_concentration": token_concentration,
            "normalized_entropy": normalized_entropy,
            "is_token_dependent": bool(token_concentration > 0.5),
            "interpretation": (
                "HIGH token dependency - likely shallow cue" 
                if token_concentration > 0.5 
                else "LOWER token dependency - may capture deeper patterns"
            ),
        }
    
    def get_top_ngrams_for_feature(
        self,
        feature_index: int,
        n: int = 2,
        top_k: int = 30,
        min_occurrences: int = 3,
    ) -> list[NgramFeatureAssociation]:
        acts = self.activations.activations[:, :, feature_index].numpy()
        tokens = self.activations.tokens.numpy()
        reasoning_mask = self.reasoning_mask.numpy()
        
        ngram_stats = defaultdict(lambda: {
            'mean_acts': [], 'max_acts': [],
            'reasoning_acts': [], 'nonreasoning_acts': [],
        })
        
        for sample_idx in range(tokens.shape[0]):
            is_reasoning = reasoning_mask[sample_idx]
            seq_len = tokens.shape[1]
            
            for pos in range(seq_len - n + 1):
                ngram_ids = tuple(int(tokens[sample_idx, pos + i]) for i in range(n))
                ngram_acts = acts[sample_idx, pos:pos + n]
                mean_act = float(np.mean(ngram_acts))
                max_act = float(np.max(ngram_acts))
                
                stats = ngram_stats[ngram_ids]
                stats['mean_acts'].append(mean_act)
                stats['max_acts'].append(max_act)
                
                if is_reasoning:
                    stats['reasoning_acts'].append(mean_act)
                else:
                    stats['nonreasoning_acts'].append(mean_act)
        
        associations = []
        for ngram_ids, stats in ngram_stats.items():
            if len(stats['mean_acts']) < min_occurrences:
                continue
            
            try:
                token_strs = tuple(self.tokenizer.decode([tid]) for tid in ngram_ids)
                ngram_str = ''.join(token_strs)
            except Exception:
                token_strs = tuple(f"<token_{tid}>" for tid in ngram_ids)
                ngram_str = ' '.join(token_strs)
            
            associations.append(NgramFeatureAssociation(
                token_ids=ngram_ids,
                token_strs=token_strs,
                ngram_str=ngram_str,
                feature_index=feature_index,
                n=n,
                mean_activation=float(np.mean(stats['mean_acts'])),
                max_activation=float(np.max(stats['max_acts'])),
                occurrence_count=len(stats['mean_acts']),
                occurrence_count_reasoning=len(stats['reasoning_acts']),
                occurrence_count_nonreasoning=len(stats['nonreasoning_acts']),
                mean_activation_in_reasoning=float(np.mean(stats['reasoning_acts'])) if stats['reasoning_acts'] else 0.0,
                mean_activation_in_nonreasoning=float(np.mean(stats['nonreasoning_acts'])) if stats['nonreasoning_acts'] else 0.0,
            ))
        
        associations.sort(key=lambda x: x.mean_activation, reverse=True)
        return associations[:top_k]
    
    def get_feature_vocabulary(
        self,
        feature_index: int,
        activation_percentile: float = 90,
    ) -> list[str]:
        acts = self.activations.activations[:, :, feature_index].numpy()
        tokens = self.activations.tokens.numpy()
        
        threshold = np.percentile(acts[acts > 0], activation_percentile) if (acts > 0).any() else 0
        
        high_act_tokens = set()
        for sample_idx in range(acts.shape[0]):
            for pos in range(acts.shape[1]):
                if acts[sample_idx, pos] > threshold:
                    high_act_tokens.add(tokens[sample_idx, pos])
        
        vocab = []
        for token_id in high_act_tokens:
            try:
                vocab.append(self.tokenizer.decode([token_id]))
            except Exception:
                pass
        
        return vocab
