import ast
import inspect
import json
import multiprocessing
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed
from traceback import print_exc

from rich.console import Console
from rich.syntax import Syntax
from tqdm import tqdm

from evalplus.data.mbpp import (
    MBPP_PLUS_VERSION,
    get_mbpp,
    get_mbpp_plus,
    get_mbpp_plus_hash,
)
from evalplus.eval import is_floats
from evalplus.eval._special_oracle import (
    MBPP_OUTPUT_NOT_NONE_TASKS,
    MBPP_OUTPUT_SET_EQ_TASKS,
    _digit_distance_nums,
    _surface_Area,
)
from evalplus.evaluate import get_groundtruth

# Template for a generated per-task test script.
# `{inputs}` and `{results}` are inlined as Python literals (str() of the lists
# passed to `.format`), `{aux_fn}` supplies the `assertion` helper, and
# `{assertion}` is a single statement run for every (inp, exp) pair.
MBPP_TEST_TEMPLATE = """\
{imports}

{aux_fn}

inputs = {inputs}
results = {results}
for i, (inp, exp) in enumerate(zip(inputs, results)):
    {assertion}
"""

# Alternative test template used to keep the dataset small for tasks whose
# expected outputs are large: instead of inlining results, the reference
# solution is embedded (renamed to `ref_func`) and each output of
# `{entry_point}` is compared against `ref_func(*inp)` when the test runs.
MBPP_CROSSCHECK_TEMPLATE = """\
import numpy as np
from math import inf

{aux_fn}

{ref_func}

inputs = {inputs}
for i, inp in enumerate(inputs):
    assertion({entry_point}(*inp), ref_func(*inp), {atol})
"""

# Source text of the default `assertion(out, exp, atol)` helper injected into
# every generated test. It embeds the source of `is_floats`.
# - When `atol` is 0 and `is_floats(exp)` holds, a tolerance of 1e-6 is used.
# - A non-exact match is accepted only if `np.allclose` passes under a
#   non-zero tolerance; otherwise exact equality is required.
# Braces are doubled because this is an f-string: `{{out}}` becomes `{out}` in
# the generated code (itself an f-string there).
ASSERTION_FN = f"""\
import numpy as np

{inspect.getsource(is_floats)}

def assertion(out, exp, atol):
    exact_match = out == exp

    if atol == 0 and is_floats(exp):
        atol = 1e-6
    if not exact_match and atol != 0:
        assert np.allclose(out, exp, rtol=1e-07, atol=atol)
    else:
        assert exact_match, f"out: {{out}}, exp: {{exp}}"
"""


def synthesize_test_code(task_id, entry_point, inputs, results, ref_func, atol):
    """Render the self-contained test script for one MBPP+ task.

    Args:
        task_id: Integer task id; returned unchanged so results coming back
            from a process pool can be matched to their task.
        entry_point: Name of the function under test.
        inputs: Argument sequences; each is splatted into the call.
        results: Expected outputs aligned with ``inputs``.
        ref_func: Source of the reference solution; only embedded for the
            cross-check tasks whose outputs are too large to inline.
        atol: Absolute tolerance passed to the generated ``assertion``.

    Returns:
        ``(task_id, test_code)`` where ``test_code`` is Python source.
    """
    # Dataset-size optimization for large outputs: embed the reference solution
    # (renamed to `ref_func`) and compare against it at test time instead of
    # inlining the expected results.
    if entry_point in ("combinations_colors", "freq_count", "get_coordinates"):
        return task_id, MBPP_CROSSCHECK_TEMPLATE.format(
            aux_fn=ASSERTION_FN,
            inputs=inputs,
            ref_func=ref_func.replace(f" {entry_point}(", " ref_func("),
            entry_point=entry_point,
            atol=atol,
        )

    # default settings
    imports = {"import numpy as np", "from math import inf"}
    aux_fn = ASSERTION_FN
    assertion = f"assertion({entry_point}(*inp), exp, {atol})"

    # ================================================ #
    # ============== special oracles ================= #

    if entry_point in MBPP_OUTPUT_SET_EQ_TASKS:
        # Order-insensitive comparison. The tolerant path compares the *sorted*
        # elements: np.allclose cannot interpret `set` objects as arrays.
        aux_fn = f"""\
{inspect.getsource(is_floats)}

def assertion(out, exp, atol):
    if atol == 0 and is_floats(exp):
        atol = 1e-6
    out = set(out)
    exp = set(exp)
    if out != exp and atol != 0:
        assert np.allclose(sorted(out), sorted(exp), rtol=1e-07, atol=atol)
    else:
        assert out == exp, f"out: {{out}}, exp: {{exp}}"
"""
    elif entry_point in MBPP_OUTPUT_NOT_NONE_TASKS:
        # `exp` is presumably the boolean "reference output is not None"
        # (it is compared with `out is not None`): bool outputs are compared
        # directly, anything else by its None-ness. The final assert is what
        # actually makes the test fail on a mismatch.
        aux_fn = f"""\
def assertion(out, exp, atol):
    if isinstance(out, bool):
        exact_match = out == exp
    else:
        exact_match = exp == (out is not None)
    assert exact_match, f"out: {{out}}, exp: {{exp}}"
"""
    elif entry_point == "surface_Area":
        # Accept either the stored expected output or the value computed by
        # the alternative reference `_surface_Area`.
        imports.add("import math")
        aux_fn = (
            inspect.getsource(_surface_Area)
            + "\n"
            + """\
def assertion(out, exp_0, exp_1, atol):
    assert abs(out - exp_0) <= atol or abs(out - exp_1) <= atol
"""
        )
        assertion = f"assertion(surface_Area(*inp), exp, _surface_Area(*inp), {atol})"

    elif entry_point == "digit_distance_nums":
        # Accept either the stored expected output or the value computed by
        # the alternative reference `_digit_distance_nums`.
        aux_fn = (
            inspect.getsource(_digit_distance_nums)
            + "\n"
            + """\
def assertion(out, exp_0, exp_1, atol):
    assert out == exp_0 or out == exp_1
"""
        )
        assertion = f"assertion(digit_distance_nums(*inp), exp, _digit_distance_nums(*inp), {atol})"

    # ============== special oracles ================= #
    # ================================================ #

    test_code = MBPP_TEST_TEMPLATE.format(
        # Sorted so the emitted import order does not depend on string hash
        # randomization (keeps the generated dataset reproducible).
        imports="\n".join(sorted(imports)),
        aux_fn=aux_fn,
        inputs=inputs,
        results=results,
        assertion=assertion,
    )

    return task_id, test_code


def deduplicate(inputs, results):
    """Remove repeated inputs, keeping each first occurrence with its result.

    Inputs are keyed by their formatted string (``f"{inp}"``), which lets
    unhashable inputs such as lists be deduplicated.

    Returns:
        ``(new_inputs, new_results)`` as new lists, in original order.

    Raises:
        AssertionError: if ``inputs`` and ``results`` differ in length.
    """
    assert len(inputs) == len(results)

    already_seen = set()
    kept_inputs = []
    kept_results = []
    for candidate, outcome in zip(inputs, results):
        key = f"{candidate}"
        if key in already_seen:
            continue
        already_seen.add(key)
        kept_inputs.append(candidate)
        kept_results.append(outcome)

    return kept_inputs, kept_results


def _exec_test_code(exec_code):
    """Child-process target: execute ``exec_code``, printing and re-raising errors.

    Defined at module level (not as a closure inside ``main``) so it can be
    pickled when multiprocessing uses the "spawn" start method.
    """
    try:
        exec(exec_code, globals())
    except Exception:
        print_exc()
        raise


def _check_groundtruth(task_id, exec_code, console, timeout=20):
    """Run ``exec_code`` in a child process; raise if it hangs or fails.

    Raises:
        RuntimeError: if the child is still alive after ``timeout`` seconds
            (it is terminated first), or exits non-zero (the code is
            pretty-printed to ``console`` first).
    """
    proc = multiprocessing.Process(target=_exec_test_code, args=(exec_code,))
    proc.start()
    proc.join(timeout=timeout)
    if proc.is_alive():
        # Kill the hung child before failing so it does not outlive us.
        proc.terminate()
        proc.join()
        raise RuntimeError(f"Timeout for Mbpp/{task_id}!")
    if proc.exitcode != 0:
        console.print(Syntax(exec_code, "python", line_numbers=True))
        raise RuntimeError(f"Error for Mbpp/{task_id}")


def main():
    """Convert MBPP+ into the original MBPP jsonl format.

    For every task (or only those given via ``--debug-tasks``), builds a test
    script from the deduplicated base+plus inputs and ground-truth outputs,
    checks that the canonical solution passes it in a child process, then
    either writes ``MbppPlus-OriginFmt-<version>.jsonl`` or, in debug mode,
    prints the converted tasks.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug-tasks", nargs="+", default=[], type=int)

    args = parser.parse_args()
    console = Console()

    # Lift CPython's int->str digit limit (when available) so that very large
    # integers in inputs/outputs can be rendered into the test code.
    if hasattr(sys, "set_int_max_str_digits"):
        sys.set_int_max_str_digits(int(10e8))

    plus_problems = get_mbpp_plus(mini=False)
    dataset_hash = get_mbpp_plus_hash()

    original_mbpp = get_mbpp()

    compatible_problems = {}
    expected_outputs = get_groundtruth(
        plus_problems, dataset_hash, MBPP_OUTPUT_NOT_NONE_TASKS
    )

    # debugging: monitoring test code size
    id2bytes = {}

    n_workers = max(1, multiprocessing.cpu_count() // 4)
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = []
        for task_id, plus_form in tqdm(plus_problems.items()):
            # EvalPlus ids look like "Mbpp/666"; the original format uses 666.
            task_id_int = int(task_id.split("/")[-1])
            if args.debug_tasks and task_id_int not in args.debug_tasks:
                continue

            origin = original_mbpp[str(task_id_int)]
            compatible_form = {
                "task_id": task_id_int,
                "code": plus_form["canonical_solution"],
                "prompt": origin["prompt"],
                "source_file": origin["source_file"],
                "test_imports": origin["test_imports"],
                "test_list": origin["test_list"],
            }
            compatible_problems[task_id_int] = compatible_form

            # base + plus inputs (an empty plus_input adds nothing)
            inputs = plus_form["base_input"] + plus_form["plus_input"]
            results = (
                expected_outputs[task_id]["base"] + expected_outputs[task_id]["plus"]
            )
            inputs, results = deduplicate(inputs, results)

            futures.append(
                executor.submit(
                    synthesize_test_code,
                    task_id_int,
                    plus_form["entry_point"],
                    inputs,
                    results,
                    compatible_form["code"],
                    plus_form["atol"],
                )
            )

        # Count submitted futures, not all problems: --debug-tasks submits a subset.
        for future in tqdm(as_completed(futures), total=len(futures)):
            task_id, test_code = future.result()
            # syntax check of test_code
            ast.parse(test_code)
            # ground-truth check: the canonical solution must pass its own test
            task = plus_problems[f"Mbpp/{task_id}"]
            exec_code = (
                task["prompt"] + "\n" + task["canonical_solution"] + "\n" + test_code
            )
            _check_groundtruth(task_id, exec_code, console)

            id2bytes[task_id] = len(test_code.encode("utf-8"))
            compatible_problems[task_id]["test"] = test_code

    # print the top-10 largest test code
    print("Top-10 largest test code comes from problems (in megabytes):")
    largest = sorted(id2bytes.items(), key=lambda item: item[1], reverse=True)[:10]
    for task_id, size in largest:
        print(f"{task_id}:\t{size / 1024 / 1024:.2f}mb")

    if args.debug_tasks:
        for problem in compatible_problems.values():
            print("--- debugging:", problem["task_id"])
            print('"""\n' + problem["prompt"] + '\n"""\n' + problem["code"])
            test_code = problem["test"]
            if len(test_code) <= 1024:
                print(test_code)
            else:
                print(test_code[:1024], "...")
                print("...", test_code[-1024:])
    else:
        out_path = f"MbppPlus-OriginFmt-{MBPP_PLUS_VERSION}.jsonl"
        with open(out_path, "w", encoding="utf-8") as f:
            for problem in compatible_problems.values():
                f.write(json.dumps(problem) + "\n")


if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()
