import os, pdb, time, re
import os.path as osp
import random
import json
 
from argparse import ArgumentParser

import openai
# API key handed to the OpenAI client further down. It is read from the
# OPENAI_API_KEY environment variable so the secret does not have to be written
# into this file; when the variable is unset the value is the empty string.
openai_key = os.environ.get("OPENAI_API_KEY", "")

# Start-of-run timestamp with minute resolution.
time_stamp = time.strftime("%Y-%m-%d_%H-%M")

################# CONFIG #####################
parser = ArgumentParser()

parser.add_argument("--data_path", type=str, default="seed_2022/logs/evaluation/unwatermarked",
                    help="log file that is scanned for 'generated_text:' lines")
parser.add_argument("--output_file", type=str, default="rebuttal/logs/gpt/unwatermarked.json",
                    help="JSON file receiving the predicted class indices per sample")
parser.add_argument("--error_file", type=str, default="rebuttal/logs/gpt/error.json",
                    help="JSON file receiving the samples whose requests failed")
# NOTE(review): the request code in this file hard-codes model="gpt-4" and does
# not read this option.
parser.add_argument("--model_name", type=str, default="gpt4")

args = parser.parse_args()
print(json.dumps(vars(args), indent=4))

# file path
# Create the folders of both result files up front. A bare file name has an
# empty dirname, and os.makedirs('') raises, so that case is skipped.
for _result_path in (args.output_file, args.error_file):
    _result_dir = osp.dirname(_result_path)
    if _result_dir:
        os.makedirs(_result_dir, exist_ok=True)
print("output_file: ", args.output_file, "\n")

# Invisible Unicode code points, kept as named constants.
ZWSP = "\u200b"  # U+200B ZERO WIDTH SPACE
ZWNJ = "\u200c"  # U+200C ZERO WIDTH NON-JOINER
ZWJ = "\u200d"   # U+200D ZERO WIDTH JOINER
IT = "\u2062"    # U+2062 INVISIBLE TIMES
IS = "\u2063"    # U+2063 INVISIBLE SEPARATOR
IP = "\u2064"    # U+2064 INVISIBLE PLUS

# Value passed as the `temperature` argument of every chat completion request.
temperature = 0.0
openai.api_key = openai_key

# Category labels; a prediction is stored as a position in this list.
classes = [
    'astro-ph',
    'cond-mat.mes-hall',
    'cond-mat.mtrl-sci',
    'cond-mat.str-el',
    'cs.CV',
    'cs.LG',
    'gr-qc',
    'hep-ph',
    'hep-th',
    'quant-ph',
]

def sanitize(sentence):
    """Strip escape text and the watermark tag from *sentence*.

    Two things are removed: every six-character sequence made of a literal
    backslash, the letter ``u`` and four hexadecimal digits, and every
    occurrence of the marker ``' [WTM]'`` (including its leading space).
    Real (unescaped) Unicode characters are left untouched.
    """
    escape_text = re.compile(r'\\u[0-9A-Fa-f]{4}')
    without_escapes = escape_text.sub('', sentence)
    # Splitting on the marker and gluing the pieces back drops every marker.
    return ''.join(without_escapes.split(' [WTM]'))

def read_strings_from_jsons(dir_path):
    """Collect the items of every ``*.json`` file directly inside *dir_path*.

    Files are visited in sorted name order. ``os.listdir`` makes no ordering
    promise, and callers that identify a string by its position in the
    returned list need that position to be the same on every run.

    Args:
        dir_path: Directory to scan (not recursive).

    Returns:
        One flat list with the items of all files, file after file. Each file
        is assumed to hold a JSON list of strings.

    Raises:
        OSError: If the directory or one of its files cannot be read.
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    all_strings = []

    for filename in sorted(os.listdir(dir_path)):
        if not filename.endswith('.json'):
            continue
        # Decode as UTF-8 explicitly instead of relying on the platform default.
        with open(os.path.join(dir_path, filename), encoding='utf-8') as f:
            all_strings.extend(json.load(f))

    return all_strings

######################## READ DATASET ############################

# Each matching log line carries one generated text after "generated_text: ";
# the capture group keeps everything up to the end of that line.
pattern = r".* - INFO - __main__ -   generated_text: (.*)"

with open(args.data_path, 'r', encoding='utf-8') as file:
    content = file.read()

matches = re.findall(pattern, content)

# Stop with a readable message instead of an IndexError when the log holds no
# generated texts; there would be nothing to classify anyway.
if not matches:
    raise SystemExit(f"no 'generated_text' lines found in {args.data_path}")

print(matches[0])

# matches = read_strings_from_jsons(args.data_path)

# The pdb.set_trace() breakpoint that used to follow has been removed: it
# stopped every run at this point and made unattended execution impossible.


###################### PROMPT ############################
   

# Instruction text that generate_prompt() places in front of every sample.
# It names the 10 categories, describes each one, and asks for a label-only
# reply with up to 3 categories. The literal ends with a newline, so the
# sample text starts on its own line. This is runtime text sent to the model:
# any edit (including the trailing spaces) changes the prompt.
definition = """
Given below are 10 categories for texts from ArXiv papers with their descriptions. Please read the descriptions and classify the provided texts to one of the paper categories.
The 10 categories are: hep-th, hep-ph, quant-ph, astro-ph, cs.CV, cs.LG, cond-mat.mes-hall, gr-qc, cond-mat.mtrl-sci, cond-mat.str-el.
hep-th stands for High Energy Physics - Theory. This category includes research papers which are centered on theoretical concepts and mathematical models in high energy physics.
hep-ph stands for High Energy Physics - Phenomenology. This category includes research papers centered on the application of theoretical physics to high energy physics experiments.
quant-ph stands for Quantum Physics. This category includes research papers centered on the theoretical and experimental aspects of the fundamental theory of quantum mechanics.
astro-ph stands for Astrophysics. This category includes research papers centered on the study of the physics of the universe, including the properties and behavior of celestial bodies.
cs.CV stands for Computer Science - Computer Vision and Pattern Recognition. This category includes research papers focused on how computers can be made to gain high-level understanding from digital images or videos.
cs.LG stands for Computer Science - Machine Learning. This category includes research papers focused on the development and implementation of algorithms that allow computers to learn from and make decisions or predictions based on data.
cond-mat.mes-hall stands for Condensed Matter - Mesoscale and Nanoscale Physics. This category includes research papers that focus on the properties and phenomena of physical systems at mesoscopic (intermediate) and nanoscopic scales.
gr-qc stands for General Relativity and Quantum Cosmology. This category includes research papers centered on theoretical and observational aspects of the theory of general relativity and its implications for understanding cosmology at the quantum scale.
cond-mat.mtrl-sci stands for Condensed Matter - Materials Science. This category includes research papers centered on the understanding, description, and development of novel materials from a physics perspective.
cond-mat.str-el stands for Condensed Matter - Strongly Correlated Electrons. This category includes research papers focused on the study of solids and liquids in which interactions among electrons play a dominant role in determining the properties of the material.
Note that you should only include the class in your reply and provide no explanations. 
Please classify the following sentence into one of the 10 categories, however, if you think that the sentence could be classified into multiple categories, you may give up to 3 most likely categories: 
"""


def generate_prompt(definition, generated_text):
    """Build the full prompt for one sample.

    Args:
        definition: Instruction text placed first.
        generated_text: Raw sample; it is passed through ``sanitize`` before
            being appended.

    Returns:
        ``definition`` immediately followed by the sanitized sample.
    """
    cleaned_text = sanitize(generated_text)
    return definition + cleaned_text

############## MAIN FUNCTION #################### 

# record time: reference point for the elapsed-time computation further down
start_time = time.time()

# save parameters: the parsed command-line options as a plain dict
params = vars(args)

# Per-sample outcomes, both keyed by the sample's position in `matches`:
#   final_result  -> list of class indices produced by get_index
#   error_dataset -> error message and prompt of a failed sample
final_result, error_dataset = {}, {}

def get_index(strings, target):
    """Translate a model reply into positions within *strings*.

    The reply is expected to list one or more labels separated by commas;
    line breaks are accepted as separators too, and whitespace around each
    label is ignored, so ``"cs.CV, cs.LG"``, ``"cs.CV,cs.LG"`` and
    ``"cs.CV\\ncs.LG\\n"`` are all understood.

    Args:
        strings: Sequence of known labels.
        target: Reply text to parse.

    Returns:
        The position of every label in reply order, or ``[-1]`` when the
        reply is not a string, contains no label at all, or contains any
        label that is not in *strings*.
    """
    if not isinstance(target, str):
        # e.g. a missing message body; treat it like an unparseable reply
        return [-1]

    pieces = target.replace('\n', ',').split(',')
    labels = [piece.strip() for piece in pieces if piece.strip()]
    if not labels:
        return [-1]

    indices = []
    for label in labels:
        try:
            indices.append(strings.index(label))
        except ValueError:
            # one unknown label invalidates the whole reply
            return [-1]
    return indices
     
def _write_json(path, payload):
    """Write *payload* to *path* as indented JSON, creating its folder if needed."""
    folder = osp.dirname(path)
    if folder:
        os.makedirs(folder, exist_ok=True)
    with open(path, 'w') as f:
        f.write(json.dumps(payload, indent=4))


# Classify every extracted sample. A sample gets up to three request attempts;
# its class indices go to final_result, or, when every attempt failed, the last
# error and the prompt go to error_dataset.
# NOTE(review): the model is hard-coded to "gpt-4"; --model_name is not read.
for i, generated_text in enumerate(matches):
    prompt = generate_prompt(definition, generated_text)
    print(f"prompt: \n{prompt}")

    # run model api and get response
    max_tokens = 400
    max_req_count = 3  # attempts left for this sample
    req_success = False
    while not req_success and max_req_count > 0:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[{'role': 'user', 'content': prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
            )

            original_answer = response['choices'][0]['message']['content']

            final_result[i] = get_index(classes, original_answer)

            print('\n\n')
            print(original_answer)
            print('||||||||||||||||||||||||||\n\n')
            req_success = True

        except Exception as e:  # broad on purpose: any failure counts as a failed attempt
            print(e)
            print(f"max_req_count: {max_req_count}")
            max_req_count -= 1
            if max_req_count > 0:
                # back off only when another attempt will follow
                time.sleep(60)
            else:
                # Attempts exhausted. Previously this branch sat behind a
                # condition that could never be false inside the loop, so
                # failures were never recorded.
                error_dataset[i] = {'error_message': str(e), 'used_prompt': prompt}

    print('>>>>>>>>>>>')

    # checkpoint the results every fifth sample so a crash loses little work
    if i % 5 == 0:
        _write_json(args.output_file, final_result)

    print("-"*70)


end_time = time.time()
elapsed_time = end_time - start_time

_write_json(args.output_file, final_result)
_write_json(args.error_file, error_dataset)