import os, pdb, time, re
import os.path as osp
import random
import json
 
from argparse import ArgumentParser

import openai
# API key for the OpenAI client. It is read from the OPENAI_API_KEY environment
# variable so the secret does not have to be written into this file; when the
# variable is unset the value is an empty string, as before.
openai_key = os.environ.get("OPENAI_API_KEY", "")

# Local time at which this line runs, at minute resolution (e.g. 2024-01-31_17-05).
time_stamp = time.strftime("%Y-%m-%d_%H-%M")

################# CONFIG #####################
parser = ArgumentParser()

parser.add_argument("--data_path", type=str, default="seed_2022/logs/evaluation/model_10c_20",
                    help="log file whose lines contain the texts to be rated")
parser.add_argument("--output_file", type=str, default="rebuttal/logs/gpt/original_naturalness.json",
                    help="JSON file receiving the model replies and the average score")
parser.add_argument("--error_file", type=str, default="rebuttal/logs/gpt/naturalness_error.json",
                    help="JSON file receiving the texts that could not be scored")
# NOTE(review): this option is not read anywhere else in the visible script;
# the request code names its model directly -- confirm before relying on it.
parser.add_argument("--model_name", type=str, default="gpt4",
                    help="name of the model to query")

args = parser.parse_args()
print(json.dumps(vars(args), indent=4))

# file path
# Both result files are written later in the run, so make sure their parent
# directories exist now rather than failing after the API calls are done.
# dirname() is "" for a bare filename, which os.makedirs() rejects, so skip it.
for _result_path in (args.output_file, args.error_file):
    _parent_dir = osp.dirname(_result_path)
    if _parent_dir:
        os.makedirs(_parent_dir, exist_ok=True)
print("output_file: ", args.output_file, "\n")

# Invisible Unicode code points, named by their usual abbreviations:
#   ZWSP  U+200B zero width space        IT  U+2062 invisible times
#   ZWNJ  U+200C zero width non-joiner   IS  U+2063 invisible separator
#   ZWJ   U+200D zero width joiner       IP  U+2064 invisible plus
ZWSP, ZWNJ, ZWJ = "\u200b", "\u200c", "\u200d"
IT, IS, IP = "\u2062", "\u2063", "\u2064"

# Hand the key held in `openai_key` to the openai client module.
openai.api_key = openai_key
# Sampling temperature passed with every chat-completion request in this script.
temperature = 0.0

def sanitize(sentence):
    """Return ``sentence`` with escape sequences and the WTM marker removed.

    Two things are stripped, in this order:

    * every literal six-character escape made of a backslash, the letter
      ``u`` and four hexadecimal digits (the written-out escape text, not
      the character it would stand for);
    * every occurrence of the marker ``" [WTM]"``, including its leading
      space.
    """
    escape_free = re.sub(r'\\u[0-9A-Fa-f]{4}', '', sentence)
    return escape_free.replace(' [WTM]', '')

def read_strings_from_jsons(dir_path):
    """Gather the items of every ``*.json`` file in ``dir_path`` into one list.

    Files are visited in sorted filename order so the result is the same on
    every run (``os.listdir`` itself returns entries in arbitrary order), and
    each file is decoded as UTF-8 regardless of the platform locale.

    Args:
        dir_path: Directory to scan; sub-directories are not descended into.

    Returns:
        A flat list holding the items of each file, file after file. Each
        file is assumed to contain a JSON list of strings.

    Raises:
        OSError: If the directory or one of its JSON files cannot be read.
        json.JSONDecodeError: If a ``.json`` file is not valid JSON.
    """
    all_strings = []

    for filename in sorted(os.listdir(dir_path)):
        if not filename.endswith('.json'):
            continue
        with open(os.path.join(dir_path, filename), encoding='utf-8') as f:
            all_strings.extend(json.load(f))

    return all_strings

######################## READ DATASET ############################

# Log lines carrying a text to rate; the capture group is everything after
# the "original line is: " marker up to the end of that line.
pattern = r".* - INFO - __main__ -   original line is: (.*)"
# pattern = r".* - INFO - __main__ -   generated_text: (.*)"

with open(args.data_path, 'r') as file:
    content = file.read()

matches = re.findall(pattern, content)

# Stop with a readable message instead of an IndexError below (and before any
# API request is made) when the log contains nothing to rate.
if not matches:
    raise SystemExit(f"no line in {args.data_path} matches pattern {pattern!r}")

# Show the first extracted text as a quick sanity check.
print(matches[0])

# matches = read_strings_from_jsons(args.data_path)


###################### PROMPT ############################

# Instruction text handed to generate_prompt() as the leading part of each
# prompt. The value begins and ends with a newline.
definition = (
    "\n"
    "Please provide a 1-10 scalar score rating the naturalness of the following text. "
    "Note that you should only give a number and no explanation."
    "\n"
)


def generate_prompt(definition, generated_text):
    """Assemble the prompt for a single text.

    The prompt is ``definition``, then the fixed lead-in
    ``"Here is the text: "``, then ``generated_text`` after it has been
    passed through ``sanitize``.
    """
    lead_in = definition + "Here is the text: "
    return lead_in + sanitize(generated_text)

############## MAIN FUNCTION ####################

# Raw model reply for each text, keyed by the text's index in `matches`;
# the overall mean is added under 'average_score' once the loop is done.
final_result = {}
# Texts that could not be scored -- either every request attempt failed or the
# reply was not a number -- keyed by index.
error_dataset = {}

# save parameters
params = vars(args)
# record time
start_time = time.time()

max_tokens = 400

# Only replies that parse as a number contribute to the mean, so a failed or
# malformed reply neither crashes the run nor re-counts an earlier score.
score_total = 0.0
scored_count = 0

for i, generated_text in enumerate(matches):
    prompt = generate_prompt(definition, generated_text)
    print(f"prompt: \n{prompt}")

    # run model api and get response
    # NOTE(review): openai.ChatCompletion is the interface of the pre-1.0
    # openai package -- confirm against the installed version.
    max_req_count = 3
    req_success = False
    original_answer = None
    while not req_success and max_req_count > 0:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[{'role': 'user', 'content': prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
            )
            original_answer = response['choices'][0]['message']['content']
            req_success = True

        except Exception as e:
            print(e)
            print(f"max_req_count: {max_req_count}")
            max_req_count -= 1
            if max_req_count > 0:
                # Wait before the next attempt; no point sleeping after the last one.
                time.sleep(60)
            else:
                # Out of attempts: remember why this text has no score.
                error_dataset[i] = {'error_message': str(e), 'used_prompt': prompt}

    if req_success:
        final_result[i] = original_answer

        print('\n\n')
        print(original_answer)
        print('||||||||||||||||||||||||||\n\n')

        try:
            score_total += float(original_answer)
        except (TypeError, ValueError):
            # The model answered with something other than a bare number.
            error_dataset[i] = {
                'error_message': f"reply is not a number: {original_answer!r}",
                'used_prompt': prompt,
            }
        else:
            scored_count += 1

    print('>>>>>>>>>>>')

    if i % 5 == 0:
        # Periodic checkpoint so an interrupted run keeps its partial results.
        with open(args.output_file, 'w') as f:
            f.write(json.dumps(final_result, indent=4))

    print("-"*70)

# Mean over the texts that were actually scored; None when there were none.
average_score = score_total / scored_count if scored_count else None
print('average_score: ', average_score)
final_result['average_score'] = average_score

end_time = time.time()
elapsed_time = end_time - start_time

with open(args.output_file, 'w') as f:
    f.write(json.dumps(final_result, indent=4))

with open(args.error_file, 'w') as f:
    f.write(json.dumps(error_dataset, indent=4))