import json
import os
import subprocess
import shutil
import signal
from pathlib import Path

from livebench.lcb_runner.evaluation.base import GenerationCorrectnessEvaluator, LCBCorrectnessCriteria

# Skeleton of the Java compilation unit: a blanket ``java.util`` import
# followed by the generated code, which is substituted for ``{generation}``
# via ``str.format``.
JAVA_SOLUTION_TEMPLATE = """
import java.util.*;

{generation}
"""

class RunnerStatus:
    """Integer outcome codes describing how a runner invocation ended."""

    # The runner process exited with return code 0.
    OK = 0
    # At least one test result was not reported as 'PASSED'.
    FAILED = 1
    # The runner process did not finish within the allotted time.
    TIMEOUT = -1
    # The runner process exited with a non-zero return code.
    ERROR = -2

def write_to_file(path: Path, data: bytes | str) -> None:
        if isinstance(data, bytes):
            with open(path, 'wb') as f:
                f.write(data)
        else:
            with open(path, 'wt') as f:
                f.write(data)


def escape_file_path(path_component: str) -> str:
    """Return ``path_component`` with '/', ':' and ' ' each replaced by '_'."""
    unsafe_chars = '/: '
    substitutions = str.maketrans(unsafe_chars, '_' * len(unsafe_chars))
    return path_component.translate(substitutions)

class JavaCorrectnessEvaluator(GenerationCorrectnessEvaluator):
    """Evaluate a generated Java solution by running it through ``java_run.sh``.

    For every problem a working directory is created under ``base_dir``
    containing the generated source (``src/Solution.java``), the test cases
    (``tests.json``) and the runner's captured ``stdout.txt`` / ``stderr.txt``.
    The runner script is handed the path of ``results.json`` and is expected
    to write a JSON list of per-test results there.
    """

    def __init__(self, base_dir: Path) -> None:
        """Store the runner script location and the working-directory root.

        Args:
            base_dir: Directory under which one sub-directory per problem is
                created.
        """
        super().__init__()
        self._shell_script = Path(os.path.dirname(__file__)) / 'java_run.sh'
        self._base_dir = base_dir

    def check(self, sample, generation: str, timeout: int = 60, debug: bool = True) -> tuple[list, dict]:
        """Evaluate ``generation`` against the tests embedded in ``sample``.

        Args:
            sample: Mapping whose ``"input_output"`` entry is a JSON string
                providing ``question_id``, ``question_title``, ``prompt``,
                ``inputs``, ``outputs`` and optionally ``fn_name``.
            generation: Java source code to evaluate.
            timeout: Per-test timeout value forwarded to the runner script.
            debug: When true, print a message if the results file cannot be
                read.

        Returns:
            The ``(results, metadata)`` pair produced by ``_check``.
        """
        in_outs = json.loads(sample["input_output"])
        problem = LCBCorrectnessCriteria(
            question_id=in_outs['question_id'],
            question_title=in_outs['question_title'],
            prompt=in_outs['prompt'],
            fn_name=in_outs.get("fn_name", None),
            test_inputs=in_outs["inputs"],
            test_outputs=in_outs["outputs"],
        )
        return self._check(problem, generation, timeout, debug)

    def _check(self, problem: LCBCorrectnessCriteria, generation: str, timeout: int = 60, debug: bool = True) -> tuple[list, dict]:
        """Run the solution for ``problem`` and translate the outcome.

        Returns:
            A ``(results, metadata)`` tuple. ``results`` holds ``True`` for
            every test reported as passed, ``False`` for the first test that
            was not (evaluation stops there), and a trailing ``-1`` when the
            runner timed out or errored. ``metadata`` is empty on success and
            otherwise describes the failure.

        Raises:
            ValueError: If the final status is not a known ``RunnerStatus``.
        """
        work_dir = self._base_dir / escape_file_path(problem.question_title)
        # Start from a clean directory so stale results are never picked up.
        shutil.rmtree(work_dir, ignore_errors=True)

        result_file, test_file = self._prepare_workdir(generation, problem, work_dir)

        # One second of slack per test, plus 30 seconds reserved for compilation.
        execution_timeout = (timeout + 1) * len(problem.test_inputs) + 30
        runner_params = [
            str(work_dir),
            str(test_file),
            str(result_file),
            str(timeout),
        ]

        status = self._spawn_runner(work_dir, runner_params, execution_timeout)

        ret = []
        failure = None
        try:
            with open(result_file, 'r') as result_fo:
                results = json.load(result_fo)

            for result in results:
                if result['status'] != 'PASSED':
                    # A failing test takes precedence over the runner status.
                    status = RunnerStatus.FAILED
                    ret.append(False)
                    failure = result
                    break
                ret.append(True)
        except (OSError, ValueError, KeyError, TypeError):
            # Missing, unreadable or malformed results file.
            if debug:
                print("No results file")
            # A clean exit without usable results must not look like a pass.
            if status == RunnerStatus.OK:
                status = RunnerStatus.ERROR

        if status == RunnerStatus.OK:
            write_to_file(work_dir / 'passed', 'passed')
            return ret, {}

        if status == RunnerStatus.FAILED:
            return ret, {
                "output": failure.get("output"),
                "expected": failure.get("expected-output"),
                "inputs": failure.get("input"),
                "error_code": -2,
                "error_message": "Wrong Answer",
            }

        if status in (RunnerStatus.TIMEOUT, RunnerStatus.ERROR):
            ret.append(-1)
            with open(work_dir / 'stderr.txt', 'r') as f:
                err = f.read()
            if status == RunnerStatus.TIMEOUT:
                message = "Timeout"
            else:
                message = "Execution Error"
            return ret, {
                "error": err,
                # Both outcomes have always been reported with code -1.
                "error_code": -1,
                "error_message": message,
            }

        raise ValueError(f"Status has unhandled value: {status}")

    def _prepare_workdir(self, generation, problem, work_dir):
        """Create ``work_dir`` and populate it with the solution and tests.

        Returns:
            ``(result_file, test_file)``: the path handed to the runner for
            its results, and the path of the written test-case file.
        """
        src_dir = work_dir / 'src'
        os.makedirs(src_dir, exist_ok=True)
        test_file = work_dir / 'tests.json'
        result_file = work_dir / 'results.json'
        self._write_solution_code(generation, src_dir)
        self._write_tests(problem, test_file)
        return result_file, test_file

    def _spawn_runner(self, work_dir, params, timeout):
        """Run the shell script with ``params`` and wait up to ``timeout`` seconds.

        The script's stdout and stderr are captured in ``stdout.txt`` and
        ``stderr.txt`` inside ``work_dir``.

        Returns:
            ``RunnerStatus.OK`` on exit code 0, ``RunnerStatus.ERROR`` on any
            other exit code (also recorded in ``work_dir / 'error'``), or
            ``RunnerStatus.TIMEOUT`` if the process had to be killed.
        """
        status = RunnerStatus.OK
        with open(work_dir / 'stdout.txt', 'w') as out_handle, \
                open(work_dir / 'stderr.txt', 'w') as err_handle:
            process = subprocess.Popen(
                ['bash', str(self._shell_script), *params],
                stdout=out_handle,
                stderr=err_handle,
            )
            try:
                return_code = process.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                status = RunnerStatus.TIMEOUT
                # NOTE(review): only the bash process itself is killed;
                # processes it spawned may outlive it.
                process.kill()
                # Reap the killed process so it does not linger as a zombie.
                process.wait()
                # Append rather than truncate, so the stderr captured up to
                # the timeout is kept in front of the banner.
                with open(work_dir / 'stderr.txt', 'at') as f:
                    f.write('\n###################################################')
                    f.write('\n# EXPIRED!!!  EXPIRED!!! EXPIRED!!! EXPIRED!!!    #')
                    f.write('\n###################################################')
            else:
                if return_code != 0:
                    write_to_file(work_dir / 'error', f'Error code {return_code}\n')
                    status = RunnerStatus.ERROR
        return status

    def _write_tests(self, problem, input_file):
        """Serialise the test cases of ``problem`` as JSON into ``input_file``."""
        io_dict = {
            'inputs': problem.test_inputs,
            'outputs': problem.test_outputs,
            'fn_name': problem.fn_name,
        }
        input_json = {
            'class': 'Solution',
            'input_output': json.dumps(io_dict),
            # Dumped last: fields of ``problem`` override same-named keys above.
            **problem.model_dump(mode='json')
        }
        write_to_file(input_file, json.dumps(input_json))

    def _write_solution_code(self, generation, src_dir):
        """Wrap ``generation`` in the Java template and write ``Solution.java``."""
        java_unit = JAVA_SOLUTION_TEMPLATE.format(generation=generation)
        write_to_file(src_dir / 'Solution.java', java_unit)
