from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple
import math
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset
import numpy as np
from src.utils import nanmin, nanmax
import gc
from src.feature import Feature, FeatureStatistics, Example

# ---------------------- Configuration ----------------------
@dataclass
class CollectionConfig:
    """Settings for one activation-collection run.

    The first five fields are required and positional; every other field
    has a default.
    """

    # ---------- Model / dictionary -----------
    model_name: str  # handed to AutoTokenizer.from_pretrained by the collector
    hook_point: str  # cache key read back after the model forward pass
    layer: int  # the forward pass is stopped at layer + 1
    dict_size: int  # feature indices must be smaller than this
    batch_size: int  # DataLoader batch size
    device: Optional[str] = 'cpu'
    feature_indices: Optional[List[int]] = None  # None selects all dict_size features

    # ---------- Dataset -----------
    dataset_path: str = "HuggingFaceFW/fineweb"
    dataset_name: str = "default"
    dataset_split: str = "train"
    streaming: bool = True

    # ----------- Buffer -----------
    pos_buffer_size: int = 128  # stored activating sequences per feature
    neg_buffer_size: int = 64  # stored non-activating sequences per feature
    seq_len: int = 256  # tokenizer pads / truncates every text to this length
    pack_size: int = 64  # sequences accumulated before buffers are updated

    # --------- Histogram -----------
    hist_bins: int = 20
    hist_min: float = 1e-1  # lowest edge of the log-spaced activation histogram
    hist_max: float = 1e2  # highest edge of the log-spaced activation histogram
    pos_bins: int = 20  # bins of the normalised-position histogram

    # ---------- Early Exit -----------
    # Thresholds compared against the positive-buffer fill statistics.
    filled_percent: float = 100.0
    avg_fill_rate: float = 1.0
    min_fill_rate: float = 1.0
    # Recognised values: "any", "percentage", "avgrate", "minrate".
    # Any other value (including this default) never triggers an early exit.
    exit_strategy: str = "tokens"

    # ---------- Various -----------
    move_pack_to_cpu: bool = False  # keep buffers and histograms on the CPU

# ---------------------- Buffer Structures ----------------------
@dataclass
class BufferData:
    """Tensors backing one per-feature example buffer."""

    # Stored token sequences.
    token_ids: torch.Tensor
    # Number of occupied slots, one entry per feature.
    counts: torch.Tensor
    # Activation values aligned with token_ids; left as None when not tracked.
    activations: Optional[torch.Tensor] = None

# ---------------------- Main Collector Class ----------------------
class LatentCollector:
    
    def __init__(self, model, sae, config: CollectionConfig):
        self.config = config
        self.device = config.device
        self.buffer_device = "cpu" if config.move_pack_to_cpu else config.device
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        self.model = model.to(self.device).eval()
        self.sae = sae.to(self.device).eval() if sae else None

        self._validate_config()
        self.feature_indices = self._init_feature_indices()
        self.num_features = len(self.feature_indices)

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            
        self.pad_token_id = self.tokenizer.pad_token_id
            
        self.special_ids = torch.tensor(
            self.tokenizer.all_special_ids,
            dtype=torch.long
        )

        self.dataloader = self._init_dataloader()
        self._init_buffers()
        self._init_histograms()
        self._init_running_stats()

        self.buffer_metrics = {
            'filled_percentage': [],
            'avg_fill_rate': [],
            'min_fill_rate': [],
            'max_fill_rate': []
        }

        self.total_valid_tokens = 0
        self.total_sequences = 0

    def _validate_config(self):
        if self.config.layer < 0:
            raise ValueError("Layer index must be non-negative")
        if self.config.dict_size <= 0:
            raise ValueError("Dictionary size must be positive")
        if self.config.batch_size <= 0:
            raise ValueError("Batch size must be positive")

    def _init_feature_indices(self) -> torch.Tensor:
        if self.config.feature_indices is None:
            indices = torch.arange(self.config.dict_size, device=self.device)
        else:
            indices = torch.tensor(self.config.feature_indices, device=self.device)
        
        if (indices >= self.config.dict_size).any():
            raise ValueError("Feature indices exceed dictionary size")
        return indices

    def _init_dataloader(self) -> DataLoader:
        cfg = self.config
        dataset = load_dataset(cfg.dataset_path, name=cfg.dataset_name, split=cfg.dataset_split, streaming=cfg.streaming)
        return DataLoader(dataset, batch_size=self.config.batch_size)




    def _init_buffers(self):
        cfg = self.config
        
        self.positive_buffer = BufferData(
            token_ids=torch.full(
                (self.num_features, cfg.pos_buffer_size, cfg.seq_len),
                self.pad_token_id, dtype=torch.long, device=self.buffer_device
            ),
            counts=torch.zeros(self.num_features, dtype=torch.long, device=self.buffer_device),
            activations=torch.zeros(
                (self.num_features, cfg.pos_buffer_size, cfg.seq_len), 
                device=self.buffer_device
            )
        )
        self.negative_buffer = BufferData(
            token_ids=torch.full(
                (self.num_features, cfg.neg_buffer_size, cfg.seq_len),
                self.pad_token_id, dtype=torch.long, device=self.buffer_device
            ),
            counts=torch.zeros(self.num_features, dtype=torch.long, device=self.buffer_device)
        )
        self.storage = {
            'activations': [],
            'input_ids': [],
        }

    def _init_histograms(self):
        cfg = self.config
        if self.device == 'mps':
            bins = torch.tensor(np.logspace(
                math.log10(cfg.hist_min),
                math.log10(cfg.hist_max),
                num=cfg.hist_bins + 1
            ), device=self.buffer_device)
        else:
            bins = torch.logspace(
                math.log10(cfg.hist_min),
                math.log10(cfg.hist_max),
                steps=cfg.hist_bins, device=self.buffer_device
            )
        self.hist_bins = bins
        self.hist_counts = torch.zeros(
            (self.num_features, cfg.hist_bins), device=self.buffer_device
        )

    def _init_running_stats(self):
        self.min_vals = torch.full((self.num_features,), torch.inf, device=self.device)
        self.max_vals = torch.full((self.num_features,), -torch.inf, device=self.device)
        self.sum_vals = torch.zeros(self.num_features, device=self.device)
        self.count_vals = torch.zeros(self.num_features, device=self.device)

    # ---------------------- Main Collection Loop ----------------------
    @torch.no_grad()
    def collect(self, max_tokens: int = 1e6) -> Dict[int, Feature]:
        self.processed_tokens = 0
        self.max_tokens = max_tokens
        progress_bar = tqdm(total=max_tokens, desc="Collecting activations")

        should_exit = False

        self._collect_garbage()
        for batch in self.dataloader:
            input_ids, attention_mask = self._prepare_batch(batch)
            hidden_state = self._get_hidden_state(input_ids, attention_mask)
            activations = self._get_activations(hidden_state)
            
            self._mask_special_tokens(activations, input_ids)
            self._update_running_stats(activations)
            self._store_batch_data(input_ids, activations)
            
            if self._should_process_pack():
                self._process_pack()
                should_exit = self._check_early_exit(progress_bar)
                self._collect_garbage()
                
            self.processed_tokens += input_ids.numel()
            progress_bar.update(input_ids.numel())
            if self.processed_tokens >= self.max_tokens:
                progress_bar.total = self.processed_tokens
                break
            elif should_exit:
                break

        progress_bar.close()
        result = self._compile_statistics()
        self._collect_garbage()
        self._reset_storage()
        return result

    # ---------------------- Helper Methods ----------------------
    def _collect_garbage(self):
        if self.device == 'mps':
            torch.mps.empty_cache()
        elif self.device == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()
    
    def _prepare_batch(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
        inputs = self.tokenizer(
            batch['text'],
            truncation=True,
            padding='max_length',
            max_length=self.config.seq_len,
            return_tensors='pt'
        )
        return inputs['input_ids'].to(self.device), inputs['attention_mask'].to(self.device)

    def _get_hidden_state(self, input_ids, attention_mask):
        _, cache = self.model.run_with_cache(
            input_ids,
            attention_mask=attention_mask,
            names_filter=[self.config.hook_point],
            stop_at_layer=self.config.layer + 1,
        )
        return cache[self.config.hook_point]

    def _get_activations(self, hidden_state):
        B, T, d = hidden_state.shape
        return self.sae.encode(hidden_state.reshape(B*T, -1))[:, self.feature_indices]

    def _mask_special_tokens(self, activations, input_ids):
        input_flat = input_ids.view(-1)
        special_mask = torch.isin(input_flat, self.special_ids.to(self.device))
        activations[special_mask] = torch.nan
        self.total_valid_tokens += (~special_mask).sum().item()
        self.total_sequences += input_ids.shape[0]

    def _update_running_stats(self, activations):
        mask = ~torch.isnan(activations) & (activations > 0)
        valid_acts = torch.where(mask, activations, torch.nan)
        
        # Update min
        current_min = nanmin(valid_acts, dim=0).values
        self.min_vals = torch.minimum(self.min_vals, torch.nan_to_num(current_min, nan=torch.inf))
        
        # Update max
        current_max = nanmax(valid_acts, dim=0).values
        self.max_vals = torch.maximum(self.max_vals, torch.nan_to_num(current_max, nan=-torch.inf))
        
        # Update sum and count
        self.sum_vals += torch.nansum(valid_acts, dim=0)
        self.count_vals += mask.sum(dim=0)

    def _store_batch_data(self, input_ids, activations):
        self.storage['activations'].append(activations)
        self.storage['input_ids'].append(input_ids.view(-1))

    def _should_process_pack(self):
        return len(self.storage['activations']) * self.config.batch_size >= self.config.pack_size

    def _process_pack(self):

        activations = torch.cat(self.storage['activations']).to(self.buffer_device)
        input_ids = torch.cat(self.storage['input_ids']).to(self.buffer_device)
            
        self._update_histograms(activations)
        self._update_buffers(activations, input_ids)
        self._reset_storage()

    def _check_early_exit(self, pbar) -> bool:
        """Check if collection should stop based on buffer fill metrics and strategy."""
        current_stats = self._get_buffer_stats()
        cfg = self.config
    
        # Extract metrics from current buffer state
        filled_pct = current_stats['filled_percentage']
        avg_fill = current_stats['avg_fill_rate']
        min_fill = current_stats['min_fill_rate']

        pbar.set_description(f"Percentage - {filled_pct:.2f}%, average - {avg_fill:.2f}, min - {min_fill:.2f}")
    
        # Map strategy to exit conditions
        strategy_checks = {
            "any": [
                filled_pct >= cfg.filled_percent,
                min_fill >= cfg.min_fill_rate
            ],
            "percentage": [filled_pct >= cfg.filled_percent],
            "avgrate": [avg_fill >= cfg.avg_fill_rate],
            "minrate": [min_fill >= cfg.min_fill_rate]
        }
    
        # Get appropriate checks for current strategy
        checks = strategy_checks.get(cfg.exit_strategy, [])
        
        # Return True if any required condition is met
        return any(checks)

    def _get_buffer_stats(self):
        """Calculate comprehensive buffer statistics"""
        counts = self.positive_buffer.counts.float()
        buffer_size = self.config.pos_buffer_size
        
        filled_mask = counts >= buffer_size
        stats = {
            'filled_percentage': (filled_mask.sum() / self.num_features * 100).item(),
            'avg_fill_rate': (counts / buffer_size).mean().item(),
            'min_fill_rate': (counts / buffer_size).min().item(),
            'median_fill_rate': (counts / buffer_size).median().item()
        }
        return stats

    # ---------------------- Buffer Management ----------------------
    def _update_buffers(self, activations, input_ids):
        B, T = self.config.pack_size, self.config.seq_len
        input_ids = input_ids.view(B, T)
        activations = activations.view(B, T, -1)
        
        special_mask = torch.isin(input_ids, self.special_ids.to(self.buffer_device))
        valid_mask = ~special_mask & (input_ids != self.pad_token_id)

        counts = self.positive_buffer.counts.float()
        buffer_size = self.config.pos_buffer_size
        filled = torch.where(counts >= buffer_size)[0].tolist()
        not_filled = torch.where(counts < buffer_size)[0].tolist()

        random_threshold = min(1, self.processed_tokens / self.max_tokens)
        to_iterate = not_filled + [
            f for f in filled if torch.rand(1).item() > random_threshold
        ]
        
        for feat_idx in to_iterate:
            feat_acts = activations[..., feat_idx]
            pos_mask = (feat_acts > 0) & valid_mask
            seq_has_act = pos_mask.any(dim=1)
            
            if seq_has_act.any():
                self._update_reservoir(
                    'positive', feat_idx, 
                    input_ids[seq_has_act], 
                    feat_acts[seq_has_act], 
                    pos_mask[seq_has_act]
                )
            
            neg_mask = ~seq_has_act & valid_mask.any(dim=1)
            if neg_mask.any():
                self._update_reservoir(
                    'negative', feat_idx, 
                    input_ids[neg_mask], 
                    None, None
                )

    def _update_histograms(self, activations):
        mask = (activations > 0) & (~torch.isnan(activations))  # Updated
        activations = torch.where(mask, activations, torch.nan)
        bin_indices = torch.bucketize(activations, self.hist_bins)
        bin_indices = torch.clamp(bin_indices, 0, self.config.hist_bins - 1)
        for b in range(self.config.hist_bins):
            in_bin = (bin_indices == b) & mask
            self.hist_counts[:, b] += in_bin.sum(dim=0)

    def _update_reservoir(self, buffer_type, feat_idx, sequences, activations, mask):
        buffer = self.positive_buffer if buffer_type == 'positive' else self.negative_buffer
        current_count = buffer.counts[feat_idx].item()
        buffer_size = buffer.token_ids.shape[1]
        
        # Calculate remaining capacity
        remaining = max(0, buffer_size - current_count)
        
        # Split into append and replace candidates
        append_seq = sequences[:remaining]
        replace_seq = sequences[remaining:]
        
        # Append new sequences
        if append_seq.size(0) > 0:
            append_end = current_count + append_seq.size(0)
            buffer.token_ids[feat_idx, current_count:append_end] = append_seq
            if buffer_type == 'positive':
                buffer.activations[feat_idx, current_count:append_end] = activations[:len(append_seq)]
            buffer.counts[feat_idx] += append_seq.size(0)
        
        # Replace existing sequences
        if replace_seq.size(0) > 0:
            # Ensure we don't exceed buffer size
            n_replace = min(replace_seq.size(0), buffer_size)
            replace_seq = replace_seq[:n_replace]
            
            # Generate unique random indices
            replace_indices = torch.randperm(buffer_size, device=self.device)[:n_replace]
            
            # Update buffer
            buffer.token_ids[feat_idx, replace_indices] = replace_seq
            if buffer_type == 'positive':
                buffer.activations[feat_idx, replace_indices] = activations[len(append_seq):len(append_seq) + n_replace]

    def _reset_storage(self):
        self.storage = {'activations': [], 'input_ids': []}

    # ---------------------- Statistics Compilation ----------------------
    def _compile_statistics(self) -> Dict[int, Feature]:
        features = {}
        hist_edges = self.hist_bins.cpu().tolist()
        total_val_tokens = max(self.total_valid_tokens, 1)
        total_seqs = max(self.total_sequences, 1)
    
        pbar = tqdm(total=len(self.feature_indices), desc="Compiling results")
        
        for buf_idx, orig_idx in enumerate(self.feature_indices.cpu().tolist()):
            pos_examples, metrics = self._process_positive_examples(buf_idx)
            neg_examples = self._process_negative_examples(buf_idx)
            stats = self._compute_feature_stats(buf_idx, metrics, total_val_tokens, total_seqs)
            feature = Feature(
                index=orig_idx,
                examples=pos_examples + neg_examples,
                statistics=stats
            )
            features[orig_idx] = feature
            pbar.update(1)
    
        pbar.close()
        return features


        for seq_idx in range(valid_count):
            raw_tokens = token_buffer[seq_idx]
            activations = act_buffer[seq_idx]

            # Filter special tokens and pad
            valid_mask = ~torch.isin(raw_tokens, self.special_ids)
            valid_tokens = raw_tokens[valid_mask].tolist()
            
            # Get activation positions in valid_tokens space
            act_mask = activations[valid_mask] > 0
            act_positions = torch.where(act_mask)[0].tolist()

            examples.append(Example(
                full_context=valid_tokens,
                activation_positions=act_positions,
                activation_values=activations[act_mask].tolist(),
                is_positive=True
            ))
    

    def _process_positive_examples(self, buf_idx) -> Tuple[List[Example], dict]:
        examples = []
        metrics = {
            'position_hist': np.zeros(self.config.pos_bins),
            'gini_values': [],
            'total_activations': 0,
            'active_sequences': 0,
            'multitoken_count': 0
        }
        
        token_buffer = self.positive_buffer.token_ids[buf_idx].cpu()
        act_buffer = self.positive_buffer.activations[buf_idx].cpu()
        valid_count = self.positive_buffer.counts[buf_idx].item()
    
        for seq_idx in range(valid_count):
            raw_tokens = token_buffer[seq_idx]
            activations = act_buffer[seq_idx]

            # Filter special tokens and pad
            valid_mask = ~torch.isin(raw_tokens, self.special_ids)
            valid_tokens = raw_tokens[valid_mask]
            if valid_tokens.numel() == 0:
                continue
    
            # Get activations within valid tokens
            act_mask = activations[valid_mask] > 0
            act_positions = torch.where(act_mask)[0].numpy()
            if act_positions.size == 0:
                continue
    
            # Update metrics
            metrics['active_sequences'] += 1
            metrics['total_activations'] += len(act_positions)
            if len(act_positions) > 1:
                metrics['multitoken_count'] += 1
    
            # Calculate positional distribution
            seq_length = len(valid_tokens)
            normalized_pos = act_positions / (seq_length - 1)
            pos_hist, _ = np.histogram(normalized_pos, 
                                     bins=self.config.pos_bins,
                                     range=(0, 1))
            metrics['position_hist'] += pos_hist
    
            # Calculate Gini coefficient
            sorted_pos = np.sort(act_positions)
            n = len(sorted_pos)
            gini_numerator = np.sum((2 * np.arange(1, n+1) - n - 1) * sorted_pos)
            gini = gini_numerator / (n * np.sum(sorted_pos)) if n > 0 else 0
            metrics['gini_values'].append(gini if not np.isnan(gini) else 0)
    
            # Create example with valid tokens only
            examples.append(Example(
                context=valid_tokens.tolist(),
                activation_positions=act_positions.tolist(),
                activation_values=activations[valid_mask][act_mask].tolist(),
                is_positive=True
            ))
    
        return examples, metrics
        

    def _process_negative_examples(self, buf_idx) -> List[Example]:
        examples = []
        token_buffer = self.negative_buffer.token_ids[buf_idx].cpu()
        valid_count = self.negative_buffer.counts[buf_idx].item()
        # Similar filtering for negative examples
        for seq_idx in range(valid_count):
            raw_tokens = token_buffer[seq_idx]
            valid_mask = ~torch.isin(raw_tokens, self.special_ids)
            valid_tokens = raw_tokens[valid_mask].tolist()
            
            examples.append(Example(
                context=valid_tokens,
                activation_positions=[],
                activation_values=[],
                is_positive=False
            ))
        return examples


    def _get_valid_token_mask(self, tokens: torch.Tensor) -> torch.Tensor:
        """Create mask filtering out special and padding tokens"""
        return ~torch.isin(tokens, self.special_ids)


    def _compute_feature_stats(self, buf_idx, metrics, total_tokens, total_seqs) -> FeatureStatistics:
        hist_cfg = self.config
        feature_min = self.min_vals[buf_idx].item()
        feature_max = self.max_vals[buf_idx].item()
        feature_mean = (self.sum_vals[buf_idx] / self.count_vals[buf_idx]).item() if self.count_vals[buf_idx] > 0 else 0.0

        # Normalize positional histogram
        pos_hist_total = max(metrics['position_hist'].sum(), 1)
        positional_hist = (metrics['position_hist'] / pos_hist_total).tolist()

        # Calculate positional statistics
        bin_width = 1.0 / hist_cfg.pos_bins
        midpoints = [(i + 0.5) * bin_width for i in range(hist_cfg.pos_bins)]
        pos_mean = np.dot(positional_hist, midpoints)
        pos_variance = np.sum([p * (m - pos_mean)**2 
                            for p, m in zip(positional_hist, midpoints)])
        pos_std = np.sqrt(pos_variance)

        return FeatureStatistics(
            min=feature_min,
            max=feature_max,
            mean=feature_mean,
            histogram_counts=self.hist_counts[buf_idx].cpu().tolist(),
            histogram_edges=self.hist_bins.cpu().tolist(),
            positional_hist=positional_hist,
            gini_coeff=np.mean(metrics['gini_values']) if metrics['gini_values'] else 0.0,
            intra_seq_density=metrics['total_activations'] / max(metrics['active_sequences'], 1),
            activation_freq=metrics['total_activations'] / total_tokens,
            sequence_penetration=metrics['active_sequences'] / total_seqs,
            pos_mean=pos_mean,
            pos_std=pos_std,
            multitoken_ratio=metrics['multitoken_count'] / max(metrics['active_sequences'], 1),
            frequency=self.count_vals[buf_idx].item() / total_tokens
        )