import torch
from tqdm import tqdm
import re
import numpy as np
import matplotlib.pyplot as plt


def get_token_indices(text, tokenizer):
    """
    Group token positions by the whitespace-separated word they belong to.

    Args:
        text (str): input text.
        tokenizer: HuggingFace-style tokenizer supporting ``return_offsets_mapping=True``.

    Returns:
        token_indices_list (list[list[int]]): positions in ``inputs["input_ids"]`` grouped per word,
            i.e. ``[[indices_for_word_1], [indices_for_word_2], ...]``. Special tokens (offset
            ``[0, 0]``) are skipped; a word that receives no token produces no entry.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        return_offsets_mapping=True,
        is_split_into_words=False,
        add_special_tokens=True,
    )
    # Index batch row 0 explicitly: squeeze() would also drop the sequence axis when the
    # encoding has a single token, turning each "offset" into a bare int.
    offset_mapping = inputs["offset_mapping"][0].tolist()

    # Character end offset of each word, taken from the real text so that leading spaces,
    # repeated spaces, tabs or newlines do not shift the word boundaries.
    word_ends = [match.end() for match in re.finditer(r"\S+", text)]

    token_indices_list = []
    word_pointer = 0
    current_word_tokens = []

    for idx, offset in enumerate(offset_mapping):
        if offset == [0, 0]:  # special token
            continue
        # Move past every word that ends at or before this token's start.
        while word_pointer < len(word_ends) and offset[0] >= word_ends[word_pointer]:
            if current_word_tokens:
                token_indices_list.append(current_word_tokens)
            word_pointer += 1
            current_word_tokens = []

        current_word_tokens.append(idx)

    if current_word_tokens:  # flush the last word
        token_indices_list.append(current_word_tokens)

    return token_indices_list

def get_token_indices_multiple(text, target_words, tokenizer, target_positions=None):
    """
    Return the token indices of several target words, optionally picking one occurrence each.

    Args:
        text (str): original text.
        target_words (list[str]): words/substrings to locate (plain substring match,
            non-overlapping, left to right).
        tokenizer: HuggingFace-style tokenizer supporting ``return_offsets_mapping=True``.
        target_positions (list[int] | None): 1-based occurrence to keep for each entry of
            ``target_words``. If None (or empty), every occurrence is kept.

    Returns:
        dict: ``{target_word: [list_of_indices, ...]}`` with one index list per kept occurrence.
        A word that is not found (or is the empty string) maps to ``[]``. A word listed several
        times accumulates the occurrences of every listing.

    Raises:
        ValueError: if a requested occurrence number is outside ``1..number_of_occurrences``.
    """
    inputs = tokenizer(text, return_tensors="pt", return_offsets_mapping=True, add_special_tokens=True)
    input_ids = inputs["input_ids"]
    print("Tokens:", tokenizer.convert_ids_to_tokens(input_ids[0]))

    # Plain Python ints are cheaper to compare than 0-d tensors in the inner loop.
    offsets = inputs["offset_mapping"][0].tolist()  # [seq_len][2]
    token_indices_dict = {}

    for word_idx, target_word in enumerate(target_words):
        occurrences = token_indices_dict.setdefault(target_word, [])
        specified_position = target_positions[word_idx] if target_positions else None

        if not target_word:
            # str.find("") matches at every index and would never advance: skip instead of
            # looping forever.
            continue

        # Character start of every non-overlapping occurrence in the text.
        start_positions = []
        found = text.find(target_word)
        while found != -1:
            start_positions.append(found)
            found = text.find(target_word, found + len(target_word))

        if not start_positions:
            continue

        if specified_position is not None:
            if specified_position > len(start_positions) or specified_position < 1:
                raise ValueError(f"指定的目标单词位置超出范围: {target_word}, 位置: {specified_position}")
            start_positions = [start_positions[specified_position - 1]]

        # A token belongs to the occurrence when its character span overlaps
        # [start_idx, end_idx); special tokens have (0, 0) and are skipped by `end <= start_idx`.
        for start_idx in start_positions:
            end_idx = start_idx + len(target_word)
            word_indices = []
            for token_idx, (start, end) in enumerate(offsets):
                if end <= start_idx:
                    continue
                if start >= end_idx:
                    break
                word_indices.append(token_idx)
            occurrences.append(word_indices)

    return token_indices_dict


def get_hidden_states_by_layer(outputs, token_indices, layers=None):
    """
    Collect, per layer, the hidden states at each target word's token positions.

    Args:
        outputs: model output exposing ``hidden_states`` (indexed by layer; each entry is indexed
            ``[0]`` for the first batch item, then by token position).
        token_indices: ``{target_word: [list_of_indices, ...]}`` with one index list per occurrence.
        layers: iterable of layer indices to extract (default: every layer). Any iterable works,
            including a generator.

    Returns:
        dict: ``{layer_idx: {target_word: [tensor, ...]}}``. Each tensor holds one row per index in
        the corresponding index list, detached and moved to the CPU. An empty index list yields a
        tensor with zero rows.

    Raises:
        ValueError: if ``outputs`` carries no ``hidden_states``.
    """
    if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
        raise ValueError("Model outputs must include 'hidden_states'. Ensure output_hidden_states=True.")

    all_hidden_states = outputs.hidden_states
    # Materialise the layer list: it is traversed twice below, and a generator would be
    # exhausted by the first pass, leaving every layer's dict empty.
    layers = range(len(all_hidden_states)) if layers is None else list(layers)

    hidden_states_dict = {layer: {} for layer in layers}
    for layer_idx in layers:
        layer_hidden_states = all_hidden_states[layer_idx][0].detach()
        for target_word, index_lists in token_indices.items():
            word_hidden_states = []
            for indices in index_lists:
                # Gather on the source device, then copy once per occurrence (not once per token).
                index = torch.as_tensor(indices, dtype=torch.long, device=layer_hidden_states.device)
                word_hidden_states.append(layer_hidden_states[index].cpu())
            hidden_states_dict[layer_idx][target_word] = word_hidden_states

    return hidden_states_dict


def construct_selfie_prompt(token_indices, term="fragment"):
    """
    Build a SelfIE interpretation prompt with one ``<unusedN>`` placeholder per token.

    NOTE(review): only the ``def`` line of this function had been commented out, which left its
    indented body as unreachable code inside ``get_hidden_states_by_layer``; it is restored here
    as a standalone function. Nothing in the visible file calls it.

    Args:
        token_indices: ``{target_word: [list_of_indices, ...]}``; the number of placeholders for a
            word is the token count of its first occurrence (``list_of_indices[0]``).
        term: accepted for compatibility; not used by the current prompt templates.

    Returns:
        (selfie_prompt, relationship_prompt, placeholder_map):
            - selfie_prompt: ``"<placeholders>. What did you see?"``;
            - relationship_prompt: a relationship question when there is more than one
              placeholder in total, otherwise None;
            - placeholder_map: ``{target_word: [placeholder, ...]}``, numbered consecutively.

    Raises:
        IndexError: if a word has no occurrence (empty index list).
    """
    placeholder_map = {}
    placeholder_count = 0
    for target_word, indices_list in token_indices.items():
        n_tokens = len(indices_list[0])
        placeholder_map[target_word] = [f"<unused{placeholder_count + k}>" for k in range(n_tokens)]
        placeholder_count += n_tokens

    # Each word's placeholders are concatenated; words are joined with " and".
    placeholder_lines = " and".join("".join(placeholders) for placeholders in placeholder_map.values())

    selfie_prompt = f"{placeholder_lines}. What did you see?"
    relationship_prompt = None
    if placeholder_count > 1:
        relationship_prompt = (
            f"What is the relationship between these{placeholder_lines}?"
        )

    return selfie_prompt, relationship_prompt, placeholder_map




def generate_all_insert_infos_combined(hidden_states_dict, unused_indices, replacing_mode='normalized', overlay_strength=1.0, inputs_position=1, mode = "normal", second_hidden = None):
    """
    Build one insertion record per layer from hidden states and placeholder positions.

    Args:
        hidden_states_dict: ``{layer_idx: {target_word: [token_hidden_states, ...]}}``; only the
            first occurrence (``[0]``) of a word is used, unpacked as (num_tokens, hidden_dim).
        unused_indices: placeholder positions, e.g. ``{'<unused0>': [[pos1]], '<unused1>': [[pos2]]}``.
            Keys are ordered by the integer they contain and the first element of each value is used.
        replacing_mode: stored verbatim under ``'replacing_mode'`` (default 'normalized').
        overlay_strength: stored verbatim under ``'overlay_strength'`` (default 1.0).
        inputs_position: key under which the ``(position, vector)`` list is stored.
        mode:
            - "normal" (or any unrecognised value): every token row of every word, in order;
            - "diff": row i of the layer's first word goes to placeholder i, and row i of the same
              word in the next layer (sorted order) goes to placeholder i + num_tokens. The last
              layer has no successor and yields no record;
            - "time_diff": like "diff", but the second vector comes from ``second_hidden[layer_idx]``;
            - "meaning_diff": inserts a rescaled difference between ``second_hidden``'s vector and
              its scaled projection onto the first vector (uses ``project_``; prints the scales).
        second_hidden: second dict with the same layout; required for "time_diff"/"meaning_diff".

    Returns:
        list[dict]: ``{'replacing_mode': ..., 'overlay_strength': ..., inputs_position: [...]}``,
        one per layer that produced at least one insertion.

    Raises:
        ValueError: if ``second_hidden`` is None in "time_diff" or "meaning_diff" mode.
    """
    if mode in ("time_diff", "meaning_diff") and second_hidden is None:
        raise ValueError(f"second_hidden must be provided in {mode} mode.")

    # Placeholder positions ordered numerically ('<unused2>' before '<unused10>').
    sorted_values = [
        unused_indices[key][0]
        for key in sorted(unused_indices, key=lambda k: int(re.search(r'\d+', k).group()))
    ]
    sorted_layers = sorted(hidden_states_dict)
    next_layer = dict(zip(sorted_layers, sorted_layers[1:]))

    all_insert_infos = []
    for layer_idx, target_words_dict in hidden_states_dict.items():
        insertions = []

        if mode == "diff":
            # Only layers with a successor can be compared; the last layer is skipped rather
            # than falling back to the single-vector "normal" layout.
            if layer_idx in next_layer:
                insertions = _paired_row_insertions(
                    _first_occurrence(target_words_dict),
                    _first_occurrence(hidden_states_dict[next_layer[layer_idx]]),
                    sorted_values,
                )

        elif mode == "time_diff":  # same layer, two different runs
            insertions = _paired_row_insertions(
                _first_occurrence(target_words_dict),
                _first_occurrence(second_hidden[layer_idx]),
                sorted_values,
            )

        elif mode == "meaning_diff":
            first_vector = _first_occurrence(target_words_dict)
            second_vector = _first_occurrence(second_hidden[layer_idx])
            project_vector = project_(first_vector, second_vector)
            scale_pr = torch.norm(second_vector, p=2) / torch.norm(project_vector, p=2)
            print("sacle_pr", scale_pr)
            diff_vector = second_vector - scale_pr * project_vector
            # Rescale the difference to the norm of the first vector.
            scale = torch.norm(first_vector, p=2) / torch.norm(diff_vector, p=2)
            print("diff sacle", scale)
            n_rows = min(first_vector.size(0), second_vector.size(0))
            for row_idx in range(n_rows):
                insertions.append((sorted_values[row_idx], scale * diff_vector))

        else:
            # Consecutive placeholders for every token row of every target word.
            placeholder_idx = 0
            for token_hidden_states in target_words_dict.values():
                first = token_hidden_states[0]
                for row_idx in range(first.size(0)):
                    insertions.append((sorted_values[placeholder_idx], first[row_idx:row_idx + 1, :]))
                    placeholder_idx += 1

        if insertions:
            all_insert_infos.append({
                'replacing_mode': replacing_mode,
                'overlay_strength': overlay_strength,
                inputs_position: insertions,
            })

    return all_insert_infos


def _first_occurrence(words_dict):
    """Return the hidden states of the first occurrence of the first word in ``words_dict``."""
    return next(iter(words_dict.values()))[0]


def _paired_row_insertions(first_vector, second_vector, sorted_values):
    """Interleave rows: first row i -> placeholder i, second row i -> placeholder i + len(first)."""
    offset = first_vector.size(0)
    pairs = []
    for row_idx in range(min(first_vector.size(0), second_vector.size(0))):
        pairs.append((sorted_values[row_idx], first_vector[row_idx:row_idx + 1, :]))
        pairs.append((sorted_values[row_idx + offset], second_vector[row_idx:row_idx + 1, :]))
    return pairs

def project_(vector_a: torch.Tensor, vector_b: torch.Tensor):
    """
    Project ``vector_b`` onto the direction of ``vector_a``.

    Both inputs are flattened first, so tensors of any shape are treated as single vectors
    (they must contain the same number of elements).

    Args:
        vector_a: base vector defining the direction.
        vector_b: vector being projected.

    Returns:
        torch.Tensor of shape ``(1, vector_a.numel())``: the projection of ``vector_b`` onto
        ``vector_a``, or all zeros when ``vector_a`` is the zero vector.
    """
    vector_a = vector_a.flatten()
    vector_b = vector_b.flatten()
    magnitude_squared_a = torch.sum(vector_a ** 2)

    if magnitude_squared_a > 0:
        unit_vector_a = vector_a / torch.sqrt(magnitude_squared_a)
        projection = torch.dot(vector_b, unit_vector_a) * unit_vector_a
    else:
        # A zero base vector has no direction: return a zero projection instead of dividing by zero.
        projection = torch.zeros_like(vector_a)

    return projection.unsqueeze(0)


def hidden_states_to_insert_infos(hidden_states, unused_indices, replacing_mode='normalized', overlay_strength=1.0, inputs_position=40):
    """
    Pair every token after the first one with a placeholder, producing a single insertion record.

    Placeholders are ordered by the integer in their key. The k-th placeholder (0-based) receives
    ``hidden_states[:, k + 1, :]``.

    Args:
        hidden_states: tensor unpacked as (batch, seq_len, hidden_dim).
        unused_indices: ``{'<unusedN>': [position, ...], ...}``; the first element of each value is used.
        replacing_mode: stored under ``'replacing_mode'``.
        overlay_strength: stored under ``'overlay_strength'``.
        inputs_position: key under which the ``(position, vector)`` list is stored.

    Returns:
        dict: ``{'replacing_mode': ..., 'overlay_strength': ..., inputs_position: [(position, vector), ...]}``.

    Raises:
        ValueError: if the number of placeholders differs from ``seq_len - 1``.
    """
    _, seq_len, _ = hidden_states.size()
    if len(unused_indices) != seq_len-1:
        raise ValueError("Length of unused_indices must be equal to sentence length minus 1.")

    def placeholder_number(key):
        return int(re.search(r'\d+', key).group())

    ordered_keys = sorted(unused_indices.keys(), key=placeholder_number)
    return {
        'replacing_mode': replacing_mode,
        'overlay_strength': overlay_strength,
        inputs_position: [
            (unused_indices[key][0], hidden_states[:, token_pos, :])
            for token_pos, key in enumerate(ordered_keys, start=1)
        ],
    }
    

def build_sentence_with_prompts(index_list, tokenizer, prompt="", prompt2=""):
    """
    Surround a run of ``<unusedN>`` placeholders with two prompt strings and locate their tokens.

    Args:
        index_list: nested iterable of integers; each integer N (flattened, in order) becomes
            the placeholder ``<unusedN>``. Placeholders are concatenated without separators.
        tokenizer: passed through to ``get_token_indices_multiple``.
        prompt: text placed before the placeholders.
        prompt2: text placed directly after the placeholders.

    Returns:
        (token_indices_dict, full_sentence): the result of
        ``get_token_indices_multiple(full_sentence, placeholders, tokenizer)`` and the built text.
    """
    placeholders = [f"<unused{index}>" for indices in index_list for index in indices]
    full_sentence = f"{prompt}{''.join(placeholders)}{prompt2}"
    token_indices = get_token_indices_multiple(full_sentence, placeholders, tokenizer)
    return token_indices, full_sentence

def build_sentence_with_prompts_modechat(index_list, tokenizer, prompt="<start_of_turn>user\n", prompt2="<start_of_turn>model\nSure, I will repeat for you:"):
    """
    Chat-formatted variant of the placeholder prompt: user turn, placeholders, end-of-turn, model turn.

    Args:
        index_list: nested iterable of integers; each integer N (flattened, in order) becomes
            the placeholder ``<unusedN>``. Placeholders are concatenated without separators.
        tokenizer: passed through to ``get_token_indices_multiple``.
        prompt: text placed before the placeholders (user-turn header by default).
        prompt2: text placed after the ``<end_of_turn>`` marker and a newline (model-turn header
            plus a "repeat" reply prefix by default).

    Returns:
        (token_indices_dict, full_sentence): the result of
        ``get_token_indices_multiple(full_sentence, placeholders, tokenizer)`` and the built text.
    """
    placeholders = [f"<unused{index}>" for indices in index_list for index in indices]
    full_sentence = f"{prompt}{''.join(placeholders)}" + "<end_of_turn>\n" + f"{prompt2}"
    token_indices = get_token_indices_multiple(full_sentence, placeholders, tokenizer)
    return token_indices, full_sentence

def build_words_with_prompts_modechat(index_dict, tokenizer, prompt="<start_of_turn>user\n", prompt2="<start_of_turn>model\n", mode="normal"):
    """
    Build a chat-formatted interpretation prompt made of ``<unusedN>`` placeholders and locate
    the placeholders' token positions in it.

    Placeholder layout (k = number of tokens in a word's first occurrence, ``len(value[0])``):
      - exactly one word, mode "normal" or "meaning_diff": one run ``<unused0>..<unused{k-1}>``,
        written into ``sentence``;
      - exactly one word, any other mode: two runs ``<unused0>..<unused{k-1}>`` and
        ``<unused{k}>..<unused{2k-1}>``, collected in ``placeholders`` (``sentence`` stays empty);
      - zero or several words: ``mode`` is forced to "normal" and each word gets its own run,
        numbered consecutively, collected in ``placeholders``.

    Prompt text:
      - "normal" with no ``placeholders``: ``sentence`` in the user turn, then the model turn
        followed by the reply prefix "Sure, I will repeat for you:";
      - "normal" with ``placeholders``: a question about the relationship between the runs
        (runs joined with " and");
      - "time_diff": a question about the difference between the two runs;
      - any other mode: ``sentence`` in the user turn, then the model turn followed by an
        "explain it from my first impression with metaphor" reply prefix.

    NOTE(review): for a single word in a mode other than "normal"/"meaning_diff"/"time_diff"
    (e.g. "diff"), the last branch uses ``sentence``, which is empty in that case, so the prompt
    contains no placeholders while ``sentence_list`` still lists 2k of them; the returned index
    lists are then empty. Confirm whether that mode needs its own prompt.

    Args:
        index_dict: presumably ``{word: [list_of_token_indices, ...]}`` as returned by
            ``get_token_indices_multiple``; only ``len(value[0])`` is used — TODO confirm.
        tokenizer: passed through to ``get_token_indices_multiple``.
        prompt: text placed at the very start (user-turn header by default).
        prompt2: text placed after the ``<end_of_turn>`` marker and a newline (model-turn header).
        mode: "normal", "meaning_diff", "time_diff", or another value (see above).

    Returns:
        (token_indices_dict, full_sentence): the result of
        ``get_token_indices_multiple(full_sentence, sentence_list, tokenizer)`` and the prompt.
    """
    # Placeholder text written directly into the prompt (single-word normal/meaning_diff case)
    sentence = ""
    # Every placeholder whose token positions will be looked up
    sentence_list = []
    placeholder_count = 0
    # Placeholder runs, one string per run (multi-run layouts)
    placeholders = []
    
    # Exactly one target word: the placeholder layout depends on the mode.
    if len(index_dict) == 1: 
        if mode == "normal" or mode == "meaning_diff":
            for key, value in index_dict.items():
                for i, index in enumerate(value[0]):
                    sentence += f"<unused{i}>"
                    sentence_list.append(f"<unused{i}>")
        else:  # any other mode (e.g. "diff", "time_diff"): two runs of k placeholders each
            for key, value in index_dict.items():
                _word = ""
                _word2 = ""
                _len = len(value[0])
                for _ in value[0]:
                    _word += f"<unused{placeholder_count}>"
                    _word2 += f"<unused{placeholder_count + _len}>"
                    sentence_list.append(f"<unused{placeholder_count}>")
                    sentence_list.append(f"<unused{placeholder_count + _len}>")
                    placeholder_count += 1
                placeholders.append(_word)
                placeholders.append(_word2)
            # Sort the "<unusedN>" entries by their numeric suffix N.
            sentence_list.sort(key=lambda x: int(re.search(r'\d+', x).group()))

    else:
        # Zero or several words: always use the "normal" relationship layout,
        # one consecutively numbered run per word.
        mode = "normal"
        for key, value in index_dict.items():
            _word = ""
            for _ in value[0]:
                _word += f"<unused{placeholder_count}>"
                sentence_list.append(f"<unused{placeholder_count}>")
                placeholder_count += 1
            placeholders.append(_word)
    
    if mode == "normal" and placeholders == []:
        # Single word (or empty dict): ask the model to repeat the placeholders.
        full_sentence = f"{prompt}{sentence}" + "<end_of_turn>\n" + f"{prompt2}" + "Sure, I will repeat for you:"
    else:
        if mode == "normal":
            # Several words: ask about the relationship between their runs.
            prompt3 = "What is the relationship between the vectors:"
            full_sentence = f"{prompt}" + f"{prompt3}" + f"{' and'.join(placeholders)}" + "<end_of_turn>\n" + f"{prompt2}"
        elif mode == "time_diff":
            # prompt3 = "What is the difference between these vectors:"
            prompt3 = "I will give you two vectors, you don’t need to decode them to any specific words."
#             prompt4 = '''Explain:
# 1. Key feature differences.
# 2. Semantic abstraction or concretization.
# 3. What does it mean if certain features remain unchanged?'''
            prompt4 = "What is the difference between their concepts?"
            full_sentence = f"{prompt}" + f"{prompt3}" + f"{' and'.join(placeholders)}" + prompt4 + "<end_of_turn>\n" + f"{prompt2}"
        else:
            # "meaning_diff" and any other mode: uses `sentence` (empty unless the
            # single-word normal/meaning_diff branch filled it — see NOTE in the docstring).
            # prompt3 = "I will give you a vector: "
            # prompt4 = " What is the meaning of the vector?"
            # full_sentence = f"{prompt}" + f"{prompt3}" + f"{sentence}" + prompt4 + "<end_of_turn>\n" + f"{prompt2}" + "Sure, I will summary for you:"

            full_sentence = f"{prompt}" + f"{sentence}" + "<end_of_turn>\n" + f"{prompt2}" + "Sure, I will explain it from my first impression with metaphor:"
            # Alternative reply prefixes that were tried:
            # Feel directly
            # Sure, I will explain the content within most relevant meaning:
            # Explain with first impression
            # Sure, I will explain the content without context:
            # I will explain with first impression with just only meaning
            # will make one sentence based on my first impression with this meaning:
            # will make Example based on my first impression with this meaning:
            # explain it from my first impression with just one meaning
            # explain it from my first impression with metaphor:


    res = get_token_indices_multiple(full_sentence, sentence_list, tokenizer)

    return res, full_sentence



def plot_token_hidden_states(hidden_states_dict1, target_word, save_path=None, hidden_states_dict2=None):
    """
    Plots the L2 norm of token_hidden_states for a target_word across layers.
    If a second (non-empty) hidden_states_dict is provided, it plots the absolute difference
    between the two per-layer L2 norms instead.

    Parameters:
    hidden_states_dict1 (dict): The first input dictionary of the form
                                {layer_idx: {target_word: [token_hidden_states]}}.
    target_word (str): The target word whose token_hidden_states to plot; layers that do not
                       contain it are skipped.
    save_path (str, optional): Path to save the plot (dpi=300). If None, the plot is not saved.
    hidden_states_dict2 (dict, optional): The second input dictionary for L2 norm difference computation.

    Raises:
    ValueError: If the two dictionaries do not yield the same layer indices in the same order.

    Returns:
    None: Displays the plot and optionally saves it.
    """
    layers1, l2_norms1 = _l2_norms_across_layers(hidden_states_dict1, target_word)

    if hidden_states_dict2:
        layers2, l2_norms2 = _l2_norms_across_layers(hidden_states_dict2, target_word)
        if not np.array_equal(layers1, layers2):
            raise ValueError("Layer indices in the two hidden_states_dict do not match.")
        _plot_layer_curve(
            layers1,
            np.abs(l2_norms1 - l2_norms2),
            label='L2 Norm Difference',
            title=f"L2 Norm Difference of Token Hidden States for '{target_word}' Across Layers",
            ylabel="L2 Norm Difference",
            save_path=save_path,
        )
    else:
        _plot_layer_curve(
            layers1,
            l2_norms1,
            label='L2 Norm',
            title=f"L2 Norm of Token Hidden States for '{target_word}' Across Layers",
            ylabel="L2 Norm",
            save_path=save_path,
        )


def _to_numpy(tensor):
    """Convert a torch tensor to a float32 NumPy array on the CPU; return anything else unchanged."""
    if isinstance(tensor, torch.Tensor):
        # float32 first: NumPy cannot represent e.g. bfloat16 tensors.
        return tensor.to(dtype=torch.float32).cpu().numpy()
    return tensor


def _l2_norms_across_layers(hidden_states_dict, target_word):
    """Return (layer indices, overall L2 norms) for every layer whose dict contains ``target_word``."""
    layers = []
    l2_norms = []
    for layer_idx, words_dict in hidden_states_dict.items():
        if target_word not in words_dict:
            continue
        token_hidden_states = words_dict[target_word]
        if isinstance(token_hidden_states, list):
            # Occurrences may have different token counts, so flatten and concatenate them
            # instead of building a (possibly ragged) np.array; the norm over all elements
            # is the same either way.
            flat = [np.ravel(_to_numpy(t)) for t in token_hidden_states]
            values = np.concatenate(flat) if flat else np.zeros(0)
        else:
            values = _to_numpy(token_hidden_states)
        layers.append(layer_idx)
        l2_norms.append(np.linalg.norm(values))
    return np.array(layers), np.array(l2_norms)


def _plot_layer_curve(layers, values, label, title, ylabel, save_path):
    """Plot ``values`` against ``layers`` as a marked line, optionally save it, then show it."""
    plt.figure(figsize=(10, 6))
    plt.plot(layers, values, marker='o', label=label)
    plt.title(title)
    plt.xlabel("Layer Index")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"Plot saved to {save_path}")

    plt.show()



def begin_2_interpret(all_insert_infos, bs, prompt, tokenizer, model, max_length=30, device="cuda"):
    """
    Generate one interpretation per insertion record, ``bs`` records per ``generate`` call.

    Each batch repeats ``prompt`` once per record, calls ``model.generate`` with the records as
    ``insert_info`` and decodes only the tokens generated after the prompt. Each batch's decoded
    texts are also printed.

    Args:
        all_insert_infos: list of insertion records (e.g. from ``generate_all_insert_infos_combined``).
        bs: number of records per batch.
        prompt: prompt text the records' placeholder positions refer to.
        tokenizer: HuggingFace-style tokenizer.
        model: model whose ``generate`` accepts an ``insert_info`` keyword (a custom extension).
        max_length: forwarded to ``model.generate``.
        device: device the tokenized prompts are moved to (default "cuda").

    Returns:
        list[str]: decoded continuations, in the same order as ``all_insert_infos``.
    """
    result = []
    for batch_start_idx in range(0, len(all_insert_infos), bs):
        with torch.no_grad():
            # Slicing past the end is safe, so no explicit min() is needed.
            batch_insert_infos = all_insert_infos[batch_start_idx:batch_start_idx + bs]
            batched_inputs = tokenizer([prompt] * len(batch_insert_infos), return_tensors="pt", add_special_tokens=True).to(device)
            # All rows use the same prompt, so row 0's length is the prompt length for every row.
            repeat_prompt_n_tokens = batched_inputs['input_ids'][0].shape[-1]
            output = model.generate(batched_inputs.input_ids, max_length=max_length, insert_info=batch_insert_infos, output_attentions=False)
            cropped_interpretation_tokens = output[:, repeat_prompt_n_tokens:]
            cropped_interpretation = tokenizer.batch_decode(cropped_interpretation_tokens, skip_special_tokens=True)
            result.extend(cropped_interpretation)
            print(cropped_interpretation)
    return result