import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from util_funcs import (
    SaeFeature,
    weighted_topk_sae_feature_agg_func,
    concat_sae_feature,
)


# sr
@torch.no_grad
def sr_compute_reps(model, batch):
    """Mean-pool the model's last hidden states over non-padding tokens.

    Args:
        model: Model called as ``model(input_ids, attention_mask=...)`` whose
            output exposes ``last_hidden_state``.
        batch: Dict with ``input_ids`` and ``attention_mask`` tensors; both
            are moved to CUDA before the forward pass.

    Returns:
        One mean-pooled vector per sample: the sum of the hidden states of
        tokens with a non-zero mask divided by the mask sum.
    """
    token_ids = batch["input_ids"].cuda()
    mask = batch["attention_mask"].cuda()
    last_hidden = model(token_ids, attention_mask=mask).last_hidden_state

    # (batch, seq) -> (batch, seq, 1) in the hidden-state dtype so it broadcasts
    token_weights = mask.unsqueeze(-1).to(last_hidden.dtype)

    pooled_sum = (last_hidden * token_weights).sum(dim=1)
    token_count = token_weights.sum(dim=1)
    return pooled_sum / token_count


# bge-m3
def bge_m3_compute_reps(model, batch, max_length):
    """Embed ``batch["text"]`` with a pre-tokenized ``model.encode`` call.

    The texts are tokenized with ``model.get_tokenizer()`` (no padding,
    truncated to ``max_length``). Each sequence is then submitted as a
    ``{"prompt_token_ids": ...}`` prompt to ``model.encode``.
    NOTE(review): the ``get_tokenizer`` / ``encode(..., use_tqdm=...)`` /
    ``.outputs.embedding`` API looks like vLLM's ``LLM`` — confirm.

    Args:
        model: Object providing ``get_tokenizer()`` and ``encode(prompts, use_tqdm=...)``.
        batch: Dict whose ``"text"`` entry holds the texts to embed.
        max_length: Maximum number of tokens kept per text.

    Returns:
        CUDA tensor stacking the embedding of every text, in input order.
    """
    tokenizer = model.get_tokenizer()
    token_id_lists = tokenizer(
        batch["text"], padding=False, truncation=True, max_length=max_length
    ).input_ids

    prompts = [{"prompt_token_ids": ids} for ids in token_id_lists]
    results = model.encode(prompts, use_tqdm=False)

    embeddings = [result.outputs.embedding for result in results]
    return torch.tensor(embeddings, device="cuda")


# ps wanda
def _get_model(model):
    if getattr(model, "model", None) is not None:
        return model.model, "model", [nn.Linear]
    if getattr(model, "transformer", None) is not None:
        return model.transformer, "transformer", [Conv1D]
    raise ValueError("Model does not have attribute 'model' or 'transformer'")


def _get_model_layers(model):
    if getattr(model, "layers", None) is not None:
        return model.layers, "layers"
    if getattr(model, "h", None) is not None:
        return model.h, "h"
    raise ValueError("Model does not have attribute 'layers' or 'h'")


def find_layers(module, layers=None, name=""):
    """
    Recursively find the layers of a certain type in a module.

    Args:
        module (nn.Module): PyTorch module.
        layers (list | tuple | None): Layer types to find. Matching uses the
            exact type (``type(m) in layers``), so subclasses are not matched.
            Defaults to ``[nn.Linear]``.
        name (str): Dotted name of ``module``; used as the prefix of the
            returned keys.

    Returns:
        dict: Mapping from dotted module name to each matching layer. If
        ``module`` itself matches, it is returned alone and its children are
        not searched.
    """
    # ``None`` sentinel instead of a shared mutable list default.
    if layers is None:
        layers = [nn.Linear]
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        found.update(
            find_layers(
                child,
                layers=layers,
                name=f"{name}.{child_name}" if name else child_name,
            )
        )
    return found


def vectorize_ps(ps):
    """Flatten per-module masks and concatenate them into one row per sample.

    Args:
        ps: dict mapping a sub-module name to a tensor whose first dimension
            is the batch dimension; all remaining dimensions are flattened.

    Returns:
        Tensor of shape ``(batch_size, total)``, with the entries
        concatenated in sorted key order.
    """
    flattened = [torch.flatten(ps[key], start_dim=1) for key in sorted(ps)]
    return torch.cat(flattened, dim=1)


def move_to_device(batch, device):
    """Move every tensor value of ``batch`` to ``device``, in place.

    Non-tensor values are left untouched. The same dict object is returned
    for convenience.
    """
    for key, value in batch.items():
        # Reassigning existing keys does not resize the dict, so this is safe
        # while iterating.
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
    return batch


@torch.no_grad
def ps_wanda_compute_reps(
    origin_model, batch, prune_ratio=0.25, common_mask=None, top_layer_num=-1
):
    """Compute per-sample Wanda pruning masks and return them as flat vectors.

    The embedding output is pushed through the model's layers one at a time.
    For every targeted layer, forward hooks capture the input of each
    prunable sub-module (``nn.Linear`` for ``.model`` wrappers, ``Conv1D``
    for ``.transformer`` wrappers). A Wanda score ``|W| * ||X||_2`` is then
    computed separately for every sample in the batch. In each output row,
    the ``prune_ratio`` fraction of input weights with the lowest score is
    marked as pruned.

    Runs under ``torch.no_grad``: the result is a bool mask, so no autograd
    graph is needed.

    Args:
        origin_model: Model exposing ``.model`` or ``.transformer``.
        batch: Dict with ``input_ids``, ``attention_mask`` and
            ``position_ids`` tensors. Its tensor values are moved to the
            model device in place.
        prune_ratio: Fraction of each output row's input weights to prune.
        common_mask: Optional index applied to the columns of the result.
        top_layer_num: If non-negative, only the last ``top_layer_num``
            layers are scored; a negative value scores every layer.

    Returns:
        Bool tensor with one row per sample, made of all scored sub-module
        masks flattened and concatenated in sorted-name order (then column-
        indexed by ``common_mask`` if given). ``False`` marks a pruned
        weight.
    """
    model, model_name, modules = _get_model(origin_model)
    layers, layers_name = _get_model_layers(model)
    num_layers = len(layers)

    batch = move_to_device(batch, model.device)

    # shape: (batch_size, seq_len, hidden_size), hidden states of the current layer
    hidden_states = model.get_input_embeddings()(batch["input_ids"])

    # keyword arguments passed to every layer call
    layer_kwargs = {}
    if model_name == "transformer":
        # NOTE: for gpt2, the position embeddings are added to the layer input
        hidden_states = hidden_states + model.wpe(batch["position_ids"])
    elif model_name == "model":
        # NOTE: rotary position embeddings are passed to every layer
        layer_kwargs["position_embeddings"] = model.rotary_emb(
            hidden_states, batch["position_ids"]
        )

    # shape: (batch_size, seq_len, 1); zeroes padding positions so they do not
    # contribute to the input norms
    attention_mask = batch["attention_mask"].unsqueeze(-1).type(model.dtype)

    def make_input_hook(store, name):
        # Forward hook that records the first positional input of a sub-module.
        def hook(module, inputs, output):
            store[name] = inputs[0]

        return hook

    ps_masks = {}  # key: sub-module name, value: (batch_size, output_size, input_size) mask

    for i, layer in enumerate(layers):
        is_target = top_layer_num < 0 or i >= num_layers - top_layer_num

        subset = {}
        subset_inputs = {}  # key: sub-module name, value: its captured input
        handles = []
        if is_target:
            subset = find_layers(
                layer, name=f"{model_name}.{layers_name}.{i}", layers=modules
            )
            handles = [
                module.register_forward_hook(make_input_hook(subset_inputs, name))
                for name, module in subset.items()
            ]

        try:
            layer_output = layer(hidden_states, **layer_kwargs)
        finally:
            # Always detach the hooks, even if the forward pass raises.
            for handle in handles:
                handle.remove()

        # Some transformers versions return a tuple from a layer, others return
        # the hidden-states tensor itself; indexing a bare tensor with [0] would
        # silently drop all but the first sample.
        hidden_states = (
            layer_output[0] if isinstance(layer_output, tuple) else layer_output
        )

        # skip the shallow layers
        if not is_target:
            continue

        # evaluate the importance of the parameters sub-module by sub-module
        for name, module in subset.items():
            # shape: (batch_size, seq_len, input_size); padding tokens are masked out
            masked_input = subset_inputs[name] * attention_mask
            weight = module.weight.data  # shape: (output_size, input_size)
            if model_name == "transformer":
                # transpose into the (output_size, input_size) layout used below
                weight = weight.T

            # wanda score, shape: (batch_size, output_size, input_size)
            wanda_score = torch.abs(weight).unsqueeze(0) * torch.norm(
                masked_input, p=2, dim=1, keepdim=True
            )

            # TODO: we do not consider structured and semi-structured sparsity for now
            num_pruned = int(wanda_score.shape[-1] * prune_ratio)
            _, pruned_idx = torch.topk(
                wanda_score,
                num_pruned,
                dim=-1,
                largest=False,
                sorted=False,
            )  # shape: (batch_size, output_size, num_pruned)

            ps_mask = torch.ones_like(weight, dtype=torch.bool).repeat(
                masked_input.shape[0], 1, 1
            )  # shape: (batch_size, output_size, input_size)
            ps_mask.scatter_(-1, pruned_idx, False)

            assert name not in ps_masks
            ps_masks[name] = ps_mask

    vec_pc_mask = vectorize_ps(ps_masks)
    if common_mask is not None:
        vec_pc_mask = vec_pc_mask[:, common_mask]
    return vec_pc_mask


def sps_compute_feature(model, batch, select_ratio=0.1, threshold=0.002):
    """Return the flat indices of parameters whose |gradient| reaches ``threshold``.

    One forward/backward pass of ``model(input_ids, labels=input_ids)`` is
    run on a single sample. The absolute gradients of all parameters,
    ordered by parameter name, are concatenated into one flat vector, and
    the positions whose value is ``>= threshold`` are returned.

    Args:
        model: Model whose first output is the loss when called with ``labels``.
        batch: Dict with an ``input_ids`` tensor containing exactly one sample.
        select_ratio: Currently unused; kept for backward compatibility.
        threshold: Minimum absolute gradient for a position to be selected.

    Returns:
        1-D tensor of indices into the flattened, name-sorted gradient vector.

    Side effects:
        Clears the model's gradients first, then leaves the new gradients on
        the parameters.
    """
    model.zero_grad()

    input_ids = batch["input_ids"]
    assert input_ids.size(0) == 1

    loss = model(input_ids.cuda(), labels=input_ids.clone().cuda())[0]
    loss.backward()

    flat_grads = []
    for _, param in sorted(model.named_parameters(), key=lambda item: item[0]):
        # Parameters that received no gradient (frozen, or unused by this
        # forward pass) have ``grad is None``. Treat them as zero so the flat
        # index layout is the same for every sample.
        grad = torch.zeros_like(param) if param.grad is None else param.grad.detach()
        flat_grads.append(grad.flatten())

    grads = torch.abs(torch.cat(flat_grads))
    return torch.where(grads >= threshold)[0]


def normalize(x):
    """Min-max normalize ``x`` along its last dimension into ``[0, 1]``.

    Args:
        x: torch.Tensor of shape ``[*, act_dim]``.

    Returns:
        Tensor of the same shape, ``(x - min) / (max - min)`` per slice.
        Slices whose values are all equal map to 0 instead of NaN.
    """
    x_min = x.min(dim=-1, keepdim=True).values
    value_range = x.max(dim=-1, keepdim=True).values - x_min
    # For a constant slice the numerator is already 0; dividing by 1 yields 0
    # instead of 0/0 = NaN.
    value_range = value_range.masked_fill(value_range == 0, 1)
    return (x - x_min) / value_range


@torch.no_grad
def rds_compute_feature(batch, model=None):
    """Return the L2-normalized hidden state of each sample's last real token.

    The last position is taken as ``attention_mask.sum(-1) - 1``.
    NOTE(review): this assumes right padding; left-padded batches would pick
    the wrong token — confirm against the collator.

    Args:
        batch: Dict with ``input_ids`` and ``attention_mask`` tensors (moved to CUDA).
        model: Model whose output exposes ``last_hidden_state``.

    Returns:
        float32 tensor with one unit-norm vector per sample.
    """
    token_ids = batch["input_ids"].cuda()
    mask = batch["attention_mask"].cuda()

    last_hidden = model(token_ids).last_hidden_state

    row_index = torch.arange(token_ids.size(0), device=token_ids.device)
    last_position = mask.sum(dim=-1) - 1
    picked = last_hidden[row_index, last_position].to(torch.float32)

    return torch.nn.functional.normalize(picked, p=2, dim=1)


@torch.no_grad
def bge_compute_feature(batch, model=None):
    """CLS-pool fixed-length chunks of each sample and average them.

    Each row of ``batch["input_ids"]`` is split into consecutive chunks of
    ``min(model.config.max_position_embeddings, seq_len)`` tokens, and every
    chunk is encoded independently. The first-token (CLS) embedding of each
    chunk is taken. The embeddings of chunks containing at least one
    non-padding token are averaged and then L2-normalized.

    Args:
        batch: Dict with ``input_ids`` and ``attention_mask`` (and optionally
            ``token_type_ids``) tensors of shape ``(bsz, seq_len)``.
        model: Encoder whose first output is the per-token hidden states.

    Returns:
        Tensor of shape ``(bsz, hidden_size)`` with unit L2 norm per row (all
        zeros for a sample that has no non-padding token).

    Raises:
        ValueError: If ``seq_len`` is not a positive multiple of the chunk length.
    """
    bsz, seq_len = batch["input_ids"].shape

    # Kept local on purpose: writing it back to model.config would persist
    # across calls and shrink the chunk size for every later batch.
    chunk_len = min(model.config.max_position_embeddings, seq_len)
    if chunk_len <= 0 or seq_len % chunk_len != 0:
        raise ValueError(
            f"seq_len ({seq_len}) must be a positive multiple of the chunk "
            f"length ({chunk_len})"
        )

    input_ids = batch["input_ids"].cuda()
    attention_mask = batch["attention_mask"].cuda()

    # shape: (bsz, num_chunks, 1); True for chunks with at least one real token
    chunk_mask = attention_mask.view(bsz, -1, chunk_len).any(dim=-1, keepdim=True)

    inputs = {
        "input_ids": input_ids.view(-1, chunk_len),
        "attention_mask": attention_mask.view(-1, chunk_len),
    }
    if "token_type_ids" in batch:
        # pass the tensor itself (not wrapped in a tuple)
        inputs["token_type_ids"] = batch["token_type_ids"].cuda().view(-1, chunk_len)

    model_output = model(**inputs)
    # CLS pooling: first token of every chunk, shape (bsz * num_chunks, hidden_size)
    chunk_embeddings = model_output[0][:, 0]
    chunk_embeddings = chunk_embeddings.view(bsz, -1, chunk_embeddings.shape[-1])

    # Average over non-empty chunks; the clamp avoids 0/0 for all-padding samples.
    num_valid_chunks = chunk_mask.sum(dim=1).clamp_min(1)
    sentence_embeddings = (chunk_embeddings * chunk_mask).sum(dim=1) / num_valid_chunks

    # normalize embeddings, shape: [bsz, hidden_size]
    return torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)


@torch.no_grad
def llm2vec_compute_feature(batch, model=None):
    sentence_embeddings = model.encode(batch, show_progress_bar=False)
    return torch.nn.functional.normalize(sentence_embeddings, p=2, dim=-1)


@torch.no_grad
def sae_compute_feature(batch, model=None, sae_models=None, k=192, avg_level="sample"):
    """Aggregate SAE features of every sample across several layers.

    For each key of ``sae_models`` (in iteration order), the hidden states at
    index ``int(<numeric suffix of the key>) + 1`` are encoded by that SAE.
    NOTE(review): the ``+ 1`` presumably skips the embedding output at
    ``hidden_states[0]`` — confirm.

    Each sample's first ``attention_mask.sum()`` token features are
    aggregated with ``weighted_topk_sae_feature_agg_func`` and concatenated
    onto that sample's features from earlier SAEs. Latent indices are
    shifted by the cumulative ``num_latents`` of the preceding SAEs, so each
    SAE occupies its own index range.

    Returns:
        list with one aggregated feature per sample (``None`` if
        ``sae_models`` is empty).
    """
    input_ids = batch["input_ids"].cuda()
    lengths = torch.sum(batch["attention_mask"], dim=-1)
    batch_size = input_ids.size(0)

    all_hidden = model(input_ids, output_hidden_states=True).hidden_states

    features = [None] * batch_size
    index_offset = 0

    for sae_name in sae_models:
        sae = sae_models[sae_name]
        hidden_index = int(sae_name.split(".")[-1].strip()) + 1

        # per-token SAE encoding; top_acts / top_indices are indexed [sample, token, ...]
        encoded = sae.encode(all_hidden[hidden_index])

        for b in range(batch_size):
            n_tokens = lengths[b]
            acts = encoded.top_acts[b, :n_tokens].to(dtype=torch.float32)
            indices = encoded.top_indices[b, :n_tokens] + index_offset
            sample_feature = weighted_topk_sae_feature_agg_func(
                SaeFeature(acts, indices),
                k=k,
                avg_level=avg_level,
            )

            previous = features[b]
            if previous is not None:
                sample_feature = concat_sae_feature([previous, sample_feature])
            features[b] = sample_feature

        index_offset += sae.num_latents

    return features


# Shared SiLU module applied to the captured gate_proj outputs in nosae_compute_feature.
silu_func = nn.SiLU()


@torch.no_grad
def nosae_compute_feature(batch, model=None, layers=(), k=192, avg_level="sample"):
    """Build SAE-style sparse features directly from MLP gate activations.

    For every entry of ``layers``, the output of
    ``model.model.layers[i].mlp.gate_proj`` is captured with a forward hook
    during one forward pass. SiLU is applied, and the top-``k`` values of
    ``|silu(gate)|`` per token are kept (refer to CATS:
    https://arxiv.org/pdf/2404.08763). The first ``attention_mask.sum()``
    tokens of each sample are aggregated with
    ``weighted_topk_sae_feature_agg_func(k=-1)``. Features are concatenated
    across layers, with indices shifted by the cumulative activation width
    of the preceding layers.

    Args:
        batch: Dict with ``input_ids`` and ``attention_mask`` tensors.
        model: Model exposing ``model.model.layers[i].mlp.gate_proj``.
        layers: Layer indices (anything ``int()`` accepts); slot ``j`` of the
            feature index space corresponds to ``layers[j]``.
        k: Number of activations kept per token and layer.
        avg_level: Forwarded to ``weighted_topk_sae_feature_agg_func``.

    Returns:
        list with one aggregated feature per sample (``None`` if ``layers``
        is empty).
    """
    input_ids = batch["input_ids"].cuda()
    lengths = torch.sum(batch["attention_mask"], dim=-1)

    layers = [int(layer) for layer in layers]

    # Keyed by position in ``layers`` (not by hook firing order), so each slot
    # lines up with its requested layer even if ``layers`` is not sorted.
    gate_outputs = {}

    def make_hook(slot):
        def hook(module, inputs, output):
            gate_outputs[slot] = output.detach()

        return hook

    handles = [
        model.model.layers[layer_index].mlp.gate_proj.register_forward_hook(
            make_hook(slot)
        )
        for slot, layer_index in enumerate(layers)
    ]
    try:
        # Only the hooked activations are needed, so hidden states are not requested.
        model(input_ids)
    finally:
        # Always detach the hooks, even if the forward pass raises.
        for handle in handles:
            handle.remove()

    offset = 0
    batch_nosae_features = [None] * input_ids.size(0)
    for slot in range(len(layers)):
        # [batch_size, seq_len, intermediate_size]
        silu_out = silu_func(gate_outputs[slot])

        acts, indices = torch.topk(torch.abs(silu_out), k=k, dim=-1, largest=True)
        for batch_index in range(input_ids.size(0)):
            cur_sample_sae_feature = weighted_topk_sae_feature_agg_func(
                SaeFeature(
                    acts[batch_index, : lengths[batch_index]].to(dtype=torch.float32),
                    indices[batch_index, : lengths[batch_index]] + offset,
                ),
                k=-1,
                avg_level=avg_level,
            )
            if batch_nosae_features[batch_index] is None:
                batch_nosae_features[batch_index] = cur_sample_sae_feature
            else:
                batch_nosae_features[batch_index] = concat_sae_feature(
                    [batch_nosae_features[batch_index], cur_sample_sae_feature]
                )

        offset += silu_out.size(-1)

    return batch_nosae_features
