import re
import tempfile
from ast import literal_eval
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from llm_mcts.data_types import Action, Grid, GridType
from llm_mcts.tasks.arc.eval import untrusted_check

# Unit-test templates live in ``unittest_templates/`` next to this module.
# ``generate_code_and_test`` picks one of the two files below by test kind.
_TEMPLATE_DIR = Path(__file__).resolve().parent / "unittest_templates"
TEST_TRANSFORM = (_TEMPLATE_DIR / "test_transform.py").resolve()
TEST_UNITTEST = (_TEMPLATE_DIR / "test_question.py").resolve()


def parse_funcname_from_answer(answer_code: str) -> Optional[str]:
    """Pick the entry-point function name out of a code answer.

    Only ``def`` statements that start at column 0 are considered, so
    indented (nested) definitions and ``async def`` are ignored. A top-level
    function named ``transform`` takes priority; otherwise the last top-level
    function in the source is chosen.

    Args:
        answer_code: Python source text to scan.

    Returns:
        The selected function name, or ``None`` if no top-level ``def`` exists.
    """
    defined = re.findall(
        r"^def\s+([a-zA-Z_][a-zA-Z_0-9]*)", answer_code, flags=re.MULTILINE
    )
    if not defined:
        return None
    return "transform" if "transform" in defined else defined[-1]


def prepare_test_code(
    test_code: str,
    funcname: str,
    input: Grid,
    output: Grid,
    output_fpath: Optional[Path],
) -> str:
    """
    Populate template test code with grid input, output and task function name.

    Every line of ``test_code`` that starts with one of the placeholder
    assignments ``__PROBLEM_INPUT__ =``, ``__PROBLEM_OUTPUT__ =``,
    ``__TASK_FUNC__ =`` or ``__OUTPUT_FILE__ =`` is replaced in full by a
    concrete assignment.

    Args:
        test_code: Template source containing the placeholder lines.
        funcname: Name of the function under test. It is emitted as a bare
            identifier, so the template binds the function object itself.
        input: Grid written via ``repr`` into ``__PROBLEM_INPUT__``.
        output: Grid written via ``repr`` into ``__PROBLEM_OUTPUT__``.
        output_fpath: Path written, resolved, as a Python string literal into
            ``__OUTPUT_FILE__``. A falsy value (``None``) yields ``''``.

    Returns:
        The populated test source. Placeholders missing from the template
        are left untouched.
    """
    output_file = str(output_fpath.resolve()) if output_fpath else ""

    replacements = (
        (r"^__PROBLEM_INPUT__\s*=.*$", f"__PROBLEM_INPUT__ = {input!r}"),
        (r"^__PROBLEM_OUTPUT__\s*=.*$", f"__PROBLEM_OUTPUT__ = {output!r}"),
        (r"^__TASK_FUNC__\s*=.*$", f"__TASK_FUNC__ = {funcname}"),
        # repr() yields a valid Python string literal even when the path holds
        # quotes or backslashes (e.g. Windows paths); hand-written '...' did not.
        (r"^__OUTPUT_FILE__\s*=.*$", f"__OUTPUT_FILE__ = {output_file!r}"),
    )

    for pattern, replacement in replacements:
        # A callable replacement is inserted verbatim. A plain string would be
        # parsed as a regex template, so backslashes in it (from paths or
        # string reprs) would raise re.error or be silently rewritten.
        test_code = re.sub(
            pattern,
            lambda _match, text=replacement: text,
            test_code,
            flags=re.MULTILINE,
        )

    return test_code


def generate_code_and_test(
    example: Dict[GridType, Grid],
    answer_code: str,
    kind: Action,
    output_fpath: Optional[Path] = None,
) -> Optional[Tuple[str, str]]:
    """
    Build the ``(answer_code, test_code)`` pair for a single example.

    Args:
        example: Mapping read at the keys ``"input"`` and ``"output"``.
        answer_code: Candidate source. Its entry point is chosen by
            ``parse_funcname_from_answer``.
        kind: ``"transform"`` or ``"question"``; selects the test template.
        output_fpath: Optional path forwarded to the template as
            ``__OUTPUT_FILE__``.

    Returns:
        ``(answer_code, test_code)``, or ``None`` if ``answer_code`` contains
        no top-level function definition.

    Raises:
        NotImplementedError: If ``kind`` is not a supported test kind.
    """
    if kind == "transform":
        template_path = TEST_TRANSFORM
    elif kind == "question":
        template_path = TEST_UNITTEST
    else:
        raise NotImplementedError(f"Unsupported test kind: {kind!r}")

    funcname = parse_funcname_from_answer(answer_code)
    if funcname is None:
        # Nothing to test, so don't bother reading the template from disk.
        return None

    # Templates are Python source files: decode them as UTF-8 explicitly
    # instead of relying on the platform's locale encoding.
    test_code = prepare_test_code(
        template_path.read_text(encoding="utf-8"),
        funcname,
        example["input"],
        example["output"],
        output_fpath,
    )
    return answer_code, test_code


def eval_if_test_pass(
    example: Dict[GridType, Grid],
    answer_code: str,
    kind: Action,
) -> Tuple[bool, Optional[List[List[int]]]]:
    """
    Run ``answer_code`` against the ``kind`` test template for one example.

    The generated test code is pointed, via ``__OUTPUT_FILE__``, at a file
    inside a temporary directory. For transform tests the test code can write
    the transformed grid there; its content is then parsed back with
    ``ast.literal_eval``.

    Args:
        example: Example whose ``"input"``/``"output"`` grids are injected.
        answer_code: Candidate source to check.
        kind: Test kind forwarded to ``generate_code_and_test``.

    Returns:
        ``(passed, output)``. ``passed`` is True iff ``untrusted_check``
        reports ``"pass"``. ``output`` is the literal parsed from the output
        file. It is ``None`` when no function name was found, the file was not
        written, or its content is not a valid Python literal. The parsed
        value's shape is not validated against the annotated grid type.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        output_fpath = Path(temp_dir) / "output.txt"
        parsed = generate_code_and_test(example, answer_code, kind, output_fpath)
        if parsed is None:
            return False, None

        code, test_code = parsed
        stat, _details = untrusted_check(code, test_code)

        output = None
        try:
            output = literal_eval(output_fpath.read_text())
        except (OSError, ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Best-effort: a missing (FileNotFoundError), unreadable or
            # malformed output file just means there is no grid to report.
            # Unlike a bare ``except:``, this lets KeyboardInterrupt and
            # SystemExit propagate.
            pass

        return stat == "pass", output
