import random
import time
from numpy.typing import NDArray
from typing import List
import numpy as np
import os
from utils.scorer.bert_score import calculate_score
# from bert_score import calculate_score
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sentence_transformers import SentenceTransformer, util

def get_bert_score(cands: List[str], refs: List[str], model_type="distilbert-base-uncased", auto=True) -> NDArray:
    """Score candidates against references and return the F1 values.

    Thin wrapper around ``calculate_score``: the call is always made in
    verbose mode and only the ``"f1"`` entry of its result is returned.

    Args:
        cands: Candidate sentences.
        refs: Reference sentences.
        model_type: Model identifier forwarded to ``calculate_score``.
        auto: Forwarded unchanged to ``calculate_score``.

    Returns:
        The ``"f1"`` entry of the ``calculate_score`` result.
    """
    scores = calculate_score(
        cands,
        refs,
        model_type=model_type,
        verbose=True,
        auto=auto,
    )
    f1_scores = scores["f1"]
    return f1_scores

def get_bleu_score(ref_sentences: List[str], cand_sentences: List[str]) -> NDArray:
    """Compute a smoothed sentence-level BLEU score for each sentence pair.

    Sentences are paired by position (``zip``, so surplus items of the longer
    list are ignored), lower-cased and split on whitespace before scoring.

    Args:
        ref_sentences: Reference sentences. Note the reference list comes
            first here.
        cand_sentences: Candidate sentences.

    Returns:
        Array holding one BLEU score per pair.

    Raises:
        AssertionError: If an element of either list is not a string.
    """
    smoothing = SmoothingFunction().method4
    scores = []
    for reference, candidate in zip(ref_sentences, cand_sentences):
        assert isinstance(reference, str), "ref_sentences must be a list of strings"
        assert isinstance(candidate, str), "cand_sentences must be a list of strings"

        # Case-insensitive, whitespace-tokenised comparison.
        reference_tokens = reference.lower().split()
        candidate_tokens = candidate.lower().split()

        scores.append(
            sentence_bleu([reference_tokens], candidate_tokens, smoothing_function=smoothing)
        )

    return np.array(scores)

def get_cosine_similarity(cands: List[str], refs: List[str], model_name='all-mpnet-base-v2', auto=True) -> NDArray:
    """Compute the embedding cosine similarity of each candidate/reference pair.

    Each pair is encoded and compared individually with a SentenceTransformer
    model (https://www.sbert.net/).

    Args:
        cands: Candidate sentences.
        refs: Reference sentences, paired with ``cands`` by position; surplus
            items of the longer list are ignored (``zip``).
        model_name: SentenceTransformer model identifier.
        auto: When True, ``model_name`` is handed to ``SentenceTransformer``
            as is. When False, the model is loaded from
            ``<cwd>/.cache/<model_name>``.

    Returns:
        1-D array with one cosine similarity per pair.

    Raises:
        AssertionError: If an element of either list is not a string.
    """
    if not auto:
        # Build the local path from the requested model instead of a
        # hard-coded name, so a non-default ``model_name`` is honoured
        # (the default still resolves to .cache/all-mpnet-base-v2).
        model_name = f'{os.getcwd()}/.cache/{model_name}'
    model = SentenceTransformer(model_name)

    cos_score_list = []
    for cand, ref in zip(cands, refs):
        assert isinstance(cand, str), "cands must be a list of strings"
        assert isinstance(ref, str), "refs must be a list of strings"

        cand_embedding = model.encode(cand, convert_to_tensor=True)
        ref_embedding = model.encode(ref, convert_to_tensor=True)

        # pytorch_cos_sim yields a 1x1 tensor for two single embeddings.
        cos_score = util.pytorch_cos_sim(cand_embedding, ref_embedding)
        cos_score_list.append(cos_score.item())

    return np.array(cos_score_list)

def get_diag_cosine_similarity(cands: List[str], refs: List[str], model_name='all-mpnet-base-v2', auto=True) -> NDArray:
    """Compute position-wise cosine similarities with batched encoding.

    Both lists are encoded in one batch each with a SentenceTransformer model
    (https://www.sbert.net/) and compared with ``util.pairwise_cos_sim``,
    i.e. ``cands[i]`` against ``refs[i]``.

    Args:
        cands: Candidate sentences.
        refs: Reference sentences, paired with ``cands`` by position.
        model_name: SentenceTransformer model identifier.
        auto: When True, ``model_name`` is handed to ``SentenceTransformer``
            as is. When False, the model is loaded from
            ``<cwd>/.cache/<model_name>``.

    Returns:
        NumPy array of the pairwise similarity scores.
    """
    if not auto:
        # Build the local path from the requested model instead of a
        # hard-coded name, so a non-default ``model_name`` is honoured
        # (the default still resolves to .cache/all-mpnet-base-v2).
        model_name = f'{os.getcwd()}/.cache/{model_name}'
    model = SentenceTransformer(model_name)

    cand_embeddings = model.encode(cands, convert_to_tensor=True)
    ref_embeddings = model.encode(refs, convert_to_tensor=True)

    # Move to CPU before converting: the embeddings may live on an accelerator.
    pairwise_scores = util.pairwise_cos_sim(cand_embeddings, ref_embeddings).cpu()

    return pairwise_scores.numpy()

def get_batch_cosine_similarity(cands: str, refs: List[str], model_name='all-mpnet-base-v2', auto=True) -> NDArray:
    """Compare one candidate sentence against every reference sentence.

    Uses a SentenceTransformer model (https://www.sbert.net/).

    Args:
        cands: A single candidate sentence.
        refs: Reference sentences to compare the candidate with.
        model_name: SentenceTransformer model identifier.
        auto: When True, ``model_name`` is handed to ``SentenceTransformer``
            as is. When False, the model is loaded from
            ``<cwd>/.cache/<model_name>``.

    Returns:
        NumPy array of cosine similarities; the leading candidate axis of the
        similarity result is squeezed away.
    """
    model_location = model_name if auto else f'{os.getcwd()}/.cache/{model_name}'
    encoder = SentenceTransformer(model_location)

    # The candidate is wrapped in a list so it is encoded as a batch of one.
    cand_embedding = encoder.encode([cands], convert_to_tensor=True)
    ref_embeddings = encoder.encode(refs, convert_to_tensor=True)

    similarity = util.pytorch_cos_sim(cand_embedding, ref_embeddings)
    return similarity.squeeze(0).cpu().numpy()

# new function by gao jun
def testOpen(system="",prompt=""):
    import openai
    request_model = "gpt-4"
    # request_model = "gpt-3.5-turbo"
    # openai.api_key = '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'   #pcl-5
    # openai.api_key = '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'
    keys = [
        "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"
    ]

    openai.api_key = random.choice(keys)

    completion = openai.ChatCompletion.create(
        model = request_model,
        # messages=[{"role": "user", "content": "Tell the world about the ChatGPT API in the style of a pirate."}]
        messages = [{"role": "system", "content": system},{"role": "user", "content": prompt}] if system else [{"role": "user", "content": prompt}],
        temperature = 0
    )
    # print(completion)
    # time.sleep(3)
    return completion

def get_diag_gpt4_similarity_exp1(cands: str, refs: List[str], model_name='all-mpnet-base-v2', auto=True) -> NDArray:
    """Ask GPT-4 for a per-value similarity score between ``cands`` and each reference.

    For the i-th reference a request is sent whose system prompt embeds that
    reference and whose user message is ``cands``. The reply is expected to be
    a dictionary literal; the score stored for the i-th reference is the entry
    keyed by the i-th name in the fixed value list.

    Args:
        cands: Sentence sent verbatim as the user message of every request.
        refs: Reference texts, one request each.
        model_name: Unused; kept for signature parity with the other scorers.
        auto: Unused; kept for signature parity with the other scorers.

    Returns:
        Float array with one score per reference. A score of 0 is recorded
        when the request keeps failing, the reply cannot be parsed, or the
        expected key is missing.
    """
    import ast
    import json

    sys_p1 = '''
    Please calculate the similarity in values（'Power', 'Spirituality', 'Benevolence', 'Tradition', 'Self-Direction', 'Achievement', 'Stimulation', 'Security', 'Conformity', 'Hedonism'） between the given sentence and the following sentences and value, with scores ranging from 0 to 1'''
    sys_p2 = '''Final output format as a dictionary
    '''
    value_list = ['Power', 'Spirituality', 'Benevolence', 'Tradition', 'Self-Direction', 'Achievement', 'Stimulation', 'Security', 'Conformity', 'Hedonism']
    max_retries = 3  # maximum number of attempts per reference
    score_list = []
    for count, q_a in enumerate(refs):
        sys_p_final = sys_p1 + '\n' + q_a + "\n\n" + sys_p2
        user_p_final = cands
        print(sys_p_final)
        print(user_p_final)

        # Reset per reference so a failed request can never reuse the reply
        # obtained for the previous reference.
        r = None
        for _ in range(max_retries):
            try:
                r = testOpen(system=sys_p_final, prompt=user_p_final)["choices"][0]["message"]["content"]
            except Exception as e:
                print("发生了错误:", e)
            else:
                break

        if r is None:
            # Every attempt failed: record a zero score instead of crashing.
            print("fail to query gpt")
            score_list.append(0.0)
            continue

        print("gpt4请求成功，输出为：")
        print(r)

        # SECURITY: the reply is untrusted model output, so it is parsed as a
        # literal (Python dict syntax first, then JSON) and never eval()'d.
        r_dict = None
        for parse in (ast.literal_eval, json.loads):
            try:
                r_dict = parse(r)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            else:
                break
        if r_dict is None:
            print("解析失败")
            r_dict = {"fail": "fail"}

        try:
            # Scores are requested in the 0..1 range, so keep the fraction
            # (int() would truncate nearly every score to 0).
            score = float(r_dict[value_list[count]])
        except (KeyError, IndexError, TypeError, ValueError):
            score_list.append(0.0)
            print("fail to parse!!!!!!!!!!!!!!!!!!!")
        else:
            score_list.append(score)
            print("字段正确，已添加")

    return np.array(score_list, dtype=float)

def get_diag_gpt4_similarity_exp2(cands: List[str], refs: List[str], model_name='all-mpnet-base-v2', auto=True) -> NDArray:
    """Ask GPT-4 to score each candidate analysis against its reference analysis.

    For every position ``i`` a request is sent whose system prompt embeds
    ``refs[i]`` as the standard analysis and whose user message is
    ``cands[i]``. The reply is expected to be a dictionary literal; the value
    of its first key is taken as the score.

    Args:
        cands: Candidate analyses.
        refs: Reference analyses; must provide an entry for every candidate.
        model_name: Unused; kept for signature parity with the other scorers.
        auto: Unused; kept for signature parity with the other scorers.

    Returns:
        Float array with one score per candidate. A score of 0 is recorded
        when the request keeps failing or the reply cannot be interpreted.

    Raises:
        IndexError: If ``refs`` is shorter than ``cands``.
    """
    import ast
    import json

    sys_p1 = '''
    Given a standard analysis, please evaluate the similarity of how reasonable between the input analysis and the standard analysis with scores ranging from 0 to 1.

    standard analysis：'''
    sys_p2 = '''Final output format as a dictionary
    '''

    max_retries = 3  # bound the retries so a persistent outage cannot hang the run
    score_matrix = np.zeros(len(cands))

    for ct, cand in enumerate(cands):
        sys_p_final = sys_p1 + '\n' + refs[ct] + "\n\n" + sys_p2
        user_p_final = cand
        print(sys_p_final)
        print(user_p_final)

        r = None
        for _ in range(max_retries):
            try:
                r = testOpen(system=sys_p_final, prompt=user_p_final)["choices"][0]["message"]["content"]
            except Exception:
                # Exception (not a bare except) keeps Ctrl-C able to stop the loop.
                print("fail to query gpt")
            else:
                break

        if r is None:
            # Every attempt failed; the slot keeps its initial score of 0.
            continue

        print(r)

        # SECURITY: the reply is untrusted model output, so it is parsed as a
        # literal (Python dict syntax first, then JSON) and never eval()'d.
        r_dict = {"fail": "fail"}
        for parse in (ast.literal_eval, json.loads):
            try:
                r_dict = parse(r)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            else:
                break

        try:
            # The key name is not fixed by the prompt, so use the first entry.
            k = list(r_dict.keys())[0]
            convt = float(r_dict[k])
        except (AttributeError, IndexError, KeyError, TypeError, ValueError):
            convt = 0

        score_matrix[ct] = convt

    return score_matrix

if __name__ == "__main__":
    # Manual smoke test: score three candidate/reference sentence pairs using
    # the model stored under ./.cache (auto=False) and print the similarities.
    cands = [
        "I like NLP",
        "I like to skate",
        "This is a dilicious food",
    ]
    refs = [
        "I like NLP",
        "I want to go to the beihai park for skating",
        "Eating it makes you happy",
    ]
    # bert_score_res = get_bert_score(cands, refs, auto=False)
    score_res = get_cosine_similarity(cands, refs, auto=False)
    print(score_res)
