"""Review model answers with a GPT-4 reviewer and record per-subquestion scores.

For every question, one chat-completion request is built from a prompt
template and sent concurrently. Each response is parsed into
``{subquestion: answer}`` scores and written to a JSONL review file sorted by
``question_id``. Failed requests are written to a separate error JSONL file.
"""
import sys
# Make modules in the parent directory importable. The path is resolved
# relative to the current working directory, not this file. Presumably this is
# where openai_concurrent lives — confirm against the repository layout.
sys.path.append("../")
import argparse
import json
import os

import shortuuid
import logging
from openai_concurrent import OpenAIChatCompletionConcurrent
# Root logging at INFO; the module logger below reports unparsable reviews.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import re
import ast

def parse(review, subquestion_index_dict):
    """Extract per-subquestion answers from a reviewer model's response.

    The last ``{...}`` span in ``review`` is evaluated as a Python literal
    after removing newlines, turning underscores into spaces and
    lower-casing it. Its keys are subquestion numbers. Answers spelled
    ``na``, ``n/a`` or ``not applicable`` are normalised to ``'N/A'``.

    Args:
        review: Raw text returned by the reviewer model.
        subquestion_index_dict: Maps 1-based subquestion numbers to the
            subquestion text used as the output key.

    Returns:
        A dict ``{subquestion: answer}`` for keys 1..5; other numbers are
        ignored. Returns ``{}`` if ``review`` contains no ``{...}`` span.
        Returns ``[]`` (kept for backward compatibility) if the span cannot
        be evaluated or mapped; the error is logged so the entry can be fixed
        by hand.
    """
    # A tuple rather than a set, so ``in`` compares with == and unhashable
    # values (e.g. lists) do not raise.
    not_applicable = ('na', 'n/a', 'not applicable')
    try:
        matches = re.findall(r'\{[^}]+\}', review)
        if not matches:
            return {}
        dictionary_part = matches[-1].replace("\n", "").replace('_', " ").lower()
        lines = ast.literal_eval(dictionary_part)
        score_dict = {}
        for key, value in lines.items():
            # Normalise the value that is actually stored in score_dict.
            if value in not_applicable:
                value = 'N/A'
            key = int(key)
            if 1 <= key <= 5:
                score_dict[subquestion_index_dict[key]] = value
        return score_dict
    except Exception as e:
        # Best effort: malformed reviewer output is logged, not raised.
        logger.error(f'{e}\nContent: {review}\n'
                     'You must manually fix the score pair.')
        return []



def gen_prompt(reviewer_jsons, prompt_jsons, response, item, subquestion_index_list):
    """Build the (system prompt, user prompt) pair for one question/response.

    The template is chosen through the first reviewer entry's ``prompt_id``,
    which is used as a 1-based position in ``prompt_jsons``.

    Args:
        reviewer_jsons: Reviewer configs; only ``reviewer_jsons[0]['prompt_id']``
            is read.
        prompt_jsons: Prompt configs with ``prompt_id``, ``system_prompt``,
            ``prompt_template`` and ``defaults`` keys. Entry ``i`` must have
            ``prompt_id == i + 1``.
        response: Model response text to be reviewed.
        item: Question record with ``text``, ``answer`` and ``metrics`` keys.
            The keys of ``metrics`` become the numbered subquestions.
        subquestion_index_list: Output list. A ``{number: subquestion}`` dict
            for this item is appended to it.

    Returns:
        Tuple ``(system_prompt, user_prompt)``.

    Raises:
        ValueError: if the prompt at position ``prompt_id - 1`` does not carry
            that ``prompt_id``.
    """
    reviewer_idx = 0
    prompt_id = reviewer_jsons[reviewer_idx]['prompt_id']
    prompt_json = prompt_jsons[prompt_id - 1]
    # Explicit check rather than ``assert``, so it still runs under ``python -O``.
    if prompt_json['prompt_id'] != prompt_id:
        raise ValueError(
            f"prompt_jsons[{prompt_id - 1}] has prompt_id "
            f"{prompt_json['prompt_id']!r}, expected {prompt_id!r}"
        )

    sys_prompt = prompt_json['system_prompt']
    prompt_template = prompt_json['prompt_template']
    defaults = prompt_json['defaults']

    # Number the subquestions from 1; parse() maps the reviewer's numeric
    # keys back to subquestion text with this dict.
    subquestion_index_dict = dict(enumerate(item["metrics"], start=1))
    subquestions = "".join(
        f"\n{number}: {subquestion}"
        for number, subquestion in subquestion_index_dict.items()
    )
    subquestion_index_list.append(subquestion_index_dict)
    prompt = prompt_template.format(
        question=item["text"],
        response=response,
        subquestions=subquestions,
        sample_answer=item["answer"],
        **defaults,
    )
    return sys_prompt, prompt


def get_json_list(file_path):
    """Load JSON data from ``file_path`` (``~`` is expanded).

    A path ending in ``.jsonl`` is read as JSON Lines. It returns a list with
    one parsed value per non-blank line. Any other path is parsed as a single
    JSON document, and its value is returned as-is (e.g. a dict for the API
    key file).
    """
    file_path = os.path.expanduser(file_path)
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        if os.path.splitext(file_path)[1] == '.jsonl':
            # Skip blank lines (e.g. a trailing empty line), which json.loads rejects.
            return [json.loads(line) for line in f if line.strip()]
        return json.load(f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-k', '--key-file', default='../openai_info/api_info.json')
    # Alternative question set: ../evaluation_set/flask_hard_atomic.jsonl
    parser.add_argument('-q', '--question-file', default='../metadata_annotation/subquestion_generation/outputs/flask_question_annotation_ver2.jsonl')
    parser.add_argument('-a', '--answer-file', default='../model_output/outputs_hard/chatgpt_hard.jsonl')
    parser.add_argument('-p', '--prompt-file', default='src/prompt.jsonl')
    parser.add_argument('-r', '--reviewer-file', default='src/reviewer.jsonl')
    parser.add_argument('-o', '--output-review-file', default='outputs/chatgpt_review_hard_binary_gt_ver2_prompt_ver1_temp07.jsonl')
    parser.add_argument('-e', '--output-error-file', default='outputs/chatgpt_review_hard_error.jsonl')
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    key_jsons = get_json_list(args.key_file)
    question_jsons = get_json_list(args.question_file)
    answer_jsons = get_json_list(args.answer_file)
    reviewer_jsons = get_json_list(args.reviewer_file)
    prompt_jsons = get_json_list(args.prompt_file)

    # Index answers by question_id once; the first row for a given id wins.
    answer_by_qid = {}
    for row in answer_jsons:
        answer_by_qid.setdefault(row.get('question_id'), row)

    # question_copy[i] and subquestion_index_list[i] describe requests[i].
    question_copy = []
    subquestion_index_list = []
    requests = []
    for question in question_jsons:
        qid = question['question_id']
        if qid not in answer_by_qid:
            # Fail loudly instead of pairing this question with another answer.
            raise KeyError(f"no answer with question_id {qid!r} in {args.answer_file}")
        answer_elem = answer_by_qid[qid]
        question_copy.append(question)
        sys_prompt, prompt = gen_prompt(reviewer_jsons, prompt_jsons, answer_elem["text"], question, subquestion_index_list)
        print(prompt)
        requests.append(
            {
                'review_id': shortuuid.uuid(),
                'question_id': qid,
                'metadata': {},
                'request': {
                    "model": "gpt-4-0613",
                    "messages": [
                        {
                            'role': 'system',
                            'content': sys_prompt
                        },
                        {
                            'role': 'user',
                            'content': prompt,
                        }
                    ]
                },
                # NOTE(review): temperature/max_tokens sit beside 'request',
                # not inside it. Whether they reach the API depends on how
                # OpenAIChatCompletionConcurrent consumes each entry — verify.
                # If only 'request' is forwarded, move them inside it.
                # Temperature 0 for reproducibility.
                "temperature": 0.0,
                "max_tokens": args.max_tokens
            }
        )

    openai_concurrent = OpenAIChatCompletionConcurrent(api_keys=key_jsons["api_keys"], requests_per_minute=60, expected_response_seconds=5)
    responses, fails = openai_concurrent.create_many(requests)

    reviews = [response['response']['choices'][0]['message']['content'] for response in responses]
    total_tokens = [response['response']['usage']['total_tokens'] for response in responses]
    print("total_token:", sum(total_tokens))

    # Create the directories of both output files. dirname is '' for a bare
    # filename, and os.makedirs('') would raise.
    for output_path in (args.output_error_file, args.output_review_file):
        output_directory = os.path.dirname(output_path)
        if output_directory:
            os.makedirs(output_directory, exist_ok=True)

    # Record failed requests in the error file and remember their positions
    # so they can be dropped from the per-request lists.
    delete_index = set()
    if fails:
        # The last position wins for duplicate ids.
        index_by_qid = {int(item.get("question_id")): index for index, item in enumerate(question_copy)}
        with open(args.output_error_file, 'w') as output_error_file:
            for fail in fails:
                print("fail:", fail)
                delete_elem_idx = index_by_qid.get(int(fail['question_id']))
                if delete_elem_idx is None:
                    logger.warning('Failed request has unknown question_id %s', fail['question_id'])
                    continue
                delete_index.add(delete_elem_idx)
                output_error_file.write(json.dumps(question_copy[delete_elem_idx]) + '\n')

    # Filter both per-request lists identically so they stay aligned with
    # `reviews`. This assumes create_many returns successful responses in
    # request order — TODO confirm.
    question_copy = [item for index, item in enumerate(question_copy) if index not in delete_index]
    subquestion_index_list = [sub for index, sub in enumerate(subquestion_index_list) if index not in delete_index]
    if len(reviews) != len(question_copy):
        logger.warning('Got %d reviews for %d remaining questions; pairing may be wrong',
                       len(reviews), len(question_copy))

    # Append mode: records from earlier runs in this file are kept and
    # re-sorted together with this run's records below.
    with open(args.output_review_file, 'a') as output_review_file:
        for idx, review in enumerate(reviews):
            record = question_copy[idx]
            qid = record['question_id']
            if qid in answer_by_qid:
                record['target_txt'] = answer_by_qid[qid]["text"]
            record['review'] = review
            record['score'] = parse(review, subquestion_index_list[idx])
            record['total_tokens_step4'] = total_tokens[idx]
            try:
                line = json.dumps(record)
            except (TypeError, ValueError):
                # Skip rather than write a blank line, which the re-read
                # below could not parse.
                print(qid)
                continue
            output_review_file.write(line + '\n')

    # Rewrite the whole review file sorted by question_id.
    with open(args.output_review_file, 'r') as output_read_file:
        json_objects = [json.loads(line) for line in output_read_file if line.strip()]
    sorted_objects = sorted(json_objects, key=lambda obj: obj.get('question_id'))

    with open(args.output_review_file, 'w') as output_write_file:
        for obj in sorted_objects:
            output_write_file.write(json.dumps(obj) + '\n')
