import torch
import torch.nn as nn

class FeatureAttackWrapper(nn.Module):
    def __init__(
        self,
        replacement_model,
        target_features_dict,
        protected_features_dict=None,
        aggregation: str = "max",
        topk_frac: float = 0.2,
        strict_hooks: bool = True,
    ):
        super().__init__()
        self.model = replacement_model
        self.target_features = target_features_dict  # {layer: [idx1, idx2...]}
        self.protected_features = protected_features_dict or {}
        self.aggregation = aggregation
        self.topk_frac = float(topk_frac)
        self.strict_hooks = bool(strict_hooks)
        self.hf_model = self._find_base_model(replacement_model)
        self._debug_once = False
        self._layer_hook_name = {}
        
        # 冻结参数，只通过 Input 更新
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @property
    def device(self):
        if hasattr(self.model, "device"):
            return self.model.device
        if hasattr(self.model, "cfg") and hasattr(self.model.cfg, "device"):
            return self.model.cfg.device
        try:
            return next(self.model.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    @property
    def dtype(self):
        if hasattr(self.model, "dtype"):
            return self.model.dtype
        if hasattr(self.model, "cfg") and hasattr(self.model.cfg, "dtype"):
            return self.model.cfg.dtype
        try:
            return next(self.model.parameters()).dtype
        except StopIteration:
            return torch.float32

    def _find_base_model(self, module):
        # 递归找到原始 HF 模型以获取 Embedding Layer
        if hasattr(module, "get_input_embeddings"): return module
        for attr in ["model", "wrapped_model", "transformer", "gpt"]:
            if hasattr(module, attr):
                found = self._find_base_model(getattr(module, attr))
                if found: return found
        return module

    def get_input_embeddings(self):
        if hasattr(self.hf_model, "get_input_embeddings"):
            return self.hf_model.get_input_embeddings()
        if hasattr(self.model, "W_E"):
            class InputEmbeddingAdapter(nn.Module):
                def __init__(self, weight: torch.Tensor):
                    super().__init__()
                    self.weight = weight
                    self.num_embeddings = weight.shape[0]

                def forward(self, input_ids: torch.Tensor):
                    return self.weight[input_ids.long()]

            return InputEmbeddingAdapter(self.model.W_E)
        raise AttributeError("Model does not expose input embeddings")

    def _aggregate(self, selected: torch.Tensor) -> torch.Tensor:
        if self.aggregation == "mean":
            return selected.mean(dim=1)
        if self.aggregation == "topk_mean":
            k = max(1, int(round(selected.shape[1] * self.topk_frac)))
            vals = selected.topk(k, dim=1).values
            return vals.mean(dim=1)
        return selected.max(dim=1).values

    def _choose_hook(self, layer: int) -> str | None:
        name_in = f"blocks.{layer}.{self.model.feature_input_hook}"
        name_out_grad = None
        if hasattr(self.model, "feature_output_hook"):
            name_out_grad = f"blocks.{layer}.{self.model.feature_output_hook}"
        name_out = None
        if hasattr(self.model, "original_feature_output_hook"):
            name_out = f"blocks.{layer}.{self.model.original_feature_output_hook}"

        for candidate in [name_in, name_out_grad, name_out]:
            if candidate and hasattr(self.model, "hook_dict") and candidate in self.model.hook_dict:
                return candidate
        return None

    def _ensure_hook_map(self, layers):
        missing = []
        for layer in layers:
            if layer in self._layer_hook_name:
                continue
            chosen = self._choose_hook(layer)
            self._layer_hook_name[layer] = chosen
            if chosen is None:
                missing.append(layer)
        if missing and self.strict_hooks:
            raise RuntimeError(f"Missing hookpoints for layers: {sorted(missing)}")

    def forward(self, inputs_embeds, **kwargs):
        """
        核心魔法：Hook 住模型，只返回特定 Feature 的激活值
        """
        layer_acts = {}
        if not self._debug_once:
            try:
                all_hooks = list(self.model.hook_dict.keys())
                print(f"[WRAP DEBUG] available hookpoints (sample): {all_hooks[:8]} ... total={len(all_hooks)}")
            except Exception as e:
                print(f"[WRAP DEBUG] unable to list hook_dict: {e}")

        layers = sorted(self.target_features.keys())
        self._ensure_hook_map(layers)

        def feature_hook_fn(activations, hook):
            layer_idx = hook.layer()
            if layer_idx in self.target_features:
                if not self._debug_once:
                    print(f"[WRAP DEBUG] layer={layer_idx} activations shape={tuple(activations.shape)} req_grad={activations.requires_grad} dtype={activations.dtype} device={activations.device}")
                # 1. 计算 Transcoder Features (Apply ReLU)
                encoded = self.model.transcoders.encode_layer(activations, layer_idx, apply_activation_function=True)
                if not self._debug_once:
                    print(f"[WRAP DEBUG] encoded shape={tuple(encoded.shape)} req_grad={encoded.requires_grad} dtype={encoded.dtype}")
                
                # 2. 选出目标 Feature Indices
                target_indices = self.target_features[layer_idx]
                idx_tensor = torch.tensor(target_indices, device=encoded.device, dtype=torch.long)
                selected = encoded.index_select(dim=-1, index=idx_tensor)
                if not self._debug_once:
                    print(f"[WRAP DEBUG] selected shape={tuple(selected.shape)} req_grad={selected.requires_grad}")
                
                # 3. 取序列最大值 (只要文中某处激活即可)
                # 输出 shape: [Batch, Len(Indices)]
                layer_acts[layer_idx] = self._aggregate(selected)
                if not self._debug_once:
                    print(f"[WRAP DEBUG] layer_acts[{layer_idx}] shape={tuple(layer_acts[layer_idx].shape)} req_grad={layer_acts[layer_idx].requires_grad}")
                    self._debug_once = True
            return activations

        # 构造 Hook 列表
        hooks = []
        for layer in layers:
            chosen = self._layer_hook_name.get(layer)
            if not self._debug_once:
                print(f"[WRAP DEBUG] layer={layer} chosen_hook={chosen}")
            if chosen:
                hooks.append((chosen, feature_hook_fn))

        # 运行模型
        # 覆盖嵌入：用 hook_embed 替换模型的词嵌入为传入的 inputs_embeds
        def override_embed(x, hook):
            return inputs_embeds

        # 构造占位 tokens（长度与 inputs_embeds 对齐），实际嵌入由上面的钩子提供
        B, T, _ = inputs_embeds.shape
        dummy_tokens = torch.zeros(B, T, dtype=torch.long, device=inputs_embeds.device)

        with self.model.hooks(fwd_hooks=hooks + [("hook_embed", override_embed)]):
            self.model(dummy_tokens)

        # 拼接所有层的结果
        outputs = []
        for layer in sorted(self.target_features.keys()):
            if layer in layer_acts:
                outputs.append(layer_acts[layer])
        
        if not outputs:
            if self.strict_hooks:
                raise RuntimeError("No feature activations captured; check hookpoints and target features")
            return torch.zeros(inputs_embeds.shape[0], 1, device=inputs_embeds.device, requires_grad=True)
            
        return torch.cat(outputs, dim=1) # [Batch, Total_Target_Features]


    def compute_next_token_logits_from_embeds(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Compute next-token logits for the given embeddings.

        This is used for semantic-preserving regularizers (e.g., KL to the base prompt)
        during GCG optimization. It must be differentiable w.r.t. inputs_embeds.
        """
        def override_embed(x, hook):
            return inputs_embeds

        B, T, _ = inputs_embeds.shape
        dummy_tokens = torch.zeros(B, T, dtype=torch.long, device=inputs_embeds.device)
        with self.model.hooks(fwd_hooks=[("hook_embed", override_embed)]):
            logits = self.model(dummy_tokens)
        return logits

    def forward_grouped(self, inputs_embeds):
        layer_sel = {}
        layer_prot = {}

        union_layers = set(self.target_features.keys()) | set(self.protected_features.keys())

        self._ensure_hook_map(sorted(union_layers))

        def feature_hook_fn(activations, hook):
            layer_idx = hook.layer()
            if (layer_idx in union_layers):
                encoded = self.model.transcoders.encode_layer(activations, layer_idx, apply_activation_function=True)
                if layer_idx in self.target_features:
                    idx_tensor = torch.tensor(self.target_features[layer_idx], device=encoded.device, dtype=torch.long)
                    selected = encoded.index_select(dim=-1, index=idx_tensor)
                    layer_sel[layer_idx] = self._aggregate(selected)
                if layer_idx in self.protected_features:
                    idx_tensor2 = torch.tensor(self.protected_features[layer_idx], device=encoded.device, dtype=torch.long)
                    selected2 = encoded.index_select(dim=-1, index=idx_tensor2)
                    layer_prot[layer_idx] = self._aggregate(selected2)
            return activations

        hooks = []
        for layer in sorted(union_layers):
            chosen = self._layer_hook_name.get(layer)
            if chosen:
                hooks.append((chosen, feature_hook_fn))

        def override_embed(x, hook):
            return inputs_embeds

        B, T, _ = inputs_embeds.shape
        dummy_tokens = torch.zeros(B, T, dtype=torch.long, device=inputs_embeds.device)

        with self.model.hooks(fwd_hooks=hooks + [("hook_embed", override_embed)]):
            self.model(dummy_tokens)

        sel_out = None
        prot_out = None
        if layer_sel:
            sel_out = torch.cat([layer_sel[l] for l in sorted(layer_sel.keys())], dim=1)
        if layer_prot:
            prot_out = torch.cat([layer_prot[l] for l in sorted(layer_prot.keys())], dim=1)
        return sel_out, prot_out
