import ast
import json

import tqdm

from lcb_runner.evaluation.pass_k_utils import compute_metrics_from_results


def parse_assert_statement(statement):
    """
    Extract the expected value from an ``assert <expr> == <expected>`` line.

    The right-hand operand of the first ``==`` comparison is returned exactly
    as it is written in ``statement``. Inputs that cannot be handled produce a
    short diagnostic string instead of raising.

    :param statement: Source text whose first statement should be an assert.
    :return: Source text of the right-hand side of ``==``, or one of the
        diagnostics ``"Invalid syntax"``, ``"Empty statement"``,
        ``"Not an assert statement"`` or ``"Not an equality assertion"``.
    """
    try:
        module = ast.parse(statement, mode="exec")
    except SyntaxError:
        return "Invalid syntax"

    if not module.body:
        return "Empty statement"

    first_stmt = module.body[0]
    if not isinstance(first_stmt, ast.Assert):
        return "Not an assert statement"

    test_expr = first_stmt.test
    is_equality = isinstance(test_expr, ast.Compare) and isinstance(
        test_expr.ops[0], ast.Eq
    )
    if not is_equality:
        return "Not an equality assertion"

    # Slice the original text so the operand's formatting is preserved verbatim.
    expected_node = test_expr.comparators[0]
    return ast.get_source_segment(statement, expected_node)


def check_testcase_output(testcase_str, expected_output):
    """
    Compare a generated test-case output against the reference output.

    ``testcase_str`` may be a bare expression (e.g. ``"[1, 2]"``), a single
    ``assert f(x) == <expected>`` line, or a multi-line snippet. In the
    multi-line case, the first non-comment line containing ``"assert"`` is
    used. The predicted value is obtained with ``eval`` on the extracted text.
    The reference value is obtained with ``json.loads`` on ``expected_output``.

    :param testcase_str: Model-generated output or assert statement.
    :param expected_output: JSON-encoded reference output.
    :return: ``True`` if both values parse and compare equal, else ``False``.
    """
    lines = testcase_str.splitlines()
    if len(lines) > 1:
        for line in lines:
            # Skip comment lines, including indented ones, so a commented-out
            # assert is never picked as the test case.
            if line.lstrip().startswith("#"):
                continue
            if "assert" in line:
                testcase_str = line
                break

    testcase_str = testcase_str.strip()

    if "assert" in testcase_str:
        testcase_output_str = str(parse_assert_statement(testcase_str))
    else:
        testcase_output_str = testcase_str

    global_result = None

    try:
        # SECURITY NOTE(review): eval() executes model-generated text with full
        # builtins; only acceptable inside a sandboxed evaluation environment.
        # ast.literal_eval would be safer but rejects some inputs eval accepts
        # (e.g. float("inf")), so it is not swapped in silently here.
        testcase_output_eval = eval(testcase_output_str)
    except (Exception, SystemExit):
        # Any failure, including a generated `exit()`, counts as a mismatch.
        # KeyboardInterrupt is deliberately left to propagate.
        global_result = False

    try:
        expected_output_eval = json.loads(expected_output)
    except (TypeError, ValueError, RecursionError):
        # ValueError covers JSONDecodeError/UnicodeDecodeError; TypeError covers
        # non-string input such as None.
        global_result = False
        print("Failed to eval expected testcase output", expected_output)

    if global_result is None:
        global_result = testcase_output_eval == expected_output_eval

    return global_result


def test_output_metrics(
    samples,
    generations,
    k_list=None,
):
    """
    Score generated test-case outputs and compute pass@k metrics.

    :param samples: Indexable collection of problem records; each record's
        JSON-encoded reference output is read from ``sample["output"]``.
    :param generations: ``generations[i]`` is the iterable of extracted outputs
        generated for ``samples[i]``.
    :param k_list: Values of k passed to ``compute_metrics_from_results``;
        defaults to ``[1, 5]`` when omitted.
    :return: ``[metrics, results]`` where ``results`` maps each sample index to
        a list containing one single-element ``[bool]`` list per generation.
    """
    # Build a fresh list per call instead of sharing a mutable default.
    if k_list is None:
        k_list = [1, 5]

    results = {}
    for idx, sample in enumerate(tqdm.tqdm(samples)):
        expected_output = sample["output"]
        # Index into generations (rather than zip) so a shorter generations
        # list still raises instead of silently dropping samples.
        results[idx] = [
            [check_testcase_output(extracted_generation, expected_output)]
            for extracted_generation in generations[idx]
        ]

    metrics = compute_metrics_from_results(results, k_list=k_list)

    return [metrics, results]
