from datasets import load_dataset
import re
from itertools import combinations
import json
import gzip
import random

from generate_answers import set_seed

# Hard-coded table mapping an integer index (0..163) to a percentage in [0, 100].
# NOTE(review): presumably a per-question HumanEval success rate from an earlier
# run; it is not referenced by any code visible in this file — confirm its source
# and whether it is still needed.
q_percentage = {0: 56.25, 1: 87.5, 2: 43.75, 3: 100.0, 4: 100.0, 5: 6.25, 6: 90.625, 7: 100.0, 8: 78.125, 9: 3.125, 10: 12.5, 11: 100.0, 12: 100.0, 13: 100.0, 14: 100.0, 15: 100.0, 16: 100.0, 17: 100.0, 18: 93.75, 19: 87.5, 20: 100.0, 21: 100.0, 22: 100.0, 23: 100.0, 24: 37.5, 25: 12.5, 26: 0.0, 27: 100.0, 28: 100.0, 29: 100.0, 30: 100.0, 31: 100.0, 32: 0.0, 33: 0.0, 34: 100.0, 35: 100.0, 36: 81.25, 37: 6.25, 38: 3.125, 39: 0.0, 40: 87.5, 41: 3.125, 42: 96.875, 43: 100.0, 44: 100.0, 45: 100.0, 46: 87.5, 47: 100.0, 48: 100.0, 49: 100.0, 50: 6.25, 51: 100.0, 52: 100.0, 53: 100.0, 54: 0.0, 55: 90.625, 56: 100.0, 57: 78.125, 58: 100.0, 59: 15.625, 60: 100.0, 61: 100.0, 62: 34.375, 63: 96.875, 64: 0.0, 65: 3.125, 66: 100.0, 67: 50.0, 68: 93.75, 69: 37.5, 70: 96.875, 71: 100.0, 72: 87.5, 73: 0.0, 74: 100.0, 75: 3.125, 76: 34.375, 77: 0.0, 78: 6.25, 79: 81.25, 80: 100.0, 81: 90.625, 82: 100.0, 83: 100.0, 84: 0.0, 85: 9.375, 86: 100.0, 87: 21.875, 88: 21.875, 89: 93.75, 90: 65.625, 91: 0.0, 92: 100.0, 93: 3.125, 94: 100.0, 95: 25.0, 96: 100.0, 97: 100.0, 98: 34.375, 99: 9.375, 100: 0.0, 101: 68.75, 102: 0.0, 103: 25.0, 104: 100.0, 105: 21.875, 106: 0.0, 107: 100.0, 108: 0.0, 109: 0.0, 110: 0.0, 111: 28.125, 112: 93.75, 113: 0.0, 114: 93.75, 115: 28.125, 116: 96.875, 117: 93.75, 118: 0.0, 119: 0.0, 120: 0.0, 121: 100.0, 122: 100.0, 123: 100.0, 124: 3.125, 125: 0.0, 126: 0.0, 127: 9.375, 128: 100.0, 129: 0.0, 130: 0.0, 131: 0.0, 132: 0.0, 133: 28.125, 134: 0.0, 135: 0.0, 136: 0.0, 137: 3.125, 138: 100.0, 139: 25.0, 140: 0.0, 141: 6.25, 142: 0.0, 143: 0.0, 144: 6.25, 145: 0.0, 146: 65.625, 147: 100.0, 148: 78.125, 149: 93.75, 150: 100.0, 151: 15.625, 152: 100.0, 153: 71.875, 154: 31.25, 155: 100.0, 156: 96.875, 157: 78.125, 158: 100.0, 159: 0.0, 160: 0.0, 161: 9.375, 162: 100.0, 163: 0.0}
# Number of questions combined into each composed prompt; used as the
# combination size in main().
NUM_Q_IN_COMP = 3


# Locates the opening of each ``assert candidate(`` call inside a test body.
_CANDIDATE_CALL_START = re.compile(r'assert candidate\(')


def _first_candidate_args(test_source):
    """Return the argument text of the first balanced ``assert candidate(...)`` call.

    Scans forward from each ``assert candidate(`` opener, tracking parenthesis
    depth and skipping over quoted string literals, so that arguments which
    themselves contain parentheses (nested calls, tuples, strings such as
    ``'(()) ()'``) are returned whole instead of being cut at the first ``)``.

    Returns ``None`` when no opener is found or none of them is ever closed.
    An empty argument list yields ``''`` (not ``None``).
    """
    for opener in _CANDIDATE_CALL_START.finditer(test_source):
        begin = opener.end()
        depth = 1  # we are already inside the opening parenthesis
        quote = None  # the active quote character while inside a string literal
        escaped = False  # True right after a backslash inside a string literal
        for pos in range(begin, len(test_source)):
            ch = test_source[pos]
            if quote is not None:
                if escaped:
                    escaped = False
                elif ch == '\\':
                    escaped = True
                elif ch == quote:
                    quote = None
            elif ch in ('"', "'"):
                quote = ch
            elif ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
                if depth == 0:
                    return test_source[begin:pos]
    return None


def prompt_composition(q_list):
    """Build one prompt asking for several functions and the product of their outputs.

    Args:
        q_list: iterable of mappings, each providing a ``'prompt'`` string (the
            function stub/description) and a ``'test'`` string containing
            ``assert candidate(...)`` checks.

    Returns:
        The composed prompt string: an introduction, one ``this is f{i}: ...``
        entry per question, and a final instruction asking for
        ``f0(args0)*f1(args1)*...`` where ``args_i`` are the arguments of the
        first ``assert candidate(...)`` call in question ``i``'s test.
        Returns ``None`` as soon as a question has no such call.
    """
    inputs_list = []
    prompt_parts = []
    for i, q in enumerate(q_list):
        curr_input = _first_candidate_args(q['test'])
        # '' is a valid (empty) argument list, so compare against None explicitly.
        if curr_input is None:
            return None
        inputs_list.append(curr_input)
        curr_prompt = q['prompt']
        prompt_parts.append(f'this is f{i}: {curr_prompt} \n')
    prompts_string = ''.join(prompt_parts)

    intro = 'Complete the following functions. Put the code for all functions together with any needed imports.'
    function_output_string = '*'.join(
        f'f{i}({curr_input})' for i, curr_input in enumerate(inputs_list)
    )
    instruction = 'finally, write code which computes and returns ' + function_output_string

    return f"{intro}\n {prompts_string} \n {instruction}"


def is_output_number(problem):
    """Tell whether the problem's tests compare ``candidate(...)`` to a numeric literal.

    Looks in ``problem['test']`` for an ``assert candidate(<args>) == <value>``
    statement whose ``<value>`` begins with digits and/or dots. ``<args>`` must
    be non-empty and contain no ``)``. Values that start with any other
    character (a sign, a bracket, a quote, ...) are not recognised.

    Returns:
        ``True`` if at least one such assert is present, ``False`` otherwise.
    """
    numeric_assert = re.compile(r'assert\s+candidate\([^\)]+\)\s*==\s*([\d\.]+)')
    # One hit is enough; there is no need to collect every match.
    return numeric_assert.search(problem['test']) is not None


def receives_single_number(problem):
    """Tell whether the problem's prompt defines a function with one bare parameter.

    Searches ``problem['prompt']`` for ``def <name>(<param>)`` where ``<param>``
    is a single identifier with no annotation, default value, or further
    arguments. Only the shape of the signature is inspected; the parameter's
    type is not checked.

    Returns:
        ``True`` if such a definition appears in the prompt, ``False`` otherwise.
    """
    signature = re.search(r'def \w+\((\w+)\)', problem['prompt'])
    if signature is None:
        return False
    return True


def stream_jsonl(filename: str):
    """
    Parses each jsonl line and yields it as a dictionary

    Files whose name ends with ``.gz`` are read through gzip; any other file is
    read as plain text. Both are decoded as UTF-8 explicitly so the result
    does not depend on the platform's default encoding. Lines that are empty
    or contain only whitespace are skipped.

    Args:
        filename: path of the ``.jsonl`` or ``.jsonl.gz`` file to read.

    Yields:
        The object decoded by ``json.loads`` for each non-blank line.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # gzip.open and open accept the same (path, mode, encoding) arguments in
    # text mode, so a single read loop serves both file kinds.
    opener = gzip.open if filename.endswith(".gz") else open
    with opener(filename, "rt", encoding="utf-8") as fp:
        for line in fp:
            if line.strip():
                yield json.loads(line)


def main():
    """Build the composed-prompt dataset from HumanEval and write it to JSON.

    Steps:
      1. Load the ``openai/openai_humaneval`` dataset and index its ``test``
         split by ``task_id``.
      2. Keep a random sample of up to 50 problems and shuffle them.
      3. For every combination of ``NUM_Q_IN_COMP`` problems, build a composed
         prompt with ``prompt_composition``; combinations for which it returns
         ``None`` are skipped.
      4. Keep a random sample of up to 500 compositions and dump them as
         ``{str(combo): {'prompt': ...}}`` to the output JSON file.
    """
    human_eval_data = load_dataset("openai/openai_humaneval")
    human_eval_dict = {q['task_id']: q for q in human_eval_data['test']}

    # random.sample() needs a real sequence: passing a dict view raises
    # TypeError on Python 3.11+, so materialise the items first.
    # NOTE(review): this sample is drawn before random.seed(42) below, so it is
    # reproducible only if the caller has already seeded `random` — confirm
    # that set_seed() does so.
    human_eval_items = list(human_eval_dict.items())
    human_eval_dict = dict(random.sample(human_eval_items, min(50, len(human_eval_items))))

    random.seed(42)
    # shuffle the dict
    tup_list = list(human_eval_dict.items())
    random.shuffle(tup_list)
    human_eval_dict = dict(tup_list)

    composition_dict = {}
    for combo in combinations(human_eval_dict.keys(), NUM_Q_IN_COMP):
        prompt_comp = prompt_composition([human_eval_dict[k] for k in combo])
        if prompt_comp is None:
            continue
        # Keyed by the string form of the task-id tuple so it is JSON-serialisable.
        composition_dict[str(combo)] = {'prompt': prompt_comp}

    composition_items = list(composition_dict.items())
    composition_dict = dict(random.sample(composition_items, min(500, len(composition_items))))
    file_path = '../datasets_and_generations/datasets/three_part_composition_dataset_human_eval.json'
    with open(file_path, 'w') as json_file:
        json.dump(composition_dict, json_file)


if __name__ == '__main__':
    # Script entry point: call the imported set_seed helper with 42, then
    # build and write the dataset.
    # NOTE(review): presumably set_seed seeds the RNGs that main() relies on —
    # confirm in generate_answers.
    set_seed(42)
    main()
