# 使用部署的模型测试大模型单次输出能力 通过查词表的方式    大批测试
import random
import json
import concurrent.futures
import numpy as np
import requests
import argparse
from transformers import AutoTokenizer
# 生成乱序词表
import random
import string
from config import model_load_config
def _str2bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` is wrong for argparse: ``bool("False")`` is True, so any
    non-empty value would switch the flag on. Accepts true/false, t/f,
    yes/no, y/n, 1/0, on/off (case-insensitive).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Create the ArgumentParser object
parser = argparse.ArgumentParser(description='accInEachStrlen')

# Register arguments. --strlen and --times have no usable default (the main
# loop calls range() on them), so argparse enforces them up front.
parser.add_argument('--strlen', type=int, required=True, help='len of string')
parser.add_argument('--seed', type=int, help='the seed of current exp')
parser.add_argument('--times', type=int, required=True, help='times of experiment')
# Accepts "--print_box", "--print_box True" and "--print_box False".
parser.add_argument('--print_box', type=_str2bool, nargs='?', const=True, default=False, help='if print box')
parser.add_argument('--random_dict', type=int, default=0, help='if random_dict and the size of random dict')
parser.add_argument('--tag_name', type=str, default="", help='tag name of model')
parser.add_argument('--suffix', type=str, default="", help='suffix')
parser.add_argument('--dict_type', type=str, default=[], nargs="+", help='type of dict')
parser.add_argument('--lenmin', type=int, default=1, help='lenmin of random dict')
parser.add_argument('--lenmax', type=int, default=4, help='lenmax of random dict')
parser.add_argument('--value_len', type=int, default=0, help='value_len of random dict')
# Parse arguments
args = parser.parse_args()

# Model endpoint settings for the selected tag.
model_config = model_load_config[args.tag_name]

modelname = model_config["model_name"]
tag_name = model_config["tag_name"]
url = model_config["url"]
mode = model_config["mode"]

# generate() only knows how to build requests for these two modes; fail
# before any work is done rather than inside the worker threads.
if mode not in ("prompt", "chat"):
    raise ValueError("unsupported mode %r for tag %r; expected 'prompt' or 'chat'" % (mode, args.tag_name))

print_box = args.print_box
times = args.times
seed = args.seed
string_len = args.strlen
random_dict = args.random_dict
suffix = args.suffix
value_len = args.value_len

lenmin = args.lenmin
lenmax = args.lenmax

# Seed the `random` module that drives map generation and query sampling, so
# the seed is more than a label in the output filename. random.seed(None)
# (no --seed given) keeps the previous unseeded behaviour. Trials run
# concurrently in a thread pool, so the interleaving of draws across trials
# (and hence exact per-trial maps) may still vary between runs.
random.seed(seed)

# NOTE(review): hard-coded bearer token; consider loading it from an
# environment variable or the model config instead of source code.
headers = {
    "Content-Type": "application/json",
    "Authorization": "Bearer token-casia-braincog-233",
}

# No --dict_type given: use every printable non-whitespace character.
dict_type = args.dict_type or ["all"]

if mode == "prompt":
    # "prompt" mode renders the chat template client-side, so a tokenizer is
    # needed. Remotely served models load it from a local copy.
    tokenizer_path = model_config["local_path"] if model_config.get("remote") else modelname
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.chat_template is None:
        tokenizer.chat_template = model_config["chat_template"]
# ======================================================================



def generate_shuffled_dict(char_types=None, size=None, min_len=None, max_len=None, max_value_len=None):
    """Build a randomly shuffled key -> value character map.

    Every argument defaults (when ``None``) to the matching module-level
    setting parsed from the command line, so ``generate_shuffled_dict()``
    behaves exactly as the script expects.

    Args:
        char_types: character classes to draw from; any of "uppercase",
            "lowercase", "digits", "punctuation", or "all". "all" means
            ``string.printable`` without its 6 trailing whitespace characters
            (digits, ASCII letters and punctuation) and overrides the others.
            Defaults to ``dict_type``.
        size: if > 0, keys and values are random strings of ``min_len`` to
            ``max_len`` characters and at most ``size`` entries are returned
            (duplicate random keys collapse). Otherwise each character of the
            pool maps to a character of the same pool, both sides shuffled.
            Defaults to ``random_dict``.
        min_len, max_len: inclusive length bounds of the random strings, used
            only when ``size`` > 0. Default to ``lenmin`` / ``lenmax``.
        max_value_len: if > 0, each value is truncated to this many
            characters. Defaults to ``value_len``.

    Returns:
        dict mapping key strings to value strings.

    Raises:
        ValueError: if ``char_types`` selects no characters at all.
    """
    char_types = dict_type if char_types is None else char_types
    size = random_dict if size is None else size
    min_len = lenmin if min_len is None else min_len
    max_len = lenmax if max_len is None else max_len
    max_value_len = value_len if max_value_len is None else max_value_len

    if "all" in char_types:
        pool = string.printable[:-6]
    else:
        # Concatenated in this fixed order regardless of the order requested.
        class_chars = {
            "uppercase": string.ascii_uppercase,
            "lowercase": string.ascii_lowercase,
            "digits": string.digits,
            "punctuation": string.punctuation,
        }
        pool = "".join(chars for name, chars in class_chars.items() if name in char_types)
    if not pool:
        raise ValueError("no characters selected by char_types=%r" % (char_types,))
    letters = list(pool)

    if size > 0:
        def random_string():
            # Length drawn uniformly from [min_len, max_len] (inclusive).
            length = random.randint(min_len, max_len)
            return "".join(random.choice(letters) for _ in range(length))

        # Draw twice as many strings as needed: duplicate keys collapse when
        # the dict is built, and the surplus keeps the result close to `size`.
        key_list = [random_string() for _ in range(size * 2)]
        value_list = [random_string() for _ in range(size * 2)]
    else:
        key_list = list(letters)
        value_list = list(letters)

    random.shuffle(key_list)
    random.shuffle(value_list)

    if max_value_len > 0:
        value_list = [value[:max_value_len] for value in value_list]
    result = dict(zip(key_list, value_list))

    if size > 0:
        result = dict(list(result.items())[:size])
    return result

def match_function(decoded_content, valuestring):
    """Return True when the model output ends with the expected decoding.

    Only the suffix is compared, so a leading prefix such as "Answer: " is
    tolerated. Trailing text (punctuation, whitespace, quotes) is not.

    Args:
        decoded_content: raw text returned by the model; ``None`` (an empty
            API content field) counts as a non-match.
        valuestring: expected decoded string.

    Returns:
        bool: whether ``decoded_content`` ends with ``valuestring``.
    """
    if decoded_content is None:
        return False
    # str.endswith instead of decoded_content[-len(valuestring):]: for an empty
    # valuestring the slice becomes [-0:], i.e. the whole string, so the
    # "suffix" check would silently compare the entire output.
    return decoded_content.endswith(valuestring)

# Reference map built once at import time. The main loop only uses its size
# to bound the range of map sizes it sweeps; every generate() call builds its
# own fresh map.
chardict = generate_shuffled_dict()
# 线程池里的生成函数
def generate(i):
    """Run one decoding trial against the deployed model and score it.

    A fresh shuffled character map is built with ``generate_shuffled_dict``.
    Its first ``i`` entries are shown to the model as the character map, and
    the model is asked to decode ``string_len`` keys sampled (with
    replacement) from those shown entries.

    Args:
        i: number of map entries included in the prompt ("map size").

    Returns:
        1 if the model output ends with the expected decoding, in either its
        space-separated or concatenated form; otherwise 0. The tuple
        ``(raw output, expected space-separated value, matched)`` is also
        appended to the module-level ``box`` list.

    Raises:
        ValueError: if the module-level ``mode`` is neither "prompt" nor "chat".
        requests.HTTPError: if the model server answers with an error status.
    """
    chardict = generate_shuffled_dict()

    # Entries shown to the model. The map is rebuilt on every call and, with
    # --random_dict, can come out smaller than `i` (duplicate random keys
    # collapse), so the size is taken from what is actually shown.
    shown_items = list(chardict.items())[:i]
    map_size = len(shown_items)

    # In-context example: the first `string_len` keys and their values.
    contextkey = " ".join(list(chardict.keys())[:string_len])
    contextkvalue = " ".join(list(chardict.values())[:string_len])

    # Character map rendered as a Python dict literal.
    chardictstring = str(dict(shown_items))

    # Query keys and expected values. randrange excludes its upper bound, so
    # every queried key is one the model was actually shown; randint(0,
    # map_size) is inclusive and could pick entry `map_size`, which is not in
    # the prompt's map.
    sampled = [shown_items[random.randrange(map_size)] for _ in range(string_len)]
    keystring = " ".join(key for key, _ in sampled)
    valuestring = " ".join(value for _, value in sampled)
    valuestring2 = "".join(value for _, value in sampled)

    inputstr = (
        """I'm going to give you a string. The string is encoded by the character map.  You need to decode the string according to the character map, output the decoded content.\n\n"""
        + """'""" + contextkey+"""' can be decoded as '""" + contextkvalue + """'.\n\n"""
        + """Character map:\n"""
        + chardictstring
        + """\n\nThe string is: '"""
        + keystring + """'"""
        + """\n\ndirectly output decoded content only. do not output the other content. """
        + """\n\nyou need to follow this Format: Answer: xxx"""
    )

    messages = [{"role": "user", "content": inputstr}]

    if mode == "prompt":
        # Chat template applied client-side; the server receives raw text.
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        data_request = {
            "prompt": prompt,
            "model": modelname,
            "max_tokens": 100,
            "temperature": 0.0,
        }
    elif mode == "chat":
        data_request = {
            "messages": messages,
            "model": modelname,
            "max_tokens": 100,
            "temperature": 0.0,
        }
    else:
        # Otherwise `data_request` would be unbound and fail with a NameError.
        raise ValueError("unsupported mode: %r (expected 'prompt' or 'chat')" % (mode,))

    response = requests.post(url, headers=headers, data=json.dumps(data_request))
    # Surface server errors directly instead of as a KeyError on the body.
    response.raise_for_status()
    choice = response.json()["choices"][0]
    decoded_content = choice["message"]["content"] if mode == "chat" else choice["text"]

    # Accept either "a b c" or "abc" at the end of the output.
    ismatch = match_function(decoded_content, valuestring) or match_function(decoded_content, valuestring2)
    box.append((decoded_content, valuestring, ismatch))

    return 1 if ismatch else 0


import os

# Running totals across all map sizes, plus the per-size accuracy history.
acc = 0
count = 0
record = []
box = []  # filled by generate(); printed (with --print_box) and reset per map size

if random_dict > 0:
    string_len4print = "rand_" + str(lenmin) + "_" + str(lenmax)
else:
    string_len4print = ""

# Output path (naming scheme unchanged). Create the directory before the
# sweep so a missing folder doesn't throw away a finished run at np.save time.
out_dir = "each_size_acc"
out_path = os.path.join(
    out_dir,
    tag_name + "_strlen-" + str(string_len) + string_len4print
    + "_dict-type-" + str(dict_type) + suffix + "_" + str(seed) + ".npy",
)
os.makedirs(out_dir, exist_ok=True)

# Sweep map sizes from string_len up to (but excluding) the reference map size,
# running `times` concurrent trials per size.
for i in range(string_len, len(chardict)):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(generate, i) for _ in range(times)]
        results = [future.result() for future in concurrent.futures.as_completed(futures)]

    acc_local = sum(results)
    count_local = len(results)
    acc += acc_local
    count += count_local
    # NaN instead of ZeroDivisionError when --times is 0.
    rate_local = acc_local / count_local if count_local else float("nan")
    rate_total = acc / count if count else float("nan")
    print(
        "seed:", seed,
        "sum:", "{:.2f}".format(rate_total),
        "sum_local:", "{:.2f}".format(rate_local),
        "acc_local:", acc_local,
        "count_local:", count_local,
    )
    if print_box:
        print(box)
    box = []
    record.append(rate_local)
print(np.array(record))

np.save(out_path, np.array(record))

 