import copy
import json
import math
from functools import partial
from typing import Optional, Tuple, List, Any

import einops
import numpy as np
from scipy import stats
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from .patch_mlp import *
from .patch_attn import *

class BaseNeuroSynapticEdit:
    def __init__(self, model_name_or_path='gpt2', device='cuda', ig_method='ig'):
        """
        Load a causal LM and its tokenizer and record where the MLP / attention
        sub-modules live inside it.

        `model_name_or_path`: str
            Hugging Face model id or local path. The architecture is picked from the
            name (case-insensitively): it must contain "gpt2", "gemma" or "llama".
        `device`: str
            GPT-2 is moved to this device. Gemma / Llama are loaded with
            ``device_map="auto"`` instead. Encoded inputs are always sent to `device`.
        `ig_method`: str
            attribution method used by `scaled_input` / `get_scores_for_layer`
            ('ig', 'amig' or 'sig').

        Raises ValueError when the model name matches none of the supported families.
        """
        self.baseline_activations = None
        self.model_name = model_name_or_path
        self.device = device
        self.ig_method = ig_method
        name = self.model_name.lower()
        if "gpt2" in name:
            self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
            self.transformer_layers_attr = "transformer.h"
            self.mlp_input = "mlp.c_fc"
            self.mlp_output = "mlp.c_proj.weight"
            # Token (word) embeddings. `transformer.wpe` holds the *position* embeddings,
            # which is inconsistent with the `embed_tokens` used for Gemma / Llama below.
            self.word_embeddings_attr = "transformer.wte"
            # Q, K and V are produced by one fused projection in GPT-2.
            self.attn_input = "attn.c_attn"
            self.attn_output = "attn.c_proj"  # projection applied to the attention output
        elif "gemma" in name:
            self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto")
            self.transformer_layers_attr = "model.layers"  # decoder layers
            self.mlp_input = "mlp.up_proj"  # MLP input projection
            self.mlp_output = "mlp.down_proj.weight"  # weight of the MLP output projection
            self.word_embeddings_attr = "model.embed_tokens"  # token embedding layer
            # Separate Q / K / V projections.
            self.attn_Q = "self_attn.q_proj"
            self.attn_K = "self_attn.k_proj"
            self.attn_V = "self_attn.v_proj"
            self.attn_output = "self_attn.o_proj"  # attention output projection
        elif "llama" in name:
            self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto")
            self.transformer_layers_attr = "model.layers"
            self.mlp_input = "mlp.gate_proj"  # "mlp.up_proj" is the other MLP input projection
            self.mlp_output = "mlp.down_proj.weight"
            self.word_embeddings_attr = "model.embed_tokens"
            self.attn_Q = "self_attn.q_proj"
            self.attn_K = "self_attn.k_proj"
            self.attn_V = "self_attn.v_proj"
            self.attn_output = "self_attn.o_proj"
        else:
            # Previously `'Llama' or 'llama' in name` was always true, so any model
            # silently got Llama attribute paths; fail loudly instead.
            raise ValueError(
                f"Unsupported model '{model_name_or_path}': expected a GPT-2, Gemma or Llama checkpoint"
            )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.hidden_size = self.model.config.hidden_size

    def _get_output_ff_layer(self, layer_idx):
        """Resolve `self.mlp_output` inside transformer layer `layer_idx` via `get_ff_layer`.

        `self.mlp_output` is set in `__init__` to a weight path (e.g. "mlp.c_proj.weight"),
        so this presumably returns that MLP output-projection weight. TODO confirm
        against `get_ff_layer`.
        """
        layers_path = self.transformer_layers_attr
        return get_ff_layer(
            self.model,
            layer_idx,
            ff_attrs=self.mlp_output,
            transformer_layers_attr=layers_path,
        )

    def n_layers(self):
        """Number of transformer layers in the wrapped model."""
        layers = self._get_transformer_layers()
        return len(layers)
    def _get_transformer_layers(self):
        """Return the container found at `self.transformer_layers_attr` on the model."""
        layers_path = self.transformer_layers_attr
        return get_attributes(self.model, layers_path)

    def _prepare_inputs(self, query, ground_truth=None):
        """
        Strip the cloze placeholder from `query` and tokenize it for next-token prediction.

        The first placeholder found among "[MASK] .", "[MASK].", "_X_." and "_X_ ." is
        removed and the query is then stripped. If none is present, the query is used
        unchanged.

        Returns `(encoded_input, mask_idx, ground_truth_label, query)`:
        - `encoded_input`: tokenizer output for the cleaned query, moved to `self.device`
        - `mask_idx`: always -1; the prediction is read at the last position
        - `ground_truth_label`: token ids of `ground_truth` without special tokens,
          or None when `ground_truth` is None
        - `query`: the cleaned query string
        """
        for placeholder in ("[MASK] .", "[MASK].", "_X_.", "_X_ ."):
            if placeholder in query:
                query = query.replace(placeholder, "").strip()
                break

        encoded_input = self.tokenizer(query, return_tensors="pt").to(self.device)

        # Causal LMs predict the next token after the query.
        mask_idx = -1

        if ground_truth is None:
            ground_truth_label = None
        else:
            # add_special_tokens=False: Llama/Gemma tokenizers otherwise prepend a BOS id,
            # which would become the first "target" token the model is scored on.
            ground_truth_label = self.tokenizer.encode(ground_truth, add_special_tokens=False)

        return encoded_input, mask_idx, ground_truth_label, query

    @staticmethod
    def _prepare_inputs_only_mask_idx_query(query):
        # Remove the placeholder _X_ from the query for GPT models
        if 'MASK' in query:
            query = query.replace("[MASK] .", "").strip()
        elif '_X_' in query:
            query = query.replace("_X_.", "").strip()
        elif '_X_ .' in query:
            query = query.replace("_X_ .", "").strip()
        mask_idx = -1
        return mask_idx, query

    def _get_answer_str(self, prompt, ground_truth):
        """
        Greedily decode as many tokens as `ground_truth` tokenizes to.

        At each step the prompt is re-tokenized, the argmax next token is decoded and
        appended to the prompt. The forward passes run under `torch.no_grad()`: only
        argmax ids are needed, so no autograd graph has to be kept.

        Returns a 1-tuple `(argmax_completion_str,)`. The tuple wrapping is kept for
        backward compatibility with existing callers.
        """
        encoded_input, mask_idx, target_label, prompt = self._prepare_inputs(
            prompt, ground_truth
        )
        argmax_completion_str = ""

        with torch.no_grad():
            for step in range(len(target_label)):
                if step > 0:
                    # re-tokenize the prompt extended with the previous greedy token
                    encoded_input, mask_idx, _, prompt = self._prepare_inputs(
                        prompt, ground_truth
                    )
                outputs = self.model(**encoded_input)
                probs = F.softmax(outputs.logits[:, mask_idx, :], dim=-1)

                _, argmax_id = [t.item() for t in probs.max(dim=-1)]
                argmax_str = self.tokenizer.decode([argmax_id])

                prompt += argmax_str
                argmax_completion_str += argmax_str
        return argmax_completion_str,

    def _generate(self, query, ground_truth):
        """
        Greedy-decode one token per ground-truth token and score the ground truth.

        At step i the (re-tokenized) query is fed to the model. The probability of the
        i-th ground-truth token and the argmax token are recorded, and the argmax token
        is appended to the query. The forward passes run under `torch.no_grad()`
        because only probabilities are needed.

        Returns `(gt_prob, argmax_prob, argmax_completion_str, argmax_tokens)`:
        - `gt_prob`: product over steps of the ground-truth token probabilities
        - `argmax_prob`: product over steps of the argmax probabilities
        - `argmax_completion_str`: the decoded greedy completion
        - `argmax_tokens`: the greedy token ids

        Raises ValueError if `ground_truth` tokenizes to no tokens.
        """
        encoded_input, mask_idx, ground_truth_label, query = self._prepare_inputs(query, ground_truth)
        if not ground_truth_label:
            raise ValueError("ground_truth must tokenize to at least one token")

        all_gt_probs = []
        all_argmax_probs = []
        argmax_tokens = []
        argmax_completion_str = ""

        with torch.no_grad():
            for step, target_idx in enumerate(ground_truth_label):
                if step > 0:
                    # re-tokenize the query extended with the previous greedy token
                    encoded_input, mask_idx, _, query = self._prepare_inputs(
                        query, ground_truth
                    )
                outputs = self.model(**encoded_input)
                probs = F.softmax(outputs.logits[:, mask_idx, :], dim=-1)
                all_gt_probs.append(probs[:, target_idx].item())

                argmax_prob, argmax_id = [t.item() for t in probs.max(dim=-1)]
                argmax_tokens.append(argmax_id)
                all_argmax_probs.append(argmax_prob)
                argmax_str = self.tokenizer.decode([argmax_id])

                query += argmax_str
                argmax_completion_str += argmax_str

        # math.prod of a single value is that value, so no length special case is needed.
        gt_prob = math.prod(all_gt_probs)
        argmax_prob = math.prod(all_argmax_probs)
        return gt_prob, argmax_prob, argmax_completion_str, argmax_tokens


    def generate_only_acc(self, query, ground_truth):
        """
        Probability the model assigns to `ground_truth` along its own greedy path.

        Equivalent to the `gt_prob` of `_generate`. At step i, P(i-th ground-truth token)
        is read and the argmax token is appended to the query before re-tokenizing, so
        later tokens are scored on the greedy prefix. The forward passes run under
        `torch.no_grad()`.

        Returns the product of the per-step ground-truth probabilities.
        Raises ValueError if `ground_truth` tokenizes to no tokens.
        """
        encoded_input, mask_idx, ground_truth_label, query = self._prepare_inputs(
            query, ground_truth
        )
        if not ground_truth_label:
            raise ValueError("ground_truth must tokenize to at least one token")

        all_gt_probs = []
        with torch.no_grad():
            for step, target_idx in enumerate(ground_truth_label):
                if step > 0:
                    # re-tokenize the query extended with the previous greedy token
                    encoded_input, mask_idx, _, query = self._prepare_inputs(
                        query, ground_truth
                    )
                outputs = self.model(**encoded_input)
                probs = F.softmax(outputs.logits[:, mask_idx, :], dim=-1)
                all_gt_probs.append(probs[:, target_idx].item())
                argmax_id = probs.max(dim=-1)[1].item()
                query += self.tokenizer.decode([argmax_id])

        return math.prod(all_gt_probs)



    def scaled_input(self, activations=None, steps=20,
                     layer_idx=0, encoded_input=None, mask_idx=None):
        """
        Tiles activations along the batch dimension - gradually scaling them over
        `steps` steps from 0 to their original value over the batch dimensions.

        `activations`: torch.Tensor
        original activations
        `steps`: int
        number of steps to take
        """
        if self.ig_method == 'ig':
            device = activations.device

            diff_activations = activations
            tiled_activations = einops.repeat(diff_activations, "b d -> (r b) d", r=steps)
            # Ensure linspace is on the same device as activations
            linspace_tensor = torch.linspace(start=0, end=1, steps=steps, device=device)[:, None]

            out = tiled_activations * linspace_tensor
            return out
        elif self.ig_method == 'amig':
            replace_token_id = self.tokenizer.eos_token_id

            all_res = []

            # get original activations for the complete input
            _, original_activations = self.get_baseline_with_activations(encoded_input,
                                                                         layer_idx=layer_idx, mask_idx=mask_idx,
                                                                         layer_attr=self.mlp_input)

            for idx in range(encoded_input['input_ids'].size(1)):
                # create a copy of the input and replace the idx-th word with mask token
                masked_input = copy.deepcopy(encoded_input)
                masked_input['input_ids'][0][idx] = replace_token_id

                # get masked activations to use as baseline
                _, baseline_activations = self.get_baseline_with_activations(masked_input,
                                                                             layer_idx=layer_idx, mask_idx=mask_idx,
                                                                             layer_attr=self.mlp_input)
                step = (original_activations - baseline_activations) / steps  # (1, ffn_size)

                res = torch.cat([torch.add(baseline_activations, step * i) for i in range(steps)], dim=0)
                all_res.append(res)
            # average
            mean_res = torch.stack(all_res).mean(dim=0)
            return mean_res
        elif self.ig_method == 'sig':
            all_res = []
            # get original activations for the complete input
            _, original_activations = self.get_baseline_with_activations(encoded_input,
                                                                         layer_idx=layer_idx, mask_idx=mask_idx,
                                                                         layer_attr=self.mlp_input)
            for idx in range(encoded_input['input_ids'].size(1)):
                # create a copy of the input and replace the idx-th word with mask token
                masked_input = copy.deepcopy(encoded_input)
                masked_input['input_ids'][0][idx] = '[PAD]'

                # get masked activations to use as baseline
                _, baseline_activations = self.get_baseline_with_activations(masked_input,
                                                                             layer_idx=layer_idx, mask_idx=mask_idx,
                                                                             layer_attr=self.mlp_input)
                step = (original_activations - baseline_activations) / steps  # (1, ffn_size)

                res = torch.cat([torch.add(baseline_activations, step * i) for i in range(steps)], dim=0)
                all_res.append(res)
            # average
            mean_res = torch.stack(all_res).mean(dim=0)
            return mean_res
        else:
            raise NotImplementedError

    def get_baseline_with_activations(
            self, encoded_input, layer_idx, mask_idx, layer_attr
    ):
        """
        Run the unmodified model once and capture the activations of one sub-module.

        `encoded_input`: mapping
            the model inputs (as produced by `self.tokenizer(...)`)
        `layer_idx`: int
            which transformer layer to hook
        `mask_idx`: int or None
            sequence position whose activations are kept; None means the last position
        `layer_attr`: str
            attribute path of the hooked sub-module inside the layer (e.g. `self.mlp_input`)

        Returns `(baseline_outputs, baseline_activations)`: the model outputs and the
        hooked activations sliced as `acts[:, mask_idx, :]`. The hook is removed and
        `self.baseline_activations` is reset even if the forward pass raises, so a
        failed call does not leave a stale hook registered on the model.
        """

        def hook_fn(acts):
            position = -1 if mask_idx is None else mask_idx
            self.baseline_activations = acts[:, position, :]

        handle = register_hook(
            self.model,
            layer_idx=layer_idx,
            f=hook_fn,
            transformer_layers_attr=self.transformer_layers_attr,
            layer_attrs=layer_attr,
        )
        try:
            baseline_outputs = self.model(**encoded_input)
            baseline_activations = self.baseline_activations
        finally:
            handle.remove()
            self.baseline_activations = None
        return baseline_outputs, baseline_activations


    def get_scores_for_layer(
            self,
            prompt: str,
            ground_truth: str,
            layer_idx: int,
            batch_size: int = 10,
            steps: int = 20,
            layer_attr='mlp.c_fc',
    ):
        """
        Attribution score of every neuron of `layer_attr` in transformer layer `layer_idx`.

        For each token of `ground_truth`, the prompt is greedily extended with the
        model's argmax prediction between tokens. Then:
        1. The activations at the prediction position are replaced by the interpolation
           points from `self.scaled_input`.
        2. The gradient of P(target token) w.r.t. those activations is summed over all
           points.
        3. The sum is multiplied by the unmodified activations / `steps`.

        'ig' averages the per-token scores over the ground-truth tokens, while 'amig'
        and 'sig' sum them. 'amig' / 'sig' always use `self.mlp_input`, the layer their
        baselines are captured at, and ignore `layer_attr`.

        `batch_size` must divide `steps`.
        Raises NotImplementedError for an unknown `self.ig_method`, and ValueError if
        `ground_truth` tokenizes to no tokens.
        """
        if self.ig_method not in ('ig', 'amig', 'sig'):
            raise NotImplementedError
        if self.ig_method != 'ig':
            layer_attr = self.mlp_input
        assert steps % batch_size == 0, "steps must be a multiple of batch_size"
        n_batches = steps // batch_size

        encoded_input, mask_idx, target_label, prompt = self._prepare_inputs(
            prompt, ground_truth,
        )
        # for autoregressive models, we might want to generate > 1 token
        n_sampling_steps = len(target_label)
        if n_sampling_steps == 0:
            raise ValueError("ground_truth must tokenize to at least one token")

        per_token_scores = []
        for i in range(n_sampling_steps):
            if i > 0:
                # re-tokenize the prompt extended with the previous greedy token
                encoded_input, mask_idx, target_label, prompt = self._prepare_inputs(
                    prompt, ground_truth
                )
            # Recomputed for every token, so the greedy next token and the activations
            # used for scaling always match the current (extended) prompt.
            baseline_outputs, baseline_activations = self.get_baseline_with_activations(
                encoded_input, layer_idx, mask_idx, layer_attr=layer_attr
            )
            next_token_str = None
            if n_sampling_steps > 1:
                argmax_next_token = (
                    baseline_outputs.logits[:, mask_idx, :].argmax(dim=-1).item()
                )
                next_token_str = self.tokenizer.decode(argmax_next_token)

            if self.ig_method == 'ig':
                scaled_weights = self.scaled_input(activations=baseline_activations, steps=steps)
            else:
                scaled_weights = self.scaled_input(
                    encoded_input=encoded_input, steps=steps, layer_idx=layer_idx, mask_idx=mask_idx
                )
            scaled_weights.requires_grad_(True)

            summed_grads = self._path_gradient_sum(
                encoded_input, mask_idx, target_label[i], layer_idx, layer_attr,
                scaled_weights, n_batches, batch_size,
            )
            # sum of gradients times W-hat / m
            per_token_scores.append(summed_grads * (baseline_activations.squeeze(0) / steps))

            if next_token_str is not None:
                prompt += next_token_str

        total = torch.stack(per_token_scores, dim=0).sum(dim=0)
        if self.ig_method == 'ig':
            return total / n_sampling_steps
        return total

    def _path_gradient_sum(self, encoded_input, mask_idx, target_idx, layer_idx, layer_attr,
                           scaled_weights, n_batches, batch_size):
        """
        Sum over all interpolation points of d P(target_idx) / d activations.

        `scaled_weights` is fed through the model in `n_batches` chunks of `batch_size`
        rows. For each chunk, `layer_attr` of layer `layer_idx` is patched at `mask_idx`
        with that chunk. The patch is always removed, even if the forward or backward
        pass raises.
        """
        # The repeated inputs do not depend on the chunk, so build them once.
        inputs = {
            key: einops.repeat(encoded_input[key], "b d -> (r b) d", r=batch_size)
            for key in ("input_ids", "attention_mask")
        }
        chunk_grads = []
        for batch_weights in scaled_weights.chunk(n_batches):
            mlp_patch_layer(
                self.model,
                layer_idx=layer_idx,
                mask_idx=mask_idx,
                replacement_activations=batch_weights,
                transformer_layers_attr=self.transformer_layers_attr,
                ff_attrs=layer_attr,
            )
            try:
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits[:, mask_idx, :], dim=-1)
                grad = torch.autograd.grad(
                    torch.unbind(probs[:, target_idx]), batch_weights
                )[0]
            finally:
                mlp_unpatch_layer(
                    self.model,
                    layer_idx=layer_idx,
                    transformer_layers_attr=self.transformer_layers_attr,
                    ff_attrs=layer_attr,
                )
            chunk_grads.append(grad.sum(dim=0))
        return torch.stack(chunk_grads, dim=0).sum(dim=0)



    def get_scores(
            self,
            prompt: str,
            ground_truth: str,
            batch_size: int = 10,
            steps: int = 20,
    ):
        """
        Gets the attribution scores for a given query and ground truth.
        `query`: str
            the query to get the attribution scores for
        `ground_truth`: str
            the ground truth / expected output
        `batch_size`: int
            batch size
        `steps`: int
            total number of steps (per token) for the integrated gradient calculations
        `attribution_method`: str
            the method to use for getting the scores. Choose from 'integrated_grads' or 'max_activations'.
        """
        scores = []
        # Assuming `self.device` is set to the primary device e.g., 'cuda:0'
        # but it's better to dynamically adjust based on where your initial tensors are.
        for layer_idx in range(self.n_layers()):
            layer_scores = self.get_scores_for_layer(
                prompt,
                ground_truth,
                layer_idx=layer_idx,
                batch_size=batch_size,
                steps=steps,
                layer_attr=self.mlp_input
            )
            # If scores are not already on 'cuda:0', ensure they are moved there.
            # It's assumed all operations can stay on the CUDA device until final aggregation.
            if layer_scores.device != self.device:
                layer_scores = layer_scores.to(self.device)
            scores.append(layer_scores)
        # Stack operations are kept on the same device as the individual scores.
        stacked_scores = torch.stack(scores)

        # 在这里返回完整的分数，然后可以被heatmap调用。
        return stacked_scores


    def get_one_query_neurons_complete(
            self,
            prompt: str,
            ground_truth: str,
            batch_size: int = 10,
            steps: int = 20,
            threshold: float = None,
            adaptive_threshold: float = None,
            percentile: float = None,
    ) -> List[List[int]]:
        attribution_scores = self.get_scores(
            prompt,
            ground_truth,
            batch_size=batch_size,
            steps=steps,
        )
        assert (
                sum(e is not None for e in [threshold, adaptive_threshold, percentile]) == 1
        ), f"Provide one and only one of threshold / adaptive_threshold / percentile"
        if adaptive_threshold is not None:
            threshold = attribution_scores.max().item() * adaptive_threshold
        if threshold is not None:
            return torch.nonzero(attribution_scores > threshold).cpu().tolist()
        else:
            s = attribution_scores.flatten().detach().cpu().numpy()
            return (
                torch.nonzero(attribution_scores > np.percentile(s, percentile))
                .cpu()
                .tolist()
            )


    def get_one_query_neurons(
            self,
            prompt: str,
            ground_truth: str,
            batch_size: int = 10,
            steps: int = 20,
            adaptive_threshold_neurons: float = None,
    ):
        attribution_scores = self.get_scores(
            prompt,
            ground_truth,
            batch_size=batch_size,
            steps=steps,
        )
        threshold = attribution_scores.max().item() * adaptive_threshold_neurons
        kns = torch.nonzero(attribution_scores > threshold).cpu().tolist()
        # 再来标记这些KNs的平均激活分数。
        scores_of_kns = torch.tensor([attribution_scores[i, j] for i, j in kns])
        average_score = scores_of_kns.mean().item()
        return kns, average_score

    def compute_attributions(self, prompts, ground_truths, batch_size=1, steps=50):
        all_scores = []
        for prompt, ground_truth in zip(prompts, ground_truths):
            scores = self.get_scores(prompt, ground_truth, batch_size, steps)
            all_scores.append(scores)

        # Stack scores along a new dimension (0th dimension), results in shape [num_prompts, 12, 3072]
        stacked_scores = torch.stack(all_scores)
        return stacked_scores

    def evaluate_neurons(self, stacked_scores, alpha, beta):
        # Compute mean and std along the dimension 0 (across all prompts)
        mean_scores = torch.mean(stacked_scores, dim=0)
        std_scores = torch.std(stacked_scores, dim=0)

        S = alpha * mean_scores - beta * std_scores

        # Get the indices where S > 0
        selected_neurons = torch.nonzero(S > 0, as_tuple=True)
        return selected_neurons

    def _classify_position_in_QKV(self, position_idx):
        """
        Map an index into a fused [Q | K | V] vector to "Q", "K" or "V".

        Each segment is `self.hidden_size` wide. Raises ValueError for indices outside
        `[0, 3 * hidden_size)`.
        """
        segment = position_idx // self.hidden_size
        for label_idx, label in enumerate(("Q", "K", "V")):
            if segment == label_idx:
                return label
        raise ValueError("Position index out of bounds for QKV matrix.")


    def get_one_query_synapses_IG(self, prompt: str, ground_truth: str, batch_size: int = 10,
                               steps: int = 20, adaptive_threshold: float = None):
        """暂时因为QKV的物理意义不明确，不使用IG"""
        attribution_scores = self.get_scores(
            prompt,
            ground_truth,
            batch_size=batch_size,
            steps=steps,
        )
        threshold = attribution_scores.max().item() * adaptive_threshold
        significant_positions = torch.nonzero(attribution_scores > threshold).cpu().tolist()

        # Classify each significant position as belonging to Q, K, or V
        classified_synapses = []
        src_synapses = []
        for layer_idx, position_idx in significant_positions:
            synapse_type = self._classify_position_in_QKV(position_idx % (3 * self.hidden_size))
            # Use modulus to adjust position index within 3*H range
            classified_synapses.append({synapse_type:[layer_idx, position_idx]})
            src_synapses.append([layer_idx, position_idx])

        return classified_synapses, src_synapses


    def _modify_neurons(
            self,
            query: str,
            ground_truth: str,
            neurons: List[List[int]],
            mode: str = "suppress",
            undo_modification: bool = True,
            layer_attr='mlp.c_fc',
    ):
        """Patch MLP neurons and report how generation of ``ground_truth`` changes.

        Runs ``self._generate`` once on the unmodified model, patches
        ``neurons`` with ``mlp_patch_layer``, then runs ``self._generate`` again.

        Args:
            query: Prompt. It is first passed through
                ``self._prepare_inputs_only_mask_idx_query``, which also yields
                the ``mask_idx`` handed to the patch.
            ground_truth: Target completion whose probability is tracked.
            neurons: Entries whose first element is the layer index. The rest
                of each entry is interpreted by ``mlp_patch_layer``.
            mode: Patch mode forwarded to ``mlp_patch_layer`` (e.g. ``"suppress"``).
            undo_modification: If true, the patched layers are restored before
                returning, including when the second ``_generate`` call raises.
                If false, the patch is left in place.
            layer_attr: Attribute path of the MLP sub-module to patch.

        Returns:
            Dict with ``"before"`` / ``"after"`` entries (``gt_prob``,
            ``argmax_completion``, ``argmax_prob``) and ``"prob_change"``: the
            relative change of the ground-truth probability, or 0 when the
            baseline probability is 0.
        """
        results_dict = {}
        mask_idx, query = self._prepare_inputs_only_mask_idx_query(query)
        (
            gt_baseline_prob,
            argmax_baseline_prob,
            argmax_completion_str,
            _,
        ) = self._generate(query, ground_truth)
        results_dict["before"] = {
            "gt_prob": gt_baseline_prob,
            "argmax_completion": argmax_completion_str,
            "argmax_prob": argmax_baseline_prob,
        }

        # Remember every layer we patch so exactly those can be restored.
        all_layers = {n[0] for n in neurons}

        mlp_patch_layer(
            self.model,
            mask_idx,
            mode=mode,
            neurons=neurons,
            transformer_layers_attr=self.transformer_layers_attr,
            ff_attrs=layer_attr,
        )
        try:
            # Ground-truth probability and argmax completion with the
            # activations modified.
            new_gt_prob, new_argmax_prob, new_argmax_completion_str, _ = self._generate(
                query, ground_truth
            )
        finally:
            # Restore even if generation fails, so an exception does not leave
            # the shared model silently patched for later calls.
            if undo_modification:
                mlp_unpatch_layers(
                    model=self.model,
                    layer_indices=all_layers,
                    transformer_layers_attr=self.transformer_layers_attr,
                    ff_attrs=layer_attr,
                )

        results_dict["after"] = {
            "gt_prob": new_gt_prob,
            "argmax_completion": new_argmax_completion_str,
            "argmax_prob": new_argmax_prob,
        }
        results_dict['prob_change'] = (
            (new_gt_prob - gt_baseline_prob) / gt_baseline_prob if gt_baseline_prob != 0 else 0
        )
        return results_dict

    def _modify_synapses_patch(self, query, ground_truth, synapses: List[List[int]], mode='suppress'):
        """Patch attention-input synapses and report how generation changes.

        Runs ``self._generate`` on the unmodified model, patches the
        ``self.attn_input`` sub-module of every layer referenced by
        ``synapses`` via ``attn_patch_layer``, and runs ``self._generate``
        again. The patched layers are always restored before returning, even
        when generation raises.

        NOTE(review): originally marked FIXME by the author. The exact meaning
        of a synapse entry (beyond ``s[0]`` being the layer index) is defined
        by ``attn_patch_layer``.

        Args:
            query: Prompt passed to ``self._generate``.
            ground_truth: Target completion whose probability is tracked.
            synapses: Entries whose first element is the layer index.
            mode: Patch mode forwarded to ``attn_patch_layer``.

        Returns:
            Dict with ``"before"`` / ``"after"`` entries (``gt_prob``,
            ``argmax_completion``, ``argmax_prob``) and ``"prob_change"``: the
            relative change of the ground-truth probability, or 0 when the
            baseline probability is 0.
        """
        results_dict = {}
        (
            gt_baseline_prob,
            argmax_baseline_prob,
            argmax_completion_str,
            _,
        ) = self._generate(query, ground_truth)
        results_dict["before"] = {
            "gt_prob": gt_baseline_prob,
            "argmax_completion": argmax_completion_str,
            "argmax_prob": argmax_baseline_prob,
        }

        # Alternative under consideration by the author: patch
        # ``self.attn_output`` instead of ``self.attn_input``.
        attn_patch_layer(
            model=self.model,
            mode=mode,
            synapses=synapses,
            transformer_layers_attr=self.transformer_layers_attr,
            layer_attrs=self.attn_input,
        )
        try:
            new_gt_prob, new_argmax_prob, new_argmax_completion_str, _ = self._generate(
                query, ground_truth
            )
        finally:
            # Always restore, so an exception does not leave the model patched.
            attn_unpatch_layers(
                model=self.model,
                layer_indices={s[0] for s in synapses},
                transformer_layers_attr=self.transformer_layers_attr,
                layer_attrs=self.attn_input,
            )

        results_dict["after"] = {
            "gt_prob": new_gt_prob,
            "argmax_completion": new_argmax_completion_str,
            "argmax_prob": new_argmax_prob,
        }
        results_dict['prob_change'] = (
            (new_gt_prob - gt_baseline_prob) / gt_baseline_prob if gt_baseline_prob != 0 else 0
        )
        return results_dict
    def _modify_synapses(self, query, ground_truth, synapses: List[List[int]], mode='suppress'):
        """
        Intended to zero or amplify rows of selected per-head attention
        matrices and measure the effect on the ground-truth probability.

        Args:
            query: Prompt text. It is tokenized with ``self.tokenizer``, moved to
                ``self.device``, and also passed to ``self._generate``.
            ground_truth: Target completion passed to ``self._generate``.
            synapses: ``[token_idx, layer_idx, head_idx]`` triples. This matches
                the format appended by ``get_knowledge_synapses_one_query``.
            mode: ``'suppress'`` zeroes row ``token_idx`` of the head's
                attention matrix, and ``'enhance'`` doubles it. Any other value
                raises ``NotImplementedError``, but only once a synapse is
                processed; with an empty ``synapses`` nothing is raised.

        Returns:
            Dict with ``"before"`` / ``"after"`` entries (``gt_prob``,
            ``argmax_completion``, ``argmax_prob``) and ``"prob_change"``: the
            relative change of the ground-truth probability, or 0 when the
            baseline probability is 0.

        NOTE(review): the edits are applied in place to the tensors returned
        in ``outputs.attentions`` of a separate forward pass. Those tensors are
        never fed back into the model. ``self._generate`` is called with only
        ``query`` / ``ground_truth``, so it presumably re-runs the model without
        seeing them, and "after" will likely equal "before". A real
        intervention needs a hook inside the attention module. Confirm before
        relying on these results.
        """
        # coefficient = coefficient if mode == "suppress" else 1.0 / coefficient
        results_dict = {}
        (
            gt_baseline_prob,
            argmax_baseline_prob,
            argmax_completion_str,
            _,
        ) = self._generate(query, ground_truth)
        results_dict["before"] = {
            "gt_prob": gt_baseline_prob,
            "argmax_completion": argmax_completion_str,
            "argmax_prob": argmax_baseline_prob,
        }
        inputs = self.tokenizer(query, return_tensors="pt").to(self.device)

        # Ensure the model returns attention scores
        outputs = self.model(**inputs, output_attentions=True)

        # Extract attention scores from the model outputs
        attention_scores = outputs.attentions  # tuple(num_layers): Tensor [batch_size, num_heads, seq_len, seq_len]

        # NOTE(review): everything below mutates only these returned tensors;
        # see the docstring for why this likely does not affect the model.
        for idx, layer_idx, head_idx in synapses:
            # Per-head [seq_len, seq_len] matrix of batch element 0
            attn_matrix = attention_scores[layer_idx][0][head_idx]

            # Row 'idx' holds how token 'idx' attends to every other token
            if mode == 'suppress':
                attn_matrix[idx, :] = 0.0  # Zero the attention paid by token 'idx'
                # attn_matrix[:, idx] = 0.0 # Modify scores influenced by the token at 'idx' todo
            elif mode == 'enhance':
                attn_matrix[idx, :] *= 2.0  # Double the attention paid by token 'idx'
                # attn_matrix[:, idx] *= 2.0 # Modify scores influenced by the token at 'idx'
            else:
                raise NotImplementedError
        new_gt_prob, new_argmax_prob, new_argmax_completion_str, _ = self._generate(
            query, ground_truth
        )
        results_dict["after"] = {
            "gt_prob": new_gt_prob,
            "argmax_completion": new_argmax_completion_str,
            "argmax_prob": new_argmax_prob,
        }
        results_dict['prob_change'] = (new_gt_prob - gt_baseline_prob) / gt_baseline_prob if gt_baseline_prob != 0 else 0
        return results_dict


    def get_predict_acc(self, query, ground_truth, neurons, mode='suppress', layer_attr='mlp.c_fc'):
        """Return the ground-truth score while the given MLP neurons are patched.

        The query is prepared with ``self._prepare_inputs_only_mask_idx_query``.
        The neurons are patched with ``mlp_patch_layer`` at the returned
        ``mask_idx``, and the result of ``self.generate_only_acc`` is returned.
        The patched layers are always restored, even if scoring raises.

        Args:
            query: Prompt text.
            ground_truth: Target completion to score.
            neurons: Entries whose first element is the layer index.
            mode: Patch mode forwarded to ``mlp_patch_layer``.
            layer_attr: Attribute path of the MLP sub-module to patch.

        Returns:
            Whatever ``self.generate_only_acc`` returns for the patched model.
        """
        mask_idx, query = self._prepare_inputs_only_mask_idx_query(query)

        mlp_patch_layer(
            self.model,
            mask_idx,
            mode=mode,
            neurons=neurons,
            transformer_layers_attr=self.transformer_layers_attr,
            ff_attrs=layer_attr,
        )
        try:
            return self.generate_only_acc(query, ground_truth)
        finally:
            # Restore even on failure so the shared model is not left patched.
            mlp_unpatch_layers(
                model=self.model,
                layer_indices={n[0] for n in neurons},
                transformer_layers_attr=self.transformer_layers_attr,
                ff_attrs=layer_attr,
            )

    def _get_model_output(self, query, max_new_tokens=20, device=None, neurons=None, layer_attr=None, mode=None):
        """Generate a continuation of ``query``, optionally with MLP neurons patched.

        Args:
            query: Prompt text.
            max_new_tokens: Maximum number of tokens to generate.
            device: Device for the tokenized inputs. Defaults to ``self.device``,
                matching the other methods of this class; the old hard-coded
                ``'cuda'`` broke models built on another device.
            neurons: If given, these neurons are patched via ``mlp_patch_layer``
                (with ``mask_idx=-1``) for the duration of generation. Each
                entry's first element is its layer index.
            layer_attr: Attribute path of the MLP sub-module to patch.
            mode: Patch mode forwarded to ``mlp_patch_layer``.

        Returns:
            The decoded text of the newly generated tokens only, with special
            tokens skipped.
        """
        if device is None:
            device = self.device

        if neurons is None:
            return self._generate_new_text(query, max_new_tokens, device)

        mlp_patch_layer(
            self.model,
            mask_idx=-1,
            mode=mode,
            neurons=neurons,
            transformer_layers_attr=self.transformer_layers_attr,
            ff_attrs=layer_attr,
        )
        try:
            return self._generate_new_text(query, max_new_tokens, device)
        finally:
            # Restore even if generation fails, so the model is not left patched.
            mlp_unpatch_layers(
                model=self.model,
                layer_indices={n[0] for n in neurons},
                transformer_layers_attr=self.transformer_layers_attr,
                ff_attrs=layer_attr,
            )

    def _generate_new_text(self, query, max_new_tokens, device):
        """Run ``self.model.generate`` on ``query`` and decode only the new tokens."""
        inputs = self.tokenizer(query, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]
        outputs = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
        )
        # Slice off the prompt so only generated tokens are decoded.
        return self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)


    def check_answer(self, query, ground_truth, neurons=None, layer_attr=None, mode=None):
        """return true or false"""
        predicted_answer = self._get_model_output(query=query, max_new_tokens=20,
                                                  neurons=neurons, layer_attr=layer_attr, mode=mode)
        return ground_truth in predicted_answer




    def get_attention_weights_for_one_query(self, query):
        inputs = self.tokenizer(query, return_tensors="pt").to(self.device)
        outputs = self.model(**inputs, output_attentions=True)
        attentions = outputs.attentions  # tuple of tensors per layer
        return attentions

    def get_knowledge_synapses_one_query(self, attention_matrices, query, adaptive_threshold_synapses=0.5):
        """Find (token, layer, head) triples whose token receives above-threshold attention.

        A "knowledge synapse" is a column of one head's attention matrix.
        Entry ``[i, j]`` is how strongly token ``i`` attends to token ``j``, so
        column ``j`` aggregates the attention that token ``j`` receives; the
        most-attended tokens are those with the largest column scores.

        Each column is scored by its mean over query rows. The score is
        compared against ``adaptive_threshold_synapses`` times the global mean
        attention weight, so both sides of the comparison are on the same scale.

        Args:
            attention_matrices: Per-layer attention tensors (as in
                ``outputs.attentions``), indexed
                ``[batch, head, query_pos, key_pos]``. Only batch element 0
                is scanned.
            query: Sequence whose ``len`` bounds the accepted token indices.
            adaptive_threshold_synapses: Multiplier on the global mean attention.

        Returns:
            List of ``[token_idx, layer_idx, head_idx]``.
        """
        # Mean attention weight over every layer, batch entry, head and position.
        global_avg_attention = torch.mean(
            torch.stack([layer.mean() for layer in attention_matrices]))
        threshold = adaptive_threshold_synapses * global_avg_attention.item()

        synapse_locations = []
        for layer_idx, layer in enumerate(attention_matrices):
            for head_idx, head_attention in enumerate(layer[0]):  # [query_pos, key_pos]
                # Mean over query rows gives one score per key column (token).
                # Reducing this vector any further (e.g. summing it) would
                # collapse it to a single scalar of about 1, so only index 0
                # could ever be selected.
                column_scores = head_attention.mean(dim=0)
                significant_idxs = torch.where(column_scores > threshold)[0]
                for idx in significant_idxs.tolist():
                    if idx < len(query):  # Ensure idx is within the range of query tokens
                        synapse_locations.append([idx, layer_idx, head_idx])

        return synapse_locations

    @staticmethod
    def get_most_attended_token(attention_matrices, query):
        """
        For a single query, identify the token that receives the most attention on average,
        """
        # Average attention across heads using PyTorch, keeping computation on the GPU
        avg_attention = torch.mean(torch.stack([attn.squeeze() for attn in attention_matrices]), dim=0)

        # Sum attention scores across all rows to find the column (token) with the highest average attention
        column_scores = torch.sum(avg_attention, dim=0)
        most_attended_idx = torch.argmax(column_scores).item()  # Use .item() to convert the result to a Python int

        return query[most_attended_idx], most_attended_idx

    def analyze_across_queries(self, queries):
        """
        """

        attended_positions = []
        attended_tokens = []
        for query in queries:
            attentions = self.get_attention_weights_for_one_query(query)
            token, idx = self.get_most_attended_token(attentions, query.split())
            attended_positions.append(idx)
            attended_tokens.append(token)

        # Determine if the position is static or dynamic
        position_static = len(set(attended_positions)) == 1
        # Determine if the token identity is consistent
        token_consistent = len(set(attended_tokens)) == 1
        # Return the analysis results
        return position_static,token_consistent

    @staticmethod
    def calculate_consistency_ratio(neurons_3d: List[List[List[int]]]) -> float:
        """
        Calculate the Consistency Ratio (CR) for a set of neurons activated by different queries.

        :param neurons_3d: A three-dimensional list of neurons. Each sub-list represents the neurons activated for one query.
        :return: Consistency Ratio (CR) for the neurons.
        """
        # Flatten the list of neurons for each query and convert to sets for easy comparison
        neuron_sets = [set(tuple(neuron) for neuron in query_neurons) for query_neurons in neurons_3d]

        # Calculate intersection and union across all sets of neurons
        intersection = set.intersection(*neuron_sets)
        union = set.union(*neuron_sets)

        # Return the Consistency Ratio (CR)
        return len(intersection) / len(union) if union else 0




def main():
    """Script entry point; currently a deliberate no-op placeholder."""
    return None


if __name__ == "__main__":
    main()