from utils import get_outdir, load_from_saved, DummyArgs
import argparse
import ast
import re
import signal


# Define a timeout handler
def timeout_handler(signum, frame):
    raise TimeoutError

# Install timeout_handler for SIGALRM (this does not arm an alarm itself).
# compute_function arms signal.alarm(2) around each generated-function call,
# and when it fires the handler raises TimeoutError inside that call.
# NOTE(review): SIGALRM exists only on Unix, and signal.signal must be called
# from the main thread, so importing this module elsewhere will fail.
signal.signal(signal.SIGALRM, timeout_handler)

def find_length_two_structure_old(s):
    """Legacy parser: return the two elements of the first 2-structure in *s*.

    Recognises either a parenthesised pair ``(a,b)`` (elements: any run of
    non-comma characters) or a bracketed pair ``[m,n]`` of bare digit runs,
    whichever occurs first in *s*.

    Returns:
        A two-element list of whitespace-stripped strings, or ``None`` when
        neither form is present.
    """
    found = re.search(r'\(([^,]+),([^,]+)\)|\[(\d+),(\d+)\]', s)
    if found is None:
        return None
    # Only one alternative can participate in a match: its two groups are
    # non-empty strings and the other alternative's groups are None.
    return [part.strip() for part in found.groups() if part is not None]

def find_length_two_structure(s):
    """Extract the two elements of the first ``(a, b)`` or ``[a, b]`` in *s*.

    Elements may not contain commas. A tuple element may also not contain
    parentheses, and a list element may not contain square brackets. This
    stops a match from running past the structure's own closing bracket
    into later text (e.g. ``"(3, 5). Note (this)"`` must give ``['3', '5']``,
    not ``['3', '5). Note (this']``).

    Args:
        s: Free-form text, e.g. a model response.

    Returns:
        A two-element list of whitespace-stripped element strings, or
        ``None`` if no such structure is found.
    """
    pattern = r'\(([^,()]+),\s*([^,()]+)\)|\[([^,\[\]]+),\s*([^,\[\]]+)\]'
    match = re.search(pattern, s)
    if match is None:
        return None
    # Exactly one alternative matched: its two groups are non-empty strings,
    # the other alternative's groups are None.
    return [group.strip() for group in match.groups() if group is not None]


def compute_function(func_str, x):
    """Define the function in *func_str* and call it on *x* under a 2 s alarm.

    NOTE(review): ``exec`` runs arbitrary code, so only call this with
    trusted *func_str*.

    Args:
        func_str: Python source that defines the target function. Its name
            is the last whitespace-separated token before the first ``(``
            (e.g. ``"def f(x):"`` -> ``f``).
        x: The argument.
            - A string starting with ``[`` is stripped of its brackets and
              split on ``', '``.
            - Each string element of a list (including one produced that
              way) is then parsed with ``ast.literal_eval``.
            - Anything else is passed through unchanged.
            The function is called with ``x`` as its single argument; a list
            is not unpacked.

    Returns:
        The function's return value. On ``TimeoutError`` (raised by the
        module's SIGALRM handler after 2 s) it returns the string
        ``"Function execution failed due to timeout"``. For any other
        exception raised while parsing elements or running the function it
        returns ``str(e)``.

    Raises:
        Exceptions from executing *func_str* or looking up the function name
        (e.g. ``SyntaxError``, ``KeyError``) propagate to the caller.
    """
    # Execute in a private copy of the module namespace. The generated code
    # still sees module-level names but cannot overwrite them (e.g. a
    # generated function called ``compute_function`` or ``get_inputs``).
    namespace = dict(globals())
    exec(func_str, namespace)

    func_name = func_str.split('(')[0].split()[-1]
    func = namespace[func_name]

    # A string that looks like a list literal is split into its raw element
    # strings. They are parsed with literal_eval below, exactly like lists
    # coming from find_length_two_structure. Converting them to float here
    # first would make literal_eval fail on every numeric element.
    if isinstance(x, str) and x.startswith('['):
        x = x.strip('[').strip(']').split(', ')

    signal.alarm(2)
    try:
        if isinstance(x, list):
            x = [ast.literal_eval(x_i) if isinstance(x_i, str) else x_i
                 for x_i in x]
        return func(x)
    except TimeoutError:
        return "Function execution failed due to timeout"
    except Exception as e:
        return str(e)
    finally:
        # Always cancel a pending alarm, even if the function itself raised
        # TimeoutError, so it cannot fire later outside this call.
        signal.alarm(0)

def get_inputs(outputs):
    """Extract the proposed input from each model output.

    If an output contains ``"\\nINPUT:"`` followed later by ``"END"``, the
    raw (unstripped) text between the first such marker and the next
    ``"END"`` after it is used. Otherwise the output falls back to
    ``find_length_two_structure``, which yields a two-element list or
    ``None``.

    Args:
        outputs: Iterable of model output strings.

    Returns:
        A list parallel to *outputs*. Each entry is a string,
        a two-element list, or ``None``.
    """
    marker = '\nINPUT:'
    inputs = []
    for output in outputs:
        start = output.find(marker)
        end = -1
        if start != -1:
            start += len(marker)
            # Look for END only after the marker. An earlier "END" must not
            # truncate the slice to an empty or wrong span.
            end = output.find('END', start)
        if end == -1:
            inputs.append(find_length_two_structure(output))
            continue
        inputs.append(output[start:end])
    return inputs

def _normalize_output(value):
    """Return ``str(value)`` with double quotes, single quotes, spaces and periods removed.

    Lets e.g. ``'abc'`` and ``"abc"`` compare equal. Removing periods also
    makes ``3.0`` and ``30`` compare equal, which matches the original
    comparison rule.
    """
    return str(value).translate(str.maketrans('', '', '"\' .'))


def evaluate_hacking_accuracy(hacking_model, generative_model, n_examples = 100, outdir = 'experiment_results', difficulty = 'moderate', documentation = False, modification_type = None):
    """Score saved 'produce-inputs' results and print summary counts.

    Loads the saved results for the given configuration and extracts one
    proposed input per output with ``get_inputs``. Each input is run through
    the example's ``metadata[i]['function']`` with ``compute_function``, and
    the normalized result is compared against ``metadata[i]['output']``.

    Args:
        hacking_model: Passed to ``DummyArgs`` as ``model``.
        generative_model: Passed to ``DummyArgs`` as ``output_model``.
        n_examples: Passed to ``DummyArgs``. The printed denominators use
            the number of loaded outputs instead.
        outdir: Results root directory passed to ``DummyArgs``.
        difficulty: Difficulty label passed to ``DummyArgs``.
        documentation: Passed to ``DummyArgs`` as ``use_documentation``.
        modification_type: Passed to ``DummyArgs`` unchanged.

    Prints the correct, incorrect, error and no-input counts out of the
    number of loaded outputs.
    """
    retrieve_args = DummyArgs(
        experiment_name = 'produce-inputs',
        model = hacking_model,
        output_model = generative_model,
        finished = True,
        n_examples = n_examples,
        outdir = outdir,
        difficulty = difficulty,
        use_documentation = documentation,
        modification_type = modification_type,
    )
    load_outdir = get_outdir(retrieve_args, require_exists = True)
    _, outputs, metadata = load_from_saved(load_outdir, 'results')
    hacking_inputs = get_inputs(outputs)

    correct = 0
    incorrect = 0
    error_rate = 0
    hacking_none = 0
    for i, hacking_input in enumerate(hacking_inputs):
        if hacking_input is None:
            hacking_none += 1
            continue
        function = metadata[i]['function']
        target_output = metadata[i]['output']
        try:
            hacking_output = compute_function(function, hacking_input)
            # Stringify before normalizing. Non-str results (int, list, ...)
            # previously raised AttributeError on .replace and were
            # miscounted as errors.
            matched = _normalize_output(hacking_output) == _normalize_output(target_output)
        except Exception:
            error_rate += 1
            continue
        if matched:
            correct += 1
        else:
            incorrect += 1

    total = len(hacking_inputs)
    print(f"Accuracy: {correct}/{total}")
    print(f"Incorrect: {incorrect}/{total}")
    print(f"Error rate: {error_rate}/{total}")
    print(f"No input: {hacking_none}/{total}")

def parse_args():
    """Parse the command line for the hacking-accuracy evaluation.

    Options:
        --hacking_model (str, default 'gpt-4-0125-preview')
        --difficulty (str, default 'moderate')
        --documentation (flag)
        --modification_type (str, default None)

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    string_options = (
        ('--hacking_model', 'gpt-4-0125-preview'),
        ('--difficulty', 'moderate'),
    )
    for flag, default_value in string_options:
        parser.add_argument(flag, type=str, default=default_value)
    parser.add_argument('--documentation', action='store_true')
    parser.add_argument('--modification_type', type=str, default=None)
    parsed = parser.parse_args()
    return parsed

if __name__ == '__main__':
    args = parse_args()
    # --modification_type may not be combined with --documentation. This is
    # checked explicitly rather than with assert so it still applies under
    # `python -O`.
    if args.modification_type is not None and args.documentation:
        raise SystemExit('--modification_type cannot be combined with --documentation')
    # NOTE: --difficulty is parsed but ignored here; every difficulty is evaluated.
    for difficulty in ['easy', 'moderate', 'hard']:
        print(f"Difficulty: {difficulty}")
        evaluate_hacking_accuracy(args.hacking_model, 'gpt-4-0125-preview', difficulty = difficulty, documentation = args.documentation, modification_type = args.modification_type)
