import matplotlib.pyplot as plt
from datasets import load_dataset
import json
import re
import numpy as np
import sys
import ast
import astor
from code_runner import test_humaneval_function
from generate_answers import set_seed

class DummyOutput:
    """Write sink that silently discards everything written to it.

    Provides the minimal stream surface (``write`` and ``flush``) so an
    instance can stand in for a text stream without raising when the caller
    flushes it.
    """

    def write(self, text):
        """Discard ``text`` and return ``None``."""
        pass

    def flush(self):
        """Do nothing; exists so flushing this sink does not raise AttributeError."""
        pass

def get_code_from_text(model_response):
    """Return the contents of the first ``` fenced block in a model response.

    Args:
    model_response (str): Raw text that may contain a fenced code block.

    Returns:
    str or bool: The text between the first and second ``` markers (or up to
    the end of the text when the fence is never closed), with a leading
    ``python`` language tag removed. ``False`` when the response contains no
    ``` marker at all.
    """
    segments = model_response.split("```")
    if len(segments) == 1:
        return False
    # Strip only the language tag that directly follows the opening fence.
    # A blanket str.replace would also delete "python" from inside the code
    # itself (string literals, identifiers, comments) and corrupt it.
    return re.sub(r'^[ \t]*python', '', segments[1], count=1, flags=re.IGNORECASE)

def get_import_lines(code: str) -> str:
    """Collect the lines of ``code`` that look like import statements.

    A line is kept (unmodified, including its indentation) when its stripped
    text begins with ``import`` or ``from``. The kept lines are returned
    joined by newlines, in their original order.
    """
    kept = []
    for raw_line in code.split('\n'):
        # Prefix test only: anything starting with these prefixes is kept.
        if raw_line.strip().startswith(('import', 'from')):
            kept.append(raw_line)
    return '\n'.join(kept)


def extract_function_code(code, function_name):
    """
    Extracts the specified top-level function from the given Python code.

    Args:
    code (str): The string of Python code. ``None`` or ``False`` (the "no code
        found" sentinel of ``get_code_from_text``) is accepted and yields ``None``.
    function_name (str): The name of the function to extract.

    Returns:
    str or None: Source code containing only the specified function, or ``None``
    when no code was supplied, the function is not defined at module level, or
    the code cannot be parsed / regenerated (the error is printed in that case).
    """
    # Nothing to parse: bail out quietly instead of letting ast.parse raise
    # and report a misleading "error" for a missing code block.
    if code is None or code is False:
        return None
    try:
        # Parse the code into an AST
        tree = ast.parse(code)

        # Only module-level definitions are searched; nested functions and
        # methods are not considered.
        for node in tree.body:
            if isinstance(node, ast.FunctionDef) and node.name == function_name:
                # Create a new module with just the specified function
                new_tree = ast.Module(body=[node], type_ignores=[])
                # Convert the AST back to code
                return astor.to_source(new_tree)

        # The function is not defined at module level
        return None

    except Exception as e:
        # Best effort: the input is machine-generated text, so any failure to
        # parse or regenerate it is reported and treated as "no function".
        print(f"An error occurred: {e}")
        return None


def get_problems_from_key(tuple_key, human_eval_data_dict, second_data_source_data_dict):
    """Look up the problem record for every problem id in ``tuple_key``.

    Ids containing the substring "Human" are resolved in
    ``human_eval_data_dict``; all others in ``second_data_source_data_dict``.
    The returned list follows the order of ``tuple_key``.
    """
    return [
        (human_eval_data_dict if "Human" in problem_id else second_data_source_data_dict)[problem_id]
        for problem_id in tuple_key
    ]

def get_multiplication_perc(tuple_key, human_eval_results, second_data_source_results, problem_success_dict,
                            curr_generation_batch, full_results=False):
    """Compute two products of per-problem success rates for one composition.

    Args:
    tuple_key: Iterable of problem ids making up the composition.
    human_eval_results: Results for ids containing "Human"; each value is
        either a number or a list of outcomes (a list is averaged).
    second_data_source_results: Results for all other ids, same value forms.
    problem_success_dict: Maps a per-problem key to the list of outcomes
        observed for that problem inside the composition generations.
    curr_generation_batch: The generations of the composition; only its
        length is used, as the denominator of the observed rates.
    full_results: Accepted but not used by this function.

    Returns:
    tuple: (product of the standalone rates of the problems in ``tuple_key``,
    product of the per-problem rates found in ``problem_success_dict``).
    """
    standalone_product = 1
    for problem_id in tuple_key:
        source = human_eval_results if "Human" in problem_id else second_data_source_results
        rate = source[problem_id]
        if type(rate) is list:
            # A raw list of outcomes is reduced to its mean.
            rate = np.sum(rate) / len(rate)
        standalone_product *= rate

    in_composition_product = 1
    for outcomes in problem_success_dict.values():
        in_composition_product *= np.sum(outcomes) / len(curr_generation_batch)

    return standalone_product, in_composition_product


def test_composition_dataset(generations, human_eval_data_dict, second_data_source_data_dict, human_eval_results,
                             second_data_source_results=None, return_full_results=False):
    """Evaluate generations for composed problems and compare against per-problem rates.

    Args:
    generations: Path to a JSON file, or an already-loaded mapping. Each key is
        the string form of a tuple of problem ids; each value is a list of
        model responses for that composition.
    human_eval_data_dict: Problem records for ids containing "Human".
    second_data_source_data_dict: Problem records for all other ids.
    human_eval_results: Standalone results for ids containing "Human".
    second_data_source_results: Standalone results for all other ids.
    return_full_results: When true, 'composition_success_perc' holds the list
        of per-generation outcomes instead of their mean.

    Returns:
    dict: Maps each composition key to a dict with the entries
    'composition_success_perc', 'multiplication_perc' and 'comp_mult_success'.
    """
    if isinstance(generations, str):
        # Context manager so the file handle is closed once the JSON is loaded.
        with open(generations) as f:
            all_generations = json.load(f)
    else:
        all_generations = generations

    results = dict()
    for key in all_generations:

        curr_generation_batch = all_generations[key]
        # NOTE(security): eval executes arbitrary expressions. Keys are expected
        # to be stringified tuples of problem ids; if generation files can come
        # from an untrusted source, replace this with ast.literal_eval.
        tuple_key = eval(key)
        curr_problem_list = get_problems_from_key(tuple_key, human_eval_data_dict, second_data_source_data_dict)

        full_success_list = []
        problem_success_dict = dict()
        for generation in curr_generation_batch:
            curr_code = get_code_from_text(generation)
            full_success = True

            # iterate over problems in current composition
            for k, problem in enumerate(curr_problem_list):
                current_function_code = extract_function_code(curr_code, problem['entry_point'])
                successfully_solved = test_humaneval_function(problem, current_function_code)
                full_success = full_success and successfully_solved
                # The position suffix keeps entries distinct when the same
                # problem id appears more than once in a composition.
                curr_problem_key = str(tuple_key[k]) + f'_p{k}'
                problem_success_dict.setdefault(curr_problem_key, []).append(successfully_solved)

            full_success_list.append(full_success)

        full_success_perc = np.sum(full_success_list) / len(curr_generation_batch)

        # Point stdout back at the real stream before printing.
        # NOTE(review): presumably the test runner may leave sys.stdout redirected — confirm.
        sys.stdout = sys.__stdout__
        print(key)
        multiplication_perc, comp_mult_success =\
            get_multiplication_perc(tuple_key, human_eval_results, second_data_source_results, problem_success_dict,
                                    curr_generation_batch)
        if not return_full_results:
            comp_success = full_success_perc
        else:
            comp_success = full_success_list
        results[key] = {'composition_success_perc': comp_success, 'multiplication_perc': multiplication_perc,
                        'comp_mult_success': comp_mult_success}

    return results


def test_human_eval_dataset(dataset_path, data_dict, return_full_results=False):
    """Evaluate generations for individual problems.

    Args:
    dataset_path: Path to a JSON file, or an already-loaded mapping from
        problem id to a list of model responses.
    data_dict: Maps each problem id to its problem record; the record's
        'entry_point' names the function extracted from each response.
    return_full_results: When true, each value in the result is the list of
        per-generation outcomes instead of their mean.

    Returns:
    dict: Maps each problem id to its success rate, or to the list of
    per-generation outcomes when ``return_full_results`` is true.
    """
    if isinstance(dataset_path, str):
        # Context manager so the file handle is closed once the JSON is loaded.
        with open(dataset_path) as f:
            all_generations = json.load(f)
    else:
        all_generations = dataset_path

    results = dict()
    for key in all_generations:
        print(key)
        curr_generation_batch = all_generations[key]
        curr_problem = data_dict[key]

        full_success_list = []
        for generation in curr_generation_batch:
            curr_code = get_code_from_text(generation)
            curr_function_code = extract_function_code(curr_code, curr_problem['entry_point'])
            is_pass = test_humaneval_function(curr_problem, curr_function_code)
            full_success_list.append(is_pass)

        success_perc = np.sum(full_success_list) / len(curr_generation_batch)
        if not return_full_results:
            results[key] = success_perc
        else:
            results[key] = full_success_list

    return results

def get_pass_rate(results):
    """Compute the running pass rate of every outcome list in ``results``.

    Args:
    results (dict): Maps a key to a sequence of per-sample outcomes
        (booleans or numbers).

    Returns:
    dict: Maps each key to a NumPy array whose element ``i`` is the mean of
    the first ``i + 1`` outcomes. An empty outcome sequence yields an empty
    array.
    """
    pass_rate_dict = dict()
    for key, curr_results in results.items():
        # Prefix sums divided by the prefix lengths give all running means in
        # one pass, and an empty sequence simply produces an empty array.
        running_passes = np.cumsum(curr_results)
        sample_counts = np.arange(1, len(curr_results) + 1)
        pass_rate_dict[key] = running_passes / sample_counts

    return pass_rate_dict


def create_pass_rate_plot(human_eval_results, comp_results):
    """Plot inverse mean running pass rates for compositions vs. the product baseline.

    Args:
    human_eval_results (dict): Maps a problem id to its list of outcomes.
    comp_results (dict): Maps a stringified pair of problem ids to a dict
        whose 'composition_success_perc' entry is a list of outcomes.

    Two curves are drawn and shown: the reciprocal of the running pass rate
    averaged over all compositions, and the reciprocal of the averaged product
    of the two constituent problems' running pass rates.
    """
    single_rates = get_pass_rate(human_eval_results)
    comp_rates = get_pass_rate(
        {name: entry['composition_success_perc'] for name, entry in comp_results.items()}
    )

    # All curves are accumulated into arrays sized after the first composition.
    num_points = len(list(comp_rates.values())[0])
    comp_total = np.zeros(num_points)
    mult_total = np.zeros(num_points)
    for name, rate_curve in comp_rates.items():
        # Each key is the string form of a two-element tuple of problem ids.
        first_id, second_id = eval(name)
        comp_total += rate_curve
        mult_total += single_rates[first_id] * single_rates[second_id]

    num_compositions = len(comp_rates)
    comp_mean = comp_total / num_compositions
    mult_mean = mult_total / num_compositions

    plt.plot(1 / comp_mean, label='Composition')
    plt.plot(1 / mult_mean, label='Multiplication')
    plt.xlabel('Number of samples')
    plt.ylabel('generation complexity')
    plt.legend()
    plt.show()

def main():
    """Run the full evaluation: score both generation sets, plot, and print totals.

    Loads the problem records, evaluates the single-problem generations and the
    composition generations (both with full per-generation results), draws the
    pass-rate plot, then prints per-composition figures and their sums.
    """
    # Context manager so the dataset file handle is closed after loading.
    with open('datasets_and_generations/datasets/human_eval.json') as f:
        human_eval_data_dict = json.load(f)

    human_eval_generation_path = 'datasets_and_generations/generations/human_eval_dataset_generations'
    human_eval_results = test_human_eval_dataset(human_eval_generation_path, human_eval_data_dict,
                                                 return_full_results=True)

    # No second data source is used here, hence the empty dicts.
    composite_dataset_generations = 'datasets_and_generations/generations/composition_dataset_generations'
    comp_results = test_composition_dataset(generations=composite_dataset_generations,
                                            human_eval_data_dict=human_eval_data_dict,
                                            second_data_source_results=dict(),
                                            second_data_source_data_dict=dict(),
                                            human_eval_results=human_eval_results, return_full_results=True)

    create_pass_rate_plot(human_eval_results, comp_results)

    comp_sum = 0
    mult_sum = 0
    comp_mult_sum = 0

    for key, entry in comp_results.items():
        print(key)
        curr_comp_success = entry['composition_success_perc']
        if isinstance(curr_comp_success, list):
            # Full results were requested above, so reduce the outcomes to a rate.
            curr_comp_success = np.sum(curr_comp_success) / len(curr_comp_success)
        comp_sum += curr_comp_success
        mult_sum += entry['multiplication_perc']
        comp_mult_sum += entry['comp_mult_success']
        print(f"Composition success: {curr_comp_success}, Multiplication: {entry['multiplication_perc']}, "
              f"Comp mult: {entry['comp_mult_success']}")
        print('='*30)

    print(comp_results)
    print(human_eval_results)
    print(f'comp sum {comp_sum}')
    print(f'mult sum {mult_sum}')
    print(f'comp mult sum {comp_mult_sum}')


# Script entry point: call set_seed (imported from generate_answers) with a
# fixed value before running the evaluation.
# NOTE(review): presumably this seeds the random number generators for reproducibility — confirm.
if __name__ == '__main__':
    set_seed(42)
    main()
