import os
# GPU index exposed to this process. CUDA_VISIBLE_DEVICES is set *before* torch is
# imported so that CUDA, when it initializes, only sees this one physical device.
DEVICE = "3"
os.environ["CUDA_VISIBLE_DEVICES"] = DEVICE
import argparse
import json
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
import numpy as np
import spacy
import matplotlib.pyplot as plt
from metrics import compute_exact_match
# NOTE(review): star import. Helpers used below but not defined in this file
# (get_context_ids, get_sentence_token_spans, get_document_token_spans,
# get_doc_sentence_span, get_output, load_data) presumably come from here.
from utils_attention import *
from visualize_layers import draw_layer_scores
# With the env var above, "cuda" maps to the single visible GPU (physical index DEVICE).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Example launch command:
#nohup python attention.py --model_path ../models/vicuna_7b_v1.5 --model_type vicuna_7b --data_path datasets_processed --setting_type concat --temperature 0.01 --hotpot --processed --zero_shot > test.out 2>&1 &

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs with ``seed``.

    CUDA generators on every visible device are seeded too, but only when at
    least one CUDA device is present.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    has_cuda_device = torch.cuda.device_count() > 0
    if has_cuda_device:
        torch.cuda.manual_seed_all(seed)


def get_system_prompt(setting_type):
    """Return the system prompt for the given experiment setting.

    Settings whose name contains ``"cred"`` get an instruction that mentions
    per-passage credibility scores. Every other setting, including ``None``,
    gets the generic passage-QA instruction.
    """
    # NOTE(author): the "cred" variant uses a different instruction, but it did
    # not behave quite as originally expected.
    wants_credibility = setting_type is not None and "cred" in setting_type
    if not wants_credibility:
        return "You're a helpful AI assistant. The assistant answers questions based on given passages.\n"
    return "You are an assistant who can answer questions based on the given passages. Each passage has a credibility score that indicates the relevance and accuracy of the passage to the question. Your answer need to combine multiple passages and their credibility."

def parser():
    """Parse the command-line options of the attention-reordering experiments.

    Returns the ``argparse.Namespace``. When ``--save_suffix`` is not given,
    it falls back to the value of ``--setting_type``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_path", type=str)
    ap.add_argument("--model_type", type=str, required=True)
    ap.add_argument("--save_suffix", type=str, default=None)
    ap.add_argument("--data_path", type=str, required=True)
    ap.add_argument("--temperature", type=float)
    ap.add_argument("--setting_type", type=str, default=None)

    # Dataset selection flags (registered in the same order as before).
    dataset_flags = (
        "--wikimulti", "--hotpot", "--musique", "--wikiqa",
        "--rgb", "--evotemp", "--misinfo",
    )
    for flag in dataset_flags:
        ap.add_argument(flag, action='store_true')

    # fastchat-specific per the author; places the question before the docs.
    ap.add_argument("--qstart", action='store_true')
    ap.add_argument("--parallel_size", type=int, default=1)
    ap.add_argument("--max_new_tokens", type=int, default=512)
    ap.add_argument("--zero_shot", action="store_true")

    # Options for processed datasets.
    ap.add_argument("--processed", action="store_true")
    ap.add_argument("--result_suffix", type=str, default=None)
    ap.add_argument("--debug", action="store_true")

    parsed = ap.parse_args()
    if parsed.save_suffix is None:
        parsed.save_suffix = parsed.setting_type
    return parsed


def span_attention_one_token(output_att, n_layers,
        items,
        item_spans,
        context_span,
        marker_impstart,
        marker_impend,
        layer_span,
        threshold, use_norm=True, return_tokens=False):
    '''
    Score context groups by the attention the latest query position pays them.

    Parameters
    ----------
    output_att : sequence of torch.Tensor
        Per-layer attention for one generation step. Each tensor is indexed as
        ``[0, :, -1, context_span[0]:context_span[1]]``, so it is assumed to be
        (batch_size, num_heads, query_length, sequence_length).
    n_layers : int
        Expected number of layers; must equal ``len(output_att)``.
    items : list of str
        Text of each group (document, sentence, ...). Unused; kept for API
        compatibility.
    item_spans : list of tuple
        ``(start, end)`` token span of each group, relative to ``context_span[0]``.
    context_span : tuple of int
        Absolute ``(start, end)`` token span of the whole context.
    marker_impstart, marker_impend : str
        Unused; kept for API compatibility.
    layer_span : tuple of int
        Half-open ``(first, last)`` range of layers to average over.
    threshold : float
        Unused; kept for API compatibility.
    use_norm : bool
        If True, each layer's context scores are rescaled to sum to 1 before
        averaging across layers.
    return_tokens : bool
        If True, return per-token scores instead of per-group scores.

    Returns
    -------
    numpy.ndarray
        Per-token scores (length ``context_span[1] - context_span[0]``) when
        ``return_tokens``. Otherwise the mean token score of each span in
        ``item_spans``, which is empty when ``item_spans`` is empty.
    '''
    assert len(output_att) == n_layers, "Compute attention scores for the specified range of layers of one generated token."
    # One row per selected layer: head-averaged attention from the last query
    # position onto every context token. .float() first because NumPy has no bfloat16.
    att_layer_scores = np.array(
        [
            output_att[l][0, :, -1, context_span[0]: context_span[1]]
            .detach()
            .cpu()
            .float()
            .numpy()
            .mean(axis=0)
            for l in range(layer_span[0], layer_span[1])
        ]
    )
    if use_norm:
        # Make layers comparable by turning each row into a distribution over the context.
        att_layer_scores /= att_layer_scores.sum(axis=1, keepdims=True)
    att_token_scores = att_layer_scores.mean(axis=0)
    if return_tokens:
        return att_token_scores
    # Aggregate token-level scores into group-level scores (mean over each span).
    return np.array(
        [
            att_token_scores[item_span[0]: item_span[1]].mean()
            for item_span in item_spans
        ]
    )


def inference_original(temperature, max_new_tokens, eval_data, shots, tokenizer, model, model_type, f, system, processed=False, new_prompt=False):
    """Baseline: generate an answer for every item and write one JSON line per item to ``f``.

    The prompt is ``system + shots + "\\n\\n"`` followed by ``item['new_prompt']``
    (when ``new_prompt``) or ``item["conversations"][0]["value"]``. The reference
    answer is ``item["conversations"][1]["value"]``. Generated text has the
    tokenizer's special tokens removed, is stripped, and is cut at the first blank
    line. When ``processed`` the whole (mutated) item is written; otherwise only
    ``{"output", "golden"}`` is written. The first 5 items are printed verbosely.
    Items whose attention-returning generate call raises are printed and skipped.
    """
    # Heuristic kept from the original: model types containing "wen" or "3"
    # (presumably Qwen / *3 families) use a plain generate call. Others also
    # request attentions so their shapes can be printed.
    plain_generate = "wen" in model_type or "3" in model_type
    for idx, item in enumerate(tqdm(eval_data)):
        verbose = idx < 5
        if verbose:
            print(idx)
        body = item['new_prompt'] if new_prompt else item["conversations"][0]["value"]
        prompt = system + shots + "\n\n" + body
        golden = item["conversations"][1]["value"]
        input_ids = tokenizer([prompt], return_tensors="pt").input_ids.to(device)

        if plain_generate:
            output_ids = model.generate(input_ids, do_sample=True, temperature=temperature,
                                        max_new_tokens=max_new_tokens)
        else:
            try:
                all_output = model.generate(input_ids, do_sample=True, temperature=temperature,
                                            max_new_tokens=max_new_tokens,
                                            return_dict_in_generate=True, output_attentions=True,
                                            # Hidden states are only read by the verbose shape
                                            # dump below. Keeping every layer for every step
                                            # otherwise wastes GPU memory and can OOM.
                                            output_hidden_states=verbose,
                                            )
            except Exception as e:
                print(e)
                continue
            attention = all_output.attentions
            hidden_states = all_output.hidden_states
            output_ids = all_output.sequences
            if verbose:
                print(f"input_id shape: {input_ids.shape}")
                print(f"attention: Tuple (one element for each generated token, {len(attention)}) of tuples (one element for each layer of the decoder, {len(attention[0])}) of torch.FloatTensor of shape (batch_size, num_heads, generated_length, sequence_length).{attention[0][0].shape}")

                print(f"hidden_states shape: Tuple (one element for each generated token, {len(hidden_states)}) of tuples (one element for each layer of the decoder, {len(hidden_states[0])}) of torch.FloatTensor of shape (batch_size, generated_length, hidden_size){hidden_states[0][0].shape}")
                print(f"output_id shape: {output_ids.shape}")
            # Release the (large) attention tensors before the next item.
            del all_output, attention, hidden_states

        # Keep only newly generated tokens.
        output = tokenizer.decode(output_ids[0][len(input_ids[0]):])
        for special_token in tokenizer.special_tokens_map.values():
            # Values are either a single token string or a list of them.
            candidates = special_token if isinstance(special_token, list) else [special_token]
            for tok in candidates:
                output = output.replace(tok, "")
        output = output.strip().split('\n\n')[0]

        if processed:
            item['output'] = output
            item['golden'] = golden
            record = item
        else:
            record = {"output": output, "golden": golden}
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

        if verbose:
            print(prompt)
            print(f"output: {output}")
            print(f"golden: {golden}")
            print('\n')

def heatmap(results, x_ids, y_ids, tokenizer, save_path="test.png"):
    """Draw ``results`` as a heatmap with token labels, save it, and show it.

    Parameters
    ----------
    results : 2-D array-like
        Values to plot; rows correspond to ``y_ids`` and columns to ``x_ids``.
    x_ids, y_ids : sequence of token ids
        Key-token ids (x axis) and query-token ids (y axis), converted to token
        strings with ``tokenizer.convert_ids_to_tokens`` for the tick labels.
    tokenizer : tokenizer exposing ``convert_ids_to_tokens``.
    save_path : str
        Where the figure is saved; defaults to the previous hard-coded "test.png".
    """
    x_tokens = tokenizer.convert_ids_to_tokens(x_ids)
    y_tokens = tokenizer.convert_ids_to_tokens(y_ids)

    # One figure per call (a second, empty figure used to be created and leaked).
    fig = plt.figure()
    plt.imshow(results, cmap="viridis")
    plt.colorbar()

    # Axis tick labels are the decoded tokens.
    plt.xticks(range(len(x_tokens)), x_tokens)
    plt.yticks(range(len(y_tokens)), y_tokens)
    plt.xlabel("Key Tokens")
    plt.ylabel("Query Tokens")
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()
    # Free the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

def draw_specific_attention(output_ids, attention, ss, se, all_output, tokenizer,
                            output_idx=3, layer_idx=30, head_idx=0):
    """Debug helper: plot one head's attention from one generated token onto a key span.

    Parameters
    ----------
    output_ids : sequence of token ids
        Generated token ids; ``output_ids[output_idx]`` labels the heatmap row.
    attention : sequence
        Per-generated-token tuples of per-layer attention tensors, indexed as
        ``attention[output_idx][layer_idx][0, head_idx, -1, ss:se]``.
    ss, se : int
        Absolute start/end of the key-token span to plot.
    all_output : generate output whose ``sequences[0][ss:se]`` labels the columns.
    tokenizer : tokenizer exposing ``convert_ids_to_tokens``.
    output_idx, layer_idx, head_idx : int
        Which generated token, decoder layer and head to show. The defaults
        reproduce the previously hard-coded choice (token 3, layer 30, head 0).
    """
    output_token = tokenizer.convert_ids_to_tokens(output_ids)[output_idx]
    print(output_token)
    attention_one_token = attention[output_idx]
    print(f"layer {len(attention_one_token)}, (batch_size, num_heads, generated_length, sequence_length){attention_one_token[0].shape}")

    # First batch element, chosen head, last query position, keys in [ss, se).
    # Convert to float32 first: NumPy has no bfloat16, so .numpy() on bf16 tensors fails.
    selected_attn = (
        attention_one_token[layer_idx][0, head_idx, -1, ss: se]
        .detach()
        .float()
        .cpu()
        .numpy()
        .reshape(1, -1)
    )
    print(f"layer {layer_idx}, head{head_idx}, attention_one_token[layer_idx][0, head_idx] = {selected_attn.shape}")
    heatmap(selected_attn, all_output.sequences[0][ss:se], [output_ids[output_idx]], tokenizer)

def reconstruct_input(item, new_doc_orders, qstart):
    """Rebuild the QA prompt with ``item['docs']`` permuted by ``new_doc_orders``.

    Parameters
    ----------
    item : dict
        Needs ``'question'`` and ``'docs'`` (each doc a dict with ``'title'``
        and ``'text'``). It is mutated: ``'new_docs'``, ``'new_order'`` (list of
        Python ints) and ``'new_prompt'`` are set.
    new_doc_orders : sequence of int
        New order as indices into ``item['docs']``. A NumPy array, list or tuple
        is accepted.
    qstart : bool
        Put the question before the documents instead of after them.

    Returns
    -------
    str
        The new prompt. Each doc is rendered as ``"title:text\\n"``, or
        ``"text\\n"`` when its title is empty.
    """
    question = item['question']
    order = np.asarray(new_doc_orders).tolist()
    docs = [item['docs'][i] for i in order]
    item['new_docs'] = docs
    item['new_order'] = order

    rendered = []
    for doc in docs:
        title = f"{doc['title']}:" if doc['title'] != "" else ""
        rendered.append(f"{title}{doc['text']}\n")
    doc_prompt = "".join(rendered)

    if qstart:
        new_prompt = f"Question:{question}\nDocs:{doc_prompt}\nAnswer:"
    else:
        new_prompt = f"Docs:{doc_prompt}\nQuestion:{question}\nAnswer:"
    item["new_prompt"] = new_prompt
    return new_prompt

def external_factor(docs):
    """Map each document's ``'cred'`` label to a multiplicative weight.

    ``'high'`` -> 1.0, ``'middle'`` -> 0.95, ``'low'`` -> 0.9. Any other label
    raises ``KeyError``. Returns a NumPy array aligned with ``docs``.
    """
    weight_of = {"high": 1.0, "middle": 0.95, "low": 0.9}
    return np.array([weight_of[doc['cred']] for doc in docs])

def attention_scores_query(output_att, context_span, doc_spans, target, layer_span, n_layers, use_norm=True, return_tokens=False):
    """Score context tokens / documents by the attention from chosen query position(s).

    Parameters
    ----------
    output_att : sequence of torch.Tensor
        Per-layer attention for one generation step, indexed as
        ``[0, :, query, key]``; assumed (batch_size, num_heads, query_len, seq_len).
    context_span : tuple of int
        Absolute ``(start, end)`` key span of the context.
    doc_spans : list of tuple
        ``(start, end)`` spans of each document, relative to ``context_span[0]``.
    target : int or (start, end)
        A single query position (Python or NumPy integer; negative indices such as
        -1 are allowed), or a half-open span of query positions. For a span, the
        positions are averaged so the span length does not weight the result.
    layer_span : tuple of int
        Half-open range of layers to average over.
    n_layers : int
        Expected number of layers; must equal ``len(output_att)``.
    use_norm : bool
        Rescale each layer's context scores to sum to 1 before averaging layers.
    return_tokens : bool
        Return per-token scores instead of per-document means.

    Returns
    -------
    numpy.ndarray
        Per-token scores over the context, or the mean score of each doc span.
    """
    assert len(output_att) == n_layers, "Compute attention scores for the specified range of layers of one generated token."
    # isinstance (not type(...) == int) so NumPy integer positions are accepted too.
    single_position = isinstance(target, (int, np.integer))
    query = int(target) if single_position else slice(target[0], target[1])

    per_layer = []
    for l in range(layer_span[0], layer_span[1]):
        # Head-averaged attention onto the context; .float() because NumPy has no bfloat16.
        head_avg = (
            output_att[l][0, :, query, context_span[0]: context_span[1]]
            .detach()
            .cpu()
            .float()
            .numpy()
            .mean(axis=0)
        )
        if not single_position:
            head_avg = head_avg.mean(axis=0)  # average over the query positions
        per_layer.append(head_avg)
    att_layer_scores = np.array(per_layer)

    if use_norm:
        att_layer_scores /= att_layer_scores.sum(axis=1, keepdims=True)
    att_token_scores = att_layer_scores.mean(axis=0)
    if return_tokens:
        return att_token_scores
    return np.array(
        [
            att_token_scores[item_span[0]: item_span[1]].mean()
            for item_span in doc_spans
        ]
    )

def attention_scores_output(len_docs, output_ids, attention, tokenizer, n_layers, docs, docs_spans, sents, sent_spans, context_spans, layer_span, method, use_norm=True, return_tokens=False):
    """Average context attention over all generated tokens in ``output_ids``.

    For each generated step ``i < len(output_ids)``, ``span_attention_one_token`` is
    applied to ``attention[i]``. The per-step results are summed and divided by
    ``len(output_ids)`` when it is non-zero; with no steps the result is all zeros.

    Parameters
    ----------
    len_docs : int
        Number of documents (length of the document-score result).
    output_ids : sequence
        Generated tokens to average over; only its length is used.
    attention : sequence
        Per-generated-step tuples of per-layer attention tensors.
    tokenizer : unused; kept for API compatibility.
    n_layers, layer_span, use_norm :
        Forwarded to ``span_attention_one_token``.
    docs, docs_spans : document texts and context-relative token spans.
    sents, sent_spans : sentence texts and context-relative token spans.
    context_spans : tuple of int
        Absolute ``(start, end)`` span of the context.
    method : int
        0: per-document mean of token scores.
        1: per-sentence means, then averaged per document.
        2: per-sentence means, then summed per document.
        Only 0 is supported with ``return_tokens``. Other values raise
        ``NotImplementedError`` once there is at least one step.
    return_tokens : bool
        Return per-token context scores instead of per-document scores.

    Returns
    -------
    numpy.ndarray (float32)
        Shape ``(len_docs,)``, or ``(context_spans[1] - context_spans[0],)`` when
        ``return_tokens``.
    """
    n_steps = len(output_ids)

    if return_tokens:
        all_token_scores = np.zeros(context_spans[1] - context_spans[0], dtype=np.float32)
        for step in range(n_steps):
            if method != 0:
                raise NotImplementedError
            all_token_scores += span_attention_one_token(attention[step], n_layers, docs, docs_spans,
                                                         context_spans, "<imstart>", "<imend>", layer_span, 0.5,
                                                         use_norm, return_tokens=True)
        if n_steps > 0:
            all_token_scores /= n_steps
        return all_token_scores

    all_scores = np.zeros(len_docs, dtype=np.float32)
    # Sentence-to-document mapping is loop-invariant; computed lazily, at most once.
    # Presumably a list (per document) of that document's sentence indices.
    doc_sent = None
    for step in range(n_steps):
        attention_one_token = attention[step]
        if method == 0:
            doc_group_scores = span_attention_one_token(attention_one_token, n_layers, docs, docs_spans,
                                                        context_spans, "<imstart>", "<imend>", layer_span, 0.5,
                                                        use_norm)
        elif method in (1, 2):
            sentence_group_scores = span_attention_one_token(attention_one_token, n_layers, sents, sent_spans,
                                                             context_spans, "<imstart>", "<imend>", layer_span, 0.5,
                                                             use_norm)
            if doc_sent is None:
                doc_sent = get_doc_sentence_span(sent_spans, docs_spans)
            # method 1 averages a document's sentence scores; method 2 sums them.
            if method == 1:
                doc_group_scores = np.array([sentence_group_scores[ids[0]: ids[-1] + 1].mean() for ids in doc_sent])
            else:
                doc_group_scores = np.array([sentence_group_scores[ids[0]: ids[-1] + 1].sum() for ids in doc_sent])
        else:
            raise NotImplementedError
        all_scores += doc_group_scores
    if n_steps > 0:
        all_scores /= n_steps
    return all_scores

def _doc_group_means(token_scores, docs_spans):
    """Return the mean of ``token_scores`` over each ``(start, end)`` span in ``docs_spans``."""
    return np.array([token_scores[span[0]: span[1]].mean() for span in docs_spans])


def _place_by_position_scores(att_order, docs_spans, position_scores):
    """Greedily assign documents to the left or right edge of the context.

    Documents are visited in ``att_order`` (most relevant first). Each one is
    compared against the next free window of its token length at the left edge
    and at the right edge. It goes to whichever window has the higher mean
    ``position_scores`` (ties go right). Returns the new document order: left
    placements in visiting order, then right placements reversed so the whole
    order reads left to right.
    """
    left, right = 0, len(position_scores)
    lefts, rights = [], []
    for att_idx in att_order:
        doc_len = docs_spans[att_idx][1] - docs_spans[att_idx][0]
        scores_left = position_scores[left: left + doc_len].mean()
        scores_right = position_scores[right - doc_len: right].mean()
        if scores_left > scores_right:
            lefts.append(att_idx)
            left += doc_len
        else:
            rights.append(att_idx)
            right -= doc_len
    assert len(lefts) + len(rights) == len(docs_spans), f"{len(lefts)},{len(rights)},{len(docs_spans)}"
    return np.array(lefts + rights[::-1])


def inference_attention_merge(temperature, max_new_tokens, eval_data, shots, tokenizer, model, model_type, f, system, processed=False, qstart=False):
    """Attention-guided document reordering followed by re-generation.

    For every item:
      1. Build the prompt from ``item['doc_prompt']`` and ``item['question']``
         (question first when ``qstart``) and generate once with attentions.
      2. Score each document by the mean over its tokens of
         |upper-layer attention from the answer tokens ``output_ids[:end]``
         - lower-layer "positional" attention|. The positional term averages the
         last position at step ``end`` and the second-to-last position at step 0.
      3. Visit documents from highest to lowest score and greedily place each at
         the left or right edge of the context, whichever edge window receives
         more positional attention.
      4. Rebuild the prompt with ``reconstruct_input``, generate again, and write
         one JSON line to ``f`` (the whole mutated item when ``processed``).

    ``item['att_order']`` and ``item['position_order']`` record the intermediate
    orders. Items are skipped in two cases: the detected document count differs
    from ``len(item['docs'])``, or a generate call raises. The number of failed
    generations is printed at the end. ``model_type`` is unused.
    """
    # Fractions of model depth. Layers [layer0*L, layer1*L) score relevance and
    # layers [0, layer0*L) score position.
    layer0, layer1 = 0.5, 1
    print("in inference_attention_merge")
    blank_id = tokenizer.pad_token_id
    print(f" blank token: id is {blank_id}")
    passes = 0  # number of items skipped because generation raised
    for idx, item in enumerate(tqdm(eval_data)):
        verbose = idx < 5
        if verbose:
            print(idx)
        if qstart:
            prompt = system + shots + "\n\n" + f"Question:{item['question']}\nDocs:{item['doc_prompt']}\nAnswer:"
        else:
            prompt = system + shots + "\n\n" + f"Docs:{item['doc_prompt']}\nQuestion:{item['question']}\nAnswer:"
        golden = item["conversations"][1]["value"]
        input_ids = tokenizer([prompt], return_tensors="pt").input_ids.to(device)

        context_spans, context_ids = get_context_ids(input_ids, item['doc_prompt'], tokenizer)
        question_spans, question_ids = get_context_ids(input_ids, item['question'], tokenizer)
        if verbose:
            print(question_ids)
            print(question_spans)
            print(tokenizer.decode(question_ids[0]))
        sent_spans, sents = get_sentence_token_spans(context_ids, tokenizer)
        docs_spans, docs = get_document_token_spans(context_ids, tokenizer)
        if len(docs) != len(item["docs"]):
            # Document segmentation disagrees with the dataset; skip the item.
            print(idx)
            print(f"len(docs, {len(docs)}) != len(item[docs], {len(item['docs'])})")
            print(f"prompt:\n{prompt}")
            print(f"docs:\n{docs}")
            print(f"doc_span:\n{docs_spans}")
            print(f"item[docs]:\n{item['docs']}")
            continue

        # ---- first pass: generate and collect attentions ----
        try:
            all_output = model.generate(input_ids, do_sample=True, temperature=temperature,
                                        max_new_tokens=max_new_tokens,
                                        return_dict_in_generate=True, output_attentions=True,)
        except Exception as e:
            print(f"Exception1: {e} in {idx}")
            passes += 1
            continue
        attention = all_output.attentions
        output_ids = all_output.sequences
        if verbose:
            print(f"input_id shape: {input_ids.shape}")
            print(f"attention: Tuple (one element for each generated token, {len(attention)}) of tuples (one element for each layer of the decoder, {len(attention[0])}) of torch.FloatTensor of shape (batch_size, num_heads, generated_length, sequence_length).{attention[0][0].shape}")
            print(f"output_id shape: {output_ids.shape}")
        output_ids = output_ids[0][len(input_ids[0]):]
        if verbose:
            print(output_ids)
        # NOTE(review): `end` comes from get_output (defined elsewhere); presumably the
        # index of the step where the answer ends. It is used as a generation-step index.
        output, end = get_output(output_ids, tokenizer)
        if verbose:
            print(prompt)
            print(output)
            print(f"end is {end}, {output_ids[end]}, <a>{tokenizer.decode(output_ids[end])}</a>")

        n_layers = len(attention[0])
        upper_layer_span = (int(layer0 * n_layers), int(layer1 * n_layers))
        lower_layer_span = (0, int(layer0 * n_layers))

        # ---- relevance: answer-token attention minus positional attention ----
        token_scores_a = attention_scores_output(len(docs), output_ids[:end], attention, tokenizer, n_layers, docs,
                                                 docs_spans, sents, sent_spans, context_spans, upper_layer_span, 0,
                                                 use_norm=True, return_tokens=True)
        token_scores_end = attention_scores_query(attention[end], context_spans, docs_spans, -1,
                                                  lower_layer_span, n_layers, use_norm=False, return_tokens=True)
        token_scores_begin = attention_scores_query(attention[0], context_spans, docs_spans, -2,
                                                    lower_layer_span, n_layers, use_norm=False, return_tokens=True)
        # Positional term, reused below for edge placement.
        token_scores_p = (token_scores_end + token_scores_begin) / 2
        all_scores = _doc_group_means(np.abs(token_scores_a - token_scores_p), docs_spans)

        combined_index = np.argsort(-all_scores) if qstart else np.argsort(all_scores)
        if verbose:
            print(f"attention: {combined_index}")
        item['att_order'] = combined_index.tolist()

        # ---- position: place most relevant docs at the stronger context edge ----
        att_order = combined_index if qstart else combined_index[::-1]  # most relevant first
        assert context_spans[1] - context_spans[0] == len(token_scores_p), f"{context_spans[1] - context_spans[0]} vs {len(token_scores_p)}"
        position_index = _place_by_position_scores(att_order, docs_spans, token_scores_p)
        item['position_order'] = position_index.tolist()
        if verbose:
            print(f"position: {position_index}")

        # Free attention tensors before the second generation.
        del all_output, attention

        # ---- second pass: generate with the reordered documents ----
        new_prompt = reconstruct_input(item, position_index, qstart)
        if verbose:
            print(new_prompt)
        prompt = system + shots + "\n\n" + new_prompt
        input_ids = tokenizer([prompt], return_tensors="pt").input_ids.to(device)
        try:
            output_ids = model.generate(input_ids, do_sample=True, temperature=temperature,
                                        max_new_tokens=max_new_tokens)
        except Exception as e:
            print(f"Exception2: {e} in {idx}")
            passes += 1
            continue
        output_ids = output_ids[0][len(input_ids[0]):]
        output, end = get_output(output_ids, tokenizer)

        if processed:
            item['output'] = output
            item['golden'] = golden
            record = item
        else:
            record = {"output": output, "golden": golden}
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

        if verbose:
            print(f"output: {output}")
            print(f"golden: {golden}")
            print('\n')
    print(f"inference_attention_merge: {passes} item(s) skipped due to generation errors")

def main():
    """Run attention-merge inference for every selected dataset and score it.

    Parses the CLI arguments, loads the tokenizer and model, and for each
    dataset enabled via ``--hotpot`` / ``--musique`` / ``--wikimulti``:

    * loads the evaluation split with ``load_data`` (``<setting>_qstart``
      when ``--qstart`` is given),
    * reads few-shot demonstrations from ``./prompt/<dataset>.txt`` unless
      ``--zero_shot`` is set,
    * passes the open output file to ``inference_attention_merge``,
    * runs ``compute_exact_match`` on the written file.

    Output files live in ``./test_zs/<dataset>/`` (zero-shot) or
    ``./test/<dataset>/`` (few-shot), or at ``delete/test.json`` with
    ``--debug``. The parent directory is created if it does not exist.
    """
    set_seed(42)
    args = parser()
    assert args.setting_type is not None, "Setting type is required in single scenario!"
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        padding_side="left"
    )
    # NOTE(review): substring heuristic for bf16 models ("wen" presumably
    # targets Qwen); note that ANY model path containing "3" also matches.
    if "wen" in args.model_path or "3" in args.model_path:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=dtype,
        device_map="auto",
        output_attentions=True,
    )
    data_list = []
    if args.hotpot:
        data_list.append("HotpotQA")
    if args.musique:
        data_list.append("Musique")
    if args.wikimulti:
        data_list.append("2wikiMultiHopQA")
    system_prompt = get_system_prompt(args.setting_type)
    print(data_list)
    for data_name in data_list:
        setting = f"{args.setting_type}_qstart" if args.qstart else f"{args.setting_type}"
        eval_data = load_data(args.data_path, data_name, setting)

        # File-name suffix per (qstart, zero_shot) combination; kept identical
        # to the historical names so existing result files keep matching.
        if args.qstart:
            variant = "_qstart_pr" if args.zero_shot else "_qstart"
        else:
            variant = "_rasubpu" if args.zero_shot else ""
        out_dir = "./test_zs" if args.zero_shot else "./test"
        output_path = os.path.join(
            out_dir, data_name,
            f"{args.model_type}_{args.save_suffix}_tmp{args.temperature}{variant}.json")

        if args.zero_shot:
            shots = ""
        else:
            with open(f'./prompt/{data_name}.txt', 'r') as f_shot:
                shots = f_shot.read()
        if args.debug:
            output_path = "delete/test.json"

        # Create the directory up front instead of failing on open() after the
        # (expensive) model load.
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        with open(output_path, "w") as f:
            inference_attention_merge(args.temperature, args.max_new_tokens, eval_data, shots, tokenizer, model,
                                      args.model_type, f, system_prompt, args.processed, args.qstart)
            # Alternative call kept for reference:
            # inference_original(args.temperature, args.max_new_tokens, eval_data, shots, tokenizer, model,
            #                    args.model_type, f, system_prompt, args.processed)

        compute_exact_match(output_path, data_name)

def layer_scores():
    """Draw per-layer attention from generated tokens onto the context tokens.

    For each dataset enabled via ``--hotpot`` / ``--musique``, examples whose
    document count differs from the expected ``doc_num`` (10 for HotpotQA,
    20 otherwise) are skipped. For the first example that generates
    successfully, the function sums, over every generated token and for every
    layer, the head-averaged attention of the last position onto the context
    span ``context_spans[0]:context_spans[1]``. The resulting
    ``(n_layers, context_len)`` matrix is passed to ``draw_layer_scores``.

    NOTE(review): the ``exit(1)`` at the end of the per-example loop stops the
    process after that first example. The per-dataset summary after the loop
    (second ``draw_layer_scores`` call, JSON dump, prints) is therefore only
    reached when every example of the dataset is skipped. This is kept
    unchanged to preserve the existing debug behaviour.
    """
    print("in layer scores")
    set_seed(42)
    args = parser()
    assert args.setting_type is not None, "Setting type is required in single scenario!"
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        padding_side="left"
    )
    if "wen" in args.model_path or "3" in args.model_path:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=dtype,
        device_map="auto",
        output_attentions=True,
    )
    data_list = []
    if args.hotpot:
        data_list.append("HotpotQA")
    if args.musique:
        data_list.append("Musique")
    system_prompt = get_system_prompt(args.setting_type)
    print(data_list)
    for data_name in data_list:
        if args.qstart:
            eval_data = load_data(args.data_path, data_name, f"{args.setting_type}_qstart")
        else:
            eval_data = load_data(args.data_path, data_name, f"{args.setting_type}")
        if args.zero_shot:
            shots = ""
        else:
            with open(f'./prompt/{data_name}.txt', 'r') as f_shot:
                shots = f_shot.read()
        doc_num = 10 if data_name == "HotpotQA" else 20
        layers = 28 if "wen" in args.model_type else 32
        # Summary matrix reported after the loop (only reached when no example
        # gets to the per-example draw/exit below).
        all_scores = np.zeros(shape=(layers, doc_num))
        passes = 0
        scores_record = []
        for idx, item in enumerate(tqdm(eval_data)):
            if len(item['docs']) != doc_num:
                passes += 1
                continue
            prompt = system_prompt + shots + "\n\n" + item["conversations"][0]["value"]
            golden = item["conversations"][1]["value"]
            input_ids = tokenizer([prompt], return_tensors="pt").input_ids
            input_ids = input_ids.to(device)
            context = item['doc_prompt']
            context_spans, context_ids = get_context_ids(input_ids, context, tokenizer)
            sent_spans, sents = get_sentence_token_spans(context_ids, tokenizer)
            docs_spans, docs = get_document_token_spans(context_ids, tokenizer)
            if len(docs) != len(item["docs"]):
                print(idx)
                print(f"len(docs, {len(docs)}) != len(item[docs], {len(item['docs'])})")
                print(f"prompt:\n{prompt}")
                print(f"docs:\n{docs}")
                print(f"doc_span:\n{docs_spans}")
                print(f"item[docs]:\n{item['docs']}")
                continue
            try:
                all_output = model.generate(input_ids, do_sample=True, temperature=args.temperature,
                                            max_new_tokens=args.max_new_tokens,
                                            return_dict_in_generate=True, output_attentions=True,)
            except Exception as e:
                print(f"Exception1: {e} in {idx}")
                passes += 1
                continue
            # generate() attentions: one entry per generation step, each a
            # per-layer tuple of (batch, heads, query_len, key_len) tensors.
            attention = all_output.attentions
            output_ids = all_output.sequences[0][len(input_ids[0]):]

            n_layers = len(attention[0])
            ctx_start, ctx_end = context_spans[0], context_spans[1]
            all_scores = np.zeros(shape=(n_layers, ctx_end - ctx_start))
            for step in range(len(output_ids)):
                step_attention = attention[step]
                # Attention of the newest position onto the context, averaged
                # over heads, one row per layer. Layers are moved to CPU one
                # at a time because device_map="auto" may spread them over
                # several devices.
                all_scores += np.array(
                    [
                        step_attention[layer][0, :, -1, ctx_start:ctx_end]
                            .detach()
                            .cpu()
                            .float()
                            .numpy()
                            .mean(axis=0)
                        for layer in range(n_layers)
                    ]
                )

            draw_layer_scores(all_scores, f"interpre/layer_scores/{args.model_type}_{data_name}_{args.setting_type}_layer_context_scores")
            exit(1)

        draw_layer_scores(all_scores, f"interpre/{args.model_type}_{data_name}_{args.setting_type}")
        with open(f"interpre/{args.model_type}_{data_name}_{args.setting_type}_scores.json",'w') as f:
            json.dump({"all_scores":all_scores.tolist(),"passes": passes, "scores": scores_record}, f, ensure_ascii=False)
        print(all_scores)
        print(f"passes = {passes}")

def load_result_data(data_path, data_name, target_model, setting_type, result_suffix, qstart, temperature=0.01):
    """Load a JSON-lines result file produced by an earlier inference run.

    The file is ``<data_path>/<data_name>/<target_model>_<setting_type>_tmp<temperature>_<result_suffix>[_qstart].json``.

    Args:
        data_path: Root directory holding one sub-directory per dataset.
        data_name: Dataset sub-directory (e.g. ``"HotpotQA"``).
        target_model: Model-type prefix of the result file name.
        setting_type: Setting tag embedded in the file name.
        result_suffix: Run suffix embedded in the file name.
        qstart: If true, load the ``_qstart`` variant of the file.
        temperature: Sampling temperature embedded in the file name as
            ``tmp<temperature>``; defaults to the previously hard-coded 0.01.

    Returns:
        list: One decoded JSON object per non-blank line, in file order.

    Raises:
        FileNotFoundError: If the result file does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    qstart_tag = "_qstart" if qstart else ""
    data_file = f"{target_model}_{setting_type}_tmp{temperature}_{result_suffix}{qstart_tag}.json"
    print(data_file)
    result_path = os.path.join(data_path, data_name, data_file)
    data = []
    # Results are written with ensure_ascii=False, so decode explicitly as UTF-8
    # rather than relying on the platform default encoding.
    with open(result_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    print(len(data))
    return data

def main_orders():
    """Re-run inference on prompts taken from another model's result files.

    Parses the CLI arguments (``--setting_type`` and ``--result_suffix`` are
    required) and loads the tokenizer/model. For every selected dataset it:

    * loads the JSON-lines results of ``target_model`` with
      ``load_result_data``,
    * writes new outputs via ``inference_original(..., new_prompt=True)``,
    * scores them with ``compute_exact_match``.

    The per-dataset items come from those results. Presumably they carry the
    other model's document order; confirm in ``inference_original``.

    Output files are
    ``./test_zs|./test/<dataset>/<model_type>_<save_suffix>_<target_model>_order_tmp<T>[_qstart].json``.
    The parent directory is created if it does not exist.
    """
    set_seed(42)
    args = parser()
    assert args.setting_type is not None, "Setting type is required in single scenario!"
    assert args.result_suffix is not None, "Result_suffix is required in special order!"
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        padding_side="left"
    )
    if "wen" in args.model_path or "3" in args.model_path:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=dtype,
        device_map="auto",
        output_attentions=True,
    )
    data_list = []
    if args.hotpot:
        data_list.append("HotpotQA")
    if args.musique:
        data_list.append("Musique")
    if args.wikimulti:
        data_list.append("2wikiMultiHopQA")
    system_prompt = get_system_prompt(args.setting_type)
    print(data_list)
    # NOTE(review): source model of the reused orders is hard-coded.
    target_model = "qwen_7b"
    print(target_model)
    # f-string (not `+`) so an omitted --save_suffix (default None) produces
    # "None_..." like main() does, instead of raising TypeError.
    save_suffix = f"{args.save_suffix}_{target_model}_order"
    for data_name in data_list:
        eval_data = load_result_data(args.data_path, data_name, target_model, args.setting_type,
                                     args.result_suffix, args.qstart)
        out_dir = "./test_zs" if args.zero_shot else "./test"
        variant = "_qstart" if args.qstart else ""
        output_path = os.path.join(
            out_dir, data_name,
            f"{args.model_type}_{save_suffix}_tmp{args.temperature}{variant}.json")
        if args.zero_shot:
            shots = ""
        else:
            with open(f'./prompt/{data_name}.txt', 'r') as f_shot:
                shots = f_shot.read()

        # Create the directory up front instead of failing on open().
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, "w") as f:
            # Alternative call kept for reference:
            # inference_attention_merge(args.temperature, args.max_new_tokens, eval_data, shots, tokenizer, model,
            #                           args.model_type, f, system_prompt, args.processed, args.qstart)
            inference_original(args.temperature, args.max_new_tokens, eval_data, shots, tokenizer, model,
                               args.model_type, f, system_prompt, args.processed, new_prompt=True)

        compute_exact_match(output_path, data_name)

def get_position(pad_token_id, input_ids, tokenizer, model=None, num_pad_tokens = 10, verbose = False):
    """Return ``input_ids`` with ``num_pad_tokens`` placeholder ids appended.

    Args:
        pad_token_id: Token id written into every appended position.
        input_ids: 2-D tensor of token ids, ``(batch, seq_len)``.
        tokenizer: Only used to print special-token diagnostics when
            ``verbose`` is true.
        model: Unused; kept so existing callers keep working.
        num_pad_tokens: Number of placeholder tokens appended to each row
            (``0`` returns the ids unchanged).
        verbose: Print shapes and special-token information.

    Returns:
        torch.Tensor: ``(batch, seq_len + num_pad_tokens)`` tensor with the
        same dtype and device as ``input_ids``.
    """
    if verbose:
        print(f"pad_token: {tokenizer.pad_token}, id: {tokenizer.pad_token_id}")
        print(f"eos_token: {tokenizer.eos_token}, id: {tokenizer.eos_token_id}")
        print(f"unk_token: {tokenizer.unk_token}, id: {tokenizer.unk_token_id}")
        print(f"input_id shape: {input_ids.shape}")
        print(input_ids)
    # Append the placeholder tokens. The explicit dtype/device/batch size keep
    # the concatenation valid for any batch and for num_pad_tokens == 0.
    # torch.tensor([[]]) would be an empty float32 tensor, not an id tensor.
    pad_tensor = torch.full(
        (input_ids.shape[0], num_pad_tokens),
        pad_token_id,
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    if verbose:
        print(pad_tensor.shape)
        print(pad_tensor)
    new_input_ids = torch.cat((input_ids, pad_tensor), dim=1)
    if verbose:
        print(f"new_input_id shape: {new_input_ids.shape}")
        print(new_input_ids)
    return new_input_ids

def position_order(pad_token_id, input_ids, device, num_pad_tokens, model, context_spans, docs_spans, layer_span,
                   n_layers, verbose=False):
    """Order documents by the attention that appended placeholder tokens pay to them.

    Appends ``num_pad_tokens`` copies of ``pad_token_id`` to ``input_ids`` and
    runs a single forward pass with ``output_attentions=True``. The documents
    are then scored with ``attention_scores_query``, using the appended
    positions ``[len(input_ids[0]), len(new_input_ids[0])]`` as the query
    selector.

    Args:
        pad_token_id: Token id used for every appended placeholder.
        input_ids: Prompt token ids, batch-first (``input_ids[0]`` is the
            first sequence).
        device: Device the model input is placed on.
        num_pad_tokens: Number of placeholder tokens to append.
        model: Causal LM, called as ``model(ids, output_attentions=True)``.
        context_spans, docs_spans, layer_span, n_layers: Forwarded unchanged
            to ``attention_scores_query``.
        verbose: Print padding, span and ordering diagnostics.

    Returns:
        ``np.argsort`` of the document scores. These are document indices from
        lowest to highest score, assuming ``attention_scores_query`` returns
        one score per document (TODO confirm).
    """
    if verbose:
        print(f"padding {num_pad_tokens} tokens with {pad_token_id}")
    input_ids = input_ids.to(device)
    # Same dtype/device/batch size as the prompt. The dtype is explicit
    # because torch.tensor([[id] * 0]) would be an empty float32 tensor.
    pad_tensor = torch.full(
        (input_ids.shape[0], num_pad_tokens),
        pad_token_id,
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    if verbose:
        print(pad_tensor.shape)
        print(pad_tensor)
    new_input_ids = torch.cat((input_ids, pad_tensor), dim=1)
    # Inference only: avoid building (and holding) an autograd graph for the
    # full attention maps.
    with torch.no_grad():
        all_output = model(new_input_ids, output_attentions=True)
    attention = all_output.attentions
    # Query positions are the appended placeholder tokens.
    split_spans = [len(input_ids[0]), len(new_input_ids[0])]
    if verbose:
        print(split_spans)
    all_scores_blank = attention_scores_query(attention, context_spans, docs_spans, split_spans, layer_span,
                                              n_layers)
    del all_output, attention
    position_index = np.argsort(all_scores_blank)
    if verbose:
        print(f"position_index: {position_index}")
    return position_index


def differ_scores():
    """Compare per-document attention scores across several scoring passes.

    For every dataset enabled via ``--hotpot`` / ``--musique`` /
    ``--wikimulti``, each example with the expected number of documents is
    processed as follows:

    1. ``model.generate`` runs with attentions returned, and the cleaned-up
       answer text is stored.
       * Answer-token scores come from ``attention_scores_output``.
       * Prefill-step scores come from ``attention_scores_query`` on
         ``attention[0]``, with two query selectors: ``-1`` and the question
         token span.
    2. A plain forward pass ``model(input_ids, output_attentions=True)`` is
       scored with the same two selectors.
    3. ``len(output_ids)`` ``<blank>`` tokens are appended with
       ``get_position``. The resulting forward pass is scored with the same
       two selectors, plus the blank-token span and the first blank position.

    All scores use the upper half of the layers (``layer_span``). Each
    successful example becomes one record. The record list is written to
    ``interpre/fixed_mean_scores/blank/<model_type>_<dataset>_<setting>_scores.json``.
    That directory is not created here.

    NOTE(review): the loop stops on ``idx - passes == 50`` (51 records when
    nothing is skipped). Examples skipped by the ``len(docs) != len(item["docs"])``
    check do not increment ``passes``, so such a skip can move
    ``idx - passes`` past 50 between checks. The loop then runs over the
    whole dataset. Consider ``len(scores_record) >= N`` instead.
    """
    print("in differ scores")
    set_seed(42)
    args = parser()
    assert args.setting_type is not None, "Setting type is required in single scenario!"
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        padding_side="left"
    )
    # NOTE(review): substring heuristic; any path containing "3" also selects bf16.
    if "wen" in args.model_path or "3" in args.model_path:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=dtype,
        device_map="auto",
        output_attentions=True,
    )

    # Register a dedicated placeholder token (with a new embedding row). It is
    # appended after the prompt in step 3 via get_position.
    blank_token = '<blank>'
    tokenizer.add_tokens([blank_token])
    model.resize_token_embeddings(len(tokenizer))
    blank_id = tokenizer.convert_tokens_to_ids(blank_token)
    print(f"Adding blank token: id is {blank_id}")

    data_list = []
    if args.hotpot:
        data_list.append("HotpotQA")
    if args.musique:
        data_list.append("Musique")
    if args.wikimulti:
        data_list.append("2wikiMultiHopQA")
    system_prompt = get_system_prompt(args.setting_type)
    print(data_list)
    for data_name in data_list:
        if args.qstart:
            eval_data = load_data(args.data_path, data_name, f"{args.setting_type}_qstart")
        else:
            eval_data = load_data(args.data_path, data_name, f"{args.setting_type}")
        if args.zero_shot:
            shots = ""
        else:
            with open(f'./prompt/{data_name}.txt', 'r') as f_shot:
                shots = f_shot.read()
        # "QA" matches both HotpotQA and 2wikiMultiHopQA; only Musique expects 20 docs.
        if "QA" in data_name:
            doc_num = 10
        else:
            doc_num = 20
        passes = 0
        scores_record = []
        for idx, item in enumerate(tqdm(eval_data)):
            # 'original_id' holds scores for the plain prompt, 'new_id' for prompt + blanks.
            record = {"idx": idx,"question": item['question'], 'original_id':{}, 'new_id':{}}
            # Print diagnostics for the first three examples only.
            verbose = False
            if idx < 3:
                verbose = True
            if len(item['docs']) != doc_num:
                passes += 1
                print(f"Doc num wrong in {idx}")
                continue
            # print(idx)
            demo = shots
            question = item['question']
            # 2wikiMultiHopQA prompts are rebuilt from doc_prompt + question
            # instead of using the stored conversation text.
            if data_name == "2wikiMultiHopQA":
                prompt = f"Docs:{item['doc_prompt']}\nQuestion:{question}\nAnswer:"
            else:
                prompt = item["conversations"][0]["value"]
            prompt = system_prompt + demo + "\n\n" + prompt
            golden = item["conversations"][1]["value"]
            # print(prompt)
            # print(golden)
            input_ids = tokenizer([prompt], return_tensors="pt").input_ids
            input_ids = input_ids.to(device)
            context = item['doc_prompt']

            context_spans, context_ids = get_context_ids(input_ids, context, tokenizer)
            sent_spans, sents = get_sentence_token_spans(context_ids, tokenizer)
            # test_spans(sent_spans, sents, context_ids, tokenizer)
            docs_spans, docs = get_document_token_spans(context_ids, tokenizer)
            # test_spans(docs_spans, docs, context_ids, tokenizer)
            # Token span of the question, located with the same helper as the context.
            question_spans, question_ids = get_context_ids(input_ids, question, tokenizer)
            # NOTE(review): this skip does not increment `passes` (see the docstring note).
            if len(docs) != len(item["docs"]):
                print(idx)
                print(f"len(docs, {len(docs)}) != len(item[docs], {len(item['docs'])})")
                print(f"prompt:\n{prompt}")
                print(f"docs:\n{docs}")
                print(f"doc_span:\n{docs_spans}")
                print(f"item[docs]:\n{item['docs']}")
                # test_spans(docs_spans, docs, context_ids, tokenizer)
                continue
            try:
                # Step 1: sampling generation with attentions.
                all_output = model.generate(input_ids, do_sample=True, temperature=args.temperature,
                                            max_new_tokens=args.max_new_tokens,
                                            return_dict_in_generate=True, output_attentions=True, )

                # print(all_output)
                # generate() attentions are indexed by generation step, then by
                # layer; attention[0] is the prompt (prefill) step.
                attention = all_output.attentions
                output_ids = all_output.sequences
                output_ids = output_ids[0][len(input_ids[0]):]
                # Decode the answer, strip special tokens, keep the first paragraph.
                output = tokenizer.decode(output_ids)
                for special_token in tokenizer.special_tokens_map.values():
                    if isinstance(special_token, list):
                        for special_tok in special_token:
                            output = output.replace(special_tok, "")
                    else:
                        output = output.replace(special_token, "")
                output = output.strip()
                output = output.split('\n\n')[0]
                # print(output)
                record['output'] = output

                n_layers = len(attention[0])
                # attention: (batch_size, num_heads, generated_length, sequence_length)
                # Score with the upper half of the layers: [int(0.5*n), n).
                layer0, layer1 = 0.5, 1.0
                layer_span = (
                    int(layer0 * n_layers), int(layer1 * n_layers)
                )
                # print(all_one_scores.shape)
                method = 0
                all_scores_a = attention_scores_output(len(docs), output_ids, attention, tokenizer, n_layers, docs,
                                                       docs_spans, sents, sent_spans, context_spans, layer_span, method)
                # attention score group -- query
                # Selectors: -1 and the question span; how each is aggregated is
                # defined by attention_scores_query (utils_attention).
                if verbose:
                    print(attention[0][0].shape)
                all_scores_q0 = attention_scores_query(attention[0], context_spans, docs_spans, -1, layer_span, n_layers)
                all_scores_q = attention_scores_query(attention[0], context_spans, docs_spans, question_spans, layer_span,
                                                      n_layers)
                record['original_id']['input_ids'] = input_ids.tolist()
                record['original_id']['model_generate'] = {}
                record['original_id']['model'] = {}
                record['original_id']['model_generate']['mean_answer'] = all_scores_a.tolist()
                record['original_id']['model_generate']['mean_question'] = all_scores_q.tolist()
                record['original_id']['model_generate']['mean_input-1'] = all_scores_q0.tolist()
                del all_output, attention
                # Step 2: plain forward pass over the same prompt.
                new_output = model(input_ids, output_attentions=True)
                attention = new_output.attentions
                if verbose:
                    print(attention[0].shape)
                all_scores_q0 = attention_scores_query(attention, context_spans, docs_spans, -1, layer_span, n_layers)
                all_scores_q = attention_scores_query(attention, context_spans, docs_spans, question_spans, layer_span,
                                                      n_layers)
                record['original_id']['model']['mean_question'] = all_scores_q.tolist()
                record['original_id']['model']['mean_input-1'] = all_scores_q0.tolist()
                del new_output, attention
                # Step 3: append one <blank> per generated token and re-run.
                # input_ids is moved to CPU before get_position; the padded ids
                # are moved back to `device` below.
                input_ids = input_ids.cpu()
                new_input_ids = get_position(blank_id,input_ids, tokenizer, num_pad_tokens=len(output_ids), verbose=verbose)
                record['new_id']['input_ids'] = new_input_ids.tolist()
                record['new_id']['model'] = {}
                new_input_ids = new_input_ids.to(device)
                new_output = model(new_input_ids, output_attentions=True)
                attention = new_output.attentions
                if verbose:
                    print(attention[0].shape)
                all_scores_q0 = attention_scores_query(attention, context_spans, docs_spans, -1, layer_span, n_layers)
                all_scores_q = attention_scores_query(attention, context_spans, docs_spans, question_spans, layer_span,
                                                      n_layers)
                record['new_id']['model']['mean_question'] = all_scores_q.tolist()
                record['new_id']['model']['mean_input-1'] = all_scores_q0.tolist()
                # Start/end positions of the appended blank tokens.
                split_spans = [len(input_ids[0]), len(new_input_ids[0])]
                if verbose:
                    print(split_spans)
                all_scores_blank = attention_scores_query(attention, context_spans, docs_spans, split_spans, layer_span,
                                                      n_layers)
                record['new_id']['model']['mean_blanks'] = all_scores_blank.tolist()
                all_scores_first_blank = attention_scores_query(attention, context_spans, docs_spans, split_spans[0], layer_span,
                                                      n_layers)
                record['new_id']['model']['mean_firstblank'] = all_scores_first_blank.tolist()
                del new_output, attention
                scores_record.append(record)
            except Exception as e:
                # Any failure drops the (partially filled) record and counts as a pass.
                print(f"Exception: {e} in {idx}")
                passes += 1
                continue
            # Early stop after roughly 51 records; see the docstring NOTE.
            if idx - passes == 50:
                break

        with open(f"interpre/fixed_mean_scores/blank/{args.model_type}_{data_name}_{args.setting_type}_scores.json", 'w') as f:
            json.dump(scores_record, f,ensure_ascii=False)
        print(f"passes = {passes}")

def one_round_scores():
    print("in one_round scores")
    set_seed(42)
    args = parser()
    assert args.setting_type is not None, "Setting type is required in single scenario!"
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        padding_side="left"
    )
    if "wen" in args.model_path or "3" in args.model_path:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=dtype,
        device_map="auto",
        output_attentions=True,
    )

    data_list = []
    if args.hotpot:
        data_list.append("HotpotQA")
    if args.musique:
        data_list.append("Musique")
    if args.wikimulti:
        data_list.append("2wikiMultiHopQA")
    system_prompt = get_system_prompt(args.setting_type)
    print(data_list)
    for data_name in data_list:
        if args.qstart:
            eval_data = load_data(args.data_path, data_name, f"{args.setting_type}_qstart")
        else:
            eval_data = load_data(args.data_path, data_name, f"{args.setting_type}")
        if args.zero_shot:
            shots = ""
        else:
            with open(f'./prompt/{data_name}.txt', 'r') as f_shot:
                shots = f_shot.read()
        if "QA" in data_name:
            doc_num = 10
        else:
            doc_num = 20
        passes = 0
        scores_record = []
        for idx, item in enumerate(tqdm(eval_data)):
            record = {"idx": idx,"question": item['question'], 'aall_scores':{}, 'q_scores':{}, 'q-1_scores':{}, 'a-1_scores': {},'a_scores':{}}
            verbose = False
            if idx < 3:
                verbose = True
            if len(item['docs']) != doc_num:
                passes += 1
                print(f"Doc num wrong in {idx}")
                continue
            # print(idx)
            demo = shots
            question = item['question']
            if data_name == "2wikiMultiHopQA":
                prompt = f"Docs:{item['doc_prompt']}\nQuestion:{question}\nAnswer:"
            else:
                prompt = item["conversations"][0]["value"]
            prompt = system_prompt + demo + "\n\n" + prompt
            golden = item["conversations"][1]["value"]
            # print(prompt)
            # print(golden)
            input_ids = tokenizer([prompt], return_tensors="pt").input_ids
            input_ids = input_ids.to(device)
            context = item['doc_prompt']

            context_spans, context_ids = get_context_ids(input_ids, context, tokenizer)
            sent_spans, sents = get_sentence_token_spans(context_ids, tokenizer)
            # test_spans(sent_spans, sents, context_ids, tokenizer)
            docs_spans, docs = get_document_token_spans(context_ids, tokenizer)
            # test_spans(docs_spans, docs, context_ids, tokenizer)
            question_spans, question_ids = get_context_ids(input_ids, question, tokenizer)
            all_output = model.generate(input_ids, do_sample=True, temperature=args.temperature,
                                            max_new_tokens=args.max_new_tokens,
                                            return_dict_in_generate=True, output_attentions=True, )

            attention = all_output.attentions
            output_ids = all_output.sequences
            output_ids = output_ids[0][len(input_ids[0]):]
            len_output_ids = len(output_ids)
            if verbose:
                print(f"len(output_ids) = {len_output_ids}, the last token id is {output_ids[len_output_ids - 1]}")
                print(output_ids)
            record['output_ids'] = output_ids.tolist()
            # print(output)
            record['output'], end = get_output(output_ids,tokenizer)
            n_layers = len(attention[0])
            # attention: (batch_size, num_heads, generated_length, sequence_length)
            layer0, layer1 = 0.5, 1.0
            last_layer_span = (
                int(layer0 * n_layers), int(layer1 * n_layers)
            )
            first_layer_span = (
                0, int(layer0 * n_layers)
            )
            method = 0
            record['aall_scores']['last_norm'] = attention_scores_output(len(docs), output_ids, attention, tokenizer,
                                                                         n_layers, docs,
                                                                         docs_spans, sents, sent_spans, context_spans,
                                                                         last_layer_span, method, use_norm=True).tolist()
            record['aall_scores']['last_nonorm'] = attention_scores_output(len(docs), output_ids, attention, tokenizer,
                                                                         n_layers, docs,
                                                                         docs_spans, sents, sent_spans, context_spans,
                                                                         last_layer_span, method, use_norm=False).tolist()
            record['aall_scores']['first_norm'] = attention_scores_output(len(docs), output_ids, attention, tokenizer,
                                                                           n_layers, docs,
                                                                           docs_spans, sents, sent_spans, context_spans,
                                                                           first_layer_span, method, use_norm=True).tolist()
            record['aall_scores']['first_nonorm'] = attention_scores_output(len(docs), output_ids, attention, tokenizer,
                                                                         n_layers, docs,
                                                                         docs_spans, sents, sent_spans, context_spans,
                                                                         first_layer_span, method, use_norm=False).tolist()


            record['a_scores']['last_norm'] = attention_scores_output(len(docs), output_ids[:-1], attention, tokenizer,
                                                                         n_layers, docs,
                                                                         docs_spans, sents, sent_spans, context_spans,
                                                                         last_layer_span, method, use_norm=True).tolist()
            record['a_scores']['last_nonorm'] = attention_scores_output(len(docs), output_ids[:-1], attention, tokenizer,
                                                                           n_layers, docs,
                                                                           docs_spans, sents, sent_spans, context_spans,
                                                                           last_layer_span, method, use_norm=False).tolist()
            record['a_scores']['first_norm'] = attention_scores_output(len(docs), output_ids[:-1], attention, tokenizer,
                                                                          n_layers, docs,
                                                                          docs_spans, sents, sent_spans, context_spans,
                                                                          first_layer_span, method, use_norm=True).tolist()
            record['a_scores']['first_nonorm'] = attention_scores_output(len(docs), output_ids[:-1], attention, tokenizer,
                                                                            n_layers, docs,
                                                                            docs_spans, sents, sent_spans,
                                                                            context_spans,
                                                                            first_layer_span, method, use_norm=False).tolist()

            record['a-1_scores']['last_norm'] = attention_scores_query(attention[-1], context_spans, docs_spans, -1,
                                                                       last_layer_span, n_layers,use_norm=True).tolist()
            record['a-1_scores']['last_nonorm'] = attention_scores_query(attention[-1], context_spans, docs_spans, -1,
                                                                       last_layer_span, n_layers, use_norm=False).tolist()
            record['a-1_scores']['first_norm'] = attention_scores_query(attention[-1], context_spans, docs_spans, -1,
                                                                       first_layer_span, n_layers, use_norm=True).tolist()
            record['a-1_scores']['first_nonorm'] = attention_scores_query(attention[-1], context_spans, docs_spans, -1,
                                                                        first_layer_span, n_layers, use_norm=False).tolist()

            record['q-1_scores']['last_norm'] = attention_scores_query(attention[0], context_spans, docs_spans, -1,
                                                                       last_layer_span, n_layers, use_norm=True).tolist()
            record['q-1_scores']['last_nonorm'] = attention_scores_query(attention[0], context_spans, docs_spans, -1,
                                                                         last_layer_span, n_layers, use_norm=False).tolist()
            record['q-1_scores']['first_norm'] = attention_scores_query(attention[0], context_spans, docs_spans, -1,
                                                                        first_layer_span, n_layers, use_norm=True).tolist()
            record['q-1_scores']['first_nonorm'] = attention_scores_query(attention[0], context_spans, docs_spans, -1,
                                                                          first_layer_span, n_layers, use_norm=False).tolist()

            record['q_scores']['last_norm'] = attention_scores_query(attention[0], context_spans, docs_spans, question_spans,
                                                                       last_layer_span, n_layers, use_norm=True).tolist()
            record['q_scores']['last_nonorm'] = attention_scores_query(attention[0], context_spans, docs_spans, question_spans,
                                                                         last_layer_span, n_layers, use_norm=False).tolist()
            record['q_scores']['first_norm'] = attention_scores_query(attention[0], context_spans, docs_spans, question_spans,
                                                                        first_layer_span, n_layers, use_norm=True).tolist()
            record['q_scores']['first_nonorm'] = attention_scores_query(attention[0], context_spans, docs_spans, question_spans,
                                                                          first_layer_span, n_layers, use_norm=False).tolist()
            scores_record.append(record)
            del all_output, attention
            if idx == 49:
                break

        with open(f"interpre/one_round_scores/{args.model_type}_{data_name}_{args.setting_type}_scores.json", 'w') as f:
            json.dump(scores_record, f,ensure_ascii=False)
        print(len(scores_record))

def token_scores():
    """Record averaged token-level context attention ("pbe" scores) per QA item.

    For every dataset enabled on the command line (``--hotpot``, ``--musique``,
    ``--wikimulti``) this loads the model and evaluation data, then for each
    item:

    * builds the prompt (system prompt + optional few-shot demo + item prompt),
    * generates an answer with ``return_dict_in_generate=True`` and
      ``output_attentions=True``,
    * computes two token-level attention profiles over the context with
      ``attention_scores_query`` over the first half of the layers
      (``first_layer_span``): one from ``attention[end]`` with query index -1,
      where ``end`` comes from ``get_output``, and one from ``attention[0]``
      with query index -2,
    * stores their element-wise mean under ``record['pbe_scores']``.

    Items whose number of documents differs from the expected count are
    skipped. Only items with index 0..49 are processed. Results for each
    dataset are written as a JSON list to
    ``interpre/token_scores/{model_type}_{dataset}_{setting_type}_pbe_scores.json``.
    """
    print("in tokens scores")
    set_seed(42)
    args = parser()
    assert args.setting_type is not None, "Setting type is required in single scenario!"
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        padding_side="left"
    )
    # bf16 for model paths containing "wen" (presumably Qwen) or "3" (presumably
    # Llama-3-style models); fp16 otherwise.
    # NOTE(review): the "3" check matches any path containing the digit 3 — confirm intended.
    if "wen" in args.model_path or "3" in args.model_path:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=dtype,
        device_map="auto",
        output_attentions=True,
    )

    data_list = []
    if args.hotpot:
        data_list.append("HotpotQA")
    if args.musique:
        data_list.append("Musique")
    if args.wikimulti:
        data_list.append("2wikiMultiHopQA")
    system_prompt = get_system_prompt(args.setting_type)
    print(data_list)

    # Create the output folder up front so results are not lost at the very end
    # of a long generation run because the directory is missing.
    out_dir = "interpre/token_scores"
    os.makedirs(out_dir, exist_ok=True)

    for data_name in data_list:
        if args.qstart:
            eval_data = load_data(args.data_path, data_name, f"{args.setting_type}_qstart")
        else:
            eval_data = load_data(args.data_path, data_name, f"{args.setting_type}")
        if args.zero_shot:
            shots = ""
        else:
            # Explicit UTF-8 so the few-shot prompt decodes identically on every locale.
            with open(f'./prompt/{data_name}.txt', 'r', encoding='utf-8') as f_shot:
                shots = f_shot.read()
        # Expected documents per item: 10 for the "*QA" datasets, 20 otherwise.
        if "QA" in data_name:
            doc_num = 10
        else:
            doc_num = 20
        passes = 0  # number of items skipped because of an unexpected doc count
        scores_record = []
        for idx, item in enumerate(tqdm(eval_data)):
            '''
            record = {"idx": idx, "question": item['question'], 'end_scores': {}, 'b_input-2_scores': {}, 'be_scores': {},
                      'be_token_scores':{}, 'real_a_scores':{},'ra_sub_be_scores':{},'ra_sub_be_token_scores':{},
                      'ra_sub_e_scores':{}, 'ra_sub_e_token_scores':{},'ra_sub_b_scores':{},'ra_sub_b_token_scores':{}}
            '''
            record = {"idx": idx, "question": item['question'], 'pbe_scores':{}}
            # Print debugging details only for the first few items.
            verbose = idx < 3
            if len(item['docs']) != doc_num:
                passes += 1
                print(f"Doc num wrong in {idx}")
                continue
            # print(idx)
            demo = shots
            question = item['question']
            if data_name == "2wikiMultiHopQA":
                prompt = f"Docs:{item['doc_prompt']}\nQuestion:{question}\nAnswer:"
            else:
                prompt = item["conversations"][0]["value"]
            prompt = system_prompt + demo + "\n\n" + prompt
            golden = item["conversations"][1]["value"]
            # print(prompt)
            # print(golden)
            input_ids = tokenizer([prompt], return_tensors="pt").input_ids
            input_ids = input_ids.to(device)
            context = item['doc_prompt']

            # Locate the context, its sentences and its documents inside the tokenized prompt.
            context_spans, context_ids = get_context_ids(input_ids, context, tokenizer)
            sent_spans, sents = get_sentence_token_spans(context_ids, tokenizer)
            # test_spans(sent_spans, sents, context_ids, tokenizer)
            docs_spans, docs = get_document_token_spans(context_ids, tokenizer)
            # test_spans(docs_spans, docs, context_ids, tokenizer)
            question_spans, question_ids = get_context_ids(input_ids, question, tokenizer)
            all_output = model.generate(input_ids, do_sample=True, temperature=args.temperature,
                                        max_new_tokens=args.max_new_tokens,
                                        return_dict_in_generate=True, output_attentions=True, )

            attention = all_output.attentions
            output_ids = all_output.sequences
            # Keep only the newly generated tokens (drop the echoed prompt).
            output_ids = output_ids[0][len(input_ids[0]):]
            len_output_ids = len(output_ids)
            if verbose:
                print(f"len(output_ids) = {len_output_ids}, the last token id is {output_ids[len_output_ids - 1]}")
                print(output_ids)
            record['output_ids'] = output_ids.tolist()
            # print(output)
            record['output'], end = get_output(output_ids, tokenizer)
            if verbose:
                print(f"end = {end}, {output_ids[end]}")
            # `attention` is indexed [generation_step][layer]; each tensor is presumably
            # (batch_size, num_heads, query_length, key_length) — TODO confirm.
            n_layers = len(attention[0])
            # Split the layers into a lower half (first_layer_span) and an upper half
            # (last_layer_span).
            layer0, layer1 = 0.5, 1.0
            last_layer_span = (
                int(layer0 * n_layers), int(layer1 * n_layers)
            )
            first_layer_span = (
                0, int(layer0 * n_layers)
            )
            method = 0
            '''
            all_scores_end = attention_scores_query(attention[end], context_spans, docs_spans, -1,
                                                                        first_layer_span, n_layers, use_norm=False)
            all_scores_input_2 = attention_scores_query(attention[0], context_spans, docs_spans, -2,
                                                                first_layer_span, n_layers, use_norm=False)
            record['end_scores'] = all_scores_end.tolist()
            record['b_input-2_scores'] = all_scores_input_2.tolist()
            '''
            token_scores_end = attention_scores_query(attention[end], context_spans, docs_spans, -1,
                                                      first_layer_span, n_layers, use_norm=False, return_tokens=True)
            token_scores_input_2 = attention_scores_query(attention[0], context_spans, docs_spans, -2,
                                                          first_layer_span, n_layers, use_norm=False, return_tokens=True)
            # Element-wise mean of the two token-level profiles.
            sum_token_scores_be = (token_scores_end + token_scores_input_2) / 2
            record['pbe_scores'] = sum_token_scores_be.tolist()
            # NOTE(review): disabled experiment below. If re-enabled, every
            # `np.abs(x, y)` must become `np.abs(x - y)`: numpy treats the second
            # positional argument of np.abs as the `out` buffer, not a subtrahend.
            '''
            group_scores = np.array(
                [
                    sum_token_scores_be[item_span[0]: item_span[1]].mean()
                    for item_span in docs_spans
                ]
            )
            record['be_token_scores'] = group_scores.tolist()
            all_scores_be = (all_scores_end + all_scores_input_2) / 2
            record['be_scores'] = all_scores_be.tolist()
            all_reala_scores = attention_scores_output(len(docs), output_ids[:end], attention, tokenizer, n_layers, docs,
                                                  docs_spans, sents, sent_spans, context_spans, last_layer_span, method,
                                                  use_norm=True)
            token_reala_scores = attention_scores_output(len(docs), output_ids[:end], attention, tokenizer, n_layers, docs,
                                                  docs_spans, sents, sent_spans, context_spans, last_layer_span, method,
                                                  use_norm=True, return_tokens=True)
            record['real_a_scores'] = all_reala_scores.tolist()
            record['ra_sub_be_scores'] = np.abs(all_reala_scores, all_scores_be).tolist()
            sub_token_scores = np.abs(token_reala_scores - sum_token_scores_be)
            group_scores = np.array(
                [
                    sub_token_scores[item_span[0]: item_span[1]].mean()
                    for item_span in docs_spans
                ]
            )
            record['ra_sub_be_token_scores'] = group_scores.tolist()
            record['ra_sub_e_scores'] = np.abs(all_reala_scores, all_scores_end).tolist()
            sub_token_scores = np.abs(token_reala_scores - token_scores_end)
            group_scores = np.array(
                [
                    sub_token_scores[item_span[0]: item_span[1]].mean()
                    for item_span in docs_spans
                ]
            )
            record['ra_sub_e_token_scores'] = group_scores.tolist()
            record['ra_sub_b_scores'] = np.abs(all_reala_scores, all_scores_input_2).tolist()
            sub_token_scores = np.abs(token_reala_scores - token_scores_input_2)
            group_scores = np.array(
                [
                    sub_token_scores[item_span[0]: item_span[1]].mean()
                    for item_span in docs_spans
                ]
            )
            record['ra_sub_b_token_scores'] = group_scores.tolist()
            '''
            scores_record.append(record)
            # Drop references to the large attention tensors before the next item.
            del all_output, attention
            # Only items with index 0..49 are processed (skipped items count too).
            if idx == 49:
                break

        print(f"Skipped {passes} items in {data_name} because of a wrong doc num")
        out_path = f"{out_dir}/{args.model_type}_{data_name}_{args.setting_type}_pbe_scores.json"
        # ensure_ascii=False writes raw non-ASCII text, so the file encoding must be explicit.
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(scores_record, f, ensure_ascii=False)
        print(len(scores_record))

if __name__ == '__main__':
    # Script entry point: exactly one experiment runs per invocation.
    # To run a different experiment, swap the call below for one of
    # main(), main_orders(), differ_scores(), one_round_scores() or token_scores().
    layer_scores()