import random

import torch
from einops import rearrange

from .models.transformers import NNsightModel


def residual_stream(model: NNsightModel, prompt: str) -> torch.Tensor:
    """
    Collects the output activations of every layer in `model.layers` for one prompt.

    Runs a single traced forward pass of `model.nnsight_model` on `prompt` with
    gradients disabled, saving `layer.output[0][0]` (moved to CPU) per layer.
    Returns Tensor of activations with shape (positions, layers, hidden_dim).
    """
    per_layer = []
    with torch.no_grad():
        with model.nnsight_model.trace(prompt):
            for block in model.layers:
                # output[0][0]: presumably the hidden states of the single
                # batch item -- confirm against the wrapped model's layer output.
                per_layer.append(block.output[0][0].cpu().save())

    stacked = torch.stack(per_layer)  # layers first, one entry per model layer
    return rearrange(stacked, "layers positions hidden_dim -> positions layers hidden_dim")


def logit_lens(model: NNsightModel, residual_stream_acts: torch.Tensor) -> torch.Tensor:
    """
    Applies the model's norm and lm_head modules to every activation vector.

    inputs: residual_stream_acts (shape: positions, layers, hidden_dim).
    Returns Tensor of logits (on CPU) with shape [positions, layers, vocab].
    """
    acts = residual_stream_acts.to(device=model.device)

    with torch.no_grad():
        # Merge positions and layers into one leading axis so that norm and
        # lm_head are each called once on a 2-D input.
        flat_acts = rearrange(
            acts,
            "positions layers hidden_dim -> (positions layers) hidden_dim",
        )
        # `_module` is used to call the wrapped modules directly, outside a trace.
        normed = model.norm._module(flat_acts)
        flat_logits = model.lm_head._module(normed).cpu()

    n_positions = acts.shape[0]
    n_layers = acts.shape[1]
    return rearrange(
        flat_logits,
        "(positions layers) vocab -> positions layers vocab",
        positions=n_positions,
        layers=n_layers,
    )


def token_identity(
    model: NNsightModel,
    residual_stream_acts: torch.Tensor,
    seed: int = 0,
) -> torch.Tensor:
    """
    Decodes activations by patching them into a token-identity prompt.

    The prompt is `[bos] t1 t1 ; t2 t2 ; ... t10 t10 ; ?`, where the ten `t`
    tokens are sampled from the tokenizer vocabulary using `seed` and the bos
    token is included only if the tokenizer defines one. For each layer, the
    prompt is run as a batch with one copy per position, the final-token output
    of that layer is overwritten with the given activations for that layer,
    and the lm_head logits at the final token are saved.

    inputs: residual_stream_acts (shape: positions, layers, hidden_dim).
    seed: seeds the sampling of the prompt tokens; the same seed yields the
        same prompt. The global `random` state is not modified.
    Returns Tensor of logits with shape [positions, layers, vocab].
    """
    tokenizer = model.nnsight_model.tokenizer
    # Private generator: draws exactly what random.seed(seed) + random.sample
    # would, without reseeding the process-wide RNG that callers may rely on.
    rng = random.Random(seed)
    sampled_tokens = rng.sample(range(tokenizer.vocab_size), 10)
    sep_token = tokenizer.encode(";", add_special_tokens=False)[0]
    query_token = tokenizer.encode("?", add_special_tokens=False)[0]

    identity_prompt = []
    if tokenizer.bos_token_id is not None:
        identity_prompt.append(tokenizer.bos_token_id)
    for t in sampled_tokens:
        # each demonstration is "token token ;"
        identity_prompt += [t, t, sep_token]
    identity_prompt.append(query_token)

    P, L = residual_stream_acts.shape[:2]

    logits = []
    with torch.no_grad():
        for layer_idx in range(L):  # IMPORTANT (nnsight): must patch in order of layers
            with model.nnsight_model.trace([identity_prompt] * P):
                # patching from some activation to (same layer, final token) patchscope
                model.layers[layer_idx].output[0][:, -1, :] = residual_stream_acts[:, layer_idx, :]
                # save logits at final token
                _logits = model.lm_head.output[:, -1, :].cpu().save()
                logits.append(_logits)

    return rearrange(torch.stack(logits), "layers positions vocab -> positions layers vocab")


def argsort_logits(logits: torch.Tensor) -> torch.Tensor:
    """
    Sorts vocabulary indices by descending logit for each (position, layer) pair.

    inputs: logits (shape: positions, layers, vocab).
    Returns Tensor of vocab indices with shape [(positions * layers), vocab];
    column 0 holds the index of the largest logit in each row.
    """
    flat_logits = rearrange(logits, "positions layers vocab -> (positions layers) vocab")
    return flat_logits.argsort(dim=1, descending=True)


def reciprocal_rank(sort_indices: torch.Tensor, shape: tuple[int, ...], token: int) -> torch.Tensor:
    """
    For given token index, returns Tensor of reciprocal ranks with shape [positions, layers].

    inputs: sort_indices (2-D, one row of sorted vocab indices per entry),
    shape (target shape of the result), token (vocab index to look up).
    The row count of sort_indices must match the number of elements in `shape`,
    and `token` is expected to occur exactly once in each row.
    """
    hits = sort_indices == token
    # The column coordinate of each hit is the token's 0-based rank in its row;
    # nonzero reports hits in row-major order, so rows keep their ordering.
    zero_based = torch.nonzero(hits, as_tuple=True)[1]
    one_based = zero_based.reshape(*shape) + 1
    return 1 / one_based
