import pandas as pd
import argparse
import os, json, time
import numpy as np
from openai import OpenAI
import openai

# Upper bound on how many times evaluateResponse attempts a single API request
# when it keeps hitting retryable errors.
MAX_RETRIES = 5
# Seconds evaluateResponse sleeps after a retryable API error.
RETRY_DELAY = 5 

def read_txt_files(directory):
    """Load every ``.txt`` file below *directory* into a dict keyed by file stem.

    The tree is walked recursively with ``os.walk``. Each key is the file name
    with its ``.txt`` suffix removed; directory components are not part of the
    key, so two files with the same name in different sub-directories collide
    and whichever is visited last wins.

    Files are decoded as UTF-8 explicitly so the result does not depend on the
    platform's default locale encoding. A file that cannot be opened or
    decoded is reported on stdout and skipped rather than aborting the scan.

    Args:
        directory: Root directory to search.

    Returns:
        A dict mapping file stem to the file's full text. Empty when no
        readable ``.txt`` file is found.
    """
    instructionSet = {}
    for root, _, files in os.walk(directory):
        for file in files:
            if not file.endswith(".txt"):
                continue
            file_path = os.path.join(root, file)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    # Strip the 4-character ".txt" suffix to form the key.
                    instructionSet[file[:-4]] = f.read()
            except (OSError, UnicodeError) as e:
                # Best effort: one unreadable file must not stop the others.
                print(f"Error reading {file_path}: {e}")
    return instructionSet

def evaluateResponse(client, instructions, problem, evaluationModel):
    """Send *problem* to a chat model and return the text of its reply.

    The request is made through ``client.chat.completions.create`` with
    *instructions* as the system message and *problem* (prefixed with
    ``"PROBLEM STATEMENT: \\n"``) as the user message.

    Rate-limit, server, connection and timeout errors are retried, up to
    ``MAX_RETRIES`` attempts in total, sleeping ``RETRY_DELAY`` seconds
    between attempts. An authentication error or any other exception ends
    the attempt immediately. All failures are reported on stdout; none is
    propagated to the caller.

    Args:
        client: Object exposing ``chat.completions.create`` (an ``OpenAI``
            client in this script).
        instructions: System prompt text.
        problem: Problem statement text. Empty or ``None`` means there is
            nothing to send.
        evaluationModel: Model identifier passed as ``model``.

    Returns:
        The content of the first choice's message, or ``None`` when
        *problem* is empty/``None`` or the request ultimately failed.
    """
    # Nothing to ask the model; also guards against None, which len() rejects.
    if not problem:
        return None

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = client.chat.completions.create(
                model=evaluationModel,
                messages=[
                    {"role": "system", "content": instructions},
                    {"role": "user", "content": "PROBLEM STATEMENT: \n" + problem},
                ],
                stream=False,
            )
            return response.choices[0].message.content

        except openai.AuthenticationError:
            print("Authentication failed: Invalid API key.")
            break  # Don't retry on bad key

        except (openai.RateLimitError, openai.InternalServerError, openai.APIConnectionError, openai.APITimeoutError) as e:
            if attempt == MAX_RETRIES:
                # No attempt follows, so don't announce a retry or waste a sleep.
                print(f"Retryable error occurred: {e}. Giving up after {MAX_RETRIES} attempts.")
                break
            print(f"Retryable error occurred: {e}. Retrying in {RETRY_DELAY} seconds...")
            time.sleep(RETRY_DELAY)

        except Exception as e:
            print(f"Unexpected error: {e}")
            break

    return None

def main():
    """Rewrite the rewritable Braingle problems with a chat model and save them.

    Command-line arguments:
        --name      Experiment name; used to build the output paths.
        --model     Model identifier passed to the chat completion API.
        --from_row  Position of the first row to process (default 0).
        --to_row    Position at which to stop, exclusive (default: end of data).

    Steps:
        1. Read ``../data/braingle/braingle_Math_rewritables.csv`` and the
           prompt files under
           ``../prompting/brainteaserPrompts/rewriting_experiments``.
        2. For every row in the selected slice whose ``'Rewritable '`` column
           equals 1 and whose ``'Question'`` is a string, call
           ``evaluateResponse`` with the ``rewrite`` prompt.
        3. Store the reply in the ``'Rewritten'`` column and append the row
           as one JSON line to the experiment's ``.jsonl`` file.
        4. Write the whole data frame to a CSV file once the loop finishes.

    All paths are relative to the current working directory.

    Raises:
        FileNotFoundError: If no ``rewrite.txt`` prompt was loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", help="Experiment Name", required=True)
    parser.add_argument("--model", help="Evaluation Model", required=True)
    parser.add_argument("--from_row", help="Continue evaluation from which row", type=int, default=0, required=False)
    parser.add_argument("--to_row", help="End evaluation at which row", type=int, default=None, required=False)
    args = parser.parse_args()

    # The dataset is fixed for this script rather than taken from the command line.
    args.dataset = "Math_rewritables"

    data = pd.read_csv(f'../data/braingle/braingle_{args.dataset}.csv')

    prompt_dir = "../prompting/brainteaserPrompts/rewriting_experiments"
    evaluationPrompts = read_txt_files(prompt_dir)
    # Fail before any API traffic instead of with a bare KeyError mid-run.
    if 'rewrite' not in evaluationPrompts:
        raise FileNotFoundError(f"No readable 'rewrite.txt' prompt found under {prompt_dir}")
    rewrite_prompt = evaluationPrompts['rewrite']

    # Output locations do not change per row, so build them once.
    output_dir = f"../response_evaluation/{args.dataset}/{args.name}"
    os.makedirs(output_dir, exist_ok=True)
    jsonl_path = f'{output_dir}/resultsEvaluations_evaluatedby{args.model}.jsonl'
    csv_path = f"../response_evaluation/{args.dataset}/{args.name}-evaluation_from_row{args.from_row}.csv"

    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    print([m.id for m in client.models.list().data])

    for index, row in data.iloc[args.from_row:args.to_row].iterrows():
        problem = row['Question']
        print(index)

        if row['Rewritable '] != 1:
            print(row['Rewritable '])
            print("not rewritable")
            continue

        # Skip non-text questions (for example the NaN pandas uses for empty cells).
        if not isinstance(problem, str):
            continue

        rewritten = evaluateResponse(client, rewrite_prompt, problem, args.model)
        data.at[index, 'Rewritten'] = rewritten

        entry = row.to_dict()
        entry["Rewritten"] = rewritten
        print(rewritten)

        # Opened in append mode per row, so rows already processed are on disk
        # even if a later row fails.
        with open(jsonl_path, 'a') as jsonfile:
            jsonfile.write(json.dumps(entry) + "\n")

    data.to_csv(csv_path, index=False)

# Run the rewriting pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()