import requests
import json
import pandas as pd
from datasets import Dataset
import datasets
import tqdm
import argparse
import os
import config
import re
import evaluate
# from rouge_score import rouge_scorer

from sentence_transformers import SentenceTransformer
# scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

# Load the ROUGE metric through `evaluate` when this module is imported; method_rouge uses it.
# NOTE(review): the absolute path is hard-coded — confirm it exists on the machine running this.
rouge = evaluate.load('/home/lirui/sar/SAR/src/evaluate-main/metrics/rouge')


def chat(prompt: str) -> str:
    """Send ``prompt`` as a single user message to the local chat endpoint.

    Posts to the chat-completions URL on localhost:8000 with model
    "llama3-8b" and temperature 0, then returns the text of the first choice.

    Args:
        prompt: Content of the user message.

    Returns:
        The ``content`` field of the first choice's message.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    url = 'http://localhost:8000/v1/chat/completions'
    headers = {
        'accept': 'application/json',
        'Content-Type': 'application/json'
    }
    data = {
        "model": "llama3-8b",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0
    }
    # NOTE(review): no timeout is passed, so a stalled server blocks this call indefinitely.
    response = requests.post(url, headers=headers, json=data)
    # Surface the HTTP status instead of an opaque KeyError/JSON error further down.
    response.raise_for_status()
    res = response.json()
    return res['choices'][0]['message']['content']




def method_rouge(response, answer):
    """Return the ROUGE-L value of ``response`` scored against ``answer``.

    Each string is wrapped in a one-element list for the module-level
    ``rouge`` metric, and the ``'rougeL'`` entry of its result is returned.
    """
    scores = rouge.compute(references=[answer], predictions=[response])
    return scores['rougeL']

def method_similarity(response, answer):
    """Return the cosine similarity between the embeddings of two texts.

    Args:
        response: Text to compare.
        answer: Reference text.

    Returns:
        The similarity as a Python scalar (``.item()`` of the 1x1 result).

    NOTE(review): this relies on a module-level ``model`` exposing ``encode``;
    this file does not assign it, so it must be set before calling.
    """
    # `util` is not imported at file level, so the original body raised
    # NameError; import it here to keep the function self-contained.
    from sentence_transformers import util

    embedding1 = model.encode(response, convert_to_tensor=True)
    embedding2 = model.encode(answer, convert_to_tensor=True)
    # Compute cosine similarity
    similarity = util.pytorch_cos_sim(embedding1, embedding2)
    return similarity.item()


def extract_numbers(text):
    """Return the first number that appears in ``text`` as a float.

    A number is a run of digits, optionally followed by ``.`` and more digits;
    signs and exponents are not recognised.

    Args:
        text: String to scan.

    Returns:
        The first integer or decimal found, converted to ``float``.

    Raises:
        IndexError: If ``text`` contains no number. The type is kept from the
            original list-indexing failure; the message now shows the text.
    """
    # Only the first match is needed, so search instead of collecting all matches.
    first = re.search(r'\d+\.\d+|\d+', text)
    if first is None:
        raise IndexError(f"no number found in text: {text!r}")
    return float(first.group())


def method_model(response,answer):
    """Ask the chat model to grade ``response`` against ``answer``.

    Builds a prompt requesting 1 (correct) or 0 (incorrect), sends it through
    ``chat``, prints the raw reply, and returns the first number that
    ``extract_numbers`` finds in that reply.
    """
    judge_prompt=f"""
    Please judge whether response is correct according to the reference answer, return 1 if yes, 0 if not, do not return anything superfluous: 
    Reference Answer: {answer}
    response:{response}
    """
    verdict=chat(judge_prompt)
    print(verdict)
    return extract_numbers(verdict)
def method_match(response,answer):
    """Return 1 if ``response`` occurs inside ``answer`` ignoring case, else 0."""
    needle = response.lower()
    haystack = answer.lower()
    return 1 if needle in haystack else 0


# Markers that end the useful part of a response; text from the first one onward is dropped.
strings_to_filter_on = [
    '.', '\n',
    'Q:', 'A:',
    'question:', 'answer:',
    'Question:', 'Answer:',
    'Questions:', 'questions:',
    'QUESTION:', 'ANSWER:',
]


def extract_before_question(text):
    """Return the stripped text preceding the first marker in ``text``.

    The markers are the entries of ``strings_to_filter_on``. The result is an
    empty string when the text starts with a marker, and ``None`` when the
    text contains no marker at all.
    """
    marker_regex = '|'.join(map(re.escape, strings_to_filter_on))
    # The leftmost marker position is exactly where the wanted prefix ends.
    hit = re.search(marker_regex, text)
    if hit is None:
        return None
    return text[:hit.start()].strip()

def ensure_file_exists(file_path, source_file_path):
    """Create ``file_path`` from ``source_file_path`` when it does not exist yet.

    The source is parsed as JSON and re-serialised with a 4-space indent.
    An existing ``file_path`` is left untouched.
    """
    if os.path.exists(file_path):
        return
    with open(source_file_path, 'r') as src:
        payload = json.load(src)
    with open(file_path, 'w') as dst:
        json.dump(payload, dst, indent=4)

def judge_coqa(dataset,model):
    """Add ROUGE-L and judge-model scores to the CoQA score file.

    The score file is seeded from the generation file if it does not exist.
    When the first item has no ``clean_response``, one is derived for every
    item from its ``"<model>_response"`` field. Every item then receives
    ``rougel`` and ``llama_score``, and the file is rewritten.

    Args:
        dataset: Not used here; both CoQA paths are fixed strings.
        model: Prefix of the ``"<model>_response"`` key holding the raw response.
    """
    path='path/uncertainty/my_uncertainty/dataset_process/coqa_dataset/train_gen.json'
    path_score='path/uncertainty/my_uncertainty/dataset_process/coqa_dataset/train_gen_score.json'
    ensure_file_exists(path_score,path)
    with open(path_score, 'r') as f:
        data = json.load(f)
    # `data and` protects the data[0] probe: empty data used to raise IndexError.
    if data and 'clean_response' not in data[0]:
        key=f"{model}_response"  # loop-invariant, so built once
        for item in tqdm.tqdm(data):
            item['clean_response']=extract_before_question(item[key])
            # No marker found: fall back to the raw response.
            if item['clean_response'] is None:
                item['clean_response']=item[key]
            print(item['clean_response'])

    for item in tqdm.tqdm(data):
        item['rougel']=method_rouge(item['clean_response'],item['answer']['normalized_value'])
        item['llama_score']=method_model(item['clean_response'],item['answer']['normalized_value'])
    with open(path_score, 'w') as f:
        json.dump(data, f, indent=4)
        

def judge_sciqa(dataset,model):
    """Add ROUGE-L and judge-model scores to a dataset/model score file.

    The score file is seeded from the matching generation file if it does not
    exist. When the first item has no ``clean_response``, one is derived for
    every item from its ``"<model>_response"`` field. Every item then receives
    ``rougel`` and ``llama_score`` computed against
    ``item['answer']['normalized_value']``, and the file is rewritten.

    Args:
        dataset: Dataset name used to build the directory name.
        model: Model name used in the file names and the response key.
    """
    path=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen.json'
    path_score=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen_score.json'
    ensure_file_exists(path_score,path)

    with open(path_score, 'r') as f:
        data = json.load(f)
    # `data and` protects the data[0] probe: empty data used to raise IndexError.
    if data and 'clean_response' not in data[0]:
        key=f"{model}_response"
        for item in tqdm.tqdm(data):
            item['clean_response']=extract_before_question(item[key])
            # No marker found: fall back to the raw response.
            if item['clean_response'] is None:
                item['clean_response']=item[key]
            print(item['clean_response'])
    for item in tqdm.tqdm(data):
        item['rougel']=method_rouge(item['clean_response'],item['answer']['normalized_value'])
        item['llama_score']=method_model(item['clean_response'],item['answer']['normalized_value'])
    with open(path_score, 'w') as f:
        json.dump(data, f, indent=4)


def judge_triviaqa(dataset,model):
    """Add ROUGE-L and judge-model scores to a dataset/model score file.

    The score file is seeded from the matching generation file if it does not
    exist. When the first item has no ``clean_response``, one is derived for
    every item from its ``"<model>_response"`` field. Every item then receives
    ``rougel`` and ``llama_score`` computed against
    ``item['answer']['normalized_value']``, and the file is rewritten.

    Args:
        dataset: Dataset name used to build the directory name.
        model: Model name used in the file names and the response key.
    """
    path=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen.json'
    path_score=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen_score.json'
    ensure_file_exists(path_score,path)

    with open(path_score, 'r') as f:
        data = json.load(f)
    # `data and` protects the data[0] probe: empty data used to raise IndexError.
    if data and 'clean_response' not in data[0]:
        key=f"{model}_response"
        for item in tqdm.tqdm(data):
            item['clean_response']=extract_before_question(item[key])
            # No marker found: fall back to the raw response.
            if item['clean_response'] is None:
                item['clean_response']=item[key]
            print(item['clean_response'])

    for item in tqdm.tqdm(data):
        item['rougel']=method_rouge(item['clean_response'],item['answer']['normalized_value'])
        item['llama_score']=method_model(item['clean_response'],item['answer']['normalized_value'])

    with open(path_score, 'w') as f:
        json.dump(data, f, indent=4)

def judge_medmcqa(dataset,model):
    """Add ROUGE-L and judge-model scores to a dataset/model score file.

    The score file is seeded from the matching generation file if it does not
    exist. When the first item has no ``clean_response``, one is derived for
    every item from its ``"<model>_response"`` field. Every item then receives
    ``rougel`` and ``llama_score`` computed against ``item['answer']``
    directly, and the file is rewritten.

    Args:
        dataset: Dataset name used to build the directory name.
        model: Model name used in the file names and the response key.
    """
    path=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen.json'
    path_score=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen_score.json'
    ensure_file_exists(path_score,path)
    with open(path_score, 'r') as f:
        data = json.load(f)
    # `data and` protects the data[0] probe: empty data used to raise IndexError.
    if data and 'clean_response' not in data[0]:
        key=f"{model}_response"
        for item in tqdm.tqdm(data):
            item['clean_response']=extract_before_question(item[key])
            # No marker found: fall back to the raw response.
            if item['clean_response'] is None:
                item['clean_response']=item[key]
            print(item['clean_response'])
    for item in tqdm.tqdm(data):
        item['rougel']=method_rouge(item['clean_response'],item['answer'])
        item['llama_score']=method_model(item['clean_response'],item['answer'])
    with open(path_score, 'w') as f:
        json.dump(data, f, indent=4)


def judge_nq(dataset,model):
    """Add ROUGE-L and judge-model scores to a dataset/model score file.

    The score file is seeded from the matching generation file if it does not
    exist. When the first item has no ``clean_response``, one is derived for
    every item from its ``"<model>_response"`` field. Every item then receives
    ``rougel`` and ``llama_score`` computed against ``item['answer']``
    directly, and the file is rewritten.

    Args:
        dataset: Dataset name used to build the directory name.
        model: Model name used in the file names and the response key.
    """
    path=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen.json'
    path_score=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen_score.json'
    ensure_file_exists(path_score,path)
    with open(path_score, 'r') as f:
        data = json.load(f)
    # `data and` protects the data[0] probe: empty data used to raise IndexError.
    if data and 'clean_response' not in data[0]:
        key=f"{model}_response"
        for item in tqdm.tqdm(data):
            item['clean_response']=extract_before_question(item[key])
            # No marker found: fall back to the raw response.
            if item['clean_response'] is None:
                item['clean_response']=item[key]
            print(item['clean_response'])

    # Both scores are always computed; method_similarity exists in this file but is not used here.
    for item in tqdm.tqdm(data):
        item['rougel']=method_rouge(item['clean_response'],item['answer'])
        item['llama_score']=method_model(item['clean_response'],item['answer'])

    with open(path_score, 'w') as f:
        json.dump(data, f, indent=4)

if __name__ == '__main__':
    # Map each supported --dataset value to the function that scores it.
    judges = {
        'coqa': judge_coqa,
        'sciqa': judge_sciqa,
        'triviaqa': judge_triviaqa,
        'medmcqa': judge_medmcqa,
        'nq': judge_nq,
    }

    parser = argparse.ArgumentParser(description='Evaluate a dataset using a text-generation model.')
    # A missing or unknown dataset used to fall through every branch and exit
    # silently; argparse now reports it as a usage error.
    parser.add_argument('--dataset', type=str, required=True, choices=list(judges),
                        help='The name of the dataset')
    parser.add_argument('--model', type=str, help='')
    args = parser.parse_args()
    # method_similarity is not used here; it would need a module-level `model`
    # (a SentenceTransformer) to be loaded first.

    judges[args.dataset](args.dataset, args.model)

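# NOTE: example invocations. The parser only defines --dataset and --model (--method is commented out),
# so the --method flags shown below are rejected by argparse as unrecognized arguments.
#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset triviaqa --method rouge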
#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset triviaqa --method similarity
#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset triviaqa --method model

#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset sciqa --method rouge --model llama3
#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset sciqa --method model --model llama3

#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset sciqa --method similarity


#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset coqa --method rouge
#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset coqa --method similarity
#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset coqa --method model


#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset  --method rouge
#CUDA_ADDRESS=0 python path/uncertainty/my_uncertainty/pipeline/002_judge.py --dataset coqa --method model