from data_utils import load_dataset
from utils import construct_prompt, random_sampling, construct_prompt_without_test, construct_prompt_instance_prompt_text
import numpy as np
import torch
import torch.nn as nn
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, LlamaForCompressionCausalLM, AutoConfig
import argparse
from typing import Dict, Optional, Sequence
import itertools
import json
import random
from openpyxl import Workbook
from scipy import stats
import copy

# Tokens treated as stopwords: lowercase English function words, followed by
# single-character punctuation marks and the newline character.
# NOTE: the apostrophe "'" intentionally appears twice in the punctuation part
# (as in the original list) so the list's length and ordering are unchanged.
stopwords = (
    "i me my myself we our ours ourselves you your yours yourself yourselves "
    "he him his himself she her hers herself it its itself they them their "
    "theirs themselves what which who whom this that these those am is are "
    "was were be been being have has had having do does did doing a an the "
    "and but if or because as until while of at by for with about against "
    "between into through during before after above below to from up down in "
    "out on off over under again further then once here there when where why "
    "how all any both each few more most other some such no nor not only own "
    "same so than too very s t can will just don should now"
).split() + [
    "!", "@", "#", "$", "%", "^", "&", "*", "(", ")",
    "-", "_", "+", "=", "[", "]", "{", "}", "|", "'",
    ";", '"', "'", "<", ">", ",", ".", "?", "/", "\n",
]

def acquire_different_type_text(params, adjusted_train_sentences, adjusted_train_labels, tokenizer, device):
    """Tokenize three "skeleton" prompts derived from the first demonstration ordering.

    All three prompts use the FIRST entry of ``adjusted_train_labels``. Every
    demonstration input is replaced by a blank, so only prompt scaffolding and
    answer labels remain:

    * ``punctuation_prompt``: ``params["prompt_prefix"] + '\\n\\n'``, i.e. the
      instruction prefix on its own.
    * ``template_prompt``: the blanked demonstrations rendered by
      ``construct_prompt_without_test`` with ``prompt_prefix`` set to ``''``.
    * ``anti_content_prompt``: the same blanked demonstrations rendered with the
      original ``prompt_prefix`` kept.

    Args:
        params: prompt configuration dict; must contain ``"prompt_prefix"``.
            It is never mutated, because each prompt builder call receives its
            own deep copy.
        adjusted_train_sentences: list of demonstration orderings. Only the
            first ordering is inspected, to decide whether demonstrations are
            pair dicts (blanked as ``{'hypothesis': '', 'premise': ''}``) or
            plain strings (blanked as ``''``). An empty first ordering is
            treated as plain strings.
        adjusted_train_labels: list of label orderings; only the first is used.
        tokenizer: callable returning a mapping whose ``'input_ids'`` entry
            supports ``.to(device)``.
        device: target device for the returned token id tensors.

    Returns:
        ``(punctuation_prompt, template_prompt, anti_content_prompt)``: token
        id tensors on ``device``.
    """
    first_sentences = adjusted_train_sentences[0]
    first_labels = adjusted_train_labels[0]
    # Check for emptiness before peeking at element 0 so a 0-shot ordering
    # does not raise IndexError.
    pair_inputs = bool(first_sentences) and isinstance(first_sentences[0], dict)

    def _tokenize(text):
        # Tokenize to a (batched) id tensor and move it to the target device.
        return tokenizer(text, return_tensors="pt")['input_ids'].to(device)

    def _copy_params():
        # Fresh per-value deep copy so the prompt builder cannot mutate `params`.
        return {key: copy.deepcopy(value) for key, value in params.items()}

    def _blank_sentences():
        # One blank demonstration input per label; a new list for every call.
        if pair_inputs:
            return [{'hypothesis': '', 'premise': ''} for _ in first_labels]
        return ['' for _ in first_labels]

    # Instruction prefix only.
    punctuation_prompt = _tokenize(params["prompt_prefix"] + '\n\n')

    # Blank demonstrations with answers, instruction prefix stripped.
    template_params = _copy_params()
    template_params["prompt_prefix"] = ''
    template_text, _ = construct_prompt_without_test(template_params, _blank_sentences(), first_labels, '')
    template_prompt = _tokenize(template_text)

    # Blank demonstrations with answers, instruction prefix kept; the caller
    # is expected to handle instruction tokens itself.
    anti_content_text, _ = construct_prompt_without_test(_copy_params(), _blank_sentences(), first_labels, '')
    anti_content_prompt = _tokenize(anti_content_text)

    return punctuation_prompt, template_prompt, anti_content_prompt
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
):
    """Register special tokens on ``tokenizer`` and return how many were new.

    Despite its name, this helper does NOT resize any embedding matrix. It has
    no ``model`` argument, so callers that add tokens must account for the
    larger vocabulary themselves.

    Args:
        special_tokens_dict: mapping such as ``{"pad_token": "[PAD]"}``,
            forwarded unchanged to ``tokenizer.add_special_tokens``.
        tokenizer: tokenizer to extend in place.

    Returns:
        The value returned by ``tokenizer.add_special_tokens``, which is also
        printed.
    """
    added_count = tokenizer.add_special_tokens(special_tokens_dict)
    print("num_new_tokens = ", added_count)
    return added_count

def load_model_lora(base_model, device, lora_weights, compression, compression_length, use_partial_mask):
    """Load a LLaMA base model and wrap it with LoRA adapter weights.

    Args:
        base_model: model name or path passed to ``from_pretrained``.
        device: ``"cuda"``, ``"mps"``, or anything else (treated as an explicit
            single device such as ``"cpu"``).
        lora_weights: adapter name or path passed to ``PeftModel.from_pretrained``.
        compression: when true, load ``LlamaForCompressionCausalLM`` with a config
            carrying ``compression_size`` and ``use_partial_mask``; otherwise load
            plain ``LlamaForCausalLM``.
        compression_length: stored as ``config.compression_size`` (compression only).
        use_partial_mask: stored as ``config.use_partial_mask`` (compression only).

    Returns:
        The ``PeftModel`` wrapping the loaded base model.
    """
    model_kwargs = {}
    if compression:
        # The config is only needed to inject the compression settings, so it
        # is fetched only on this path (previously it was fetched and discarded
        # for plain models).
        # NOTE(review): the config uses cache_dir='.cache' but the weights below
        # use the default HF cache; confirm whether that split is intended.
        config = AutoConfig.from_pretrained(
            base_model,
            cache_dir='.cache',
        )
        config.compression_size = compression_length
        config.use_partial_mask = use_partial_mask
        print('use_partial_mask = ', use_partial_mask)
        model_kwargs["config"] = config
        model_cls = LlamaForCompressionCausalLM
    else:
        model_cls = LlamaForCausalLM

    # Per-device loading options for the base model and the LoRA wrapper.
    if device == "cuda":
        model_kwargs.update(load_in_8bit=False, torch_dtype=torch.float16, device_map="auto")
        peft_kwargs = {"torch_dtype": torch.float16}
    elif device == "mps":
        model_kwargs.update(device_map={"": device}, torch_dtype=torch.float16)
        peft_kwargs = {"device_map": {"": device}, "torch_dtype": torch.float16}
    else:
        model_kwargs.update(device_map={"": device}, low_cpu_mem_usage=True)
        peft_kwargs = {"device_map": {"": device}}

    model = model_cls.from_pretrained(base_model, **model_kwargs)
    model = PeftModel.from_pretrained(model, lora_weights, **peft_kwargs)
    return model

def main(model, lora_weight, dataset, num_seeds, all_shots, subsample_test_set, compression, compression_length, use_partial_mask, compression_without_input, with_prompt_text, compression_without_prompt_text, with_sequence_order, compression_token_initialization,  load_in_8bit, only_first_half_layers, rank_correlated):
    test_inference = True

    print("lora_weight = ", lora_weight)
    print("model = ", model)
    print("dataset = ", dataset)
    print("with_prompt_text = ", with_prompt_text)
    print("compression_without_prompt_text = ", compression_without_prompt_text)




    if test_inference:
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        base_model = 'decapoda-research/llama-7b-hf' if lora_weight is not None else model
        # base_model = model

        

        # model = LlamaForCausalLM.from_pretrained(
        #     base_model, device_map={"": device}, low_cpu_mem_usage=True, cache_dir='.cache',
        # )
        # model = PeftModel.from_pretrained(
        #     model,
        #     lora_weights,
        #     device_map={"": device},
        # )
        if model == 'decapoda-research/llama-7b-hf':
            tokenizer = LlamaTokenizer.from_pretrained('baffo32/decapoda-research-llama-7B-hf', cache_dir='.cache')
        elif model == 'openlm-research/open_llama_3b':
            tokenizer = LlamaTokenizer.from_pretrained('openlm-research/open_llama_3b',local_files_only=True,  cache_dir='.cache')
        elif model == 'decapoda-research/llama-13b-hf':
            tokenizer = LlamaTokenizer.from_pretrained('dfurman/LLaMA-13B', cache_dir='.cache')
        elif model == 'decapoda-research/llama-30b-hf':
            tokenizer = LlamaTokenizer.from_pretrained('TheBloke/llama-30b-supercot-SuperHOT-8K-fp16', cache_dir='.cache')
        else:
            tokenizer = LlamaTokenizer.from_pretrained(base_model,local_files_only=True,  cache_dir='.cache')
      
        # if lora_weight is None:

        IGNORE_INDEX = -100
        DEFAULT_PAD_TOKEN = "[PAD]"
        DEFAULT_EOS_TOKEN = "</s>"
        DEFAULT_BOS_TOKEN = "<s>"
        DEFAULT_UNK_TOKEN = "<unk>"
        special_tokens_dict = dict()
        if tokenizer.pad_token is None:
            special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
        if tokenizer.eos_token is None:
            special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
        if tokenizer.bos_token is None:
            special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
        if tokenizer.unk_token is None:
            special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

        
        num_new_tokens = smart_tokenizer_and_embedding_resize(
            special_tokens_dict=special_tokens_dict,
            tokenizer=tokenizer,
            # model=model,
        )

        if lora_weight is not None:
            model = load_model_lora(base_model, device, lora_weight, compression, compression_length, use_partial_mask)
                        # unwind broken decapoda-research config
            # model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
            # model.config.bos_token_id = 1
            # model.config.eos_token_id = 2

            print("model = ", model)
        else:
            config = transformers.AutoConfig.from_pretrained(
                base_model,
                cache_dir='.cache',
                local_files_only=True,
            )
            # if model == 'decapoda-research/llama-7b-hf':
            #     config = transformers.AutoConfig.from_pretrained(
            #         'decapoda-research/llama-7b-hf',
            #         cache_dir='.cache',
            #     )
            # else:
            #     config = transformers.AutoConfig.from_pretrained(
            #         'openlm-research/open_llama_3b',
            #         cache_dir='.cache',
            #     )
            config.compression_token_initialization = compression_token_initialization
            if compression_token_initialization:
                config.initialize_ids = tokenizer("Article: N/A \n\n Answer: N/A \n\n", return_tensors="pt")['input_ids'].tolist()
            if compression: 
                
                config.vocab_size += num_new_tokens
                config.compression_size = compression_length
                config.use_partial_mask = use_partial_mask
                print("use_partial_mask = ", use_partial_mask)

                # if '13b' in base_model:
                model = LlamaForCompressionCausalLM.from_pretrained(
                    base_model,
                    config=config,
                    load_in_8bit=load_in_8bit,
                    trust_remote_code=True,
                    # tie_weights=True,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    cache_dir='.cache',
                )   
                # else:
                #     model = LlamaForCompressionCausalLM.from_pretrained(
                #         base_model,
                #         config=config,
                #         # load_in_8bit=load_8bit,
                #         trust_remote_code=True,
                #         # tie_weights=True,
                #         torch_dtype=torch.float16,
                #         device_map="auto",
                #         cache_dir='.cache',
                #     )                
            else:
                model = LlamaForCausalLM.from_pretrained(
                    base_model,
                    config=config,
                    # load_in_8bit=load_8bit,
                    # tie_weights=True,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    cache_dir='.cache',
                    local_files_only=True,
                )
            
        # num_new_tokens = smart_tokenizer_and_embedding_resize(
        #     special_tokens_dict=special_tokens_dict,
        #     tokenizer=tokenizer,
        #     # model=model,
        # )
        temperature=0.8
        top_p=0.75
        top_k=40
        num_beams=4
        max_new_tokens=2 if compression else 1
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            # use_cache=True,
        )

    # dataset_name = ['agnews', 'atis', 'cb', 'dbpedia', 'lama', 'rte', 'slot-movies', 'sst2', 'trec']

     


    # if dataset == 'atis':
    model.eval()


    
    # Different ids got different training set, so we can not just mix them up.
    if dataset == "lama":
        all_lamas = [1001,101,103,106,108,127,1303,131,136,1376,138,140,1412,159,17,176,178,19,190,20,264,27,276,279,30,31,36,361,364,37,39,407,413,449,463,47,495,527,530,740,937]
        all_params = []
        for which_lama in all_lamas:
            # p = deepcopy(default_params)
            p = {}
            p['dataset'] = f"lama_{which_lama}"
            all_params.append(p)
        
        correct = 0
        test_samples_num = 0
        samples_num = 0
        for param_index, params in enumerate(all_params):
            orig_train_sentences, orig_train_labels, orig_test_sentences, orig_test_labels = load_dataset(params)
            
            np.random.seed(num_seeds)
            # AgNews 7600
            # samples_num = subsample_test_set
            few_shot = all_shots
            if test_samples_num == 0:
                test_sentences, test_labels = orig_test_sentences, orig_test_labels
                samples_num += len(orig_test_labels)

            else:
                test_sentences, test_labels = random_sampling(orig_test_sentences, orig_test_labels, samples_num)

            train_sentences, train_labels = random_sampling(orig_train_sentences, orig_train_labels, few_shot)
            # print("----------Train-----------")
            # print(orig_train_sentences[:5], orig_train_labels[:5])
            # print("----------Test-----------")
            # print(test_sentences[0], test_labels[0])



            # print('----------Prompt--------------')

            for i, each_test in enumerate(test_sentences):
                # print("i = ", i, end='\r')
                prompt = construct_prompt(params, train_sentences, train_labels, each_test)
                # print("constructed prompt = ", prompt)


                if test_inference:
                    # prompt = 'The highest mountain in the world is '
                    inputs = tokenizer(prompt, return_tensors="pt")
                    input_ids = inputs["input_ids"].to(device)
                    # generate_params = {
                    #     "input_ids": input_ids,
                    #     "generation_config": generation_config,
                    #     "return_dict_in_generate": True,
                    #     "output_scores": True,
                    #     "max_new_tokens": max_new_tokens,
                    # }
                    # print("yesyesyes???")
                    with torch.no_grad():
                        if not compression:
                            generation_output = model.generate(
                                input_ids=input_ids,
                                generation_config=generation_config,
                                return_dict_in_generate=True,
                                output_scores=True,
                                max_new_tokens=max_new_tokens,
                            )
                        else:
                            generation_output = model.compression_generate(
                                input_ids=input_ids,
                                generation_config=generation_config,
                                return_dict_in_generate=True,
                                output_scores=True,
                                max_new_tokens=max_new_tokens,
                            )

                    s = generation_output.sequences[0]
                    output = tokenizer.decode(s)

                    # answer = output.split()[-1]
                    answer = output.split(':')[-1].strip()
                    # print("------------Model Output------------")
                    
                    # print("label = ", params['label_dict'][test_labels[i]])
                    # print("label = ", test_labels[i])
                    # print("generation_output = ", generation_output)
                    # print("output = ", output)
                    # if i > 20: 
                    #     exit()
                    if dataset in ['cb', 'rte', 'dbpedia', 'sst2', 'trec']:
                        if answer in params['inv_label_dict'].keys() and test_labels[i] == params['inv_label_dict'][answer]:
                            correct += 1
                    else:
                        if answer.startswith(test_labels[i]) or test_labels[i].startswith(answer):
                            correct += 1
        print("model = ", model)
        print("dataset = ", dataset)
        print("num_seeds = ", num_seeds)
        print("all_shots = ", all_shots)
        print("ACC = ", correct / samples_num)

    # if dataset == "slot-movies":
    else:
        params = {
            'dataset': dataset,
        }
        orig_train_sentences, orig_train_labels, orig_test_sentences, orig_test_labels = load_dataset(params)

    # train_sentences = orig_train_sentences[:3]
    # train_labels = orig_train_labels[:3]
    # test_sentences = orig_test_sentences[0]
    # test_labels = orig_test_labels[0]


        np.random.seed(num_seeds)
        # AgNews 7600
        few_shot = all_shots
        if subsample_test_set == 0:
            test_sentences, test_labels = orig_test_sentences, orig_test_labels
            samples_num = len(orig_test_labels)
        else:
            samples_num = min(subsample_test_set, len(orig_test_labels))
            test_sentences, test_labels = random_sampling(orig_test_sentences, orig_test_labels, samples_num)


        train_sentences, train_labels = random_sampling(orig_train_sentences, orig_train_labels, few_shot)
        print("----------Train-----------")
        print(train_sentences[:8],train_labels[:8])

 
        # easier_train_sentences, easier_train_labels, easier_idx = construct_easier_order(train_sentences, train_labels)
        # harder_train_sentences, harder_train_labels, harder_idx = construct_harder_order(train_sentences, train_labels)
        # adjusted_train_sentences = [easier_train_sentences, train_sentences, harder_train_sentences]
        # adjusted_train_labels = [easier_train_labels, train_labels, harder_train_labels]

        def generate_permutations(n):
            # 生成1到n的数字列表
            nums = list(range(1, n+1))
            # 使用itertools.permutations生成全排列
            permutations = list(itertools.permutations(nums))
            return permutations

        all_rank = generate_permutations(few_shot)
        adjusted_train_sentences = [[train_sentences[i - 1] for i in each_rank] for each_rank in all_rank]
        adjusted_train_labels = [[train_labels[i - 1] for i in each_rank] for each_rank in all_rank]



        

        # count = 3
        # while count > 0:
        count = 0
        # best_res = 0
        # best_rank = None
        # res_list = []
        final_res = {}

        print(params)
        # exit()

        punctuation_prompt, template_prompt, anti_content_prompt = acquire_different_type_text(params, adjusted_train_sentences, adjusted_train_labels, tokenizer, device)
        print("punctuation_prompt = ", punctuation_prompt)
        print("template_prompt = ", template_prompt)
        print("anti_content_prompt = ", anti_content_prompt)


        all_res = []
        for tmp_train_sentences, tmp_train_labels in zip(adjusted_train_sentences, adjusted_train_labels):
            correct_single = 0
            # print("prompt = ", prompt)
            # for single sentence
            for i, each_test in enumerate(test_sentences):
                # params['prompt_prefix'] = origin_prefix
                # print("origin_prefix = ", or)
                prompt = construct_prompt(params, tmp_train_sentences, tmp_train_labels, each_test)
                if test_inference:
                    inputs = tokenizer(prompt, return_tensors="pt")
                    input_ids = inputs["input_ids"].to(device)

                    with torch.no_grad():
                        generation_output = model.generate(
                            input_ids=input_ids,
                            generation_config=generation_config,
                            return_dict_in_generate=True,
                            output_scores=True,
                            max_new_tokens=max_new_tokens,
                        )

                    s = generation_output.sequences[0]
                    s = s.masked_fill(s.eq(-99), 0)
                    output = tokenizer.decode(s)
                    answer = output.split(':')[-1].replace("</s>", '').strip()
                    # print("output = ", output)
                    # if i > 20: 
                    #     exit()
                    if dataset in ['cb', 'rte', 'dbpedia', 'sst2', 'trec', 'agnews']:
                        if answer in params['inv_label_dict'].keys() and test_labels[i] == params['inv_label_dict'][answer]:
                            correct_single += 1
                    else:
                        if answer.startswith(test_labels[i]) or test_labels[i].startswith(answer):
                            correct_single += 1
                # pass


            # print("model = ", model)
            # print("dataset = ", dataset)
            # print("num_seeds = ", num_seeds)
            # print("all_shots = ", all_shots)
            # print("single_sentence = ", only_single_sentence)
            # print("tmp_rank = ", all_rank[count])
            # print("ACC_single = ", correct_single / samples_num)
            all_res.append(correct_single / samples_num)
            count += 1
        # for each in all_res:
        #     print(each)
        print("all_res = ", all_res)
        arr = np.array(all_res)

        # 使用广播计算每对元素的相减绝对值
        abs_diff_matrix = np.abs(arr[:, np.newaxis] - arr)

        count = 0
        tokens = [None for i in range(few_shot)]
        for tmp_train_sentences, tmp_train_labels in zip(adjusted_train_sentences, adjusted_train_labels):
        # print('----------Prompt--------------')
            correct = 0
            # with torch.autocast("cuda"):
            # tmp_train_sentences = adjusted_train_sentences[count - 1]
            # tmp_train_labels = adjusted_train_labels[count - 1]
            
            # for i, each_test in enumerate(test_sentences):
                # print("i = ", i, end='\r')
                # if compression_without_input:
                #     # this is the test instance input
                #     if compression_without_prompt_text:
                #         prompt, input_prompt_text, instance_prompt = construct_prompt_instance_prompt_text(params, train_sentences, train_labels, each_test)
                #         instance_prompt = input_prompt_text + instance_prompt
                #     else:
                #         prompt, instance_prompt = construct_prompt_without_test(params, train_sentences, train_labels, each_test)



                #     # prompt = construct_prompt(params, train_sentences, train_labels, each_test)
                # else:
                #     prompt = construct_prompt(params, train_sentences, train_labels, each_test)
                
                # print("constructed prompt = ", prompt)
            prompt, _ = construct_prompt_without_test(params, tmp_train_sentences, tmp_train_labels, '')
            # print("prompt = ", prompt)



            if test_inference:
                # prompt = 'The highest mountain in the world is '
                inputs = tokenizer(prompt, return_tensors="pt")
                # used to give a context to the answer.

                    # print("after compression tokens inputs = ", inputs)
                input_ids = inputs["input_ids"].to(device)

                with torch.no_grad():
                    # generation_output = model.generate(
                    #     input_ids=input_ids,
                    #     generation_config=generation_config,
                    #     return_dict_in_generate=True,
                    #     output_scores=True,
                    #     max_new_tokens=max_new_tokens,
                    # )
                    output = model(input_ids, output_hidden_states=True)
                    hidden_states = output.hidden_states
                    # for layer, hidden_state in enumerate(hidden_states):
                        # print(f"Layer {layer} representation:")
                        # print(hidden_state.size())
                        # print(hidden_state)
                        
                        # print()
                        # exit()
                
                # print("input_ids = ", input_ids.size())
                # print(input_ids)
                # 13:
                # 2: instruction
                # every 3: one sample 
                # 依次取出每一个sample每一层的representation
                # 然后存在一个文件里，变成list然后json就可以
                # 现在所有的rep都看，到时候做增强的时候需要把instruction的也加进来。
                # 写代码的时候，还是需要qkv的结果，emmm
                # 还是写一个representation 转kv的吧，这样的话会好弄一点，至少不用在generation_utils里边做很多修改？。
                all_mask = input_ids.eq(13).long()
                all_mask = torch.cumsum(all_mask, dim=-1)
                # print("all_mask = ", all_mask)
                all_mask =torch.cat((all_mask[:, :1], all_mask[:, :-1]), dim=-1)
                # print("all_mask2 = ", all_mask)
                # print("all_mask.gt = ", all_mask.ge(2))
                # print("all_mask.le = ", all_mask.lt(5))



                # instruction_mask = all_mask.lt(2)
                mask_list = []
                for i in range(few_shot):
                    if params['dataset'] == 'cb' or params['dataset'] == 'rte':
                        tmp_mask = all_mask.ge(2 + (i * 4)) & all_mask.lt(2 + (i + 1) * 4)
                    else:
                        tmp_mask = all_mask.ge(2 + (i * 3)) & all_mask.lt(2 + (i + 1) * 3)
                    # tmp_mask = tmp_mask[:-1]
                    mask_list.append(tmp_mask)
                # print("mask_list = ", len(mask_list))
                # print(mask_list)

                # instruction_rep = torch.masked_select(hidden_states[0], instruction_mask.unsqueeze(-1).repeat(1, 1, hidden_state.size(-1))).view(1, -1,hidden_state.size(-1))

                # instruction_rep = instruction_rep.mean(dim=-2)
                # instruction_rep = instruction_rep.tolist()
                # 先不管instruction_rep
                # print("instruction_rep = ", instruction_rep.size())
                # print(instruction_rep)


                result_list = [[] for i in range(len(mask_list))]



                for demo, demo_mask in enumerate(mask_list):
                    real_demo = all_rank[count][demo]
                    # print("real_demo = ", real_demo)
                    if count == 0:
                        tokens[real_demo - 1] = torch.masked_select(input_ids, demo_mask)
                    for layer, hidden_state in enumerate(hidden_states):
                        sentence_rep =  torch.masked_select(hidden_state, demo_mask.unsqueeze(-1).repeat(1, 1, hidden_state.size(-1))).view(1, -1,hidden_state.size(-1))
                        # result_list[real_demo - 1].append(sentence_rep.mean(dim=-2).tolist())
                        result_list[real_demo - 1].append(sentence_rep)
                
                # print("result_list = ", result_list)
                # 存成一个文件，然后做一个py处理这个文件，算一个相似度矩阵。
                # 在读的时候做一个all permutation转字符串的操作吧，这样应该就可以了。
                final_res[str(all_rank[count])] = result_list

            # print("model = ", model)
            # print("dataset = ", dataset)
            # print("num_seeds = ", num_seeds)
            # print("all_shots = ", all_shots)
            # print("tmp_rank = ", all_rank[count])
            # print("ACC = ", correct / samples_num)
            # res_list.append(correct / samples_num)
            # if best_res < correct / samples_num:
            #     best_res = correct / samples_num
            #     best_rank = all_rank[count]
            count += 1
        

        # with open(res_file, 'r') as reader:
        #     res = reader.readlines()
        # res = [float(each.strip()) for each in res]
        # key: permutation name.
        # list:
        #   demos
        #   layers
        #   dims
        # print("repsentations = ", representations)
        # for each in representations.keys():
        #     print(each)
        rank_keys = list(final_res.keys())
        # sorted_rank_keys = [x for _, x in sorted(zip(res, rank_keys))]
        # for each in sorted_rank_keys:
        #     print(each)
        # print("res = ", res)
        demonstration_num = len(final_res[rank_keys[0]])
        # print("tensor = ", final_res[rank_keys[0]][0][0].size())
        # print("tokens = ", tokens)
        # exit()
        # print("tensor = ", final_res[rank_keys[0]][0].size())
        layers_num = len(final_res[rank_keys[0]][0])
        # token_num = len(final_res[rank_keys[0]][0][0][0])
        dim_num = len(final_res[rank_keys[0]][0][0][0][0])
        # 4 27 2400
        # 27 = 第一层输入加所有层输出。
        # print(demonstration_num, layers_num, token_num, dim_num)
        # exit()

        # print(representations[rank_keys[0]][0][-1][0])

        # print(representations[rank_keys[-1]][0][-1][0])
        # demos * Similar(permutations * permutation) * layer_nums
        # cosine_similarity = 1 - distance.cosine(representations[rank_keys[0]][0][0][0], representations[rank_keys[-1]][0][0][0])
        # print("similarity_demo = ", cosine_similarity)

        # cosine_similarity = 1 - distance.cosine(representations[rank_keys[0]][0][-1][0], representations[rank_keys[-1]][0][-1][0])
        # print("similarity_demo = ", cosine_similarity)

        # Don't need the workbook anymore, we are trying to know the exact average results of the similarity.
        # workbook = Workbook()
        # default_sheet = workbook.active
        # default_sheet.title = "Default Sheet"

        # big_difference = [[] for each in ]
        cos = nn.CosineSimilarity(dim=0, eps=1e-9) 
        data = []
        # big_difference_data = []
        # cur_sheet = workbook.create_sheet(title="Layer " + str(k))
        cnt = 1
        template_difference = []
        stopword_difference = []
        content_difference = []
        template_difference_correlated = []
        stopword_difference_correlated = []
        content_difference_correlated = []
        template_pearson = []
        stopword_pearson = []
        content_pearson = []
        template_spearman = []
        stopword_spearman = []
        content_spearman = []
        for demo in range(demonstration_num):
            token_num = tokens[demo].size(0)
            # print("token_num = ", token_num)
            # print("tokens = ", tokens[demo])
            differences = {}
            detailed_differences = {}
            # all_data = []
            for t in range(token_num):
            # 这里决定是算相邻的相似度还是按照表现排序的相似度
                # print(tokens[demo][t])
                # print(tokenizer.convert_ids_to_tokens(tokens[demo][t].unsqueeze(-1)))
                # exit()
                # token = tokenizer.convert_ids_to_tokens(tokens[demo][t].unsqueeze(-1))
                # print("token = ", token)
                # data.append(tokenizer.convert_ids_to_tokens(tokens[demo][t].unsqueeze(-1)))
                
                tmp2_data = []
                if only_first_half_layers:
                    start_layer = layers_num // 4
                    tmp_layers_num = layers_num // 2
                else:
                    start_layer = 1
                    tmp_layers_num = layers_num 
                for k in range(start_layer, tmp_layers_num):
                # data.append()
            # for i_key in sorted_rank_keys:
            # 这里用的是相邻的相似度，无敌了。
            # 其实也无所谓，直接用最大值跟平均值吧
                    # print("k = ", k)
                    tmp1_data = []
                    for i_key in rank_keys:
                        tmp_data = []
                        for j_key in rank_keys:
                        # for j_key in sorted_rank_keys:
                            # print(final_res[i_key][demo][k][0][t].size())
                            # print(final_res[j_key][demo][k][0][t].size())
                            # exit()
                            cosine_similarity = (1 - cos(final_res[i_key][demo][k][0][t], final_res[j_key][demo][k][0][t])) / 2
                            # cosine_similarity = 1 - (1 + cos(final_res[i_key][demo][k][0][t], final_res[j_key][demo][k][0][t])) / 2
                            # print('{:.3f}'.format(cosine_similarity), end=' ')
                            tmp_data.append(cosine_similarity.item())
                        # cka_similarity = feature_space_linear_cka(representations[i_key][demo][-1], representations[j_key][demo][-1]) 
                        # print(cka_similarity, end=' ')
                        # break
                        # print()
                        tmp1_data.append(tmp_data)
                        # only one difference is not enough
                        # break
                    # print("temp1_data = ", tmp1_data)
                    tmp2_data.append(tmp1_data)
                    # exit()
                tmp2_data = np.array(tmp2_data)
                # 这里应该还是一个某个demonstration的第几个字
                mean = tmp2_data.mean()
                max0 = tmp2_data.max()
                std = tmp2_data.std()
                # 这里如果是两个demonstration里边有同样的字就会出问题对吧。
                # 一句话里边有两个一样的也不行，shit
                # 嗷嗷没问题，因为是按照demo来的。
                
                differences[t] = (mean, max0, std)
                detailed_differences[t] = tmp2_data.mean(axis=0)
                # differences[tokens[demo][t]] = (mean, max0, std)
            # exit()
            # print("differences = ", differences.keys())
            template_res = []
            punctuation_res = []
            content_res = []
            template_detail = []
            punctuation_detail = []
            content_detail = []
            # each key is a number
            # print(differences)
            # each key_is a number, the value is the np array. the difference among all the permutations 24 * 24?
            # maybe we need to flatten the matrix?
            # print("detailed_differences = ", detailed_differences)
            # print("abs_diff_matrix = ", abs_diff_matrix)
            # exit()
            # if rank_correlated:
            #     res_pearson = []
            #     res_spearman = []
            #     index = []
            #     for i, each in enumerate(differences.keys()):
            #         index.append(each)
            #         res_pearson.append(np.corrcoef(abs_diff_matrix.flatten(), detailed_differences[each].flatten())[0][1])
            #         res_spearman.append(stats.spearmanr(abs_diff_matrix.flatten(), detailed_differences[each].flatten())[0])
            #     AB_sorted = sorted(zip(index, res_pearson), key=lambda pair: pair[1])
            #     A_sorted_by_B, B_sorted = zip(*AB_sorted)
            #     print("Pearson:")
            #     for i, each in enumerate(A_sorted_by_B):
            #         print(tokenizer.convert_ids_to_tokens(tokens[demo][each].unsqueeze(-1)), ": ", B_sorted[i])

            #     AB_sorted = sorted(zip(index, res_spearman), key=lambda pair: pair[1])
            #     A_sorted_by_B, B_sorted = zip(*AB_sorted)
            #     print("Spearman: ")
            #     for i, each in enumerate(A_sorted_by_B):
            #         print(tokenizer.convert_ids_to_tokens(tokens[demo][each].unsqueeze(-1)), ": ", B_sorted[i])

                # exit()
                # continue
            # flag = 0
            for i, each in enumerate(differences.keys()):
                # print("each = ", tokenizer.convert_ids_to_tokens(tokens[demo][each].unsqueeze(-1)))
                # if flag == 1:
                #     print("punctuation")
                #     flag = 0
                #     continue
                if tokens[demo][each] in template_prompt:
                    template_res.append(differences[each])
                    template_detail.append(detailed_differences[each])
                    # print("template")
                    continue
                if tokens[demo][each] in punctuation_prompt:
                    punctuation_res.append(differences[each])
                    punctuation_detail.append(detailed_differences[each])
                    # print("punctuation")
                    # each_1 = list(differences.keys())[i + 1]
                    # punctuation_res.append(differences[each_1])
                    # flag = 1
                    continue
                content_detail.append(detailed_differences[each])
                content_res.append(differences[each])
                # print("content")
            # print("template_res = ", template_res)
            # print("punctuation_res = ", punctuation_res)
            # print("content_res = ", content_res)



            print("Demo:--------------- ", cnt)
            cnt += 1
            if len(template_res) == 0:
                print("no template for this demo")
                # template_difference.append(None)
                # template_pearson.append(None)
                # template_spearman.append(None)
            else:
                template_avg = 0
                template_max = 0
                template_std = 0
                for each in template_res:
                    template_avg += each[0]
                    template_max += each[1]
                    template_std += each[2]
                template_avg /= len(template_res)
                template_max /= len(template_res)
                template_std /= len(template_res)


                print("template")
                print(template_avg, template_max, template_std)
                template_difference.append(template_avg)



                correlated_factor_avg = 0
                spearmanr_corr = 0
                for each in template_detail:
                    correlated_factor_avg += np.corrcoef(abs_diff_matrix.flatten(), each.flatten())
                    spearmanr_corr_tmp, _ = stats.spearmanr(abs_diff_matrix.flatten(), each.flatten())
                    spearmanr_corr += spearmanr_corr_tmp
                print("correlated_factor = ")
                print(correlated_factor_avg / len(template_detail))
                print(spearmanr_corr / len(template_detail))
                template_pearson.append(correlated_factor_avg[0][1] / len(template_detail))
                template_spearman.append(spearmanr_corr / len(template_detail))


            if len(punctuation_res) == 0:
                print("no punctuation for this demo")
                # stopword_difference.append(None)
                # stopword_pearson.append(None)
                # stopword_spearman.append(None)
            else:
                punctuation_avg = 0
                punctuation_max = 0
                punctuation_std = 0
                for each in punctuation_res:
                    punctuation_avg += each[0]
                    punctuation_max += each[1]
                    punctuation_std += each[2]
                punctuation_avg /= len(punctuation_res)
                punctuation_max /= len(punctuation_res)
                punctuation_std /= len(punctuation_res)
                print("punctuation")
                print(punctuation_avg, punctuation_max, punctuation_std)
                stopword_difference.append(punctuation_avg)





                correlated_factor_avg = 0
                spearmanr_corr = 0
                for each in punctuation_detail:
                    correlated_factor_avg += np.corrcoef(abs_diff_matrix.flatten(), each.flatten())
                    spearmanr_corr_tmp, _ = stats.spearmanr(abs_diff_matrix.flatten(), each.flatten())
                    spearmanr_corr += spearmanr_corr_tmp
                print("correlated_factor = ")
                print(correlated_factor_avg / len(punctuation_detail))
                print(spearmanr_corr / len(punctuation_detail))
                stopword_pearson.append(correlated_factor_avg[0][1] / len(punctuation_detail))
                stopword_spearman.append(spearmanr_corr / len(punctuation_detail))

            if len(content_res) == 0:
                print("no content for this demo")
                # content_difference.append(None)
                # content_pearson.append(None)
                # content_spearman.append(None)
            else:
                content_avg = 0
                content_max = 0
                content_std = 0
                for each in content_res:
                    content_avg += each[0]
                    content_max += each[1]
                    content_std += each[2]
                content_avg /= len(content_res)
                content_max /= len(content_res)
                content_std /= len(content_res)
                print("content")
                print(content_avg, content_max, content_std)  
                content_difference.append(content_avg)



                correlated_factor_avg = 0
                spearmanr_corr = 0
                for each in content_detail:
                    correlated_factor_avg += np.corrcoef(abs_diff_matrix.flatten(), each.flatten())
                    spearmanr_corr_tmp, _ = stats.spearmanr(abs_diff_matrix.flatten(), each.flatten())
                    spearmanr_corr += spearmanr_corr_tmp
                print("correlated_factor = ")
                print(correlated_factor_avg / len(content_detail))
                print(spearmanr_corr / len(content_detail))
            
                content_pearson.append(correlated_factor_avg[0][1] / len(content_detail))
                content_spearman.append(spearmanr_corr / len(content_detail))
            
            if len(content_res) == 0 or len(punctuation_res) == 0 or len(template_res) == 0:
                pass
            else:
                content_difference_correlated.append(content_difference[-1])
                template_difference_correlated.append(template_difference[-1])
                stopword_difference_correlated.append(stopword_difference[-1])

        # print the correlation factor    
        print()
        print("template-stopword-content difference")
        print(template_difference)
        print(stopword_difference)
        print(content_difference)
        
        print(sum(template_difference)/len(template_difference))
        print(sum(stopword_difference)/len(stopword_difference))
        print(sum(content_difference)/len(content_difference))
        print()

        print(sum(template_pearson)/len(template_pearson))
        print(sum(stopword_pearson)/len(stopword_pearson))
        print(sum(content_pearson)/len(content_pearson))
        print()

        print(sum(template_spearman)/len(template_spearman))
        print(sum(stopword_spearman)/len(stopword_spearman))
        print(sum(content_spearman)/len(content_spearman))
        print()

        # we have to delete all the elements which have None.

        print("template tokens <-> content tokens")
        # print("pearson")
        print("pearson and spearman")

        print(np.corrcoef(template_difference_correlated, content_difference_correlated)[0,1])
        corr, p_value = stats.spearmanr(template_difference_correlated, content_difference_correlated)
        # print("spearman")
        print(corr, p_value)

        print("template tokens <-> stopword tokens")
        # print("pearson")
        print("pearson and spearman")
        print(np.corrcoef(template_difference_correlated, stopword_difference_correlated)[0,1])
        corr, p_value = stats.spearmanr(template_difference_correlated, stopword_difference_correlated)
        # print("spearman")
        print(corr, p_value)

        print("content tokens <-> stopword tokens")
        print("pearson and spearman")
        print(np.corrcoef(content_difference_correlated, stopword_difference_correlated)[0,1])
        corr, p_value = stats.spearmanr(content_difference_correlated, stopword_difference_correlated)
        print(corr, p_value)
            # exit()
            # big_difference_data.append(differences)
                # tmp2_data = []
                # tmp1_data = []
                # print(mean, max0, std)

                # exit()
                    # data.append([''])
                    # data.append([''])
                    # data.append([''])
                # big_difference_data.append(data[-1][0])
                # data.append([''])
                # data.append([''])
                # data.append([''])
                # data.append([''])
                # data.append([''])
            # 每个demo中每个词与每个词的区别。
            # 要不就直接看第一个permutation和最后一个permutation的区别？这样不用看结果，并且这样的话理论上区别应该是最大的（1, 2, 3, 4 vs 4, 3, 2, 1）
            # print('data = ', data)
            # exit()
            # for row in data:
                # cur_sheet.append(row)
            # default_sheet.append(big_difference_data)
            # for row in data:

        # workbook.save(res_file.replace('.res', '.full.xlsx'))
        # file_prefix = 'rep_save/all_sent_'
        # file_name = file_prefix + base_model.split('/')[-1] + "_"+ str(num_seeds) + '.json'
        # res_name = file_prefix + base_model.split('/')[-1] + "_"+ str(num_seeds) + '.res'

        # json_data = json.dumps(final_res)
        # with open(file_name, 'w') as file:
        #     file.write(json_data)
        # with open(res_name, 'w') as file:
        #     file.write(' ')
             
        # exit()



                

                # s = generation_output.sequences[0]
                # s = s.masked_fill(s.eq(-99), 0)
                # exit()
                # output = tokenizer.decode(s)

                # answer = output.split()[-1]
                # answer = output.split(':')[-1].replace("</s>", '').strip()
                # if compression:
                #     print("------------Model Output------------")
                    
                #     print("s = ", s)
                #     print("output = ", output)
                #     print("answer = ", answer)
                #     print("label = ", params['label_dict'][test_labels[i]])
                #     print("label = ", test_labels[i])
                #     if i > 20: 
                #         exit()
                    # print("output = ", output)
                    # print("generation_output = ", generation_output)
                    # exit()
                # exit()

                    
            #         if dataset in ['cb', 'rte', 'dbpedia', 'sst2', 'trec', 'agnews']:
            #             if answer in params['inv_label_dict'].keys() and test_labels[i] == params['inv_label_dict'][answer]:
            #                 correct += 1
            #         else:
            #             if answer.startswith(test_labels[i]) or test_labels[i].startswith(answer):
            #                 correct += 1

        # print final results
        # harder_rank, easier_rank, full_rank

        # print("rank_number = ", len(all_rank))
        # print("all_permutation = ", all_rank)


        # sorted_zipped = sorted(zip(res_list, all_rank), key=lambda x: x[0], reverse=True)

        # new_res_list, new_rank = zip(*sorted_zipped)
        # harder_idx = tuple([each + 1 for each in harder_idx])
        # easier_idx = tuple([each + 1 for each in easier_idx])
        # print("harder_idx = ", harder_idx)
        # print("easier_idx = ", easier_idx)
        # print("harder_rank = ", new_rank.index(harder_idx))
        # print("harder_res = ", new_res_list[new_rank.index(harder_idx)])
        # print("easier_rank = ", new_rank.index(easier_idx))
        # print("easier_res = ", new_res_list[new_rank.index(easier_idx)])

        # print("best_res = ", best_res)
        # print("best_permutation = ", best_rank)
        # print("worst_res = ", new_res_list[-1])
        # print("worst_permutation = ", new_rank[-1])

        # print("average_res = ", sum(new_res_list)/len(new_res_list))
        # print("mid_res = ", new_res_list[len(new_res_list) // 2])



        
            # print("")
            # count -= 1
            # change the order of the training sentences
            # combined = list(zip(train_sentences, train_labels))
            # random.shuffle(combined)
            # train_sentences, train_labels = zip(*combined)










def build_arg_parser():
    """Build the command-line parser for this script.

    The parsed namespace is forwarded as ``main(**vars(args))``, so every
    ``dest`` below must match a keyword parameter of ``main``; do not rename
    them without updating ``main``.

    Returns:
        argparse.ArgumentParser: parser with the required ``--model``,
        ``--dataset``, ``--num_seeds`` and ``--all_shots`` options plus the
        optional boolean flags (all default to ``False``) and optional
        integer/string options (all default to ``None``).
    """
    parser = argparse.ArgumentParser()

    # required arguments
    parser.add_argument('--model', dest='model', action='store', required=True,
                        help='name of model(s), e.g., GPT2-XL')
    parser.add_argument('--dataset', dest='dataset', action='store', required=True,
                        help='name of dataset(s), e.g., agnews')
    parser.add_argument('--num_seeds', dest='num_seeds', action='store', required=True, type=int,
                        help='num seeds for the training set')
    parser.add_argument('--all_shots', dest='all_shots', action='store', required=True, type=int,
                        help='num training examples to use')

    # optional values
    parser.add_argument('--lora_weight', dest='lora_weight', action='store', required=False, default=None,
                        help='optional LoRA weights to load for the model (default: none)')
    parser.add_argument('--subsample_test_set', dest='subsample_test_set', action='store', required=False,
                        type=int, default=None,
                        help='size of test set to use to speed up eval. None means using all test set')
    parser.add_argument('--compression_length', dest='compression_length', action='store', required=False,
                        default=None, type=int,
                        help='number of compression tokens')

    # boolean flags: absent -> False, present -> True
    parser.add_argument('--compression', dest='compression', action='store_true',
                        help='whether to use the compression generation mode')
    parser.add_argument('--load_in_8bit', dest='load_in_8bit', action='store_true',
                        help='whether to load the model in 8-bit precision')
    parser.add_argument('--compression_without_input', dest='compression_without_input', action='store_true',
                        help='whether to apply compression without the input')
    parser.add_argument('--use_partial_mask', dest='use_partial_mask', action='store_true',
                        help='whether to use the partial_mask')
    parser.add_argument('--with_prompt_text', dest='with_prompt_text', action='store_true',
                        help='whether to use the prompt text')
    parser.add_argument('--compression_without_prompt_text', dest='compression_without_prompt_text',
                        action='store_true',
                        help='whether to apply compression without the prompt text')
    parser.add_argument('--with_sequence_order', dest='with_sequence_order', action='store_true',
                        help='whether to test the sequence order of the model')
    parser.add_argument('--compression_token_initialization', dest='compression_token_initialization',
                        action='store_true',
                        help='whether to use compression-token initialization')
    parser.add_argument('--only_first_half_layers', dest='only_first_half_layers', action='store_true',
                        help='whether to restrict the layer-wise analysis to a subset of the first half '
                             'of the layers instead of all layers')
    parser.add_argument('--rank_correlated', dest='rank_correlated', action='store_true',
                        help='whether to run the rank-correlation analysis')
    return parser


if __name__ == '__main__':
    # vars() turns the Namespace into a dict whose keys are the `dest`s above,
    # which are passed straight through as keyword arguments to main().
    args = vars(build_arg_parser().parse_args())
    main(**args)