import os
import datetime
import openai
import random
import argparse
from config import config
from dataset import DataSet
from verifier import MyVerifier
from solver import MySolver
from utils import ( 
                    input_to_text_string,
                    output_to_text_string,
                    output_from_text,
                    get_prompt,
                    parse_gpt_output,
                    run_python_file,
                    get_runtime_error_iterative_prompt,
                    get_verification_error_iterative_prompt,
                    get_timeout_error_iterative_prompt
                 )

### fixed run settings: reproducible RNG, prompting mode, API credentials
random.seed(0)
mode = 'pal-sat'
openai.api_key = os.environ.get("OPENAI_API_KEY")

### command line: -i is the mandatory input dataset, -o an optional file of solved samples
parser = argparse.ArgumentParser(
    prog=f'PAL+SAT solver for the {config["problem_name"]} Problem',
    description=f'Generates a solver script for the {config["problem_name"]} Problem using Solved Examples and GPT',
)
parser.add_argument(
    '-i', '--input_dataset_path',
    dest="i",
    type=str,
    required=True,
    help="Path to the File containing Unsolved Input Samples",
)
parser.add_argument(
    "-o", "--output_dataset_path",
    dest="o",
    type=str,
    required=False,
    default=None,
    help="Path to the File containing corresponding Solved Samples",
)
args = parser.parse_args()
input_dataset_path, output_dataset_path = args.i, args.o

### extra keyword arguments forwarded to the dataset, solver and text helpers (none by default)
kwargs = {}

print(f"Sample Input Sample Path: {input_dataset_path}, Output Sample Path: {output_dataset_path}")
dataset = DataSet()(
    input_dataset_file=input_dataset_path,
    output_dataset_file=output_dataset_path,
    **kwargs
)

### never test more samples than the dataset holds
num_samples_to_test = min(config["num_samples_to_test"], len(dataset))
print(f"Number of Samples to Test: {num_samples_to_test}")

### the first sample is the worked example shown in the base prompt;
### when the dataset carries no output for it, MySolver supplies one
initial_sample_input = dataset[0]["input"]
if dataset[0]["output"] is not None:
    initial_sample_output = dataset[0]["output"]
else:
    initial_sample_output = MySolver()(initial_sample_input, **kwargs)[0]

print(f"Sample Input:\n{input_to_text_string(initial_sample_input, **kwargs)}")
print(f"Sample Output:\n{output_to_text_string(initial_sample_output, **kwargs)}")

base_prompt = get_prompt(initial_sample_input, initial_sample_output, mode=mode, **kwargs)
print(base_prompt)

### initial context is base prompt
message_history = [{'role': 'user', 'content': base_prompt}]


### iterative prompting: ask GPT for a solver script, run it on the first
### num_samples_to_test samples, and feed the first failure back to GPT as the
### next user message; stops early once a script passes every tested sample
### NOTE(review): the keys of run_output ("TIMEOUT-ERROR", "RUNTIME-ERROR",
### "ERROR", "STD-ERROR", "STD-OUTPUT") come from run_python_file, which is not
### visible here -- confirm against utils
correct_code_generated = False
for prompting_level in range(config["iterative_prompting_depth"]): ### number of levels of iterative prompting
    print(f"------ DEPTH: {prompting_level+1} --------")
    completion = openai.ChatCompletion.create( ### prompt GPT with entire history
        model=config["model"],
        messages=message_history,
        max_tokens=config["max_tokens"],
        temperature=config["temperature"]
    )

    assistant_response = completion.choices[0].message.content ### append prompt response to message history
    message_history.append({
        'role': 'assistant',
        'content': assistant_response
    })

    ### parse response and create code file
    code_file_name = f"code-{prompting_level+1}.py"
    with open(code_file_name, "w") as file:
        file.write(parse_gpt_output(assistant_response, depth=prompting_level+1))

    ### test code on a few training samples
    incorrect_found = False
    for idx in range(num_samples_to_test):
        ### get inputs
        input_sample = dataset[idx]["input"]

        ### prepare input.txt
        with open("input.txt", "w") as input_file:
            input_file.write(input_to_text_string(input_sample, **kwargs))
        ### prepare output.txt, clear contents if any
        with open("output.txt", "w") as _:
            pass

        ### run generated code on the given input
        run_output = run_python_file(code_file_name, depth=prompting_level+1, timeout=config['timeout'])
        print(idx)
        print(run_output)

        ### timeout error
        if "TIMEOUT-ERROR" in run_output:
            print("Timeout Error")
            error = run_output["TIMEOUT-ERROR"]
            incorrect_found = True
            iterative_prompt = get_timeout_error_iterative_prompt(config["timeout_error_feedback_level"], input_sample, error, **kwargs)
            message_history.append({
                'role': 'user',
                'content': iterative_prompt
            })
            break

        ### runtime error
        if ("RUNTIME-ERROR" in run_output) or ("STD-ERROR" in run_output and len(run_output["STD-ERROR"])):
            print("Runtime Error / Output to Standard Error Stream")
            if "RUNTIME-ERROR" in run_output:
                ### this branch is entered on the "RUNTIME-ERROR" key, so fall back
                ### to its value when there is no separate "ERROR" entry instead
                ### of failing with a KeyError
                error = run_output["ERROR"] if "ERROR" in run_output else run_output["RUNTIME-ERROR"]
            else:
                error = run_output["STD-ERROR"]
            incorrect_found = True
            iterative_prompt = get_runtime_error_iterative_prompt(config["runtime_error_feedback_level"], input_sample, error, **kwargs)
            message_history.append({
                'role': 'user',
                'content': iterative_prompt
            })
            break

        ### try to get output board from output.txt
        output_lines = []
        with open("output.txt", 'r') as f:
            output_lines = f.readlines()
        output_sample = output_from_text(output_lines, **kwargs)

        if output_sample["ERROR"] is not None: ### error while extracting output board
            print(f"Board Parsing Error for Output: {output_lines}")
            error_type = "Output Format Error: Could not Parse Output from output.txt"
            error = output_sample["ERROR"]
            incorrect_found = True
        else:
            ### check output board
            verification_result = MyVerifier()(input_sample, output_sample["OUTPUT"], **kwargs)
            error = verification_result["reason"]
            error_type = "Output Verification Error: Does not Match Expected Output"
            incorrect_found = not verification_result["result"]
            if incorrect_found:
                print("Verification Error")

        if incorrect_found:
            print("Error In Output")
            ### readlines() keeps each line's trailing newline, so join without a
            ### separator; joining on '\n' would show GPT a blank line after every line
            text_file_output = ''.join(output_lines)
            ### "STD-ERROR" is presence-checked above, so treat "STD-OUTPUT" the same way
            std_output = run_output['STD-OUTPUT'] if 'STD-OUTPUT' in run_output else ''
            generated_output = f"output.txt:\n{text_file_output}\nstd-output:\n{std_output}"
            ### the expected output shown to GPT is produced by MySolver for this sample
            iterative_prompt = get_verification_error_iterative_prompt(config["verification_error_feedback_level"], input_sample, MySolver()(input_sample, **kwargs)[0], error_type, error, generated_output, **kwargs)
            message_history.append({
                'role': 'user',
                'content': iterative_prompt
            })
            break

        ### use STD-OUTPUT and STD-ERROR

    ### every tested sample passed: the script at this depth is accepted
    if not incorrect_found:
        correct_code_generated = True
        print(f"Solution at Depth: {prompting_level+1} is Correct")
        break
        


### generate log file
### the log holds the run settings, the first sample, the full conversation and
### the final verdict; it is written as UTF-8 so non-ASCII text in the prompts
### or model replies cannot fail to encode under a narrower platform default
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
with open(f"log-{timestamp}.txt", 'w', encoding="utf-8") as log_file:
    separator = "-"*50

    ### header: dataset paths, first sample and configuration
    log_file.write(f"Sample Input Sample Path: {input_dataset_path}, Output Sample Path: {output_dataset_path}\n")
    log_file.write(f"Sample Input:\n{input_to_text_string(initial_sample_input, **kwargs)}\n")
    log_file.write(f"Sample Output:\n{output_to_text_string(initial_sample_output, **kwargs)}\n")
    log_file.write(f"Config: {config}\n")
    log_file.write(f"Kwargs: {kwargs}\n")
    log_file.write(separator + "\n")

    ### conversation: one role / content section per message, in order
    for message in message_history:
        log_file.write(f"{message['role']}\n")
        log_file.write(separator + "\n")
        log_file.write(f"{message['content']}\n")
        log_file.write("\n")
        log_file.write(separator + "\n")

    ### final verdict of the prompting loop
    log_file.write("Correct Solution" if correct_code_generated else "Could not find correct solution, prompting depth limit reached")