import json
import tqdm 
import re


import pandas as pd
from datasets import Dataset
import datasets
import tqdm
import argparse
import os
import config
import re

def extract_last_qa(text):
    """Return the last ``Q:`` segment of ``text``, stripped of surrounding whitespace.

    A segment starts right after ``Q:`` and runs up to the next ``A:`` or, when
    none follows, to the end of the text. If ``text`` contains no ``Q:`` at all,
    the whole text is returned with the sentence
    ``Give me the briefest possible answer.`` removed and whitespace stripped.
    """
    # DOTALL lets a question span several lines.
    question_re = re.compile(r'Q:(.*?)(?=A:|$)', re.DOTALL)
    questions = question_re.findall(text)

    if not questions:
        return text.replace('Give me the briefest possible answer.','').strip()
    return questions[-1].strip()
def extract_last_question_answer(text):
    """Return the last ``Question: `` segment of ``text``, stripped of whitespace.

    A segment starts right after ``Question: `` (the trailing space is required)
    and runs up to the next ``Answer:`` or, when none follows, to the end of the
    text. If nothing matches, the whole text is returned with the sentence
    ``Give me the briefest possible answer.`` removed and whitespace stripped.
    """
    # DOTALL lets a question span several lines.
    question_re = re.compile(r'Question: (.*?)(?=Answer:|$)', re.DOTALL)
    questions = question_re.findall(text)

    if not questions:
        return text.replace('Give me the briefest possible answer.','').strip()
    return questions[-1].strip()


def determine_correct(rougel,llama_score,rougel_threshold=0.5,llama_score_threshold=0.5):
    """Return 1 when either score reaches its threshold, otherwise 0.

    Args:
        rougel: ROUGE-L score of the generated answer.
        llama_score: Score stored under ``llama_score`` for the generated answer.
        rougel_threshold: Minimum ``rougel`` that counts as correct (default 0.5).
        llama_score_threshold: Minimum ``llama_score`` that counts as correct
            (default 0.5).

    Returns:
        ``1`` if ``rougel >= rougel_threshold`` or
        ``llama_score >= llama_score_threshold``, else ``0``.
    """
    if rougel >= rougel_threshold or llama_score >= llama_score_threshold:
        return 1
    return 0
        
def coqa_process(model):
    """Build the CoQA training file from the scored generations of ``model``.

    Reads ``{model}_train_gen_score.json`` from the CoQA dataset folder, labels
    every record with ``determine_correct`` using its ``rougel`` and
    ``llama_score`` fields, and writes a list of ``{"text", "correct"}``
    entries to ``dataset_train_ready/coqa/train.json``. The record text is
    stored unchanged (no question extraction is applied).

    Args:
        model: Model name used to locate the score file.
    """
    # Set locally so the function does not depend on a module-level ``dataset``
    # global, which only exists when the file is run as a script.
    dataset='coqa'
    path_score=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen_score.json'
    path_save='path/uncertainty/my_uncertainty/dataset_train_ready/coqa/train.json'
    # Create the folder the output is actually written into; otherwise the
    # final open() fails when that folder does not exist yet.
    os.makedirs(os.path.dirname(path_save), exist_ok=True)
    with open(path_score, 'r') as f:
        data = json.load(f)
    new_data=[]
    for item in tqdm.tqdm(data):
        text=item['text']
        rougel=item['rougel']
        llama_score=item['llama_score']
        correct=determine_correct(rougel,llama_score)
        new_data.append({"text":text,"correct":correct})
    with open(path_save,'w') as f:
        json.dump(new_data,f,indent=4)
    


def sciqa_process(model):
    """Build the SciQA training file from the scored generations of ``model``.

    Reads ``{model}_train_gen_score.json`` from the SciQA dataset folder,
    reduces each record's text to its last ``Q:`` segment with
    ``extract_last_qa``, labels it with ``determine_correct`` using the
    ``rougel`` and ``llama_score`` fields, and writes a list of
    ``{"text", "correct"}`` entries to ``dataset_train_ready/sciqa/train.json``.

    Args:
        model: Model name used to locate the score file.
    """
    dataset='sciqa'
    path_score=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen_score.json'
    path_save=f'path/uncertainty/my_uncertainty/dataset_train_ready/{dataset}/train.json'
    # Create the folder the output is actually written into; otherwise the
    # final open() fails when that folder does not exist yet.
    os.makedirs(os.path.dirname(path_save), exist_ok=True)
    with open(path_score, 'r') as f:
        data = json.load(f)
    new_data=[]
    for item in tqdm.tqdm(data):
        # Only the ``Q:`` extraction is kept: an earlier
        # ``extract_last_question_answer`` result was always overwritten by it.
        text=extract_last_qa(item['text'])
        print(text)
        rougel=item['rougel']
        llama_score=item['llama_score']
        correct=determine_correct(rougel,llama_score)
        new_data.append({"text":text,"correct":correct})
    with open(path_save,'w') as f:
        json.dump(new_data,f,indent=4)

def process(dataset,model):
    """Turn the scored generations of ``model`` on ``dataset`` into a training file.

    Loads ``{model}_train_gen_score.json`` from the ``{dataset}_dataset`` folder.
    For each record it extracts the last ``Q:`` segment of ``text`` with
    ``extract_last_qa``, prints it, and derives a 0/1 label from ``rougel`` and
    ``llama_score`` via ``determine_correct``. The resulting list of
    ``{"text", "correct"}`` entries is written, indented by 4 spaces, to
    ``dataset_train_ready/{dataset}_{model}/train.json``.

    Args:
        dataset: Dataset name used in the input and output paths.
        model: Model name used in the input and output paths.
    """
    score_file=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen_score.json'
    out_dir=f'path/uncertainty/my_uncertainty/dataset_train_ready/{dataset}_{model}'
    out_file=f'path/uncertainty/my_uncertainty/dataset_train_ready/{dataset}_{model}/train.json'

    # The output folder is created before the score file is read.
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    with open(score_file, 'r') as src:
        records = json.load(src)

    def to_example(record):
        # Convert one scored record into a training entry.
        question=extract_last_qa(record['text'])
        print(question)
        label=determine_correct(record['rougel'],record['llama_score'])
        return {"text":question,"correct":label}

    examples=[to_example(record) for record in tqdm.tqdm(records)]

    with open(out_file,'w') as dst:
        json.dump(examples,dst,indent=4)

def medmcqa_process(model):
    """Build the MedMCQA training file from the scored generations of ``model``.

    Loads ``{model}_train_gen_score.json`` from the MedMCQA dataset folder. For
    each record it extracts the last ``Q:`` segment of ``text`` with
    ``extract_last_qa``, prints it, and copies the record's existing ``correct``
    value as the label. The list of ``{"text", "correct"}`` entries is written,
    indented by 4 spaces, to ``dataset_train_ready/medmcqa/train.json``.

    Args:
        model: Model name used to locate the score file.
    """
    dataset="medmcqa"
    score_file=f'path/uncertainty/my_uncertainty/dataset_process/{dataset}_dataset/{model}_train_gen_score.json'
    out_dir=f'path/uncertainty/my_uncertainty/dataset_train_ready/{dataset}'
    out_file=f'path/uncertainty/my_uncertainty/dataset_train_ready/{dataset}/train.json'

    # The output folder is created before the score file is read.
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    with open(score_file, 'r') as src:
        records = json.load(src)

    examples=[]
    for record in tqdm.tqdm(records):
        question=extract_last_qa(record['text'])
        print(question)
        # The label is taken from the record as-is; no score thresholding here.
        examples.append({"text":question,"correct":record['correct']})

    with open(out_file,'w') as dst:
        json.dump(examples,dst,indent=4)

    
if __name__ == '__main__':
    # Script entry point: read the dataset and model names from the command
    # line and build the corresponding training file with ``process``.
    parser = argparse.ArgumentParser(description='Evaluate a dataset using a text-generation model.')
    # Both options are required: without them the paths would be built from the
    # string "None" and the run would fail only after creating a stray folder.
    parser.add_argument('--dataset', type=str, required=True, help='The name of the dataset')
    parser.add_argument('--model', type=str, required=True, help='The name of the model')
    args = parser.parse_args()
    # Kept as module-level names because other functions in this file may read them.
    dataset=args.dataset
    model=args.model
    process(dataset,model)

    # if args.dataset == 'coqa':
    #     coqa_process(model)
    # elif args.dataset == 'sciqa':
    #     sciqa_process(model)
    #     # judge_sciqa(args.method)
    # elif args.dataset == 'triviaqa':
    #     triviaqa_process(model)
    # elif args.dataset == 'medmcqa':
    #     medmcqa_process(model)


        # judge_triviaqa(args.method)
#python path/uncertainty/my_uncertainty/pipeline/003_correct.py --dataset coqa --model llama3
#python path/uncertainty/my_uncertainty/pipeline/003_correct.py --dataset sciqa --model llama3
#python path/uncertainty/my_uncertainty/pipeline/003_correct.py --dataset triviaqa --model llama3


#python path/uncertainty/my_uncertainty/pipeline/003_correct.py --dataset nq --model opt_2.7b


#python path/uncertainty/my_uncertainty/pipeline/003_correct.py --dataset medmcqa --model opt_2.7b
