import os, sys, inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
import re
import json
import numpy as np
import copy
import tqdm
import argparse
from src.data_util import load_data
from src.common import completion_with_backoff,completion_with_mistralC,completion_with_backoff_n

def build_arg_parser():
    """Return the command-line parser for this scoring script.

    All three options are required.
    """
    parser = argparse.ArgumentParser(
        description="Sample model scores for every example of a dataset.")
    parser.add_argument("--dataset_name", type=str, required=True,
                        help='dataset to score ("ambigqa" or "ambig_inst")')
    parser.add_argument("--output_path", type=str, required=True,
                        help="where to write results; a .txt suffix is replaced by .json")
    parser.add_argument("--sample_n", type=int, required=True,
                        help="number of completions sampled per example")
    return parser


parser = build_arg_parser()
# Parse sys.argv only when run as a script: importing this module (e.g. to
# reuse extract_score) must not consume the importer's argv or exit on it.
args = parser.parse_args() if __name__ == '__main__' else None

def update_prompt_ambigqa(filepath, question):
    """Read the AmbigQA prompt template at ``filepath`` and insert ``question``.

    If the template contains the two-line marker ``"Question:\\n{question}"``,
    only that marker is filled. Otherwise every bare ``{question}`` placeholder
    is substituted, so a one-line layout like ``"Question: {question}"`` also
    works.

    Args:
        filepath: Path to a UTF-8 text template.
        question: Question text to insert verbatim.

    Returns:
        The filled-in prompt.

    Raises:
        ValueError: The template has no ``{question}`` placeholder; without
            this check the prompt would be sent with the question missing.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        prompt = f.read()

    marker = 'Question:\n{question}'
    if marker in prompt:
        # Preferred layout: fill only the exact two-line marker.
        return prompt.replace(marker, f'Question:\n{question}')
    if '{question}' in prompt:
        return prompt.replace('{question}', str(question))
    raise ValueError(
        f"No '{{question}}' placeholder found in prompt template {filepath!r}.")

def update_prompt_ambiginst(filepath, instruction, input):
    """Read the AmbigInst prompt template and fill its two placeholders.

    Every ``{instruction}`` is replaced by ``instruction`` and every ``{input}``
    by ``input``. Both substitutions happen in a single pass, so placeholder-like
    text inside the inserted values (e.g. an instruction mentioning
    ``{input}``) is kept literally instead of being substituted again.

    Args:
        filepath: Path to a UTF-8 text template.
        instruction: Instruction text to insert.
        input: Input text to insert. The name shadows the builtin but is kept
            for backward compatibility with keyword callers.

    Returns:
        The filled-in prompt.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        prompt = f.read()

    values = {'{instruction}': instruction, '{input}': input}
    # A callable replacement returns the value verbatim; a string replacement
    # would have re.sub interpret backslashes (e.g. "\1") in the values.
    return re.sub(r'\{instruction\}|\{input\}',
                  lambda m: values[m.group(0)], prompt)

_SCORE_PATTERN = re.compile(r"SCORE\s*:\s*([0-9]*\.?[0-9]+)")


def extract_score(ans):
    """Return the first ``SCORE: <number>`` value found in ``ans`` as a float.

    The match is case-sensitive and accepts only unsigned numbers such as
    ``1``, ``0.75`` or ``.5``.

    Raises:
        ValueError: ``ans`` contains no such score.
    """
    found = _SCORE_PATTERN.search(ans)
    if found is None:
        raise ValueError("No score found in the response.")
    return float(found.group(1))

def _build_prompt(dataset_name, case):
    """Return the filled-in prompt for one example of ``dataset_name``.

    Reads ``case["question"]`` for ambigqa and ``case["orig_instruction"]`` /
    ``case["input"]`` for ambig_inst.

    Raises:
        FileNotFoundError: ``dataset_name`` is not supported. The exception
            type is kept from the original script.
    """
    if dataset_name == "ambigqa":
        return update_prompt_ambigqa(
            "lib_prompt/ask4conf/ask_ambigqa_user.txt", case["question"])
    if dataset_name == "ambig_inst":
        return update_prompt_ambiginst(
            "lib_prompt/ask4conf/ask_ambiginst_user.txt",
            case["orig_instruction"], case["input"])
    raise FileNotFoundError("Dataset not found")


def _parse_scores(response, sample_n, default=0.5):
    """Return one score per sampled completion, using ``default`` when unparsable."""
    scores = []
    for sample_id in range(sample_n):
        text = response['choices'][sample_id]['message']['content']
        try:
            score = extract_score(text)
        except (ValueError, TypeError):
            # ValueError: no "SCORE: x" in the text.
            # TypeError: the content is not a string (e.g. None).
            score = default
        scores.append(score)
    return scores


def _json_output_path(output_path):
    """Map a ``.txt`` output path to its ``.json`` sibling; leave other paths as-is."""
    if output_path.endswith(".txt"):
        # Swap only the suffix. str.replace would also rewrite a ".txt"
        # occurring earlier in the path, e.g. in a directory name.
        return output_path[:-len(".txt")] + ".json"
    return output_path


def main(args):
    """Sample model scores for every example of a dataset and save them as JSON.

    Each example's filled prompt is sent once to the model with
    ``n=args.sample_n`` samples at temperature 1. The parsed scores are stored
    under ``"score"`` on a deep copy of the example. All results are written in
    a single JSON list at the end of the run.

    Args:
        args: Parsed CLI namespace with ``dataset_name``, ``output_path`` and
            ``sample_n``.
    """
    sample_n = args.sample_n
    json_path = _json_output_path(args.output_path)
    save_dir = os.path.dirname(json_path)
    # dirname is "" for a bare filename, and os.makedirs("") would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    print("save logs to ", json_path)
    model_name = 'gpt-4o-mini-2024-07-18'
    max_tokens = 512

    test_data = load_data(args.dataset_name)

    all_results = []
    for idx in tqdm.tqdm(range(len(test_data))):
        case = test_data[idx]
        messages = [
            {"role": "user", "content": _build_prompt(args.dataset_name, case)},
        ]
        response = completion_with_backoff(
            model=model_name,
            messages=messages,
            temperature=1,
            max_tokens=max_tokens,
            n=sample_n,
        )
        scores = _parse_scores(response, sample_n)

        result = copy.deepcopy(case)
        if args.dataset_name == 'ambig_inst':
            # NOTE(review): duplicated under a shorter key, presumably for a
            # downstream consumer. Confirm before removing.
            result['orig_inst'] = case['orig_instruction']
        result['score'] = scores
        print(scores)
        all_results.append(result)

    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=4)

if __name__ == '__main__':
    # Script entry point: hand the module-level parsed CLI namespace to main.
    main(args=args)

