from collections import Counter

from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel
import openai
from openai import OpenAI
from dotenv import load_dotenv
import os
import string
import json
import time
import requests
from vllm import LLM, SamplingParams
import math
from statistics import mean
import numpy as np
# Number of CUDA devices visible to torch; used below as vLLM's tensor-parallel size.
gpu_count = torch.cuda.device_count()

# Path or identifier of the model to load. It is empty in this file —
# presumably it has to be filled in before the script is run.
model_path = ""

# NOTE(review): max_logprobs=1000 presumably has to be at least the logprobs=1000
# that llama() requests per generated token — confirm against the vLLM docs.
llm = LLM(model=model_path, tensor_parallel_size=gpu_count, max_logprobs=1000, max_model_len=4096)

# NOTE(review): in the visible part of this file the tokenizer is only referenced
# from commented-out code.
tokenizer = AutoTokenizer.from_pretrained(model_path)


def llama(prompt, stop=None):
    """Sample a short completion for ``prompt`` from the module-level ``llm``.

    Decoding uses temperature 0.9 with top_p 0.9 and at most 20 new tokens.
    1000 log-probabilities are requested per generated token and 5 per
    prompt token.

    Args:
        prompt: Prompt passed straight to ``llm.generate``.
        stop: Optional stop string(s) forwarded to ``SamplingParams``.

    Returns:
        A ``(text, outputs)`` tuple: the text of the first completion of the
        first request, and the full object returned by ``llm.generate``.
    """
    decoding_config = dict(
        temperature=0.9,
        top_p=0.9,
        max_tokens=20,
        stop=stop,
        logprobs=1000,
        prompt_logprobs=5,
    )
    request_outputs = llm.generate(prompt, SamplingParams(**decoding_config))
    first_completion = request_outputs[0].outputs[0]
    return first_completion.text, request_outputs

def perplexity(logprobs):
    """Return the perplexity of a token sequence from its log-probabilities.

    The result is ``exp(-mean(logprobs))``, i.e. the log-probabilities are
    treated as natural logarithms.

    Args:
        logprobs: Iterable with one log-probability per token. Any iterable
            is accepted (it is materialised once), not only sequences.

    Returns:
        The perplexity as a float.

    Raises:
        ValueError: If ``logprobs`` contains no values; perplexity is
            undefined for an empty sequence.
    """
    # Materialise first so generators work and the length is known.
    values = list(logprobs)
    if not values:
        raise ValueError("perplexity() requires at least one log-probability")
    # Average log-probability over all tokens.
    avg_logprob = sum(values) / len(values)
    return math.exp(-avg_logprob)

def check_string(A, B):
    """Return whether answer ``A`` equals reference ``B``, ignoring case.

    Args:
        A: The answer to check. Converted with ``str()`` before comparing.
        B: Either a single reference value or a collection (list, tuple,
            set, frozenset) of acceptable reference values. Non-string
            references such as numbers are compared by their ``str()`` form.

    Returns:
        ``True`` if ``A`` matches ``B`` (or any element of ``B``) after
        lower-casing both sides, otherwise ``False``. ``B is None`` never
        matches.
    """
    if B is None:
        return False
    A_lower = str(A).lower()
    if isinstance(B, (list, tuple, set, frozenset)):
        # str() keeps non-string items (e.g. years stored as ints) comparable.
        return any(A_lower == str(item).lower() for item in B)
    # Scalars of any type are compared by their string form, so a numeric
    # reference no longer falls through and returns None.
    return A_lower == str(B).lower()

def have_common_substring_of_length(str1, str2, length):
    """Tell whether ``str1`` and ``str2`` share a substring of exactly ``length`` characters.

    The comparison is case-sensitive. A non-positive ``length`` never
    matches, and neither does a ``length`` larger than either string.
    """
    if length <= 0:
        return False
    # Every window of `length` characters found in the first string.
    windows_of_first = set()
    for start in range(len(str1) - length + 1):
        windows_of_first.add(str1[start:start + length])
    # Look for any window of the second string among them.
    return any(
        str2[start:start + length] in windows_of_first
        for start in range(len(str2) - length + 1)
    )


# Prompt template for the QA task; the question is inserted with str.format below.
with open("./prompts/QA.txt", "r", encoding='utf-8') as f:
    task_prompt = f.read()

# Evaluation data; only the "tail" entry is iterated below.
with open("head_to_tail_dbpedia.json", "r", encoding='utf-8') as f:
    dataset = json.load(f)

dic_list = []  # one record per question; dumped to JSON at the end of the script
perplexity_list = []  # perplexity of every generated answer
# The three lists below are filled in the loop but not read in the visible part of this file.
right_perplexity_list = []
wrong_perplexity_list = []
unsure_perplexity_list = []
count = 0  # questions processed so far
right_count = 0  # questions whose answer matched the ground truth
length = 3  # substring length used when matching a candidate token against the truth
rank_list = []  # per question: adjusted rank of the first matching candidate, or 1000
for d in tqdm(dataset["tail"][0:]):
    count += 1
    # Index 2 of a record is used as the question and index 3 as the ground truth.
    question = d[2]
    truth = d[3]
    # NOTE(review): truth_first is only used by the commented-out lines below,
    # and truth[0] raises IndexError when truth is an empty list.
    if isinstance(truth, list):
        truth_first = truth[0]
    else:
        truth_first = str(truth)
    # The prompt is terminated with a non-breaking space (U+00A0).
    prompt = task_prompt.format(question) + '\xa0'
    # tokens = tokenizer.tokenize(truth_first)
    # prompt = "Do not think of {}.\n".format(truth_first) + prompt
    gen, outputs = llama(prompt, stop="\n")
    # Remove one leading whitespace character. An empty generation is padded
    # with " " first so that gen[0] is safe; it ends up as "" again.
    if len(gen) == 0:
        gen += " "
    if gen[0].isspace():
        gen = gen[1:]
    # print(outputs[0].outputs[0].logprobs)
    # One log-probability per generated position: the first entry of that
    # position's logprob dict.
    # NOTE(review): presumably the first entry is the sampled token — confirm
    # against vLLM's ordering of the logprob dict.
    logprob_list = []
    for logprob in outputs[0].outputs[0].logprobs:
        logprob = {key: value.__dict__ for key, value in logprob.items()}
        first_item = next(iter(logprob.items()))
        prob = first_item[1]['logprob']
        logprob_list.append(prob)
    # print(logprob_list)
    p = perplexity(logprob_list)
    perplexity_list.append(p)
    # Candidates recorded for the first generated position only, converted to
    # plain dicts (these are also what gets stored in the output record).
    logprobs = outputs[0].outputs[0].logprobs[0]
    logprobs = {key: value.__dict__ for key, value in logprobs.items()}
    value_list = list(logprobs.values())
    # Walk the candidates in dict order and stop at the first one that shares a
    # substring of `length` characters with the ground truth. Candidates shorter
    # than `length` characters can never match.
    flag = False
    skip_num = 0  # whitespace-only candidates seen so far
    for i in value_list:
        ans = i['decoded_token']
        if ans.isspace():
            skip_num += 1
            continue
        if isinstance(truth, list):
            # List items are passed as-is (no str() conversion, unlike the scalar branch).
            for t in truth:
                flag = have_common_substring_of_length(ans, t, length=length)
                if flag is True:
                    break
        else:
            flag = have_common_substring_of_length(ans, str(truth), length=length)
        if flag:
            # Reported rank = the candidate's own rank minus the whitespace
            # candidates skipped before it.
            # NOTE(review): this assumes the skipped candidates all rank above
            # the match — verify against the dict ordering.
            rank = i['rank']
            rank = rank - skip_num
            break
    # 1000 marks "no candidate matched".
    if flag is False:
        rank = 1000

    rank_list.append(rank)

    dic_temp = {}
    dic_temp["question"] = question
    dic_temp["truth"] = truth
    dic_temp["answer"] = gen
    dic_temp["logprobs"] = logprobs
    # print(logprobs)
    dic_list.append(dic_temp)
    # print(outputs[0].outputs[0].logprobs[0])
    # NOTE(review): an answer of exactly "unsure" is recorded here and is still
    # counted in the right/wrong split below.
    if gen == "unsure":
        unsure_perplexity_list.append(p)

    if check_string(gen, truth):
        right_count += 1
        right_perplexity_list.append(p)
    else:
        wrong_perplexity_list.append(p)
    # print("gen:", gen)
    # print("per", p)
    print(truth)
    print(gen)
    print("acc:{}/{}".format(right_count, count))
    # file = open("./result11.txt", "a")
    # file.write(question + '\n')
    # file.write(truth + '\n')
    # file.write("find in top {}".format(find) + '\n')
# Count how many times each rank value occurs.
counter = Counter(rank_list)
counter_dict = dict(counter)

# Print the rank histogram.
print(counter_dict)

def fill(data):
    """Add a zero entry for every missing integer key between the smallest and largest key.

    The dictionary is modified in place and also returned.

    Args:
        data: Mapping from integer keys to counts.

    Returns:
        The same ``data`` object, now containing every integer key in
        ``[min(data), max(data)]``. An empty mapping is returned unchanged.
    """
    # min()/max() raise ValueError on an empty mapping; there is nothing to fill.
    if not data:
        return data

    # The bounds are evaluated once, before any key is inserted.
    for key in range(min(data), max(data) + 1):
        data.setdefault(key, 0)
    return data

def _count_within_rank(rank_counts, k):
    """Return how many questions have a recorded rank of at most ``k``."""
    return sum(n for rank, n in rank_counts.items() if rank <= k)


# Dense, key-sorted view of the rank histogram and its running total.
# An empty histogram is left as-is instead of being passed to fill().
data = fill(counter_dict) if counter_dict else counter_dict
sorted_data = dict(sorted(data.items(), key=lambda item: item[0], reverse=False))
keys = list(sorted_data.keys())
values = list(sorted_data.values())
cumulative_values = np.cumsum(values)

# "top k" is the number of questions whose rank is <= k. It is summed by rank
# value instead of being read from cumulative_values[k - 1]: that position only
# corresponds to rank k when the smallest rank is exactly 1, and it raises
# IndexError when fewer than k distinct ranks exist (e.g. every question missed).
print("top 100:  ", _count_within_rank(sorted_data, 100))
print("top 50", _count_within_rank(sorted_data, 50))
print("top 10:", _count_within_rank(sorted_data, 10))
print("top 5:", _count_within_rank(sorted_data, 5))
print("top 1:", _count_within_rank(sorted_data, 1))

# Persist the per-question records collected by the evaluation loop.
with open("prompts/dp_logprobs_head_8b.json", 'w', encoding='utf-8') as file:
    json.dump(dic_list, file, ensure_ascii=False, indent=4)



