import os

from iclr_project.helpers import constants
os.environ["HF_HOME"] = constants.MODELS_DIR

from nnsight import LanguageModel, list as nnlist
import torch
import numpy as np

from itertools import accumulate
import re
import math


class ObservableLanguageModel:
    """Wrap an nnsight ``LanguageModel`` to tokenize prompts and capture the
    output of a chosen submodule ("hookpoint") during tracing or generation.
    """

    def __init__(self, model, device="cuda", dtype=torch.bfloat16):
        """Load ``model`` through nnsight and cache its tokenizer and hidden size.

        Args:
            model: Model name/path forwarded to ``nnsight.LanguageModel``.
            device: Forwarded as ``device_map``.
            dtype: Forwarded as ``torch_dtype``.
        """
        self.device = device
        self.dtype = dtype
        self.model_name = model

        self._model = LanguageModel(self.model_name, device_map=self.device, torch_dtype=self.dtype)
        self.tokenizer = self._model.tokenizer
        self.d_model = self._attempt_to_infer_hidden_layer_dimensions()

    def _return_submodule(self, submodule_name):
        """Resolve a dot-separated attribute path (e.g. ``"a.b"``) on the wrapped model."""
        module = self._model
        for attr_name in submodule_name.split("."):
            module = getattr(module, attr_name)
        return module

    def _tokenize_prompt(self, prompt_text, add_generation_prompt=True, return_tensors="pt", apply_chat_template=True):
        """Tokenize ``prompt_text`` into input ids.

        With ``apply_chat_template`` the text is wrapped as a single user turn and
        passed through ``tokenizer.apply_chat_template``; its return value is used
        directly (presumably an id tensor when ``return_tensors="pt"`` -- confirm
        for the installed transformers version). Otherwise the text is tokenized
        directly with special tokens and only ``input_ids`` is returned.
        """
        if apply_chat_template:
            # Apply the chat template to the prompt text and tokenize it
            input_tokens = self.tokenizer.apply_chat_template(
                [
                    {"role": "user", "content": f"{prompt_text}"},
                ],
                add_generation_prompt=add_generation_prompt,
                return_tensors=return_tensors
            )
        else:
            # Tokenize the prompt text without applying the chat template
            input_tokens = self.tokenizer(
                prompt_text,
                return_tensors=return_tensors,
                add_special_tokens=True,
            )
            input_tokens = input_tokens["input_ids"]
        return input_tokens

    def _attempt_to_infer_hidden_layer_dimensions(self):
        """Return ``config.hidden_size`` of the wrapped model as an int.

        Raises:
            ValueError: if the config has no ``hidden_size`` attribute. This is a
                subclass of ``Exception``, so existing broad handlers still catch it.
        """
        config = self._model.config
        if hasattr(config, "hidden_size"):
            return int(config.hidden_size)

        raise ValueError(
            "Could not infer hidden number of layer dimensions from model config"
        )

    def trace(self, text, hookpoint=None, apply_chat_template=False):
        """Run a single forward pass and capture the output of ``hookpoint``.

        Args:
            text: A prompt string; a list/tuple of token ids (a flat sequence gets
                a leading batch dimension); or an already-tokenized tensor, used as-is.
            hookpoint: Dot-separated submodule path. ``None`` captures the output of
                the wrapped model itself.
            apply_chat_template: Only used when ``text`` is a string.

        Returns:
            ``{"all": {"tokens", "activations", "text"}}``. Tokens are moved to CPU;
            a leading size-1 dimension is squeezed from tokens and activations;
            ``text`` is the ``batch_decode`` of the input ids.
        """
        if isinstance(text, str):
            input_tokens = self._tokenize_prompt(text, apply_chat_template=apply_chat_template)
        elif isinstance(text, (list, tuple)):
            # Plain Python id sequences have no .detach()/.squeeze(); make a tensor
            # and give a flat sequence the batch dimension the model expects.
            input_tokens = torch.as_tensor(text)
            if input_tokens.dim() == 1:
                input_tokens = input_tokens.unsqueeze(0)
        else:
            # Assume text is already tokenized
            input_tokens = text

        module = self._return_submodule(hookpoint) if hookpoint else self._model

        with self._model.trace(input_tokens):
            full_activations = module.output.save()

        return {
            "all": {
                "tokens": input_tokens.detach().squeeze(0).cpu(),
                "activations": full_activations.detach().squeeze(0),
                "text": self.tokenizer.batch_decode(input_tokens, skip_special_tokens=False)
            }
        }

    def generate(self, prompt_text, max_new_tokens=32, do_sample=False, hookpoint=None, apply_chat_template=True):
        """Generate a continuation, then re-trace prompt + continuation to capture
        the ``hookpoint`` output over every position.

        Returns:
            A dict with ``"input"``, ``"output"`` and ``"all"`` entries, each holding
            ``tokens`` (CPU, leading size-1 dim squeezed), ``activations`` (leading
            size-1 dim squeezed) and ``text``; ``"output"`` also has ``text_clean``
            (decoded with special tokens skipped).
        """
        input_tokens = self._tokenize_prompt(prompt_text, apply_chat_template=apply_chat_template)
        n_input_tokens = input_tokens.shape[-1]

        module = self._return_submodule(hookpoint) if hookpoint else self._model

        with self._model.generate(input_tokens, max_new_tokens=max_new_tokens, do_sample=do_sample):
            all_token_ids = self._model.generator.output.save()  # prompt + generated ids
        # Second pass over the full sequence so activations cover prompt and generation.
        with self._model.trace(all_token_ids):
            full_activations = module.output.save()

        # Split into prompt / generated parts along the sequence axis.
        # NOTE(review): assumes the hookpoint output is a (batch, seq, hidden) tensor.
        input_activations = full_activations[:, :n_input_tokens, :]
        output_tokens = all_token_ids[:, n_input_tokens:]
        output_activations = full_activations[:, n_input_tokens:, :]

        input_texts = self.tokenizer.batch_decode(input_tokens, skip_special_tokens=False)
        output_texts = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=False)
        output_texts_clean = self.tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
        full_texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=False)

        return {
            "input": {
                "tokens": input_tokens.detach().squeeze(0).cpu(),
                "activations": input_activations.detach().squeeze(0),
                "text": input_texts
            },
            "output": {
                "tokens": output_tokens.detach().squeeze(0).cpu(),
                "activations": output_activations.detach().squeeze(0),
                "text": output_texts,
                "text_clean": output_texts_clean
            },
            "all": {
                "tokens": all_token_ids.detach().squeeze(0).cpu(),
                "activations": full_activations.detach().squeeze(0),
                "text": full_texts
            }
        }


class SparseAutoEncoder(torch.nn.Module):
    """One-hidden-layer sparse autoencoder.

    The encoder is a linear map followed by ReLU; the decoder is a plain linear
    map back to the input width. Both layers are built with ``torch.nn.Linear``
    defaults and then moved to ``device`` / cast to ``dtype`` with one ``.to``.
    """

    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: torch.device,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.d_in, self.d_hidden = d_in, d_hidden
        self.device = device
        self.dtype = dtype
        # Creation order (encoder first) is kept so default initialisation
        # consumes the RNG exactly as before.
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        self.to(device=self.device, dtype=self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map inputs to non-negative feature activations: relu(W_enc x + b_enc)."""
        pre_activation = self.encoder_linear(x)
        return torch.relu(pre_activation)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Map feature activations back to input space: W_dec f + b_dec."""
        return self.decoder_linear(x)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, features)`` for the batch ``x``."""
        features = self.encode(x)
        reconstruction = self.decode(features)
        return reconstruction, features


def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.bfloat16,
):
    """Build a ``SparseAutoEncoder`` and load weights from a saved state dict.

    Args:
        path: File containing the state dict (loaded with ``weights_only=True``).
        d_model: Input width of the autoencoder.
        expansion_factor: Hidden width is ``d_model * expansion_factor``.
        device: Device for both the model and ``torch.load``'s ``map_location``.
        dtype: Parameter dtype of the created model. ``load_state_dict`` copies
            checkpoint tensors into the already-allocated parameters, so the
            loaded weights end up in this dtype. Defaults to the previous
            behaviour (bfloat16).

    Returns:
        The populated ``SparseAutoEncoder``.
    """
    sae = SparseAutoEncoder(
        d_model,
        d_model * expansion_factor,
        device,
        dtype=dtype,
    )
    state_dict = torch.load(
        path, weights_only=True, map_location=device
    )
    sae.load_state_dict(state_dict)

    return sae

def get_nonzero_pos_vals(features: torch.Tensor):
    """
    Return (token_positions, concept_positions, concept_values) for a
    [seq_len, d_hidden] activation tensor.

    * token_positions   – 1-D np.int64,  len == seq_len + 1; entries for token
                          ``i`` live in ``[token_positions[i], token_positions[i+1])``
    * concept_positions – 1-D np.int64,  concatenated concept ids
    * concept_values    – 1-D np.float32, same length as concept_positions

    A token with no non-zero activation is represented by a single sentinel
    entry (concept id 0, value 0.0), so every token owns at least one slot.
    A zero-length sequence yields ``token_positions == [0]`` and empty arrays.
    """

    # ── 1  move to CPU
    features = features.detach().cpu()
    n_tokens, width = features.shape

    if n_tokens == 0:
        return (
            np.zeros(1, dtype=np.int64),
            np.empty(0, dtype=np.int64),
            np.empty(0, dtype=np.float32),
        )

    # ── 2  where are the non-zeros?
    idx = features.nonzero(as_tuple=False)
    rows, cols = idx[:, 0], idx[:, 1]                  # int64
    # fp32 values: numpy has no bfloat16, and fp32 holds bf16 values exactly
    vals = features[rows, cols].float()

    # ── 3  stable sort by row
    order = rows.argsort(stable=True)
    rows, cols, vals = rows[order], cols[order], vals[order]

    # ── 4  how many activations per token?
    sizes = torch.bincount(rows, minlength=n_tokens).tolist()

    # ── 5  per-token chunks; ids stay int64 and values fp32 in separate lists,
    #       so large column indices are never routed through a float cast
    #       (float32 cannot represent every integer above 2**24).
    sentinel_col = torch.zeros(1, dtype=torch.int64)
    sentinel_val = torch.zeros(1, dtype=torch.float32)
    col_chunks, val_chunks = [], []
    for c, v in zip(cols.split(sizes), vals.split(sizes)):
        if c.numel() == 0:                             # empty token → sentinel
            c, v = sentinel_col, sentinel_val
        col_chunks.append(c)
        val_chunks.append(v)

    # ── 6  flatten
    token_positions = np.array(
        [0, *accumulate(len(c) for c in col_chunks)], dtype=np.int64
    )
    concept_positions = torch.cat(col_chunks).numpy()
    concept_values = torch.cat(val_chunks).numpy()

    # ── 7  sanity guard (remove if you like)
    assert concept_positions.max() < features.shape[1], (
        f"column index {concept_positions.max()} exceeds width {features.shape[1]}"
    )

    return token_positions, concept_positions, concept_values

def is_whitespace_token(piece):
    """Return True when a vocabulary piece represents only whitespace.

    That is the case when the piece is empty or all literal whitespace, or
    when every character is one of the tokenizer space/newline markers
    ``Ġ``, ``Ċ`` or ``▁``.
    """
    marker_chars = frozenset("ĠĊ▁")
    if not piece.strip():
        return True
    return set(piece) <= marker_chars

# Vocabulary pieces made only of sentence punctuation / newlines, or a run of
# two or more whitespace characters, count as delimiters.
_SPLIT_PIECE_RE = re.compile(r'[.!?;:\n]+|\s{2,}')


def _find_delimiter_ids(tokenizer_vocab):
    """Return the frozenset of token ids whose piece matches the delimiter rule."""
    return frozenset(
        tid
        for piece, tid in tokenizer_vocab
        if _SPLIT_PIECE_RE.fullmatch(piece) or is_whitespace_token(piece)
    )


def _split_after_delimiters(tokens, delimiter_ids):
    """Cut ``tokens`` right after each delimiter; return (start, end, tokens) triples."""
    segments = []
    start = 0
    current = []
    for i, t in enumerate(tokens):
        current.append(t)
        # Normalise 0-d tensors / numpy scalars to Python numbers so the set
        # lookup compares by value (tensors hash by identity).
        key = t.item() if hasattr(t, "item") else t
        if key in delimiter_ids:            # delimiter ⇒ cut *after* it
            segments.append((start, i + 1, current))
            current = []
            start = i + 1
    if current:                             # tail without trailing delimiter
        segments.append((start, start + len(current), current))
    return segments


def _merge_short_segments(segments, min_tokens):
    """Accumulate consecutive segments until each has at least ``min_tokens`` tokens.

    A short remainder at the end is emitted as its own (short) segment.
    """
    combined = []
    buf_start = None
    buf_tokens = []
    for seg_start, seg_end, seg_tokens in segments:
        if buf_tokens:
            candidate_start, candidate_tokens = buf_start, buf_tokens + seg_tokens
        else:
            candidate_start, candidate_tokens = seg_start, seg_tokens

        if len(candidate_tokens) < min_tokens:
            buf_start, buf_tokens = candidate_start, candidate_tokens
        else:
            combined.append((candidate_start, seg_end, candidate_tokens))
            buf_start, buf_tokens = None, []

    if buf_tokens:  # leftover
        combined.append((buf_start, buf_start + len(buf_tokens), buf_tokens))
    return combined


def _cap_sections(segments, max_sections):
    """Merge consecutive segments in equal-sized groups (from the start) so that
    at most ``max_sections`` remain."""
    if len(segments) <= max_sections:
        return segments
    group_size = math.ceil(len(segments) / max_sections)
    merged = []
    for i in range(0, len(segments), group_size):
        chunk = segments[i:i + group_size]
        merged_tokens = [tok for _, _, seg_tokens in chunk for tok in seg_tokens]
        merged.append((chunk[0][0], chunk[-1][1], merged_tokens))
    # ceil(...) already yields <= max_sections groups; the slice is a safety net.
    return merged[:max_sections]


def split_tokens(tokens, tokenizer_vocab, min_tokens=10, max_sections=10):
    """
    Split a sequence of token IDs after delimiter tokens, merge short
    segments, cap the number of sections, and return both the segments and
    their (start, end) indices in the original list.

    Parameters
    ----------
    tokens : sequence of int
        Token ids (0-d tensors / numpy scalars are also matched by value).
    tokenizer_vocab : iterable of (str, int)
        ``(piece, token_id)`` pairs, such as ``tokenizer.get_vocab().items()``.
        A token id is a delimiter when its piece consists only of ``.!?;:`` /
        newlines, is a run of 2+ whitespace characters, or is a whitespace
        marker token (see ``is_whitespace_token``).
    min_tokens : int
        Consecutive segments are merged until each holds at least this many
        tokens (a short final remainder may stay short).
    max_sections : int
        Upper bound on the number of returned sections.

    Returns
    -------
    sections : list of list
        The token subsequences, each including any delimiter tokens that
        terminated it.
    indices  : list of (int, int)
        Matching (start, end) pairs (end is *exclusive*).
    """
    delimiter_ids = _find_delimiter_ids(tokenizer_vocab)
    segments = _split_after_delimiters(tokens, delimiter_ids)
    segments = _merge_short_segments(segments, min_tokens)
    segments = _cap_sections(segments, max_sections)

    sections = [seg_tokens for _, _, seg_tokens in segments]
    indices = [(seg_start, seg_end) for seg_start, seg_end, _ in segments]
    return sections, indices

def ReLU(x):
    """Rectified linear unit: multiply ``x`` by its ``x > 0`` mask."""
    keep = x > 0
    return x * keep

def sigmoid(z):
    """Logistic function ``1 / (1 + exp(-z))`` evaluated with NumPy."""
    denominator = 1 + np.exp(-z)
    return 1 / denominator
