from dataclasses import dataclass
from typing import Callable
import torch
from torch import Tensor
from jaxtyping import Float

from sae_lens import SAE, HookedSAETransformer


@dataclass
class SteeringConfig:
    """Settings describing one feature-steering intervention.

    The steering vector is built as
    ``gamma * max_feature_activation * W_dec[feature_index]`` and added to
    the hooked activations from ``start_position`` onwards.
    """

    # Row of the SAE decoder matrix used as the steering direction.
    feature_index: int
    # Steering strength; multiplies the decoder direction.
    gamma: float = 1.0
    # Second multiplier on the decoder direction (applied together with gamma).
    max_feature_activation: float = 1.0
    # Layer number carried with the config.
    # NOTE(review): not read by FeatureSteerer in this file, which takes its
    # hook point from the SAE config -- confirm who consumes this field.
    layer_index: int = 8
    # First sequence position that gets steered; values <= 0 steer every position.
    start_position: int = 0


class FeatureSteerer:
    """Generate text from a hooked model, optionally steered along an SAE feature.

    Steering adds ``gamma * max_feature_activation * W_dec[feature_index]`` to
    the activations at the SAE's hook point for the duration of one
    ``generate`` call. The hook is always removed afterwards, even when
    generation raises.

    Note:
        Hook cleanup goes through ``model.reset_hooks(...,
        including_permanent=False)``, which is a model-wide reset rather than
        a removal of just the steering hook.
    """

    # Chat-turn terminator that is used as an extra stop token whenever the
    # tokenizer's vocabulary contains it.
    _END_OF_TURN_TOKEN = "<end_of_turn>"

    def __init__(
        self,
        model: HookedSAETransformer,
        sae: SAE,
    ):
        """Bind the steerer to a model and the SAE providing steering directions.

        Args:
            model: Hooked transformer used for generation and hook registration.
            sae: Sparse autoencoder whose decoder rows are the steering
                directions and whose config names the hook point.
        """
        self.model = model
        self.sae = sae
        # The hook name is looked up under ``cfg.metadata`` first and directly
        # on ``cfg`` as a fallback, so both config layouts are accepted.
        try:
            self.hook_name = sae.cfg.metadata.hook_name
        except AttributeError:
            self.hook_name = sae.cfg.hook_name
        self._hooks_active = False

        # Detached so that steering never contributes gradients to the SAE.
        self.W_dec = sae.W_dec.detach()

    def _create_steering_hook(
        self,
        config: SteeringConfig,
    ) -> Callable:
        """Build a forward hook that adds the configured steering vector.

        The returned hook is stateful and intended for a single generation
        run: it tracks how many sequence positions earlier forward passes
        covered, so ``config.start_position`` is interpreted as an absolute
        position even when later passes only carry newly generated tokens
        (KV-cached decoding).

        Args:
            config: Feature index, scaling factors and start position.

        Returns:
            A ``hook(activations, hook)`` callable returning the steered
            activations.
        """
        decoder_direction = self.W_dec[config.feature_index]
        steering_vector = config.gamma * config.max_feature_activation * decoder_direction

        # Number of absolute sequence positions covered by previous passes.
        tokens_seen = 0

        def steering_hook(
            activations: Float[Tensor, "batch seq d_model"],
            hook,
        ) -> Float[Tensor, "batch seq d_model"]:
            nonlocal tokens_seen
            steering_vec = steering_vector.to(dtype=activations.dtype, device=activations.device)
            seq_len = activations.shape[1]

            if seq_len <= tokens_seen:
                # Shorter than (or equal to) what was already processed: this
                # pass holds only new tokens appended after the cached prefix.
                offset = tokens_seen
                tokens_seen += seq_len
            else:
                # First pass, or a full re-run over the grown sequence.
                offset = 0
                tokens_seen = seq_len

            # Translate the absolute start position into this pass's indices.
            local_start = max(config.start_position - offset, 0)
            if local_start == 0:
                return activations + steering_vec
            if local_start >= seq_len:
                # Every position in this pass lies before the start position.
                return activations

            result = activations.clone()
            result[:, local_start:] = activations[:, local_start:] + steering_vec
            return result

        return steering_hook

    def _register_hook(self, config: SteeringConfig):
        """Attach a fresh steering hook for ``config`` at the SAE's hook point."""
        hook_fn = self._create_steering_hook(config)
        self.model.add_hook(self.hook_name, hook_fn)
        self._hooks_active = True

    def _clear_hooks(self):
        """Reset the model's non-permanent hooks if this steerer registered one."""
        if self._hooks_active:
            self.model.reset_hooks(clear_contexts=True, including_permanent=False)
            self._hooks_active = False

    def _resolve_device(self):
        """Return the model's device, preferring ``model.cfg.device``."""
        try:
            return self.model.cfg.device
        except AttributeError:
            return self.model.device

    def _encode_prompt(self, prompt: str, apply_chat_template: bool) -> Tensor:
        """Tokenize ``prompt`` (optionally chat-formatted) into input ids on the model device."""
        device = self._resolve_device()
        tokenizer = self.model.tokenizer

        if apply_chat_template and hasattr(tokenizer, 'apply_chat_template'):
            messages = [{"role": "user", "content": prompt}]
            formatted_prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            # Special tokens are not added again here -- presumably because
            # the rendered template already carries them.
            inputs = tokenizer(
                formatted_prompt,
                return_tensors="pt",
                add_special_tokens=False,
            ).to(device)
        else:
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

        return inputs["input_ids"]

    def _stop_token_ids(self) -> "int | list[int] | None":
        """Return the EOS id, plus the end-of-turn id when the vocabulary has one.

        Returns ``None`` when the tokenizer defines no EOS token.
        """
        tokenizer = self.model.tokenizer
        eos_id = getattr(tokenizer, 'eos_token_id', None)
        if eos_id is None:
            return None

        end_of_turn_id = tokenizer.convert_tokens_to_ids(self._END_OF_TURN_TOKEN)
        # A missing token maps to the unknown-token id (or None on tokenizers
        # without one); in either case only the plain EOS id is used.
        if end_of_turn_id is None or end_of_turn_id == tokenizer.unk_token_id:
            return eos_id
        return [eos_id, end_of_turn_id]

    def _generate_text(
        self,
        prompt: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        do_sample: bool,
        apply_chat_template: bool,
        generate_kwargs: dict,
    ) -> str:
        """Tokenize, generate with whatever hooks are attached, and decode the continuation."""
        input_ids = self._encode_prompt(prompt, apply_chat_template)

        # Caller-supplied kwargs win over these defaults instead of colliding
        # with them as duplicate keyword arguments.
        call_kwargs = {
            "stop_at_eos": True,
            "eos_token_id": self._stop_token_ids(),
            "verbose": False,
            **generate_kwargs,
        }

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                **call_kwargs,
            )

        # Drop the prompt tokens so that only the continuation is decoded.
        prompt_length = input_ids.shape[1]
        generated = outputs[0, prompt_length:]
        return self.model.tokenizer.decode(generated, skip_special_tokens=True)

    def generate_with_steering(
        self,
        prompt: str,
        config: SteeringConfig,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
        apply_chat_template: bool = True,
        **generate_kwargs,
    ) -> str:
        """Generate a continuation of ``prompt`` with the steering hook attached.

        Args:
            prompt: User text; wrapped as a single user message when the chat
                template is applied.
            config: Steering settings used to build the hook.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature passed to ``model.generate``.
            top_p: Nucleus-sampling threshold passed to ``model.generate``.
            do_sample: Whether to sample instead of decoding greedily.
            apply_chat_template: Use the tokenizer's chat template when it has one.
            **generate_kwargs: Extra keyword arguments for ``model.generate``;
                they override the defaults ``stop_at_eos``, ``eos_token_id``
                and ``verbose``.

        Returns:
            The decoded continuation, without the prompt and special tokens.
        """
        self._clear_hooks()

        try:
            self._register_hook(config)
            return self._generate_text(
                prompt,
                max_new_tokens,
                temperature,
                top_p,
                do_sample,
                apply_chat_template,
                generate_kwargs,
            )
        finally:
            # Always detach the hook so later calls start from a clean model.
            self._clear_hooks()

    def generate_baseline(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
        apply_chat_template: bool = True,
        **generate_kwargs,
    ) -> str:
        """Generate a continuation of ``prompt`` without any steering hook.

        Takes the same arguments as :meth:`generate_with_steering` except for
        ``config`` and returns the decoded continuation in the same form.
        """
        self._clear_hooks()
        return self._generate_text(
            prompt,
            max_new_tokens,
            temperature,
            top_p,
            do_sample,
            apply_chat_template,
            generate_kwargs,
        )
