import argparse
import datetime
import os
import shutil

from config import config
from dataset import DataSet
from verifier import MyVerifier
from utils import input_to_text_string, output_from_text, run_python_file

# Command-line interface. The sentence below is a description of the tool, not
# its program name, so it goes in `description=` (argparse puts `prog` in the
# usage line).
parser = argparse.ArgumentParser(description='Verifies the Programs Generated in a single Run of the PaL-SAT algorithm')
parser.add_argument('-p', '--program_dir', type=str, required=True, dest="p", help="Path to the Directory containing code samples")
parser.add_argument('-d', '--dataset', type=str, required=True, dest="d", help="Path to the Dataset File to be Tested")
parser.add_argument('-t', '--type', type=str, required=True, dest="t", help="Type of Dataset for Testing", choices=["train", "validation", "test"])
args = parser.parse_args()
code_directory = args.p
testing_dataset_path = args.d
type_of_dataset = args.t

# Options forwarded to the dataset loader, the text (de)serialisers and the
# verifier. NOTE(review): 'n' presumably is the board size — confirm.
kwargs = {
    'n': 9
}
dataset = DataSet()(input_dataset_file=testing_dataset_path, output_dataset_file=None, **kwargs)
number_of_samples = len(dataset)
# This timestamp becomes part of the score-log filename, so it must not contain
# ':' — that character is invalid in Windows filenames.
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

# Scratch file that each candidate program is copied into before it is executed.
temp_python_filename = "temp.py"

def get_python_files(directory):
    """Return the names of the ``.py`` files directly inside *directory*.

    Only regular files are returned. A sub-directory whose name happens to end
    in ``.py`` is skipped, because it cannot be copied and run as a program.
    The names are sorted so that programs are evaluated and logged in a
    reproducible order; ``os.listdir`` makes no ordering guarantee.

    Args:
        directory: Path of the directory to scan (not recursive).

    Returns:
        A sorted list of bare file names, relative to *directory*.
    """
    return sorted(
        name
        for name in os.listdir(directory)
        if name.endswith('.py') and os.path.isfile(os.path.join(directory, name))
    )

def copy_file(source_file, destination_file):
    """Copy *source_file* to *destination_file*, overwriting the destination.

    The copy is byte-for-byte. Unlike a text-mode read/write, it does not
    depend on the platform's default encoding, so non-ASCII characters in a
    generated program cannot raise a Unicode error. It also leaves line
    endings untouched.

    Args:
        source_file: Path of the file to copy.
        destination_file: Path to write. An existing file is replaced.
    """
    try:
        shutil.copyfile(source_file, destination_file)
    except shutil.SameFileError:
        # Source and destination are the same file, so the destination
        # already holds the desired content; treat the copy as a no-op.
        pass
        
def _evaluate_sample(idx, input_sample):
    """Run the program in ``temp_python_filename`` on one dataset sample and grade it.

    Side effects: overwrites ``input.txt`` with the serialised sample, truncates
    ``output.txt`` before the run and reads it back afterwards, and prints
    progress/diagnostic lines to stdout.

    Args:
        idx: Index of the sample in ``dataset``. It is only printed.
        input_sample: The sample's ``"input"`` entry from ``dataset``.

    Returns:
        One of ``"correct"``, ``"timeout"``, ``"runtime"`` or ``"verification"``.
        ``"verification"`` covers both an unparseable output board and a board
        the verifier rejects.
    """
    with open("input.txt", "w") as input_file:
        input_file.write(input_to_text_string(input_sample, **kwargs))
    # Truncate output.txt so that a program which writes nothing cannot be
    # graded on the board left there by the previous sample.
    with open("output.txt", "w"):
        pass

    run_output = run_python_file(temp_python_filename, depth=0, timeout=config['timeout'])
    print(idx)
    print(run_output)

    if "TIMEOUT-ERROR" in run_output:
        print("Timeout Error")
        return "timeout"
    # Any non-empty stderr output is treated as a runtime failure, even if the
    # program also exited normally.
    if ("RUNTIME-ERROR" in run_output) or ("STD-ERROR" in run_output and len(run_output["STD-ERROR"])):
        print("Runtime Error / Output to Standard Error Stream")
        return "runtime"

    with open("output.txt", 'r') as f:
        output_lines = f.readlines()
    output_sample = output_from_text(output_lines, **kwargs)
    if output_sample["ERROR"] is not None:  # the output board could not be parsed
        print(f"Board Parsing Error for Output: {output_lines}")
        return "verification"

    verification_result = MyVerifier()(input_sample, output_sample["OUTPUT"], **kwargs)
    if not verification_result["result"]:
        print("Verification Error")
        return "verification"
    return "correct"


# Grade every program in code_directory on every dataset sample and write one
# summary section per program to the score log.
_scratch_path = os.path.abspath(temp_python_filename)
with open(f"scores-{type_of_dataset}-{timestamp}.txt", 'w') as log_file:
    log_file.write(f"Code Directory: {code_directory}, Dataset File: {testing_dataset_path}\n")
    log_file.write(f"Number of Testing Samples: {number_of_samples}\n\n")

    for code_file in get_python_files(code_directory):
        code_path = os.path.join(code_directory, code_file)
        # When code_directory is the working directory, the scratch copy left by
        # an earlier run would otherwise be graded as if it were a candidate.
        if os.path.abspath(code_path) == _scratch_path:
            print(f"Skipping scratch file {code_path}")
            continue
        print(f"Testing {code_path}")
        copy_file(source_file=code_path, destination_file=temp_python_filename)

        counts = {"correct": 0, "timeout": 0, "runtime": 0, "verification": 0}
        for idx in range(number_of_samples):
            counts[_evaluate_sample(idx, dataset[idx]["input"])] += 1

        log_file.write(f"{code_file}\nCorrect: {counts['correct']}\nRuntime-Error: {counts['runtime']}\nTimeout-Error: {counts['timeout']}\nVerification-Error: {counts['verification']}\n\n")
        # Each sample may run for up to config['timeout'], so a full run can be
        # long. Push each program's scores to disk as soon as they are known.
        log_file.flush()
        print("-"*30)
    