from collections import defaultdict
from transformers import CLIPProcessor, CLIPVisionModel, CLIPModel, CLIPVisionModelWithProjection, \
    CLIPTextModelWithProjection
import types
import os
import matplotlib.pyplot as plt
import torch
import argparse
from PIL import Image
import numpy as np
from tqdm import tqdm
import cv2
from sklearn.linear_model import OrthogonalMatchingPursuit


class CLIPVisionEncoderWithOutputs:
    """Wrap a CLIP vision tower and capture per-layer intermediate activations.

    Every encoder layer's self-attention ``forward`` is replaced by a capturing
    wrapper and several forward (pre-)hooks are registered, so that a call to
    :meth:`extract_features` fills ``self.layer_outputs[layer_idx]`` with:

    - ``'layer_input'``: input of the encoder layer (forward pre-hook).
    - ``'attention_weights'``: attention probabilities of the layer.
    - ``'cls_subvalues'``: CLS-row attention weights multiplied elementwise into
      the per-head value vectors (only when the attention module exposes q/k/v
      projections and head sizes).
    - ``'attention_output'``: output of the self-attention module.
    - ``'residual_output'``: input of ``layer_norm2``, i.e. the residual stream
      after the attention sub-block.
    - ``'coefficient'``: output of the MLP activation function.
    - ``'ffn_output'``: output of the MLP.
    - ``'layer_output'``: first element of the encoder layer's output.
    - ``'reconstruction_error'``: mean absolute difference between
      ``layer_input + attention_output + ffn_output`` and ``layer_output``
      (computed when those entries are present).

    All stored tensors are detached clones.
    """
    def __init__(self, model_name="openai/clip-vit-large-patch14", device="cuda"):
        """Load processor and vision model for ``model_name`` onto ``device`` and install the capture machinery.

        Args:
            model_name: Hugging Face checkpoint name.
            device: device the vision model and processed inputs are moved to.
        """
        self.processor = CLIPProcessor.from_pretrained(model_name)
        # NOTE(review): "eager" attention is presumably requested so that attention
        # probabilities are materialised and returned; confirm for the installed
        # transformers version.
        self.model = CLIPVisionModelWithProjection.from_pretrained(model_name, attn_implementation="eager").to(device)
        self.projection = self.model.visual_projection
        # layer index -> {activation name -> captured tensor (or float error)}
        self.layer_outputs = defaultdict(dict)
        self.device = device
        self._modify_attention_modules()
        self._register_hooks()

    def _modify_attention_modules(self):
        """Replace each encoder layer's ``self_attn.forward`` with a capturing wrapper.

        The wrapper calls the module's original ``forward`` with
        ``output_attentions=True`` and records the attention weights, the
        attention output and (when possible) ``'cls_subvalues'`` into the shared
        ``self.layer_outputs`` dict.
        """
        for i, layer in enumerate(self.model.vision_model.encoder.layers):
            attention_module = layer.self_attn
            attn_class_name = attention_module.__class__.__name__  # NOTE(review): unused
            original_forward = attention_module.forward
            # `layer_idx=i` binds the current loop index as a default value; a plain
            # closure over `i` would make every wrapper see the last index.
            def new_forward(self, hidden_states, attention_mask=None, causal_attention_mask=None,
                            output_attentions=False, layer_idx=i):
                """Capturing replacement for the attention forward.

                ``self`` is the attention module (bound via ``types.MethodType``
                below). Returns ``(attn_output, attn_weights)`` when
                ``output_attentions`` is true, otherwise ``attn_output`` alone.
                """
                # `_original_forward` is assigned on every module before this wrapper is
                # installed, so the MRO fallback in the else-branch is defensive only.
                if hasattr(self, '_original_forward'):
                    attn_output, attn_weights = self._original_forward(
                        hidden_states=hidden_states,
                        attention_mask=attention_mask,
                        causal_attention_mask=causal_attention_mask,
                        output_attentions=True
                    )
                else:
                    parent_class = self.__class__.__mro__[1]
                    parent_forward = getattr(parent_class, 'forward')
                    attn_output, attn_weights = parent_forward(
                        self,
                        hidden_states=hidden_states,
                        attention_mask=attention_mask,
                        causal_attention_mask=causal_attention_mask,
                        output_attentions=True
                    )
                layer_outputs = getattr(self, 'layer_outputs', None)
                if layer_outputs is None:
                    return (attn_output, attn_weights) if output_attentions else attn_output
                layer_outputs[layer_idx]['attention_weights'] = attn_weights.detach().clone()
                bsz, tgt_len, embed_dim = hidden_states.size()
                # Look up projections under both CLIP-style and BERT-style attribute names.
                q_proj = getattr(self, 'q_proj', None) or getattr(self, 'query', None)
                k_proj = getattr(self, 'k_proj', None) or getattr(self, 'key', None)
                v_proj = getattr(self, 'v_proj', None) or getattr(self, 'value', None)

                if q_proj and k_proj and v_proj:
                    # Recompute q/k/v from the same hidden states; only the values are used below.
                    query_states = q_proj(hidden_states)
                    key_states = k_proj(hidden_states)
                    value_states = v_proj(hidden_states)

                    num_heads = getattr(self, 'num_heads', None) or getattr(self, 'num_attention_heads', None)
                    head_dim = getattr(self, 'head_dim', None)
                    if head_dim is None and num_heads is not None:
                        head_dim = embed_dim // num_heads

                    if num_heads is not None and head_dim is not None:
                        # -> (bsz, num_heads, seq_len, head_dim)
                        query_states = query_states.view(bsz, -1, num_heads, head_dim).transpose(1, 2)
                        key_states = key_states.view(bsz, -1, num_heads, head_dim).transpose(1, 2)
                        value_states = value_states.view(bsz, -1, num_heads, head_dim).transpose(1, 2)

                        # Attention row of query position 0 (the CLS token). Integer indexing drops
                        # that axis: assuming attn_weights is (bsz, num_heads, tgt_len, src_len),
                        # this is (bsz, num_heads, src_len).
                        cls_attn_weights = attn_weights[:, :, 0, :]

                        cls_attn_weights_expanded = cls_attn_weights.unsqueeze(-1)  # (bsz, num_heads, src_len, 1)
                        value_expanded = value_states  # (bsz, num_heads, seq_len, head_dim)
                        # Elementwise weight * value per head and source position:
                        # (bsz, num_heads, seq_len, head_dim) under the shape assumption above.
                        cls_subvalues = cls_attn_weights_expanded * value_expanded
                        layer_outputs[layer_idx]['cls_subvalues'] = cls_subvalues.detach().clone()

                layer_outputs[layer_idx]['attention_output'] = attn_output.detach().clone()

                return (attn_output, attn_weights) if output_attentions else attn_output

            attention_module._original_forward = original_forward

            bound_method = types.MethodType(new_forward, attention_module)
            attention_module.forward = bound_method
            # Share the wrapper's capture dict with the extractor (same object, not a copy).
            attention_module.layer_outputs = self.layer_outputs

    def _register_hooks(self):
        """Attach hooks that capture per-layer activations.

        Per encoder layer: layer input (pre-hook), layer output, the input of
        ``layer_norm2`` (residual after attention), MLP output and MLP
        activation output.
        """
        for i, layer in enumerate(self.model.vision_model.encoder.layers):
            layer.register_forward_pre_hook(self._get_layer_input_hook(i))

            layer.register_forward_hook(self._get_layer_output_hook(i))

            layer.layer_norm2.register_forward_hook(self._get_residual_input_hook(i))

            layer.mlp.register_forward_hook(self._get_ffn_output_hook(i))

            layer.mlp.activation_fn.register_forward_hook(self._get_activation_output_hook(i))

    def _get_layer_input_hook(self, layer_idx):
        """Return a forward pre-hook storing the layer's first positional input as ``'layer_input'``."""
        def hook(module, input):
            self.layer_outputs[layer_idx]['layer_input'] = input[0].detach().clone()

        return hook

    def _get_residual_input_hook(self, layer_idx):
        """Return a hook on ``layer_norm2`` storing its input (post-attention residual) as ``'residual_output'``."""
        def hook(module, input, output):
            self.layer_outputs[layer_idx]['residual_output'] = input[0].detach().clone()

        return hook

    def _get_activation_output_hook(self, layer_idx):
        """Return a hook on the MLP activation storing its output as ``'coefficient'``."""
        def hook(module, input, output):
            self.layer_outputs[layer_idx]['coefficient'] = output.detach().clone()

        return hook

    def _get_layer_output_hook(self, layer_idx):
        """Return a hook storing ``output[0]`` as ``'layer_output'`` and checking the residual decomposition."""
        def hook(module, input, output):
            # NOTE(review): assumes the encoder layer returns a tuple whose first
            # element is the hidden states; verify for the installed transformers version.
            self.layer_outputs[layer_idx]['layer_output'] = output[0].detach().clone()

            if all(k in self.layer_outputs[layer_idx] for k in ['layer_input', 'attention_output', 'ffn_output']):
                layer_input = self.layer_outputs[layer_idx]['layer_input']
                attn_output = self.layer_outputs[layer_idx]['attention_output']
                ffn_output = self.layer_outputs[layer_idx]['ffn_output']
                layer_output = self.layer_outputs[layer_idx]['layer_output']

                # Sanity check: output should equal input + attention write + FFN write.
                reconstructed = layer_input + attn_output + ffn_output
                error = torch.mean(torch.abs(reconstructed - layer_output))
                self.layer_outputs[layer_idx]['reconstruction_error'] = error.item()

        return hook

    def _get_ffn_output_hook(self, layer_idx):
        """Return a hook on the MLP storing its output as ``'ffn_output'``."""
        def hook(module, input, output):
            self.layer_outputs[layer_idx]['ffn_output'] = output.detach().clone()

        return hook

    def extract_features(self, image_path, save_outputs=False):
        """Run one image through the vision model and return its embedding plus captured activations.

        Args:
            image_path: path of the image to open (converted to RGB).
            save_outputs: when true, also ``torch.save`` the captured dict to
                ``clip_vision_outputs_transformers.pt``.

        Returns:
            ``(image_embeds, layer_outputs)``. ``layer_outputs`` is
            ``self.layer_outputs`` itself; it is cleared and refilled on every
            call, so earlier return values are overwritten by later calls.
        """
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)

        self.layer_outputs.clear()

        with torch.no_grad():
            outputs = self.model(
                pixel_values=inputs.pixel_values,
                output_attentions=True
            )
        # Overwrite the wrapper-captured weights with the model's returned attentions.
        attention_weights = outputs.attentions
        for i, attn in enumerate(attention_weights):
            self.layer_outputs[i]['attention_weights'] = attn.detach().clone()
        if save_outputs:
            torch.save(self.layer_outputs, "clip_vision_outputs_transformers.pt")
            print("Saved all outputs to clip_vision_outputs_transformers.pt")

        return outputs.image_embeds, self.layer_outputs


def get_fc2_params(model, layer_num):
    """Return the transposed ``fc2`` weight of vision-encoder MLP ``layer_num``.

    Assuming ``fc2`` is an ``nn.Linear`` (weight stored as (out, in)), the
    result is (in_features, out_features): one row per intermediate neuron.
    """
    fc2_layer = model.vision_model.encoder.layers[layer_num].mlp.fc2
    raw_weight = fc2_layer.weight.data
    return raw_weight.T


def get_fc2_bias(model, layer_num):
    """Return the ``fc2`` bias of vision-encoder MLP ``layer_num`` as a (1, out_features) row."""
    fc2_bias = model.vision_model.encoder.layers[layer_num].mlp.fc2.bias.data
    return fc2_bias[None, :]


def get_bsvalues(vector, extractor, u, s, label_emb):
    """Score ``vector`` against label embeddings through a fixed-statistics post-LayerNorm.

    The vector is normalised with the caller-supplied mean ``u`` and variance
    ``s``, affinely transformed with the post-LayerNorm gamma/beta, passed
    through ``extractor.projection``, scaled by 100 and multiplied with
    ``label_emb.T``.
    """
    post_ln = extractor.model.vision_model.post_layernorm
    normalized = (vector - u) / torch.sqrt(s + 1e-12)
    # Reshape gamma/beta so they broadcast over the last (hidden) dim at any rank.
    broadcast_shape = [1] * (normalized.dim() - 1) + [-1]
    gamma = post_ln.weight.reshape(broadcast_shape)
    beta = post_ln.bias.reshape(broadcast_shape)
    scaled_projection = 100 * extractor.projection(gamma * normalized + beta)
    return scaled_projection @ label_emb.T


def get_consine(vector, extractor, s, label_emb):
    """Cosine similarity between a projected residual-stream ``vector`` and a single label embedding.

    The post-LayerNorm is applied in a partially frozen form:

    - the mean is taken from ``vector`` itself (last dim);
    - the variance ``s`` is supplied by the caller rather than recomputed;
    - only gamma is applied, and the LayerNorm bias is deliberately omitted.

    The result is passed through ``extractor.projection``, unit-normalised and
    dotted with ``label_emb``.

    Args:
        vector: tensor whose last dim is the hidden size; any leading dims.
        extractor: object exposing ``model.vision_model.post_layernorm`` and
            ``projection``.
        s: variance used as the LayerNorm denominator (broadcast against ``vector``).
        label_emb: one embedding, flattened to a column vector.

    Returns:
        Tensor of cosine values with ``vector``'s leading dims.
    """
    post_ln = extractor.model.vision_model.post_layernorm
    normalized = (vector - vector.mean(dim=-1, keepdim=True)) / torch.sqrt(s + 1e-12)
    broadcast_shape = [1] * (normalized.dim() - 1) + [-1]
    # LayerNorm bias intentionally not added (direction-only contribution).
    vector_ln = post_ln.weight.reshape(broadcast_shape) * normalized
    visual_embedding = extractor.projection(vector_ln)
    # Out-of-place: `norm` saves its input for backward, so dividing in place broke autograd.
    visual_embedding = visual_embedding / visual_embedding.norm(dim=-1, keepdim=True)
    cosine = visual_embedding @ label_emb.reshape(-1, 1)
    return cosine[..., 0]


def dict_to_list(dictionary, div=1):
    """Convert ``{key: value}`` into ``[[key, value / div], ...]`` in the dict's iteration order."""
    pairs = []
    for key in dictionary:
        pairs.append([key, dictionary[key] / div])
    return pairs


def sum_by_first_element(data_list):
    """Sum values grouped by the integer prefix of their keys.

    Each item is ``[key, value, ...]`` where ``key`` starts with an integer
    followed by ``'_'`` (or is just an integer string). Values are converted with
    ``float``. Returns ``[[prefix, total], ...]`` sorted ascending by prefix.
    """
    totals = {}
    for item in data_list:
        prefix = int(item[0].split('_')[0])
        amount = float(item[1])
        try:
            totals[prefix] += amount
        except KeyError:
            totals[prefix] = amount
    return [[prefix, totals[prefix]] for prefix in sorted(totals)]


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Pass the detected device explicitly: the extractor's default is "cuda", which
    # fails on CPU-only machines and can disagree with the text model's device.
    extractor = CLIPVisionEncoderWithOutputs(model_name="openai/clip-vit-large-patch14", device=device)
    model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14",
                                                        attn_implementation="eager").to(device)
    # Architecture constants of the vision encoder: layer count, heads per layer, per-head width.
    LAYER_NUM = len(extractor.model.vision_model.encoder.layers)
    HEAD_NUM = extractor.model.vision_model.encoder.layers[0].self_attn.num_heads
    HEAD_DIM = extractor.model.vision_model.encoder.layers[0].self_attn.head_dim


    def transfer_output(model_output):
        """Flatten captured per-layer activations of batch item 0 into nested Python lists.

        The layer input is the captured ``'layer_input'`` for layer 0 and the
        previous layer's ``'layer_output'`` for later layers. Every entry except
        the attention weights is indexed at batch element 0 and converted with
        ``.tolist()``; attention weights are kept as the captured tensors.

        Returns an 8-tuple of per-layer lists: (layer inputs, attention outputs,
        residual outputs, FFN outputs, layer outputs, CLS attention sub-values,
        activation coefficients, attention weights).
        """
        list_keys = ('attention_output', 'residual_output', 'ffn_output',
                     'layer_output', 'cls_subvalues', 'coefficient')
        layer_inputs = []
        collected = {key: [] for key in list_keys}
        attn_scores = []
        for layer_i in range(LAYER_NUM):
            record = model_output[layer_i]
            if layer_i == 0:
                layer_in = record['layer_input']
            else:
                # The residual stream entering layer i is layer i-1's output.
                layer_in = model_output[layer_i - 1]['layer_output']
            layer_inputs.append(layer_in[0].tolist())
            for key in list_keys:
                collected[key].append(record[key][0].tolist())
            attn_scores.append(record['attention_weights'])
        return (layer_inputs,) + tuple(collected[key] for key in list_keys) + (attn_scores,)

    # detecting knowledge neuron
    # Score every FFN neuron ("layer_neuron") and attention value dimension
    # ("layer_head_dim") by how much adding its write to the CLS residual stream
    # raises the cosine similarity with the reference embedding, summed over images.
    reference = "text"
    # The tokenised prompt must live on the text model's device (was hard-coded "cuda").
    inputs = extractor.processor(text=["taylor swift"], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        text_features = model(**inputs).text_embeds
        text_features /= text_features.norm(dim=1, keepdim=True)
    ffn_subvalue_list = {}
    for img in tqdm(range(1, 11)):
        image_path = os.path.join("VisEnt/celebrity/taylor swift", str(img).zfill(3) + ".jpeg")
        with torch.no_grad():
            features, layer_outputs = extractor.extract_features(image_path, save_outputs=False)
            features_unit = features / features.norm(dim=1, keepdim=True)
            reference_emb = text_features if reference == "text" else features_unit
            # Variance of the final-layer CLS state; get_consine uses it as a fixed LayerNorm scale.
            x = layer_outputs[LAYER_NUM - 1]['layer_output'][:, 0]
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            u, s = u.reshape(-1), s.reshape(-1)
            (all_pos_layer_input, all_pos_attn_output, all_pos_residual_output, all_pos_ffn_output,
             all_pos_layer_output, all_last_attn_subvalues, all_pos_coefficient_scores,
             all_attn_scores) = transfer_output(layer_outputs)

            # FFN neurons: CLS activation coefficient times its fc2 column, added on top
            # of the post-attention residual (plus the fc2 bias).
            for layer_i in range(LAYER_NUM):
                coefficient_scores = torch.tensor(all_pos_coefficient_scores[layer_i][0]).unsqueeze(1).to(device)
                cur_ffn_subvalues = coefficient_scores * get_fc2_params(extractor.model, layer_i)
                cur_residual = torch.tensor(all_pos_residual_output[layer_i][0]).unsqueeze(0).to(device)
                cur_residual += get_fc2_bias(extractor.model, layer_i)
                origin_prob_log = get_consine(cur_residual, extractor, s, reference_emb)
                cur_ffn_subvalues_probs_log = get_consine(cur_ffn_subvalues + cur_residual, extractor, s,
                                                          reference_emb)
                cur_ffn_subvalues_probs_log_increase = cur_ffn_subvalues_probs_log - origin_prob_log
                for index, ffn_increase in enumerate(cur_ffn_subvalues_probs_log_increase.tolist()):
                    key = f"{layer_i}_{index}"
                    ffn_subvalue_list[key] = ffn_subvalue_list.get(key, 0) + ffn_increase

            # Attention value dimensions: each (head, dim)'s write summed over source
            # positions, added on top of the layer input at CLS (plus out_proj bias).
            for test_layer in range(LAYER_NUM):
                attn_module = extractor.model.vision_model.encoder.layers[test_layer].self_attn
                cur_layer_input = torch.tensor(all_pos_layer_input[test_layer]).to(device)
                cur_v_heads_recompute = torch.tensor(all_last_attn_subvalues[test_layer]).permute(1, 0, 2).to(device)
                cur_attn_o_split = attn_module.out_proj.weight.data.view(1, -1, HEAD_NUM, HEAD_DIM)
                cur_attn_o_recompute = (cur_attn_o_split * cur_v_heads_recompute.unsqueeze(1)).permute(0, 2, 3, 1)
                cur_layer_input_last = cur_layer_input[:1] + attn_module.out_proj.bias.reshape(1, -1)
                origin_prob = get_consine(cur_layer_input_last, extractor, s, reference_emb)
                cur_attn_o_recompute = cur_attn_o_recompute.sum(dim=0)
                cur_attn_o_head_plus = cur_attn_o_recompute + cur_layer_input_last.reshape(1, 1, -1)
                cur_attn_plus_probs = get_consine(cur_attn_o_head_plus, extractor, s, reference_emb)
                cur_attn_plus_probs_increase = cur_attn_plus_probs - origin_prob
                for head_index, head_row in enumerate(cur_attn_plus_probs_increase.tolist()):
                    for attn_neuron_index, increase in enumerate(head_row):
                        key = f"{test_layer}_{head_index}_{attn_neuron_index}"
                        ffn_subvalue_list[key] = ffn_subvalue_list.get(key, 0) + increase

    ffn_subvalue_list = dict_to_list(ffn_subvalue_list, div=1)
    ffn_subvalue_list_sort = sorted(ffn_subvalue_list, key=lambda x: x[-1])[::-1]
    print("Top 10 knowledge neurons:")
    print(ffn_subvalue_list_sort[:10])

    # visual explanation
    neuron = ffn_subvalue_list_sort[0][0]
    reference = "text"
    image_path = os.path.join("VisEnt/celebrity/taylor swift", "001.jpeg")
    # Run the encoder on the displayed image for BOTH neuron kinds. Previously only the
    # FFN branch did this, so an attention neuron's saliency was computed from the last
    # image of the detection loop while being drawn over 001.jpeg.
    with torch.no_grad():
        features, layer_outputs = extractor.extract_features(image_path, save_outputs=False)
        (all_pos_layer_input, all_pos_attn_output, all_pos_residual_output, all_pos_ffn_output,
         all_pos_layer_output, all_last_attn_subvalues, all_pos_coefficient_scores,
         all_attn_scores) = transfer_output(layer_outputs)

    def attn_position_contributions(layer_i):
        """Per-source-position write into the CLS residual stream by layer ``layer_i``'s attention.

        Combines the captured CLS-row weighted values (assumed positions x heads x
        head_dim after the permute) with ``out_proj.weight`` and sums over heads and
        head dims, giving one hidden-size vector per position (out_proj bias excluded).
        """
        v_heads = torch.tensor(all_last_attn_subvalues[layer_i]).permute(1, 0, 2).to(device)
        out_weight = extractor.model.vision_model.encoder.layers[layer_i].self_attn.out_proj.weight.data
        o_split = out_weight.view(1, -1, HEAD_NUM, HEAD_DIM)
        return (o_split * v_heads.unsqueeze(1)).permute(0, 2, 3, 1).sum(dim=[1, 2])

    with torch.no_grad():
        neuron_parts = [int(part) for part in neuron.split("_")]
        revert = 1
        if len(neuron_parts) == 2:
            ffn_layer, ffn_neuron = neuron_parts
            encoder_layer = extractor.model.vision_model.encoder.layers[ffn_layer]
            # fc1 row: the direction this FFN neuron reads from its LayerNorm'd input.
            neuron_key = encoder_layer.mlp.fc1.weight.data[ffn_neuron, :]
            ln = encoder_layer.layer_norm2
            # The FFN input already contains the same layer's attention output, hence "+ 1".
            source_layers = range(ffn_layer + 1)
        else:
            attn_layer, attn_head, attn_neuron = neuron_parts
            encoder_layer = extractor.model.vision_model.encoder.layers[attn_layer]
            # v_proj row: the value dimension this attention neuron reads from its input.
            neuron_key = encoder_layer.self_attn.v_proj.weight.data[attn_head * HEAD_DIM + attn_neuron, :]
            # NOTE(review): this is the weighted value at source position 0 (the CLS token)
            # only, not the sum over positions; confirm that is the intended sign test.
            all_contribution = all_last_attn_subvalues[attn_layer][attn_head][0][attn_neuron]
            if all_contribution < 0:
                print("negative contribution")
                revert = -1
            ln = encoder_layer.layer_norm1
            source_layers = range(attn_layer)

        # Accumulate over ALL contributing layers; the previous per-layer assignment
        # kept only the last layer's contribution.
        position_scores = np.zeros(len(all_pos_layer_input[0]))
        for layer_i in source_layers:
            contributions = attn_position_contributions(layer_i)
            # LayerNorm approximation: centre and scale by gamma. The per-token variance
            # division is a single scale that the min-max normalisation below cancels.
            contributions = (contributions - contributions.mean(dim=-1, keepdim=True)) * ln.weight
            position_scores += (contributions * neuron_key).sum(dim=-1).cpu().numpy() * revert

    # Drop the CLS position and lay the patch scores out on the (square) patch grid.
    patch_scores = position_scores[1:]
    grid_side = int(round(np.sqrt(patch_scores.size)))
    saliency_map = patch_scores.reshape(grid_side, grid_side)

    image = Image.open(image_path).convert("RGB")
    # Channels reordered RGB -> BGR for OpenCV.
    pixel_values = extractor.processor(images=image, return_tensors="pt")['pixel_values'][:, [2, 1, 0]]
    # Undo the processor's own normalisation (CLIP statistics, not ImageNet ones). Reverse
    # the statistics to match the BGR channel order and clip before the uint8 cast.
    image_processor = extractor.processor.image_processor
    mean = torch.tensor(list(image_processor.image_mean)[::-1]).reshape(1, 3, 1, 1)
    std = torch.tensor(list(image_processor.image_std)[::-1]).reshape(1, 3, 1, 1)
    bgr_image = (pixel_values.cpu() * std + mean)[0].permute(1, 2, 0).clamp(0, 1).detach().numpy()
    image = (bgr_image * 255).astype(np.uint8)
    height, width = image.shape[:2]
    saliency_resized = cv2.resize(saliency_map, (width, height), interpolation=cv2.INTER_CUBIC)
    saliency_range = saliency_resized.max() - saliency_resized.min()
    if saliency_range > 0:
        saliency_resized = (saliency_resized - saliency_resized.min()) / saliency_range
    else:  # flat map: avoid 0/0 -> NaN
        saliency_resized = np.zeros_like(saliency_resized)
    saliency_resized = (saliency_resized * 255).astype(np.uint8)
    saliency_colored = cv2.applyColorMap(saliency_resized, cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(image, 0.6, saliency_colored, 0.4, 0)
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    plt.axis("off")

    plt.tight_layout()
    plt.show()



    # linguistic explanation
    # The extractor exposes CLIP's visual projection as `.projection`
    # (`extractor.visual_projection` does not exist and raised AttributeError).
    projection = extractor.projection
    ln = extractor.model.vision_model.post_layernorm
    # Fixed variance for the post-LayerNorm approximation.
    # NOTE(review): presumably a typical final-layer CLS variance; verify its origin.
    s = 0.3866
    weight = ln.weight
    file_path = '20k.txt'
    # Same checkpoint as the extractor's processor, so reuse it instead of reloading.
    processor = extractor.processor
    # Read all words into a list
    with open(file_path, 'r', encoding='utf-8') as file:
        words = [line.strip() for line in file]
    # Unit-normalised CLIP text embeddings of the vocabulary, embedded in batches.
    batch_size = 500
    word_feature_batches = []
    with torch.no_grad():
        for start in range(0, len(words), batch_size):
            inputs = processor(text=words[start:start + batch_size], return_tensors="pt", padding=True).to(device)
            # CLIPTextModelWithProjection has no get_text_features(); its forward()
            # returns the projected embeddings as `text_embeds`.
            batch_features = model(**inputs).text_embeds
            word_feature_batches.append(batch_features / batch_features.norm(dim=1, keepdim=True))
    text_features_list = torch.cat(word_feature_batches, dim=0)

    def sparse_concepts(direction, sign=1, n_concepts=5):
        """Explain a residual-stream ``direction`` as a sparse mix of vocabulary embeddings.

        The direction is processed in these steps:

        1. Centre it, divide by sqrt(s) and scale by the post-LayerNorm gamma
           (bias omitted).
        2. Multiply by ``sign``.
        3. Map it through the visual projection and unit-normalise it.
        4. Fit it against the word embeddings with Orthogonal Matching Pursuit
           (``n_concepts`` non-zero coefficients).

        Returns the ``n_concepts`` words with the largest coefficients.
        """
        with torch.no_grad():
            centred = (direction - direction.mean(dim=-1, keepdim=True)) / np.sqrt(s + 1e-12)
            representation = projection(centred * weight * sign).unsqueeze(0).to(device)
            representation = representation / representation.norm(dim=1, keepdim=True)
        omp = OrthogonalMatchingPursuit(n_nonzero_coefs=n_concepts)
        omp.fit(text_features_list.cpu().numpy().T, representation.cpu().numpy()[0])
        return [words[i] for i in np.argsort(omp.coef_)[::-1][:n_concepts]]

    concepts = []
    idx = neuron
    if len(neuron.split("_")) == 2:
        L, K = idx.split("_")
        # fc2 column K: the direction FFN neuron K writes into the residual stream.
        knowledge_neuron = extractor.model.vision_model.encoder.layers[int(L)].mlp.fc2.weight.data[:, int(K)]
        concepts.append(sparse_concepts(knowledge_neuron))
    else:
        # Flip directions whose contribution was negative. A bool
        # (`all_contribution > 0`) previously multiplied the direction by False,
        # giving a zero vector and NaN after normalisation.
        revert = -1 if all_contribution < 0 else 1
        L, H, K = idx.split("_")
        K = int(H) * HEAD_DIM + int(K)  # per-head width from the model, not a hard-coded 64
        knowledge_neuron = extractor.model.vision_model.encoder.layers[int(L)].self_attn.out_proj.weight.data[:, K]
        concepts.append(sparse_concepts(knowledge_neuron, sign=revert))
    print(concepts)


    #detecting query neurons
    # For the selected neuron, score earlier FFN neurons ("layer_neuron") and attention
    # value dimensions ("layer_head_dim") by the inner product of their residual-stream
    # write (centred, scaled by the relevant LayerNorm gamma, no variance division)
    # with the selected neuron's input weights.
    # NOTE(review): this reads whichever `layer_outputs` / `all_pos_*` values the
    # preceding code last produced via `transfer_output`; confirm they belong to the
    # intended image.
    if len(neuron.split("_")) == 2:
        l_n = neuron
        with torch.no_grad():
            ffn_layer, ffn_neuron = l_n.split("_")
            ffn_layer, ffn_neuron = int(ffn_layer), int(ffn_neuron)
            cur_ffn_neuron = ffn_neuron
            # fc1 row of the target FFN neuron: the direction it reads from its input.
            ffn_neuron_key = extractor.model.vision_model.encoder.layers[ffn_layer].mlp.fc1.weight.data[cur_ffn_neuron,
                             :]
            ln = extractor.model.vision_model.encoder.layers[ffn_layer].layer_norm2
            # NOTE(review): u and s are computed here but not used below.
            x = layer_outputs[ffn_layer]['residual_output'][:, 0]
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            u, s = u.reshape(-1), s.reshape(-1)
            head_neuron_dic = {}
            # Earlier FFN neurons: fc2 weight columns scaled by the CLS activation coefficients.
            for layer_i in range(ffn_layer):
                cur_layer_input = torch.tensor(all_pos_coefficient_scores[layer_i][0]).to(device)
                cur_layer_recompute = extractor.model.vision_model.encoder.layers[
                                          layer_i].mlp.fc2.weight.data * cur_layer_input.unsqueeze(0)
                for neuron_index in range(cur_layer_input.size(0)):
                    # Column view of cur_layer_recompute; the in-place ops below mutate it,
                    # which is harmless because each column is used exactly once.
                    cur_layer_neurons = cur_layer_recompute[:, neuron_index]
                    cur_layer_neurons -= cur_layer_neurons.mean(dim=-1, keepdim=True)
                    cur_layer_neurons *= ln.weight
                    cur_layer_neurons_innerproduct = torch.sum(cur_layer_neurons * ffn_neuron_key, dim=-1)
                    head_neuron_dic[str(layer_i) + "_" + str(neuron_index)] = cur_layer_neurons_innerproduct.item()
            # cur_inner_all = torch.sum(ln(torch.tensor(all_pos_residual_output[attn_layer][attn_pos]).to(device))*ffn_neuron_key, -1)
            # Attention value dimensions up to and including this layer (the FFN input
            # already contains the same layer's attention output).
            for layer_i in range(ffn_layer + 1):
                cur_layer_input = torch.tensor(all_pos_layer_input[layer_i]).to(device)
                cur_v_heads_recompute = torch.tensor(all_last_attn_subvalues[layer_i]).permute(1, 0, 2).to(device)
                cur_attn_o_split = extractor.model.vision_model.encoder.layers[
                    layer_i].self_attn.out_proj.weight.data.view(1, -1, HEAD_NUM, HEAD_DIM)
                cur_attn_o_recompute = cur_attn_o_split * (cur_v_heads_recompute.unsqueeze(1))
                cur_attn_o_recompute = cur_attn_o_recompute.permute(0, 2, 3, 1)
                # Sum over source positions -> (heads, head_dim, hidden), assuming the
                # captured sub-values are (heads, positions, head_dim).
                cur_attn_o_recompute = cur_attn_o_recompute.sum(dim=[0])
                for head_index in range(cur_attn_o_recompute.size(0)):
                    for neuron_index in range(cur_attn_o_recompute.size(1)):
                        cur_layer_neurons = cur_attn_o_recompute[head_index][neuron_index]
                        cur_layer_neurons -= cur_layer_neurons.mean(dim=-1, keepdim=True)
                        cur_layer_neurons *= ln.weight
                        cur_layer_neurons_innerproduct = torch.sum(cur_layer_neurons * ffn_neuron_key, dim=-1)
                        head_neuron_dic[str(layer_i) + "_" + str(head_index) + "_" + str(
                            neuron_index)] = cur_layer_neurons_innerproduct.item()
        head_neuron_list = dict_to_list(head_neuron_dic)
        # Stable ascending sort, then reversed: descending by score.
        head_neuron_list_sort = sorted(head_neuron_list, key=lambda x: x[-1])[::-1]
    else:
        l_h_n = neuron
        with torch.no_grad():
            attn_layer, attn_head, attn_neuron = l_h_n.split("_")
            attn_layer, attn_head, attn_neuron = int(attn_layer), int(attn_head), int(attn_neuron)
            cur_attn_neuron = attn_head * HEAD_DIM + attn_neuron
            # v_proj row of the target value dimension: the direction it reads from its input.
            attn_neuron_key = extractor.model.vision_model.encoder.layers[attn_layer].self_attn.v_proj.weight.data[
                              cur_attn_neuron, :]
            # NOTE(review): index [0] is source position 0 (the CLS token) only, not the
            # sum over positions; confirm this is the intended sign test.
            all_contribution = all_last_attn_subvalues[attn_layer][attn_head][0][attn_neuron]
            revert = 1
            if all_contribution < 0:
                print("negative contribution")
                revert = -1
            ln = extractor.model.vision_model.encoder.layers[attn_layer].layer_norm1
            # NOTE(review): u and s are computed here but not used below.
            x = torch.tensor(all_pos_layer_input[attn_layer][0]).to(device)
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            u, s = u.reshape(-1), s.reshape(-1)
            head_neuron_dic = {}
            # Earlier FFN neurons (strictly before the attention layer).
            for layer_i in range(attn_layer):
                cur_layer_input = torch.tensor(all_pos_coefficient_scores[layer_i][0]).to(device)
                cur_layer_recompute = extractor.model.vision_model.encoder.layers[
                                          layer_i].mlp.fc2.weight.data * cur_layer_input.unsqueeze(0)
                for neuron_index in range(cur_layer_input.size(0)):
                    cur_layer_neurons = cur_layer_recompute[:, neuron_index]
                    cur_layer_neurons -= cur_layer_neurons.mean(dim=-1, keepdim=True)
                    cur_layer_neurons *= ln.weight
                    cur_layer_neurons_innerproduct = torch.sum(cur_layer_neurons * attn_neuron_key, dim=-1)
                    head_neuron_dic[str(layer_i) + "_" + str(neuron_index)] = cur_layer_neurons_innerproduct.item() * revert
            # cur_inner_all = torch.sum(ln(torch.tensor(all_pos_residual_output[attn_layer][attn_pos]).to(device))*ffn_neuron_key, -1)
            # Earlier attention value dimensions (strictly before the attention layer).
            for layer_i in range(attn_layer):
                cur_layer_input = torch.tensor(all_pos_layer_input[layer_i]).to(device)
                cur_v_heads_recompute = torch.tensor(all_last_attn_subvalues[layer_i]).permute(1, 0, 2).to(device)
                cur_attn_o_split = extractor.model.vision_model.encoder.layers[layer_i].self_attn.out_proj.weight.data.view(
                    1, -1, HEAD_NUM, HEAD_DIM)
                cur_attn_o_recompute = cur_attn_o_split * (cur_v_heads_recompute.unsqueeze(1))
                cur_attn_o_recompute = cur_attn_o_recompute.permute(0, 2, 3, 1)
                cur_attn_o_recompute = cur_attn_o_recompute.sum(dim=0)
                for head_index in range(cur_attn_o_recompute.size(0)):
                    for neuron_index in range(cur_attn_o_recompute.size(1)):
                        cur_layer_neurons = cur_attn_o_recompute[head_index][neuron_index]
                        cur_layer_neurons -= cur_layer_neurons.mean(dim=-1, keepdim=True)
                        cur_layer_neurons *= ln.weight
                        cur_layer_neurons_innerproduct = torch.sum(cur_layer_neurons * attn_neuron_key, dim=-1)
                        head_neuron_dic[str(layer_i) + "_" + str(head_index) + "_" + str(
                            neuron_index)] = cur_layer_neurons_innerproduct.item() * revert
        head_neuron_list = dict_to_list(head_neuron_dic)
        head_neuron_list_sort = sorted(head_neuron_list, key=lambda x: x[-1])[::-1]
    print(f"top 10 query neurons for neuron {neuron}:")
    print(head_neuron_list_sort[:10])
