from data_utils import load_dataset
from utils import construct_prompt, random_sampling, construct_prompt_without_test, construct_prompt_instance_prompt_text
import numpy as np
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, LlamaForCompressionCausalLM, AutoConfig
import argparse
from typing import Dict, Optional, Sequence
import itertools
import json
import random
from scipy.spatial import distance
from openpyxl import Workbook

import numpy as np



def _generate_permutations(n):
    """Return all ordered selections of 1..n distinct values drawn from ``1..n``.

    Tuples are grouped by length (1, 2, ..., n). Within each length they come
    in ``itertools.permutations`` order, which is lexicographic here because
    the input range is sorted. ``main`` relies on this order matching the line
    order of the results file.
    """
    values = range(1, n + 1)
    return [
        perm
        for length in range(1, n + 1)
        for perm in itertools.permutations(values, length)
    ]


def _append_ctx(key, ctx):
    """Return ``key`` followed by ``ctx`` (context placed after the key)."""
    return key + ctx


def _prepend_ctx(key, ctx):
    """Return ``ctx`` followed by ``key`` (context placed before the key)."""
    return ctx + key


def _print_acc(hits, total):
    """Print the hit rate ``hits / total``, or ``nan`` if nothing was compared."""
    # Some checks can have no valid pairs (e.g. for small num_demos); guard
    # against dividing by zero in that case.
    print("acc = ", hits / total if total else float("nan"))


def _ordering_agreement(scores, candidates, contexts, combine):
    """Count how often adding the same context preserves a score ordering.

    For every ``ctx`` in ``contexts`` and every ordered pair of distinct
    candidates ``(a, b)`` such that both ``combine(a, ctx)`` and
    ``combine(b, ctx)`` are keys of ``scores``, check whether
    ``scores[a] > scores[b]`` agrees with
    ``scores[combine(a, ctx)] > scores[combine(b, ctx)]``.

    Args:
        scores: mapping from permutation tuple to its score.
        candidates: keys being compared against each other.
        contexts: keys used as the shared context.
        combine: ``combine(key, ctx)`` builds the extended tuple.

    Returns:
        ``(hits, total)``: the number of agreeing pairs and the number of
        pairs compared.
    """
    hits = total = 0
    for ctx in contexts:
        for a in candidates:
            for b in candidates:
                if a == b:
                    continue
                ext_a, ext_b = combine(a, ctx), combine(b, ctx)
                # Extensions that repeat a demo or exceed the maximum length
                # are not scored; skip those pairs.
                if ext_a not in scores or ext_b not in scores:
                    continue
                if (scores[a] > scores[b]) == (scores[ext_a] > scores[ext_b]):
                    hits += 1
                total += 1
    return hits, total


def _extension_gain(scores, keys, demos, combine):
    """Count how often adding one demo to a permutation raises its score.

    For every ``demo`` in ``demos`` and every ``perm`` in ``keys`` whose
    extension ``combine(perm, demo)`` is a key of ``scores``, count the
    extensions that score strictly higher than ``perm``.

    Returns:
        ``(hits, total)``.
    """
    hits = total = 0
    for demo in demos:
        for perm in keys:
            extended = combine(perm, demo)
            if extended not in scores:
                continue
            if scores[extended] > scores[perm]:
                hits += 1
            total += 1
    return hits, total


def main(res_file: str, num_demos: int = 4) -> None:
    """Print score-ordering statistics for every ordered selection of demos.

    ``res_file`` holds one float score per line. The i-th score belongs to the
    i-th tuple of ``_generate_permutations(num_demos)``: tuples are grouped by
    length, then ordered lexicographically. The function prints the permutation
    list and the scores, then a series of checks. Each check is a header line
    followed by ``acc = <fraction of compared pairs for which it held>``
    (``nan`` if no pair could be compared).

    Args:
        res_file: path of the results file. Blank lines are skipped. Scores
            beyond the number of permutations are ignored.
        num_demos: number of distinct demos the permutations are built from.
            Defaults to 4, the value this analysis was originally written for.

    Raises:
        ValueError: if a non-blank line is not a float, or the file holds fewer
            scores than there are permutations.
    """
    with open(res_file, 'r', encoding='utf-8') as reader:
        # Skip blank lines (e.g. a trailing empty line) instead of failing on
        # float('').
        res = [float(line.strip()) for line in reader if line.strip()]

    rank_keys = _generate_permutations(num_demos)
    print("rank_keys = ", rank_keys)
    print('res = ', res)

    if len(res) < len(rank_keys):
        raise ValueError(
            f"{res_file}: expected at least {len(rank_keys)} scores for "
            f"{num_demos} demos, found {len(res)}"
        )
    # Map each permutation tuple to its score. Membership tests on this dict
    # are O(1); the hot triple loops below depend on that.
    new_res = dict(zip(rank_keys, res))

    demos = [key for key in rank_keys if len(key) == 1]
    # Fixed contexts used around a single demo. Single demos and full-length
    # permutations are excluded from this role.
    middle = [key for key in rank_keys if 1 < len(key) < num_demos]

    # (header, candidates compared, shared contexts, how context is attached)
    agreement_checks = [
        ("better permutation + one same demo is still better?",
         rank_keys, demos, _append_ctx),
        ("one same demo + better permutation is better?",
         rank_keys, demos, _prepend_ctx),
        ("same permutation + better demo is better?",
         demos, middle, _prepend_ctx),
        ("better demo + same permutation is better?",
         demos, middle, _append_ctx),
        ("same permutation + better permutation is better?",
         rank_keys, rank_keys, _prepend_ctx),
        ("better permutation + same permutation is better?",
         rank_keys, rank_keys, _append_ctx),
    ]
    for title, candidates, contexts, combine in agreement_checks:
        print(title)
        _print_acc(*_ordering_agreement(new_res, candidates, contexts, combine))

    gain_checks = [
        ("One permutation + one demo is better?", _append_ctx),
        ("One demo + one permutation  is better?", _prepend_ctx),
    ]
    for title, combine in gain_checks:
        print(title)
        _print_acc(*_extension_gain(new_res, rank_keys, demos, combine))
    # TODO: inspect the best- and worst-scoring permutations.
    # 看最好的和最坏的




if __name__ == '__main__':
    # Command-line entry point: read the results file path and run the analysis.
    parser = argparse.ArgumentParser(
        description='Compare scores of demo permutations read from a results file.'
    )
    parser.add_argument(
        '--res_file', dest='res_file', action='store', required=True,
        help='path to a text file with one float score per line, '
             'one line per demo permutation',
    )

    args = vars(parser.parse_args())
    main(**args)