from __future__ import annotations

import asyncio
import hashlib
import logging
import os
import re
import tempfile
from typing import Optional, Union

import attrs
from cattrs.preconf.json import make_converter

json_converter = make_converter()

logger = logging.getLogger(__name__)


@attrs.frozen
class CompilationFailure:
    out: CommandOut


def get_default_num_workers():
    MIN_SET_ASIDE_CORES = 10
    FRAC_CORES = 0.9

    total_cores = os.cpu_count()
    assert total_cores is not None
    worker_cores = min(
        round(FRAC_CORES * total_cores), total_cores - MIN_SET_ASIDE_CORES
    )
    max_workers = max(1, worker_cores)
    logger.debug("max workers = ", max_workers)
    return max_workers


run_semaphore = asyncio.Semaphore(get_default_num_workers())


@attrs.frozen
class CommandReturn:
    stdout: str
    stderr: str
    return_code: int


@attrs.frozen
class CommandTimeout:
    ...


CommandOut = Union[CommandReturn, CommandTimeout]


def command_out_is_ok(x: CommandOut) -> CommandReturn:
    return isinstance(x, CommandReturn) and x.return_code == 0


async def simple_command_run(
    command: str,
    timeout_seconds: float = 2.0,
    memory_limit_mb: float = 1024.0,
    inp: Optional[str] = None,
) -> CommandOut:
    # Convert memory_limit from MB to KB
    memory_limit_kb = int(memory_limit_mb * 1024)

    async with run_semaphore:
        process = await asyncio.create_subprocess_exec(
            "bash",
            "-c",
            # Use ulimit to set the virtual memory limit
            f"ulimit -v {memory_limit_kb}; {command}",
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            text=False,
        )
        try:
            stdout_bytes, stderr_bytes = await asyncio.wait_for(
                process.communicate(
                    input=inp.encode("utf-8") if inp is not None else None
                ),
                timeout=timeout_seconds,
            )
            stdout = stdout_bytes.decode("utf-8", errors="replace")
            stderr = stderr_bytes.decode("utf-8", errors="replace")
            return CommandReturn(
                stdout=stdout, stderr=stderr, return_code=process.returncode
            )
        except asyncio.TimeoutError:
            return CommandTimeout()
        finally:
            try:
                process.kill()
            except Exception:
                ...


async def read_output(stream, output_list):
    async for chunk in stream:
        output_list.append(chunk.decode())


compile_dir = os.path.expanduser("~/.cache/_cpp_compile/")
os.makedirs(compile_dir, exist_ok=True)


@attrs.frozen
class CompiledProgram:
    _code_hash: str = attrs.field(alias="_code_hash")

    @property
    def code_hash(self) -> str:
        return self._code_hash

    @property
    def binary_path(self) -> str:
        return f"/{compile_dir}/{self.code_hash}.out"

    @property
    def error_path(self) -> str:
        return f"/{compile_dir}/{self.code_hash}_err.json"

    @classmethod
    async def compile_cpp_code(
        cls, code: str
    ) -> Union[CompiledProgram, CompilationFailure]:
        salt = "cpp_compile_no_c"
        m = hashlib.sha256()
        m.update((code + salt).encode("utf-8"))
        code_hash = m.hexdigest()

        out = cls(_code_hash=code_hash)

        try:
            with open(out.error_path, "r") as f:
                return json_converter.loads(f.read(), CompilationFailure)
        except FileNotFoundError:
            ...

        if os.path.exists(out.binary_path):
            return out

        with tempfile.TemporaryDirectory(prefix=compile_dir) as temp:
            cpp_path = f"{temp}/file.cpp"
            temp_binary_path = f"{temp}/file.out"
            with open(cpp_path, "w") as f:
                f.write(code)

            compile_command = [
                "g++",
                "-std=c++20",
                "-O3",
                cpp_path,
                "-o",
                temp_binary_path,
            ]

            compile_out = await simple_command_run(" ".join(compile_command))

            if not command_out_is_ok(compile_out):
                fail = CompilationFailure(compile_out)
                err_s = json_converter.dumps(fail)
                with open(out.error_path, "w") as f:
                    f.write(err_s)
                return fail

            # mv approach to avoid collide
            os.rename(temp_binary_path, out.binary_path)

        return out

    async def run_input(
        self,
        inp: Optional[str],
        timeout_seconds: float = 0.5,
        memory_limit_mb: float = 1_024.0,
        override_cache: bool = False,
    ) -> CommandOut:
        # FIXME: some test cases inputs in original APPS are lists instead of strings
        inp = str(inp)
        # tentatively removing cacheing with Redis
        # key = cache_key(
        #     dict(code_hash=self.code_hash, inp=inp, timeout_seconds=timeout_seconds, memory_limit_mb=memory_limit_mb),
        #     "code_exec_no_c",
        # )

        # if not override_cache:
        #     cache_result = await RedisWrapper.singleton().read(key)
        #     if cache_result is not None:
        #         return json_converter.loads(cache_result, CommandOut)

        out = await simple_command_run(
            self.binary_path,
            timeout_seconds=timeout_seconds,
            memory_limit_mb=memory_limit_mb,
            inp=inp,
        )
        # await RedisWrapper.singleton().write(key, json_converter.dumps(out, unstructure_as=CommandOut))
        return out

    async def run_inputs(
        self,
        inps: list[str],
        timeout_seconds: float = 0.5,
        memory_limit_mb: float = 1_024.0,
        override_cache: bool = False,
    ) -> list[CommandOut]:
        # key = cache_key(
        #     dict(code_hash=self.code_hash, inps=inps, timeout_seconds=timeout_seconds, memory_limit_mb=memory_limit_mb),
        #     "code_exec_many",
        # )

        # if not override_cache:
        #     cache_result = await RedisWrapper.singleton().read(key)
        #     if cache_result is not None:
        #         return json_converter.loads(cache_result, list[CommandOut])

        out = await asyncio.gather(
            *(
                self.run_input(
                    inp=inp,
                    timeout_seconds=timeout_seconds,
                    memory_limit_mb=memory_limit_mb,
                    override_cache=override_cache,
                )
                for inp in inps
            )
        )
        # await RedisWrapper.singleton().write(key, json_converter.dumps(out, unstructure_as=list[CommandOut]))
        return out


from enum import Enum, auto


class TestCaseOutcome(Enum):
    INVALID = auto()
    CORRECT = auto()
    INCORRECT = auto()
    TIMEOUT = auto()

    def broadly_incorrect(self):
        return self in [TestCaseOutcome.INCORRECT, TestCaseOutcome.TIMEOUT]


def indent_into_function(
    lines: list[str],
    name: str = "f____",
    args_names: list[str] = [],
    call: bool = False,
) -> list[str]:
    """Indent a list of lines into a function definition.

    import math
    print(input())
    ->
    import math
    def f____():
        print(input())"""

    def import_filter(line):
        return line.startswith("import ") or line.startswith("from ")

    import_lines = [line for line in lines if import_filter(line)]
    other_lines = [line for line in lines if not import_filter(line)]
    declaration = f"def {name}({', '.join(args_names)}):"
    out = import_lines + [declaration] + ["    " + line for line in other_lines]
    if call:
        out.append(f"{name}()")
    return out


import ast


def is_float(text):
    try:
        text = float(text)
        return True
    except:
        return False


def check(pred, answer):
    try:
        pred = pred.strip()
        answer = answer.strip()
    except:
        return False

    if pred == answer:
        return True

    # normalize list
    # overly long strings can introduce segmenation fault in ast.literal_eval
    if answer.startswith("[") and answer.endswith("]") and len(answer) < 1e6:
        try:
            answer = ast.literal_eval(answer)
            answer = "\n".join(answer).strip()
        except:
            answer = answer

    if pred.startswith("[") and pred.endswith("]") and len(pred) < 1e6:
        try:
            pred = ast.literal_eval(pred)
            pred = "\n".join(pred).strip()
        except:
            pass

    # list special judge
    try:
        if "\n" in pred.strip() and "\n" in answer.strip():
            pred = [i.strip() for i in pred.strip().split("\n")]
            answer = [i.strip() for i in answer.strip().split("\n")]
            flag = True
            for i_pred, i_answer in zip(pred, answer):
                if i_pred != i_answer:
                    if not (
                        is_float(i_pred)
                        and is_float(i_answer)
                        and abs(float(i_pred) - float(i_answer)) < 1e-4
                    ):
                        flag = False
                        break
            return flag
    except:
        pass

    try:  # float special judge
        pred = float(pred)
        answer = float(answer)
        flag1 = abs(pred - answer) < (max(abs(answer), abs(pred)) * 1e-4)
        flag2 = abs(pred - answer) < 1e-4
        flag = flag1 | flag2
        return flag
    except:
        pass

    flag = pred == answer
    return flag


def get_outcome(correct_output: Optional[str], cpp_outcome: CommandOut):
    if correct_output is None:
        return TestCaseOutcome.INVALID
    elif isinstance(cpp_outcome, CommandTimeout):
        return TestCaseOutcome.TIMEOUT
    elif command_out_is_ok(cpp_outcome) and check(correct_output, cpp_outcome.stdout):
        return TestCaseOutcome.CORRECT
    else:
        return TestCaseOutcome.INCORRECT


@attrs.frozen
class RanTestCase:
    input: str
    actual_output_full: CommandOut
    outcome: TestCaseOutcome
    expected_output: Optional[str] = None

    @property
    def actual_output(self):
        if isinstance(self.actual_output_full, CommandTimeout):
            return "TIMEOUT"
        return self.actual_output_full.stdout


async def get_test_case_outcomes(problem, solution, test_cases: list[dict]):
    compiled = await CompiledProgram.compile_cpp_code(solution)
    if isinstance(compiled, CompilationFailure):
        # assert not isinstance(compiled.out, CommandTimeout), "TODO"
        return [
            {
                "input": case["input"],
                "output": case["output"],
                "pred": "compile error",
                "flag": False,
            }
            for case in test_cases
        ]

    run_cpp_outputs = await compiled.run_inputs(
        inps=[case["input"] for case in test_cases]
    )

    outcomes, complete_outcomes = [], []
    for correct_output, cpp_output, input in zip(
        [case["output"] for case in test_cases],
        run_cpp_outputs,
        [case["input"] for case in test_cases],
    ):
        outcome = get_outcome(correct_output, cpp_output)
        complete_outcomes.append(
            RanTestCase(
                input=input,
                actual_output_full=cpp_output,
                outcome=outcome,
                expected_output=correct_output,
            )
        )
        outcomes.append(outcome)
    # return outcomes, complete_outcomes

    res = [
        {
            "input": item.input,
            "output": item.expected_output,
            "pred": item.actual_output_full.stdout
            if isinstance(item.actual_output_full, CommandReturn)
            else "timeout",
            "flag": item.outcome == TestCaseOutcome.CORRECT,
        }
        for item in complete_outcomes
    ]

    for i in res:
        assert isinstance(i, dict)
    return res


async def debug():
    code = """#include <iostream>
#include <string>

int main() {
    std::string input;
    std::getline(std::cin, input);
    std::cout << input << std::endl;
    return 0;
}
"""
    test_cases = [{"input": "1", "output": "1"}]
    res = await asyncio.gather(*[get_test_case_outcomes("", code, test_cases)])
    return res


from tqdm.asyncio import tqdm_asyncio


async def batched_execute_code_test_case(batched_data):
    batched_res = await tqdm_asyncio.gather(
        *[
            get_test_case_outcomes(i["question"], i["solution"], i["test_cases"])
            for i in batched_data
        ]
    )

    return batched_res


if __name__ == "__main__":
    res = asyncio.run(debug())
    print(res)
