from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
from torch import Tensor
from jaxtyping import Float, Int
import tqdm

from sae_lens import SAE, HookedSAETransformer


@dataclass
class FeatureActivations:
    """Bundle of SAE feature activations together with their provenance.

    Holds the activation tensor, the token ids it was computed from, one
    label and one source string per sample, and the identifiers of the layer,
    model and SAE the activations belong to.
    """

    # Activation values, indexed as (sample, sequence position, feature).
    activations: Float[Tensor, "samples seq features"]
    # Token ids, indexed as (sample, sequence position).
    tokens: Int[Tensor, "samples seq"]
    # Per-sample reasoning flag and source string.
    is_reasoning: list[bool]
    sources: list[str]

    # Identifiers recorded alongside the tensors.
    layer_index: int
    model_name: str
    sae_name: str

    def save(self, path: Path):
        """Write every field to ``path`` with ``torch.save``.

        Missing parent directories are created first.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "activations": self.activations,
            "tokens": self.tokens,
            "is_reasoning": self.is_reasoning,
            "sources": self.sources,
            "layer_index": self.layer_index,
            "model_name": self.model_name,
            "sae_name": self.sae_name,
        }
        torch.save(payload, target)

    @classmethod
    def load(cls, path: Path) -> "FeatureActivations":
        """Rebuild an instance from a file written by :meth:`save`.

        Tensors are mapped onto the CPU while loading. A ``KeyError`` is
        raised if the stored dictionary lacks one of the expected keys.
        """
        stored = torch.load(path, map_location="cpu")
        field_names = (
            "activations",
            "tokens",
            "is_reasoning",
            "sources",
            "layer_index",
            "model_name",
            "sae_name",
        )
        return cls(**{name: stored[name] for name in field_names})

    @property
    def n_samples(self) -> int:
        """Size of the first (sample) axis of ``activations``."""
        return self.activations.size(0)

    @property
    def seq_len(self) -> int:
        """Size of the second (sequence) axis of ``activations``."""
        return self.activations.size(1)

    @property
    def n_features(self) -> int:
        """Size of the third (feature) axis of ``activations``."""
        return self.activations.size(2)

    def get_reasoning_mask(self) -> Tensor:
        """Return ``is_reasoning`` as a boolean tensor."""
        flags = self.is_reasoning
        return torch.tensor(flags, dtype=torch.bool)

    def get_max_activations(self) -> Float[Tensor, "samples features"]:
        """Return the maximum of each feature over the sequence axis."""
        peak_values, _ = self.activations.max(dim=1)
        return peak_values

    def get_mean_activations(self) -> Float[Tensor, "samples features"]:
        """Return the mean of each feature over the sequence axis."""
        return torch.mean(self.activations, dim=1)


class FeatureCollector:
    """Collect SAE feature activations from a hooked transformer.

    The model and the SAE are loaded lazily: the model on first use, the SAE
    whenever a layer different from the currently loaded one is requested.
    """

    def __init__(
        self,
        model_name: str = "google/gemma-2-2b",
        sae_name: str = "gemma-scope-2b-pt-res-canonical",
        sae_id_format: str = "layer_{layer}/width_16k/canonical",
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Store the configuration; nothing is loaded here.

        Args:
            model_name: Name passed to ``HookedSAETransformer`` when loading.
            sae_name: SAE release name passed to ``SAE.from_pretrained``.
            sae_id_format: Template with a ``{layer}`` placeholder that is
                formatted with the layer index to obtain the SAE id.
            device: Device for the model, the SAE and each token batch.
            dtype: Dtype the model is loaded with.
        """
        self.model_name = model_name
        self.sae_name = sae_name
        self.sae_id_format = sae_id_format
        self.device = device
        self.dtype = dtype

        self.model: Optional[HookedSAETransformer] = None
        self.sae: Optional[SAE] = None
        self.current_layer: Optional[int] = None

    def load_model(self):
        """Load the model on first call; later calls do nothing."""
        if self.model is not None:
            return
        print(f"Loading model: {self.model_name}")
        self.model = HookedSAETransformer.from_pretrained_no_processing(
            self.model_name,
            device=self.device,
            dtype=self.dtype,
        )

    def load_sae(self, layer_index: int):
        """Load the SAE for ``layer_index`` unless it is already loaded."""
        if self.current_layer == layer_index:
            return
        print(f"Loading SAE for layer {layer_index}")
        sae_id = self.sae_id_format.format(layer=layer_index)
        loaded = SAE.from_pretrained(
            release=self.sae_name,
            sae_id=sae_id,
            device=self.device,
        )
        # from_pretrained may hand back a tuple whose first element is the
        # SAE (presumably in some sae_lens versions -- confirm); unwrap it
        # before publishing so self.sae never holds the raw tuple.
        self.sae = loaded[0] if isinstance(loaded, tuple) else loaded
        self.current_layer = layer_index

    def tokenize_texts(
        self,
        texts: list[str],
        max_length: int = 512,
    ) -> Int[Tensor, "batch seq"]:
        """Tokenize ``texts`` with the model's tokenizer.

        Every text is truncated and padded to ``max_length``. Loads the model
        if necessary.

        Returns:
            The ``input_ids`` tensor produced by the tokenizer.
        """
        self.load_model()

        encoded = self.model.tokenizer(
            texts,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        return encoded["input_ids"]

    def _resolve_hook_name(self) -> str:
        """Return the hook name recorded in the loaded SAE's config.

        Reads ``cfg.metadata.hook_name`` and falls back to ``cfg.hook_name``
        when that attribute path does not exist.
        """
        try:
            return self.sae.cfg.metadata.hook_name
        except AttributeError:
            # Only a missing attribute triggers the fallback; a bare except
            # here would also swallow KeyboardInterrupt and unrelated errors.
            return self.sae.cfg.hook_name

    def collect_activations(
        self,
        tokens: Int[Tensor, "batch seq"],
        layer_index: int,
        is_reasoning: list[bool],
        sources: list[str],
        batch_size: int = 8,
        max_features: Optional[int] = None,
    ) -> FeatureActivations:
        """Run the model with the SAE attached and gather its activations.

        Args:
            tokens: Token ids to run, processed ``batch_size`` rows at a time.
            layer_index: Layer whose SAE is loaded and recorded in the result.
            is_reasoning: Labels stored unchanged in the result.
            sources: Source strings stored unchanged in the result.
            batch_size: Number of rows per forward pass.
            max_features: If given, keep only the first ``max_features``
                entries along the last activation axis.

        Returns:
            A ``FeatureActivations`` holding float32 CPU activations, the
            tokens moved to CPU, and the supplied metadata.
        """
        self.load_model()
        self.load_sae(layer_index)

        acts_key = f"{self._resolve_hook_name()}.hook_sae_acts_post"

        all_activations = []
        n_samples = tokens.shape[0]

        with tqdm.tqdm(total=n_samples, desc="Collecting activations") as pbar:
            for i in range(0, n_samples, batch_size):
                batch = tokens[i:i + batch_size].to(self.device)

                with torch.no_grad():
                    _, cache = self.model.run_with_cache_with_saes(
                        batch,
                        saes=[self.sae],
                        use_error_term=True,
                    )

                acts = cache[acts_key]

                if max_features is not None:
                    acts = acts[:, :, :max_features]

                # Keep only a float32 CPU copy of each batch.
                all_activations.append(acts.cpu().float())

                pbar.update(len(batch))
                # Drop every reference to this batch's device tensors
                # (including `acts`, which may be a view into the cache)
                # before asking the allocator to release cached memory.
                del cache, batch, acts
                torch.cuda.empty_cache()

        activations = torch.cat(all_activations, dim=0)

        return FeatureActivations(
            activations=activations,
            tokens=tokens.cpu(),
            is_reasoning=is_reasoning,
            sources=sources,
            layer_index=layer_index,
            model_name=self.model_name,
            sae_name=self.sae_name,
        )

    def collect_from_datasets(
        self,
        reasoning_dataset,
        nonreasoning_dataset,
        layer_index: int,
        max_length: int = 512,
        batch_size: int = 8,
        max_features: Optional[int] = None,
    ) -> FeatureActivations:
        """Collect activations for a reasoning and a non-reasoning dataset.

        Both datasets are loaded via their ``load()`` method and iterated
        once; each sample must expose ``text`` and ``source``. Reasoning
        samples come first and are labelled ``True``, non-reasoning samples
        follow and are labelled ``False``.

        Returns:
            The result of :meth:`collect_activations` on the tokenized texts.
        """
        reasoning_dataset.load()
        nonreasoning_dataset.load()

        all_texts = []
        is_reasoning = []
        sources = []

        labelled_datasets = (
            (reasoning_dataset, True),
            (nonreasoning_dataset, False),
        )
        for dataset, label in labelled_datasets:
            for sample in dataset:
                all_texts.append(sample.text)
                is_reasoning.append(label)
                sources.append(sample.source)

        tokens = self.tokenize_texts(all_texts, max_length=max_length)

        return self.collect_activations(
            tokens=tokens,
            layer_index=layer_index,
            is_reasoning=is_reasoning,
            sources=sources,
            batch_size=batch_size,
            max_features=max_features,
        )
