from dataclasses import dataclass
from typing import List, Optional, Dict, Any, Union, Iterator, Tuple
import textwrap
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.ticker import LogLocator, LogFormatter
import random

import pickle
from pathlib import Path
from dataclasses import dataclass, asdict, field

import math
import torch
from collections import Counter

@dataclass
class Example:
    """A token sequence together with the positions/values where a feature fired.

    ``activation_positions[i]`` is paired with ``activation_values[i]``; the
    ``Feature`` helpers index the two lists in parallel. Negative examples
    (``is_positive=False``) are built with empty position/value lists.
    """
    context: List[int]  # token ids (decoded one at a time via tokenizer.decode([id]))
    activation_positions: List[int]  # indices into ``context``
    activation_values: List[float]  # activation at the matching position
    is_positive: bool = True  # False marks a negative (non-activating) example
    

@dataclass
class FeatureStatistics:
    """Summary statistics for one feature's activations.

    The required fields are supplied at construction time. The optional fields
    from ``token_entropy`` onward are filled in afterwards by
    ``Feature.compute_token_statistics`` and ``Feature.compute_cosine_statistics``.
    """
    min: float  # smallest activation value (shown as the activation range)
    max: float  # largest activation value
    mean: float  # mean activation value
    # Activation histogram, plotted by Feature.plot_activation_distribution.
    # NOTE(review): that plot drops histogram_counts[0] to match
    # np.diff(histogram_edges), which suggests len(counts) == len(edges) — confirm.
    histogram_counts: List[float]
    histogram_edges: List[float]
    positional_hist: List[float]
    gini_coeff: float
    intra_seq_density: float
    activation_freq: float
    sequence_penetration: float  # fraction of sequences with an activation (printed as %)
    pos_mean: float  # mean activation position
    pos_std: float  # std of activation position
    multitoken_ratio: float
    frequency: float  # activation frequency, reported "per token"
    token_entropy: Optional[float] = None  # natural-log entropy of the activated-token distribution
    num_unique_tokens: Optional[int] = None  # number of distinct activated tokens
    mean_activations_per_token: Optional[float] = None  # activations / distinct activated tokens
    # Histogram / mean / std of pairwise embedding similarities between
    # distinct activated tokens.
    cosine_sim_hist: Optional[List[int]] = None
    cosine_sim_mean: Optional[float] = None
    cosine_sim_std: Optional[float] = None
    # (token_id, score) pairs: the 10 highest / lowest dot products between the
    # feature embedding and the token embedding matrix.
    top_tokens: Optional[List[Tuple[int, float]]] = None
    bot_tokens: Optional[List[Tuple[int, float]]] = None
    # Mean of those feature/token scores over the activated tokens.
    cosine_with_activated_tokens: Optional[float] = None
    

@dataclass
class InterpretationScore:
    """A numeric score for an ``Interpretation``, with optional free-text commentary."""
    value: float
    commentary: Optional[str] = None
    

@dataclass
class Interpretation:
    """A textual explanation of a feature, optionally with commentary and a score."""
    value: str  # the interpretation text shown in reports
    commentary: Optional[str] = None
    score: Optional[InterpretationScore] = None
    

@dataclass
class Feature:
    """A single feature: its examples, summary statistics and interpretations.

    ``examples`` mixes positive examples (where the feature fired) and negative
    examples (``is_positive=False``); they are exposed separately through
    :attr:`positive_examples` and :attr:`negative_examples`.
    """

    index: Optional[int] = None
    examples: List[Example] = field(default_factory=list)
    statistics: Optional[FeatureStatistics] = None
    interpretations: List[Interpretation] = field(default_factory=list)

    @property
    def positive_examples(self) -> List[Example]:
        """Examples with ``is_positive`` set."""
        return [ex for ex in self.examples if ex.is_positive]

    @property
    def negative_examples(self) -> List[Example]:
        """Examples with ``is_positive`` cleared."""
        return [ex for ex in self.examples if not ex.is_positive]

    def split_examples(self, context_window: int, delimiter: int = -1) -> List[Example]:
        """Cut stored examples into windows of about ``context_window`` tokens.

        Positive examples: activation positions are clustered by
        :meth:`_group_activations`. Each cluster becomes one window centred on
        the midpoint of its first and last position. Only activations that fall
        inside the window are kept, re-indexed relative to the window start, so
        clusters wider than the window lose their outermost activations.

        Negative examples longer than ``context_window`` are reduced to a single
        uniformly random window of exactly ``context_window`` tokens. Shorter
        ones are returned unchanged (the same object).

        Args:
            context_window: Target window length in tokens.
            delimiter: Currently unused; kept for interface compatibility.

        Returns:
            New examples: all positive windows first, then negative windows.
        """
        half = context_window // 2
        new_examples: List[Example] = []

        for example in self.positive_examples:
            # Pair each position with its value *before* sorting so the two
            # lists stay aligned even when positions are stored out of order.
            pairs = sorted(
                zip(example.activation_positions, example.activation_values),
                key=lambda pv: pv[0],
            )
            groups = self._group_activations([pos for pos, _ in pairs], context_window)

            # Groups are consecutive runs of the sorted positions, so each
            # group corresponds to the next len(group) entries of ``pairs``.
            offset = 0
            for group in groups:
                group_pairs = pairs[offset:offset + len(group)]
                offset += len(group)
                if not group_pairs:
                    continue

                center = (group_pairs[0][0] + group_pairs[-1][0]) // 2
                start = max(0, center - half)
                end = min(len(example.context), center + half)

                kept = [(pos - start, val) for pos, val in group_pairs if start <= pos < end]
                if kept:
                    new_examples.append(Example(
                        context=example.context[start:end],
                        activation_positions=[pos for pos, _ in kept],
                        activation_values=[val for _, val in kept],
                        is_positive=True,
                    ))

        for example in self.negative_examples:
            n_tokens = len(example.context)
            if n_tokens > context_window:
                # torch.randint's upper bound is exclusive; the +1 makes the
                # final window (ending on the last token) reachable.
                start = int(torch.randint(0, n_tokens - context_window + 1, (1,)).item())
                new_examples.append(Example(
                    context=example.context[start:start + context_window],
                    activation_positions=[],
                    activation_values=[],
                    is_positive=False,
                ))
            else:
                new_examples.append(example)

        return new_examples

    def _group_activations(self, positions: List[int], window_size: int) -> List[List[int]]:
        """Split sorted ``positions`` into runs whose consecutive gaps are <= ``window_size``.

        The returned groups are consecutive slices of ``sorted(positions)``.
        """
        positions = sorted(positions)
        groups = []
        current_group = []

        for pos in positions:
            if not current_group or pos - current_group[-1] <= window_size:
                current_group.append(pos)
            else:
                groups.append(current_group)
                current_group = [pos]
        if current_group:
            groups.append(current_group)
        return groups

    def compute_token_statistics(self) -> None:
        """Fill the token-diversity fields of ``self.statistics``.

        Sets ``token_entropy``, ``num_unique_tokens`` and
        ``mean_activations_per_token`` from the tokens at activation positions
        of the positive examples. Requires ``self.statistics`` to be set.
        """
        activated_tokens = self._get_all_activated_tokens()
        token_stats = self._compute_token_statistics(activated_tokens)

        self.statistics.token_entropy = token_stats['entropy']
        self.statistics.num_unique_tokens = token_stats['num_unique']
        self.statistics.mean_activations_per_token = token_stats['mean_activations']

    def compute_cosine_statistics(self, feature_embedding: torch.Tensor, embedding_matrix: torch.Tensor, cosine_bins: int = 20) -> None:
        """Fill the embedding-space fields of ``self.statistics``.

        First computes, via :meth:`_compute_cosine_similarities`, the histogram,
        mean and std of pairwise similarities between embeddings of the distinct
        activated tokens. Then scores every row of ``embedding_matrix`` against
        ``feature_embedding`` and records the (up to) 10 highest and lowest
        scoring token ids, plus the mean score over the activated tokens.

        Scores are plain dot products. They are cosines only if the embeddings
        are unit-normalised (NOTE(review): assumed by callers, not checked here).

        Args:
            feature_embedding: This feature's vector; presumably shape (d,) or
                (1, d) — TODO confirm.
            embedding_matrix: Token embedding matrix, one row per token id.
            cosine_bins: Number of histogram bins over [-1, 1].

        Requires ``self.statistics`` to be set.
        """
        activated_tokens = self._get_all_activated_tokens()

        sim_hist, sim_mean, sim_std = self._compute_cosine_similarities(
            activated_tokens, embedding_matrix, cosine_bins
        )
        self.statistics.cosine_sim_hist = sim_hist
        self.statistics.cosine_sim_mean = sim_mean
        self.statistics.cosine_sim_std = sim_std

        scores = (feature_embedding @ embedding_matrix.T).flatten()
        # torch.topk raises when k exceeds the number of elements.
        k = min(10, scores.numel())
        top_vals, top_idx = torch.topk(scores, k=k, largest=True)
        bot_vals, bot_idx = torch.topk(scores, k=k, largest=False)

        # Same validity filter as _compute_cosine_similarities. An empty
        # selection would otherwise make .mean() return NaN.
        vocab_size = embedding_matrix.shape[0]
        valid_tokens = [t for t in activated_tokens if t < vocab_size]
        self.statistics.cosine_with_activated_tokens = (
            scores[valid_tokens].mean().item() if valid_tokens else None
        )

        self.statistics.top_tokens = [(i.item(), v.item()) for i, v in zip(top_idx, top_vals)]
        self.statistics.bot_tokens = [(i.item(), v.item()) for i, v in zip(bot_idx, bot_vals)]

    def _get_all_activated_tokens(self) -> List[int]:
        """Return the token id at every in-range activation position of the positive examples.

        Duplicates are kept, one entry per activation.
        """
        # Negative positions are rejected too: they would silently index from
        # the end of the context.
        return [
            example.context[pos]
            for example in self.positive_examples
            for pos in example.activation_positions
            if 0 <= pos < len(example.context)
        ]

    def _compute_token_statistics(self, tokens: List[int]) -> Dict:
        """Return entropy (natural log), distinct-token count and mean activations per distinct token.

        For empty ``tokens`` every metric is 0.
        """
        if not tokens:
            return {'entropy': 0.0, 'num_unique': 0, 'mean_activations': 0.0}

        counter = Counter(tokens)
        total = len(tokens)
        num_unique = len(counter)

        probs = [count / total for count in counter.values()]
        entropy = -sum(p * math.log(p) for p in probs if p > 0)

        return {
            'entropy': entropy,
            'num_unique': num_unique,
            'mean_activations': total / num_unique if num_unique else 0.0
        }

    def _compute_cosine_similarities(self, tokens: List[int],
                                   embedding_matrix: torch.Tensor,
                                   bins: int) -> Tuple[List[int], float, float]:
        """Histogram, mean and std of pairwise dot products between distinct token embeddings.

        Tokens with ids >= ``embedding_matrix.shape[0]`` are ignored. Each
        unordered pair is counted once (strict lower triangle). The histogram
        covers [-1, 1], which assumes unit-normalised embeddings (NOTE(review):
        not checked here).

        Returns ``([0] * bins, None, None)`` when ``embedding_matrix`` is None
        or fewer than two distinct valid tokens are available.
        """
        if embedding_matrix is None or len(tokens) == 0:
            return [0]*bins, None, None

        unique_tokens = list(set(t for t in tokens if t < embedding_matrix.shape[0]))
        if len(unique_tokens) < 2:  # need at least one pair
            return [0]*bins, None, None

        embeddings = embedding_matrix[unique_tokens]
        cosine_sims = torch.mm(embeddings, embeddings.T)

        # Strict lower triangle: drops self-similarities and mirrored duplicates.
        mask = torch.tril(torch.ones_like(cosine_sims, dtype=bool), -1)
        filtered_sims = cosine_sims[mask].cpu().numpy()

        hist, _ = np.histogram(filtered_sims, bins=bins, range=(-1.0, 1.0))
        return hist.tolist(), filtered_sims.mean().item(), filtered_sims.std().item()

    def plot_activation_distribution(self, figsize=(10, 5), fontsize=12, show_annotations=True, show_collected=True):
        """Plot the activation-value distribution of the positive examples and show the figure.

        Draws a density histogram of every activation value in the positive
        examples, on a log-scaled x axis. When ``show_collected`` is set, it
        overlays the histogram stored in ``self.statistics``. Vertical lines mark
        the stored min, max and mean; ``show_annotations`` adds text labels to
        them. Ends with ``plt.show()``.
        """
        example_activations = []
        for ex in self.examples:
            if ex.is_positive:
                example_activations.extend(ex.activation_values)

        hist_counts = np.array(self.statistics.histogram_counts)
        hist_edges = np.array(self.statistics.histogram_edges)

        # Normalize the stored counts to a probability density.
        # NOTE(review): histogram_counts[0] is dropped so the length matches
        # np.diff(edges), which assumes len(counts) == len(edges) (e.g. a leading
        # extra bin). The total still includes that bin — confirm against the
        # code that fills FeatureStatistics.
        bin_widths = np.diff(hist_edges)
        hist_density = hist_counts[1:] / (bin_widths * hist_counts.sum())

        plt.figure(figsize=figsize)

        sns.histplot(
            x=example_activations,
            stat='density',
            kde=False,
            color='#1f77b4',  # Muted blue
            label='Example Activations',
            edgecolor='black',
            linewidth=0.5
        )

        if show_collected:
            plt.stairs(
                hist_density,
                hist_edges,
                color='#ff7f0e',  # Muted orange
                linewidth=2,
                label='Collected Histogram'
            )

        plt.xscale("log")
        plt.title(
            f"Activation Distribution for Feature {self.index}",
            fontsize=fontsize,
            pad=20
        )
        plt.xlabel("Activation Value (log scale)", fontsize=fontsize)
        plt.ylabel("Density", fontsize=fontsize)
        plt.legend(fontsize=fontsize, frameon=False)

        plt.grid(True, linestyle='--', alpha=0.6, which='both')
        plt.gca().xaxis.set_major_locator(LogLocator(base=10, numticks=15))
        plt.gca().xaxis.set_major_formatter(LogFormatter(labelOnlyBase=False))
        plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x:.2f}"))

        plt.axvline(self.statistics.min, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
        plt.axvline(self.statistics.max, color='green', linestyle='--', linewidth=1.5, alpha=0.7)
        plt.axvline(self.statistics.mean, color='purple', linestyle='--', linewidth=1.5, alpha=0.7)

        if show_annotations:
            # Labels are offset 10% to the right of each line (multiplicative,
            # because the axis is log-scaled) and staggered vertically.
            plt.text(
                self.statistics.min * 1.1,
                plt.ylim()[1] * 0.9,
                f'Min: {self.statistics.min:.2f}',
                color='red',
                fontsize=fontsize,
                rotation=90,
                verticalalignment='top'
            )

            plt.text(
                self.statistics.max * 1.1,
                plt.ylim()[1] * 0.8,
                f'Max: {self.statistics.max:.2f}',
                color='green',
                fontsize=fontsize,
                rotation=90,
                verticalalignment='top'
            )

            plt.text(
                self.statistics.mean * 1.1,
                plt.ylim()[1] * 0.7,
                f'Mean: {self.statistics.mean:.2f}',
                color='purple',
                fontsize=fontsize,
                rotation=90,
                verticalalignment='top'
            )

        plt.tight_layout()
        plt.show()

    def show(
        self,
        tokenizer: Any,
        context_window: int,
        max_example_length: int = 60,
        examples_per_quantile: int = 2,
        include_statistics: bool = True,
        show_only_top: bool = False
    ) -> str:
        """Build a human-readable text report for this feature.

        The report contains a header, all interpretations, optionally the
        formatted statistics, and sampled context examples. The examples are
        produced by :meth:`split_examples` and :meth:`_format_examples`. Output
        is non-deterministic: window sampling uses ``torch.randint`` and example
        sampling uses ``random.sample``.

        Args:
            tokenizer: Object exposing ``decode(list_of_ids) -> str``.
            context_window: Window length passed to :meth:`split_examples`.
            max_example_length: Character limit for each rendered example.
            examples_per_quantile: Examples sampled from each activation quantile.
            include_statistics: Whether to include the statistics section.
            show_only_top: Only show the top activation quantile.
        """
        report = []

        report.append(f"Feature {self.index}")
        if self.interpretations:
            interpretations = '\n'.join([i.value for i in self.interpretations])
            report.append(f"\nInterpretations:\n{interpretations}")

        if include_statistics:
            stats = self._format_statistics(tokenizer)
            report.append(stats)

        splitted_examples = self.split_examples(context_window)
        examples_report = self._format_examples(
            splitted_examples,
            tokenizer,
            max_example_length,
            examples_per_quantile,
            show_only_top
        )
        report.append(examples_report)

        return "\n".join(report)

    def _format_statistics(self, tokenizer: Any) -> str:
        """Render ``self.statistics`` as a text section; missing (None) values print as "N/A"."""
        s = self.statistics

        def fmt(value, format_spec):
            if value is None:
                return "N/A"
            try:
                return format(value, format_spec)
            except (TypeError, ValueError):
                return str(value)

        lines = [
            "\nStatistics:",
            f"Activation Range: {fmt(s.min, '.2f')} - {fmt(s.max, '.2f')} (μ={fmt(s.mean, '.2f')})",
            "Token Diversity:",
            f"  - Unique tokens: {fmt(s.num_unique_tokens, '')}",
            f"  - Entropy: {fmt(s.token_entropy, '.2f')}",
            "Activation Distribution:",
            f"  - Frequency: {fmt(s.frequency, '.2e')} (per token)",
            f"  - Per Sequence: {fmt(s.sequence_penetration, '.1%')} of sequences",
            "Context Patterns:",
            f"  - Position: μ={fmt(s.pos_mean, '.2f')} ± {fmt(s.pos_std, '.2f')}",
            f"  - Multitoken ratio: {fmt(s.multitoken_ratio, '.2f')}",
            "Embedding Space:",
            f"  - Cossim between active tokens: μ={fmt(s.cosine_sim_mean, '.2f')} ± {fmt(s.cosine_sim_std, '.2f')}",
            f"  - Cossim with active tokens: μ={fmt(s.cosine_with_activated_tokens, '.2f')}"
        ]

        # top_tokens / bot_tokens hold (token_id, score) pairs, where the score
        # is the dot product with the feature embedding (see
        # compute_cosine_statistics).
        if s.top_tokens:
            top_tokens = ', '.join([
                f"{repr(tokenizer.decode([tid]))} ({fmt(score, '.1f')})"
                for tid, score in s.top_tokens
            ])
            lines.append(f"\nTop Tokens: {top_tokens}")
        if s.bot_tokens:
            bot_tokens = ', '.join([
                f"{repr(tokenizer.decode([tid]))} ({fmt(score, '.1f')})"
                for tid, score in s.bot_tokens
            ])
            lines.append(f"Bottom Tokens: {bot_tokens}")

        return "\n".join(lines)

    def _format_examples(
        self,
        examples: List[Example],
        tokenizer: Any,
        max_length: int,
        n_examples: int,
        show_only_top: bool
    ) -> str:
        """Render randomly sampled positive examples grouped by activation quantile.

        Positive examples are ranked by their peak activation and split into
        top quarter / middle half / bottom quarter. Up to ``n_examples`` are
        sampled from each group. Activated tokens are wrapped as
        ``<<token (value)>>``; adjacent highlights are merged, and the string is
        truncated to ``max_length`` characters.
        """
        if not examples:
            return "\n\nNo examples available"

        sorted_examples = sorted(
            [ex for ex in examples if ex.is_positive],
            key=lambda x: max(x.activation_values, default=0),
            reverse=True
        )

        # Quartile boundaries. With fewer than 4 examples, n // 4 == 0 would
        # leave the "Top" group empty (and show_only_top would print nothing),
        # so the top group gets at least one example. For n >= 4 the split is
        # the plain n // 4 and 3n // 4.
        n = len(sorted_examples)
        q1 = max(1, n // 4) if n else 0
        q3 = max(q1, 3 * n // 4)

        if show_only_top:
            quantile_groups = {
                'Top 75% - 100%': sorted_examples[:q1],
            }
        else:
            quantile_groups = {
                'Top 75% - 100%': sorted_examples[:q1],
                'Medium 25% - 75%': sorted_examples[q1:q3],
                'Bottom 0% - 25%': sorted_examples[q3:]
            }

        examples_report = ["\nContext Examples:"]
        for group_name, group_examples in quantile_groups.items():
            if not group_examples:
                continue

            samples = sorted(
                random.sample(group_examples, min(n_examples, len(group_examples))),
                key=lambda x: max(x.activation_values, default=0),
                reverse=True
            )
            examples_report.append(f"\n{group_name}:")

            for ex in samples:
                tokens = []
                activation_map = dict(zip(ex.activation_positions, ex.activation_values))

                for i, token_id in enumerate(ex.context):
                    decoded = tokenizer.decode([token_id])
                    if i in activation_map:
                        tokens.append(f"<<{decoded} ({activation_map[i]:.1f})>>")
                    else:
                        tokens.append(decoded)

                # Merge adjacent highlights, then truncate.
                context_str = ''.join(tokens).replace(">><<", "")
                if len(context_str) > max_length:
                    context_str = context_str[:max_length] + "..."

                # Guard: np.mean of an empty list is NaN and emits a warning.
                mean_activation = np.mean(ex.activation_values) if ex.activation_values else 0.0
                examples_report.append(f"- ({mean_activation:.1f}) - {repr(context_str)[1:-1]}")

        return '\n'.join(examples_report)

    def __repr__(self) -> str:
        """Short summary: index, example count and, if any, the latest interpretation with its score."""
        base = f"Feature({self.index}, {len(self.examples)} examples"
        # ``interpretations`` defaults to an empty list, so a truthiness test is
        # needed (an ``is not None`` check would index an empty list).
        if self.interpretations:
            latest = self.interpretations[-1]
            return f"{base}, interpretation={latest.value} ({latest.score}))"
        return f"{base})"