# Put sae_bench_utils/gpt2_tc_utils on sys.path so that `sae_training` can also be imported as a top-level package.
import os
import sys

import torch

sys.path.append(
    os.path.join(os.path.dirname(__file__), "../sae_bench_utils/gpt2_tc_utils")
)

from ff_kv_sae.sae_bench_utils.gpt2_tc_utils.sae_training.sparse_autoencoder import (
    SparseAutoencoder,
)

from .base_sae import BaseSAE


class GPT2Transcoder(BaseSAE):
    """Expose a pretrained GPT-2 small transcoder through the ``BaseSAE`` interface.

    The weights are read from a checkpoint with
    ``SparseAutoencoder.load_from_pretrained`` and copied onto the requested
    ``device`` / ``dtype``.

    Attributes:
        input_hook: Hook name stored under the ``"input"`` key of ``hook_points``.
        features_hook: Hook name stored under the ``"features"`` key.
        output_hook: Hook name stored under the ``"output"`` key.
        W_enc, b_enc: Encoder weight and bias from the checkpoint.
        b_dec_in: The checkpoint's ``b_dec``; subtracted from the input in
            :meth:`encode`.
        W_dec: Decoder weight from the checkpoint.
        b_dec: The checkpoint's ``b_dec_out``; added to the output in
            :meth:`decode`.
    """

    def __init__(
        self,
        hook_layer: int,
        device: torch.device,
        dtype: torch.dtype,
        hook_points: dict | None = None,
    ):
        """Load the transcoder checkpoint for ``hook_layer`` and set up the wrapper.

        Args:
            hook_layer: Block index; interpolated into the default hook names
                and into the checkpoint file name.
            device: Device forwarded to ``BaseSAE`` and used for the weights.
            dtype: Dtype forwarded to ``BaseSAE`` and used for the weights.
            hook_points: Optional mapping with the keys ``"input"``,
                ``"features"`` and ``"output"``. When ``None``, names derived
                from ``hook_layer`` are used.

        Raises:
            KeyError: If ``hook_points`` is given but lacks one of the three keys.
        """
        # Use provided hook_points or default
        if hook_points is None:
            hook_points = {
                "input": f"blocks.{hook_layer}.ln2.hook_normalized",
                "features": f"blocks.{hook_layer}.mlp.hook_post",
                "output": f"blocks.{hook_layer}.hook_mlp_out",
            }

        self.input_hook = hook_points["input"]
        self.features_hook = hook_points["features"]
        self.output_hook = hook_points["output"]

        # Get weights and config. The path is relative and is handed to the
        # loader unchanged; the file name always refers to ln2.hook_normalized,
        # independent of any custom hook_points.
        gpt2_tc = SparseAutoencoder.load_from_pretrained(
            path=f"gpt-2-small-transcoders/final_sparse_autoencoder_gpt2-small_blocks.{hook_layer}.ln2.hook_normalized_24576.pt",
        )
        d_in = gpt2_tc.cfg.d_in
        d_sae = gpt2_tc.cfg.d_sae
        model_name = "gpt2"

        # Initialize BaseSAE first
        super().__init__(
            d_in=d_in,
            d_sae=d_sae,
            model_name=model_name,
            hook_layer=hook_layer,
            device=device,
            dtype=dtype,
            hook_name=self.input_hook,  # Input hook
        )

        # Assign parameters after super().__init__() is called
        self.d_in = d_in
        self.d_sae = d_sae
        self.model_name = model_name

        def _place(weight: torch.Tensor) -> torch.nn.Parameter:
            # The loader decides where the checkpoint tensors live; copy them to
            # the device/dtype the caller asked for. Re-wrapping is needed
            # because ``.to`` may hand back a plain tensor rather than a Parameter.
            moved = weight.detach().to(device=device, dtype=dtype)
            return torch.nn.Parameter(moved, requires_grad=weight.requires_grad)

        self.W_enc = _place(gpt2_tc.W_enc)
        self.W_dec = _place(gpt2_tc.W_dec)
        self.b_enc = _place(gpt2_tc.b_enc)
        # A transcoder has separate input-side and output-side biases.
        # NOTE(review): relies on the loaded SparseAutoencoder exposing ``b_dec``
        # (input-side bias) next to ``b_dec_out`` — confirm against sae_training.
        self.b_dec_in = _place(gpt2_tc.b_dec)
        # ``b_dec`` keeps holding the output-side bias, so ``decode`` and any
        # code reading ``b_dec`` behave as before.
        self.b_dec = _place(gpt2_tc.b_dec_out)

    def encode(self, x: torch.Tensor):
        """Return the ReLU feature activations for ``x``.

        Computes ``relu((x - b_dec_in) @ W_enc + b_enc)``.
        Assumes the last dimension of ``x`` matches the first dimension of
        ``W_enc`` — TODO confirm against callers.
        """
        # Centre with the input-side bias, not the output bias used by decode().
        centered = x - self.b_dec_in
        pre_acts = centered @ self.W_enc + self.b_enc
        return torch.relu(pre_acts)

    def decode(self, feature_acts: torch.Tensor):
        """Map feature activations to the output: ``feature_acts @ W_dec + b_dec``."""
        return (feature_acts @ self.W_dec) + self.b_dec

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and decode the resulting activations; return the reconstruction."""
        feature_acts = self.encode(x)
        return self.decode(feature_acts)
