import pickle
from transformers import LlamaTokenizer
from transformers import AutoTokenizer


import pandas as pd
from datasets import Dataset
import datasets
import tqdm
import argparse
import os
import config

import evaluate


import re
# Example token IDs (numeric tokens)
# token_ids = [1234, 5678, 91011]  # These are example IDs; replace them with the actual token IDs

import requests


import tqdm
import torch
import json

# Load the ROUGE metric once, at import time, so every call to method_rouge
# reuses the same metric object. The metric is read from an absolute local
# path, so it presumably only resolves on the machine it was written for.
rouge = evaluate.load('/home/lirui/sar/SAR/src/evaluate-main/metrics/rouge')


def save_json(data,file_path):
    """Write *data* to *file_path* as pretty-printed UTF-8 JSON.

    The file is opened in write mode (an existing file is overwritten), the
    output is indented by four spaces, and non-ASCII characters are written
    as-is rather than as ``\\uXXXX`` escapes.
    """
    with open(file_path, mode='w', encoding='utf-8') as out_file:
        json.dump(data, out_file, indent=4, ensure_ascii=False)

def extract_last_qa(text):
    """Return the last ``Q: ...`` segment of *text*, stripped of whitespace.

    A segment begins right after ``"Q: "`` and extends (across newlines) up
    to the next ``"A:"`` or, failing that, the end of the text. Returns
    ``None`` when *text* contains no such segment.
    """
    questions = re.findall(r'Q: (.*?)(?=A:|$)', text, re.DOTALL)

    # Guard clause: nothing that looks like a question was found.
    if not questions:
        return None
    return questions[-1].strip()

def extract_last_question_answer(text):
    """Return the last ``Question: ...`` segment of *text*, stripped.

    A segment begins right after ``"Question: "`` and extends (across
    newlines) up to the next ``"Answer:"`` or, failing that, the end of the
    text. Returns ``None`` when *text* contains no such segment.
    """
    segments = re.findall(r'Question: (.*?)(?=Answer:|$)', text, re.DOTALL)
    # Only the most recent question matters; earlier ones are ignored.
    return segments[-1].strip() if segments else None

# Markers at which extract_before_question cuts its input.
strings_to_filter_on = [
    '.', '\n', 'Q:', 'A:', 'question:', 'answer:', 'Question:', 'Answer:',
    'Questions:', 'questions:', 'QUESTION:', 'ANSWER:',
]


def extract_before_question(text):
    """Return the part of *text* that precedes the first filter marker.

    The markers are the entries of ``strings_to_filter_on`` (matched
    literally). The returned prefix is stripped of surrounding whitespace and
    may be empty when *text* starts with a marker. Returns ``None`` when no
    marker occurs in *text*.
    """
    # Escape each marker so characters such as '.' are matched literally.
    alternatives = '|'.join(map(re.escape, strings_to_filter_on))
    found = re.search(rf'(.*?)(?={alternatives})', text, re.DOTALL)
    return found.group(1).strip() if found else None

def chat(prompt: str):
    """Send *prompt* as a single user message to the local chat endpoint.

    The request is POSTed to ``http://localhost:8000/v1/chat/completions``
    with model ``llama3-8b`` and temperature 0.

    Args:
        prompt: Text placed in the ``content`` of the one user message.

    Returns:
        The ``content`` field of the first choice's message in the JSON reply.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    url = 'http://localhost:8000/v1/chat/completions'
    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json'
    }
    payload = {
        "model": "llama3-8b",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0
    }
    response = requests.post(url, headers=headers, json=payload)
    # Surface the real HTTP failure instead of an opaque KeyError / JSON
    # decoding error when the error body lacks the expected 'choices' field.
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']



def extract_numbers(text):
    """Return the first number that appears in *text* as a float.

    A number is a run of digits, optionally followed by ``.`` and more
    digits. Signs and exponents are not recognised, so ``"-1"`` yields
    ``1.0``.

    Args:
        text: String to scan.

    Returns:
        The first integer or decimal found in *text*, converted to ``float``.

    Raises:
        IndexError: If *text* contains no digits. The message includes the
            offending text to make the failure diagnosable.
    """
    # Only the first number is needed, so stop at the first match instead of
    # collecting and converting every number in the text.
    first = re.search(r'\d+\.\d+|\d+', text)
    if first is None:
        raise IndexError(f'no number found in text: {text!r}')
    return float(first.group())

def determine_correct(rougel,llama_score):
    """Turn the two scores into a binary correctness label.

    Returns 1 when ``rougel`` is at least 0.5 or ``llama_score`` is at least
    0.5, and 0 otherwise. ``llama_score`` is only compared when the ROUGE-L
    check fails.
    """
    min_rougel = 0.5
    min_llama_score = 0.5

    passed = rougel >= min_rougel or llama_score >= min_llama_score
    return 1 if passed else 0
        

def method_model(response,answer):
    """Ask the chat model to grade *response* against the reference *answer*.

    Builds a judging prompt, sends it through ``chat``, prints the raw reply,
    and returns the first number found in that reply via ``extract_numbers``.
    The prompt asks the model to reply 1 for correct and 0 for incorrect.
    """
    prompt=f"""
    Please judge whether response is correct according to the reference answer, return 1 if yes, 0 if not, do not return anything superfluous: 
    Reference Answer:: {answer}
    response:{response}
    """
    raw_verdict = chat(prompt)
    print(raw_verdict)
    return extract_numbers(raw_verdict)

def method_rouge(response,answer):
    """Return the ``rougeL`` score of *response* against the reference *answer*.

    Uses the module-level ``rouge`` metric with a single prediction and a
    single reference.
    """
    scores = rouge.compute(predictions=[response], references=[answer])
    return scores['rougeL']


def judge_correct(response,answer):
    """Label *response* as correct (1) or incorrect (0) against *answer*.

    Computes the ROUGE-L score first, then the chat-model score, and combines
    the two with ``determine_correct``.
    """
    # Arguments are evaluated left to right: ROUGE first, then the model call.
    return determine_correct(
        method_rouge(response, answer),
        method_model(response, answer),
    )


def _resolve_tokenizer_path(model):
    """Return the local tokenizer directory for *model*, or None if unsupported."""
    # 'llama' is tested first so that it takes precedence when a model name
    # contains both substrings.
    if 'llama' in model:
        return '/mnt/data/model/LLM-Research/Meta-Llama-3-8B-Instruct'
    if 'opt' in model:
        return '/mnt/data/model/opt/opt-2.7b'
    return None


def main(argv=None):
    """Judge every pickled generation and write the labels to ``eval.json``.

    For each record in the pickle at ``--generation_path`` the prompt token
    IDs are decoded back to text (reduced to the last ``Question:`` segment
    unless the dataset is ``coqa``), the most likely generation is judged
    against the first reference answer, and ``{'text', 'correct'}`` is
    collected. The collected list is saved as JSON under
    ``path/uncertainty/my_uncertainty/dataset_train_ready/{dataset}_{model}/``.

    Args:
        argv: Command-line arguments without the program name; ``None`` makes
            argparse read ``sys.argv``.

    Raises:
        SystemExit: On invalid arguments, including a ``--model`` value that
            contains neither ``'opt'`` nor ``'llama'``.
    """
    parser = argparse.ArgumentParser(description='Evaluate a dataset using a text-generation model.')
    parser.add_argument('--dataset', type=str, help='The name of the dataset')
    parser.add_argument('--model', type=str, required=True,
                        help="The name of the model (must contain 'opt' or 'llama')")
    parser.add_argument('--generation_path', type=str, required=True,
                        help='Path to the pickled generations file')
    args = parser.parse_args(argv)

    dataset = args.dataset
    model = args.model
    generation_path = args.generation_path

    # Fail fast with a usage error instead of hitting an unbound
    # tokenizer path after the generations have already been loaded.
    tokenizer_path = _resolve_tokenizer_path(model)
    if tokenizer_path is None:
        parser.error(f"unsupported model {model!r}: expected a name containing 'opt' or 'llama'")

    # NOTE: pickle.load can execute arbitrary code; only point this at
    # generation files that come from a trusted source.
    with open(generation_path, 'rb') as file:
        data = pickle.load(file)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, legacy=False)

    # Create the output directory up front so that a missing directory cannot
    # discard the results after every sample has been judged.
    output_path = f'path/uncertainty/my_uncertainty/dataset_train_ready/{dataset}_{model}/eval.json'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    new_data = []
    # Records are accessed by position; the container type of the pickled
    # data is not visible here, so positional indexing is kept.
    for i in tqdm.tqdm(range(len(data))):
        sample = data[i]
        tokens = tokenizer.convert_ids_to_tokens(sample['prompt'])
        text = tokenizer.convert_tokens_to_string(tokens)
        print(text)
        # coqa prompts are kept whole; for other datasets only the final
        # question of the prompt is kept.
        if dataset != 'coqa':
            text = extract_last_question_answer(text)
        print(text)

        response = sample['most_likely_generation']
        answer = sample['answer'][0]
        correct = judge_correct(response, answer)
        new_data.append({'text': text, 'correct': correct})

    save_json(new_data, output_path)


if __name__ == '__main__':
    main()
