import argparse
import asyncio
import datetime
from functools import partial
import json
import math
import os
from pathlib import Path
from string import Template
import subprocess
import sys
from typing import Callable, Protocol

import ale_bench
import ale_bench.result
import ale_bench.session
import openhands.agenthub  # noqa F401 (we import this to get the agents registered)
from openhands.controller.agent import Agent
from openhands.controller.replay import ReplayManager
from openhands.controller.state.state import State
from openhands.core.config import (
    AppConfig,
    get_parser,
    setup_config_from_args,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.loop import run_agent_until_done
from openhands.core.schema import AgentState
from openhands.core.setup import (
    create_agent,
    create_controller,
    create_memory,
    create_runtime,
    generate_sid,
    initialize_repository_for_runtime,
)
from openhands.events import EventSource, EventStreamSubscriber
from openhands.events.action import AgentFinishAction, MessageAction, NullAction
from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.io import read_input, read_task
from openhands.mcp import fetch_mcp_tools_from_config
from openhands.memory.memory import Memory
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync


class FakeUserResponseFunc(Protocol):
    """Structural type of a callable that answers on behalf of the human user.

    `run_controller` invokes such a callable with the controller's current
    state whenever the agent is awaiting user input, and posts the returned
    string to the event stream as the user's message.
    """

    def __call__(
        self,
        state: State,
        encapsulate_solution: bool = False,
        try_parse: Callable[[Action | None], str] | None = None,
    ) -> str: ...


# Human-readable language label for each supported `--code-language` key;
# substituted into the prompts as the language the agent must use.
CODE_LANGUAGE_STRING = dict(
    cpp20="C++20 (gcc 12.2.0)",
    python="Python (CPython 3.11.4)",
    rust="Rust (rustc 1.70.0)",
)
# Libraries the solution may use, per `--code-language` key. Each value is a
# ready-made bullet list that is substituted verbatim into INITIAL_MESSAGE as
# `${libraries}`, so the text below is part of the prompt.
CODE_LANGUAGE_LIBRARIES = {
    "cpp20": """- AC Library@1.5.1
- Boost@1.82.0
- GMP@6.2.1
- Eigen@3.4.0-2ubuntu2""",
    "python": """- numpy==1.24.1
- scipy==1.10.1
- networkx==3.0
- sympy==1.11.1
- sortedcontainers==2.4.0
- more-itertools==9.0.0
- shapely==2.0.0
- bitarray==2.6.2
- PuLP==2.7.0
- mpmath==1.2.1
- pandas==1.5.2
- z3-solver==4.12.1.0
- scikit-learn==1.2.0
- ortools==9.5.2237
- ac-library-python
- setuptools==66.0.0
- cppyy==2.4.1
- torch==1.13.1
- polars==0.15.15
- lightgbm==3.3.1
- gmpy2==2.1.5
- numba==0.57.0""",
    "rust": """- ac-library-rs@=0.1.1
- once_cell@=1.18.0
- static_assertions@=1.1.0
- varisat@=0.2.2
- memoise@=0.3.2
- argio@=0.2.0
- bitvec@=1.0.1
- counter@=0.5.7
- hashbag@=0.1.11
- pathfinding@=4.3.0
- recur-fn@=2.2.0
- indexing@=0.4.1
- amplify@=3.14.2
- amplify_derive@=2.11.3
- amplify_num@=0.4.1
- easy-ext@=1.0.1
- multimap@=0.9.0
- btreemultimap@=0.1.1
- bstr@=1.6.0
- az@=1.2.1
- glidesort@=0.1.2
- tap@=1.0.1
- omniswap@=0.1.0
- multiversion@=0.7.2
- num@=0.4.1
- num-bigint@=0.4.3
- num-complex@=0.4.3
- num-integer@=0.1.45
- num-iter@=0.1.43
- num-rational@=0.4.1
- num-traits@=0.2.15
- num-derive@=0.4.0
- ndarray@=0.15.6
- nalgebra@=0.32.3
- alga@=0.9.3
- libm@=0.2.7
- rand@=0.8.5
- getrandom@=0.2.10
- rand_chacha@=0.3.1
- rand_core@=0.6.4
- rand_hc@=0.3.2
- rand_pcg@=0.3.1
- rand_distr@=0.4.3
- petgraph@=0.6.3
- indexmap@=2.0.0
- regex@=1.9.1
- lazy_static@=1.4.0
- ordered-float@=3.7.0
- ascii@=1.1.0
- permutohedron@=0.2.4
- superslice@=1.0.0
- itertools@=0.11.0
- itertools-num@=0.1.3
- maplit@=1.0.2
- either@=1.8.1
- im-rc@=15.1.0
- fixedbitset@=0.4.2
- bitset-fixed@=0.1.0
- proconio@=0.4.5
- text_io@=0.1.12
- rustc-hash@=1.1.0
- smallvec@=1.11.0""",
}
# Shell command shown to the agent for compiling (or pre-checking) its
# solution in `/workspace`, keyed by `--code-language`.
CODE_LANGUAGE_COMPILE_COMMAND = {
    "cpp20": " ".join(
        [
            "g++-12",
            "-std=gnu++20",
            "-O2",
            "-DONLINE_JUDGE",
            "-DATCODER",
            "-Wall",
            "-Wextra",
            "-mtune=native",
            "-march=native",
            "-fconstexpr-depth=2147483647",
            "-fconstexpr-loop-limit=2147483647",
            "-fconstexpr-ops-limit=2147483647",
            "-I/opt/ac-library",
            "-I/opt/boost/gcc/include",
            "-L/opt/boost/gcc/lib",
            "-o",
            "solution.out",
            "solution.cpp",
            "-lgmpxx",
            "-lgmp",
            "-I/usr/include/eigen3",
        ]
    ),
    "python": (
        "python3.11 -m py_compile solution.py; "
        "python3.11 solution.py ONLINE_JUDGE 2> /dev/null"
    ),
    "rust": "RUST_BACKTRACE=0 cargo build --release --quiet --offline",
}
# Shell command shown to the agent for running its built solution.
CODE_LANGUAGE_RUN_COMMAND = dict(
    cpp20="./solution.out",
    python="python3.11 solution.py",
    rust="./target/release/solution",
)
# File extension of `/workspace/solution.<ext>` for each language key.
FILE_EXTENSION = dict(cpp20="cpp", python="py", rust="rs")


# First user message sent to the agent. It is a `string.Template`; the
# `__main__` section fills these placeholders via `.substitute(...)`:
#   duration_in_hours, language, file_extension, libraries, compile_command,
#   run_command, problem_name, time_limit, memory_limit, problem_statement.
# NOTE: the wording below is the prompt itself; editing it changes what the
# agent is told.
INITIAL_MESSAGE = Template(
    "You are a world-class algorithm engineer, and you are very good at programming. "
    "Now, you are participating in a programming contest. "
    "You are asked to solve a heuristic problem, known as an NP-hard problem. "
    "The duration of the contest is ${duration_in_hours} hours.\n\n"
    "There is a problem statement at the end of this message. "
    "First, please analyze the problem statement. "
    "Please think about the essential points of the problem and possible algorithms to get higher rank in the contest.\n\n"
    "Next, please implement your solution in **${language}**. "
    "Your solution code should be written in the **`/workspace/solution.${file_extension}`**. "
    "You can use external libraries as follows:\n${libraries}\n\n"
    "You can run your code with the following commands if you are at the `/workspace` directory:\n```bash\n# Compile\n${compile_command}\n# Run\n${run_command}\n```\n\n"
    "After you implement your solution, **you must ask user (not exit) to evaluate your temporary solution** whether your solution is effective or not. "
    "The user will read your code from `/workspace/solution.${file_extension}` file. "
    "If the user can not find your solution file, the user will ask you to provide your code again. "
    "Please make sure that your code is exist in proper way.\n\n"
    "The user can run your solution code with 50 public cases and give you feedback including the score if the code is accepted or the error message if the code is not accepted. "
    "**Even if your solution get accepted, you must refine your solution or try another approach to get a better score.**\n\n"
    "You can use the internet to search an algorithm or a library to solve the problem. "
    "But you must implement your solution by yourself (no plagiarism) and do not see any contents related to this problem. "
    "More specifically, you must not see web pages that contain the word \"AtCoder\", \"AHC\", \"${problem_name}\" or contents considered to be directly related to this problem because some people were already solved this problem. "
    "We will check web pages you visited in order to detect the plagiarism and cheating. "
    "If we detect that you are violating this rule, you will be disqualified from the contest.\n\n"
    "[Problem statement]\nExecution time limit: ${time_limit} sec / Memory limit: ${memory_limit} MB\n${problem_statement}"
)
# Follow-up user message sent after each public evaluation. Placeholders
# filled by `create_feedback_message`: `feedback` (the formatted test result)
# and `file_extension` (extension of the solution file).
FEEDBACK_MESSAGE = Template(
    "${feedback}\n\n"
    "Based on the above feedback, please consider the ways to improve your solution. "
    "Firstly, please analyze this given feedback and list what insights can be gained from it. "
    "Then, based on the insights, please refine your code to achieve better performance. "
    "It can be a simple bug fix, the introduction of a new algorithm, or any degree of change from minor to major. "
    "If you think you did your best on the task, please finish the interaction.\n"
    "Again, you should implement your entire solution in the `/workspace/solution.${file_extension}` file."
)


# Path of the per-problem experiment log. It stays None until the `__main__`
# section assigns it; functions that log (e.g. `create_feedback_message`)
# read this module-level name at call time.
LOG_FILE = None


async def run_controller(
    config: AppConfig,
    initial_user_action: Action,
    sid: str | None = None,
    runtime: Runtime | None = None,
    agent: Agent | None = None,
    exit_on_message: bool = False,
    fake_user_response_fn: FakeUserResponseFunc | None = None,
    headless_mode: bool = True,
    memory: Memory | None = None,
) -> State | None:
    """Run one agent session until it reaches an end state and return its state.

    Missing collaborators are created from `config`: the agent (plus its MCP
    tools), the runtime (connected here, with the selected repository
    initialized if one is configured) and the memory.

    Args:
        config: Application configuration driving every component created here.
        initial_user_action: First user action posted to the event stream. It
            must be a `NullAction` when `config.replay_trajectory_path` is
            set, because it is then replaced by the first replayed event.
        sid: Session id; generated from `config` when omitted.
        runtime: Existing runtime to reuse instead of creating one.
        agent: Existing agent to reuse instead of creating one.
        exit_on_message: If true, answer every request for user input with
            `/exit`.
        fake_user_response_fn: Callable producing the user's reply from the
            controller state; when None (and `exit_on_message` is false) the
            reply is read interactively.
        headless_mode: Forwarded to `create_runtime`.
        memory: Existing memory to reuse instead of creating one.

    Returns:
        The controller state fetched after the controller has been closed.

    Exceptions raised by the main agent loop are logged and not re-raised, so
    the session/trajectory saving below still runs.
    """
    sid = sid or generate_sid(config)

    if agent is None:
        agent = create_agent(config)
        mcp_tools = await fetch_mcp_tools_from_config(config.mcp)
        agent.set_mcp_tools(mcp_tools)

    # when the runtime is created, it will be connected and clone the selected repository
    repo_directory = None
    if runtime is None:
        runtime = create_runtime(
            config,
            sid=sid,
            headless_mode=headless_mode,
            agent=agent,
        )
        # Connect to the runtime
        call_async_from_sync(runtime.connect)

        # Initialize repository if needed
        if config.sandbox.selected_repo:
            repo_directory = initialize_repository_for_runtime(
                runtime,
                selected_repository=config.sandbox.selected_repo,
            )

    event_stream = runtime.event_stream

    # when memory is created, it will load the microagents from the selected repository
    if memory is None:
        memory = create_memory(
            runtime=runtime,
            event_stream=event_stream,
            sid=sid,
            selected_repository=config.sandbox.selected_repo,
            repo_directory=repo_directory,
        )

    # In replay mode the initial action comes from the recorded trajectory,
    # so the caller must not have supplied a real one.
    replay_events: list[Event] | None = None
    if config.replay_trajectory_path:
        logger.info('Trajectory replay is enabled')
        assert isinstance(initial_user_action, NullAction)
        replay_events, initial_user_action = load_replay_log(
            config.replay_trajectory_path
        )

    controller, initial_state = create_controller(
        agent, runtime, config, replay_events=replay_events
    )

    assert isinstance(
        initial_user_action, Action
    ), f'initial user actions must be an Action, got {type(initial_user_action)}'
    logger.debug(
        f'Agent Controller Initialized: Running agent {agent.name}, model '
        f'{agent.llm.config.model}, with actions: {initial_user_action}'
    )

    # start event is a MessageAction with the task, either resumed or new
    if initial_state is not None and initial_state.last_error:
        # we're resuming the previous session
        event_stream.add_event(
            MessageAction(
                content=(
                    "Let's get back on track. If you experienced errors before, do "
                    'NOT resume your task. Ask me about it.'
                ),
            ),
            EventSource.USER,
        )
    else:
        # init with the provided actions
        event_stream.add_event(initial_user_action, EventSource.USER)

    def on_event(event: Event) -> None:
        # Reply as the user whenever the agent reports it is awaiting input.
        if isinstance(event, AgentStateChangedObservation):
            if event.agent_state == AgentState.AWAITING_USER_INPUT:
                if exit_on_message:
                    message = '/exit'
                elif fake_user_response_fn is None:
                    message = read_input(config.cli_multiline_input)
                else:
                    message = fake_user_response_fn(controller.get_state())
                action = MessageAction(content=message)
                event_stream.add_event(action, EventSource.USER)

    event_stream.subscribe(EventStreamSubscriber.MAIN, on_event, sid)

    # Agent states passed to run_agent_until_done as the states that end the run.
    end_states = [
        AgentState.FINISHED,
        AgentState.REJECTED,
        AgentState.ERROR,
        AgentState.PAUSED,
        AgentState.STOPPED,
    ]

    try:
        await run_agent_until_done(controller, runtime, memory, end_states)
    except Exception as e:
        # Logged only: the persistence steps below must still execute.
        logger.error(f'Exception in main loop: {e}')

    # save session when we're about to close
    if config.file_store is not None and config.file_store != 'memory':
        end_state = controller.get_state()
        # NOTE: the saved state does not include delegates events
        end_state.save_to_session(
            event_stream.sid, event_stream.file_store, event_stream.user_id
        )

    await controller.close(set_stop_state=False)

    state = controller.get_state()

    # save trajectories if applicable
    if config.save_trajectory_path is not None:
        # if save_trajectory_path is a folder, use session id as file name
        if os.path.isdir(config.save_trajectory_path):
            file_path = os.path.join(config.save_trajectory_path, sid + '.json')
        else:
            file_path = config.save_trajectory_path
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        histories = controller.get_trajectory(config.save_screenshots_in_trajectory)
        with open(file_path, 'w') as f:  # noqa: ASYNC101
            json.dump(histories, f, indent=4)

    return state


def replace_details_in_case_result(result: ale_bench.result.Result) -> ale_bench.result.Result:
    """Return a copy of *result* that keeps bulky details for a single case.

    The first case whose judge result equals the overall judge result keeps
    its input, output, standard error and local visualization. For every
    other case those four fields are set to ``None``. Judge result, message,
    scores, execution time and memory usage are copied for all cases.
    """
    slimmed_cases = []
    details_kept = False
    for case in result.case_results:
        # The comparison is only evaluated until one detailed case was found.
        keep_details = (not details_kept) and case.judge_result == result.overall_judge_result
        if keep_details:
            details_kept = True
        slimmed_cases.append(
            ale_bench.result.CaseResult(
                input_str=case.input_str if keep_details else None,
                output_str=case.output_str if keep_details else None,
                error_str=case.error_str if keep_details else None,
                judge_result=case.judge_result,
                message=case.message,
                absolute_score=case.absolute_score,
                relative_score=case.relative_score,
                local_visualization=case.local_visualization if keep_details else None,
                execution_time=case.execution_time,
                memory_usage=case.memory_usage,
            )
        )
    return ale_bench.result.Result(
        allow_score_non_ac=result.allow_score_non_ac,
        resource_usage=result.resource_usage,
        case_results=slimmed_cases,
    )


def case_result_feedback(case_idx: int, case_result: ale_bench.result.CaseResult) -> str:
    """Format one case result as an indented, multi-line feedback entry.

    The entry shows the case number, absolute score, execution time (three
    decimals), memory usage (``memory_usage`` integer-divided by 1024 twice,
    labelled MB), and the quoted standard error and message.
    """
    report_lines = [
        f"- Case {case_idx}:",
        f"    Absolute score: {case_result.absolute_score}",
        f"    Execution time: {case_result.execution_time:.3f} sec",
        f"    Memory usage: {case_result.memory_usage // 1024 // 1024} MB",
        f'    Standard error: "{case_result.error_str}"',
        f'    Message: "{case_result.message}"',
    ]
    return "\n".join(report_lines)


def _append_feedback_log(message: str) -> None:
    """Append one line to ``LOG_FILE`` and close the handle immediately.

    Does nothing while ``LOG_FILE`` is still ``None`` (the module was imported
    instead of being run as a script), so building feedback never fails just
    because logging is not configured.
    """
    if LOG_FILE is None:
        return
    with LOG_FILE.open("a") as log_file:
        print(message, file=log_file)


def create_feedback_message(public_result: ale_bench.result.Result, code_language: str, vis_port: int) -> str:
    """Build the user message that reports a public evaluation to the agent.

    Args:
        public_result: Result of the public evaluation. For an accepted result
            the first case must still carry ``input_str``/``output_str``,
            because they are quoted in the visualization instructions.
        code_language: Language key used to look up the solution file
            extension in ``FILE_EXTENSION``.
        vis_port: Port of the visualization server embedded in the link.

    Returns:
        ``FEEDBACK_MESSAGE`` with the formatted result substituted in.

    Side effects:
        Appends a one-line summary (accepted score or failing judge result)
        to ``LOG_FILE`` when it is set.
    """
    overall_judge_result = public_result.overall_judge_result
    feedback = f"[Public test result]\nOverall judge result: {overall_judge_result.value}\n"
    if overall_judge_result == ale_bench.result.JudgeResult.ACCEPTED:
        _append_feedback_log(f"Accepted (score: {public_result.overall_absolute_score})")
        feedback += f"Overall absolute score: {public_result.overall_absolute_score}\n"
        feedback += "\n".join([
            f"- Case {i}: {case_result.absolute_score}"
            for i, case_result in enumerate(public_result.case_results, 1)
        ])
        feedback += (
            f"\nYou can see your solution in `http://172.17.0.1:{vis_port}?lang=en` to visualize your solution. "
            f"If you want to visualize your solution in the local web UI, please access the link above. "
            f"Then, please detect the Input field in the web page (recognize its \"bid\") and paste ```txt\n{public_result.case_results[0].input_str}\n``` to the input field. "
            f"Also, please detect the Output field in the web page (recognize its \"bid\") and paste ```txt\n{public_result.case_results[0].output_str}\n``` to the output field in the page. "
            "If you do this correctly, you can see the visualization of your solution under the Input and Output fields in the web page. "
            "This visualization will help you to understand the problem and the features of your solution."
        )
    else:
        _append_feedback_log(f"Not Accepted ({overall_judge_result.value})")
        # Report the first case that shares the overall (failing) judge result;
        # fall back to the first case when none matches.
        selected_case_idx = next(
            (
                idx
                for idx, case_result in enumerate(public_result.case_results)
                if case_result.judge_result == overall_judge_result
            ),
            0,
        )
        feedback += case_result_feedback(selected_case_idx + 1, public_result.case_results[selected_case_idx])
    return FEEDBACK_MESSAGE.substitute(feedback=feedback, file_extension=FILE_EXTENSION[code_language])


def auto_continue_response(
    state: State,
    encapsulate_solution: bool = False,
    try_parse: Callable[[Action | None], str] | None = None,
    ale_bench_session: ale_bench.session.Session | None = None,  # Additional
    code_language: str | None = None,  # Additional
    solution_file_path: Path | None = None,  # Additional
    code_history: list | None = None,  # Additional
    finished_at: datetime.datetime | None = None,  # Additional
) -> str:
    """Answer the agent as the user: evaluate its solution and return feedback.

    The first three parameters follow ``FakeUserResponseFunc``;
    ``encapsulate_solution`` and ``try_parse`` are accepted for that
    interface and not used. The remaining ("Additional") parameters are
    expected to be bound in advance, e.g. with ``functools.partial``.

    Args:
        state: Controller state; its last agent message is stored in the
            history entry.
        ale_bench_session: Session used for the public evaluation.
        code_language: Language key of the solution.
        solution_file_path: Host path of the solution file. Its parent
            directory is treated as the workspace directory.
        code_history: List that receives one
            ``(result, language, code, agent message)`` tuple per evaluation.
        finished_at: Timezone-aware deadline used to report the remaining time.

    Returns:
        The feedback message for the evaluated solution, or a reminder (with
        the remaining time) when the solution file does not exist.

    Raises:
        ValueError: If any of the additional parameters is missing.
    """
    # Validate inputs
    if ale_bench_session is None:
        raise ValueError("ALE Bench session is not provided.")
    if code_language is None:
        raise ValueError("Code language is not provided.")
    if solution_file_path is None:
        raise ValueError("Solution file path is not provided.")
    if code_history is None:
        raise ValueError("Code history is not provided.")
    if finished_at is None:
        raise ValueError("Finished at time is not provided.")
    # Get the last message from the state (may be None if the agent has not
    # produced a message yet -- guarded below).
    message = state.get_last_agent_message()  # MessageAction
    message_content = message.content if message is not None else ""
    # Change owner of the workspace directory (caution: we need `sudo`).
    # The workspace is derived from the solution path instead of a global that
    # only exists when this file runs as a script.
    workspace_dir = solution_file_path.parent
    subprocess.run(["sudo", "chown", "-R", f"{os.getuid()}:{os.getgid()}", str(workspace_dir)])
    if not solution_file_path.is_file():
        remaining_time = finished_at - datetime.datetime.now(tz=datetime.timezone.utc)
        remaining_time_in_seconds = max(0, remaining_time.total_seconds())
        # Report in hours when more than one hour is left, otherwise minutes.
        if remaining_time_in_seconds > 3600:
            remaining_time_str = f"{round(remaining_time_in_seconds / 3600, 1)} hours"
        else:
            remaining_time_str = f"{round(remaining_time_in_seconds / 60, 1)} minutes"
        return (
            f"The solution file (named `/workspace/solution.{FILE_EXTENSION[code_language]}`) is not found. "
            f"Please implement your solution in {CODE_LANGUAGE_STRING[code_language]} and save it in the solution file.\n"
            f"You have {remaining_time_str} left to finish the task."
        )

    code = solution_file_path.read_text()
    public_result = replace_details_in_case_result(
        ale_bench_session.public_eval(code=code, code_language=code_language)
    )
    code_history.append((public_result, code_language, code, message_content))

    return create_feedback_message(
        public_result,
        code_language=code_language,
        vis_port=ale_bench_session.visualization_server_port,
    )


def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
    """
    Load trajectory from given path, serialize it to a list of events, and return
    two things:
    1) A list of events except the first action
    2) First action (user message, a.k.a. initial task)

    Raises:
        ValueError: If the path does not exist, is not a regular file, does
            not contain valid JSON (the decode error is chained as the cause),
            or the trajectory does not start with a ``MessageAction``.
    """
    try:
        path = Path(trajectory_path).resolve()

        if not path.exists():
            raise ValueError(f'Trajectory file not found: {path}')

        if not path.is_file():
            raise ValueError(f'Trajectory path is a directory, not a file: {path}')

        # Parse the JSON first so the file is closed before events are built.
        with open(path, 'r', encoding='utf-8') as file:
            raw_trajectory = json.load(file)

        events = ReplayManager.get_replay_events(raw_trajectory)
        # Explicit check instead of `assert`: it survives `python -O` and also
        # covers an empty trajectory, which used to surface as IndexError.
        if not events or not isinstance(events[0], MessageAction):
            raise ValueError(
                f'Trajectory in {trajectory_path} must start with a user message'
            )
        return events[1:], events[0]
    except json.JSONDecodeError as e:
        raise ValueError(f'Invalid JSON format in {trajectory_path}: {e}') from e


if __name__ == '__main__':
    # Script entry point: run one ALE-Bench problem with an OpenHands agent,
    # then perform the private evaluation and persist all artifacts under
    # `experiments/<name>/`.

    # Add argument
    parser = argparse.ArgumentParser()
    parser.add_argument("--problem-id", type=str, required=True)
    parser.add_argument("--code-language", type=str, required=True, choices=["cpp20", "python", "rust"])
    parser.add_argument("--max-time-in-hours", type=float, required=True)
    # Parse additional arguments
    ale_args, openhands_args = parser.parse_known_args()
    # Parse OpenHands arguments (openhands.core.configparse_arguments)
    args = get_parser().parse_args(openhands_args)
    if args.version:
        print(f'OpenHands version: {openhands.__version__}')
        sys.exit(0)

    # Set up the experiment directory and variables
    exp_dir = Path(__file__).resolve().parent / "experiments" / args.name
    exp_dir.mkdir(parents=True, exist_ok=True)
    LOG_FILE = exp_dir / f"log_{ale_args.problem_id}.txt"
    public_eval_code_history = []

    def _write_log(message: str) -> None:
        """Append one line to the experiment log and close the file handle."""
        with LOG_FILE.open("a") as log_file:
            print(message, file=log_file)

    # Set up the OpenHands
    config: AppConfig = setup_config_from_args(args)
    config.llms["llm"].log_completions = True
    config.llms["llm"].log_completions_folder = str(exp_dir / "llm_completions")
    config.file_store = "local"
    config.file_store_path = str(exp_dir / "file_store")
    config.save_trajectory_path = str(exp_dir / "trajectory.json")
    config.save_screenshots_in_trajectory = True
    config.workspace_base = str(exp_dir / "workspace")
    config.workspace_mount_path = str(exp_dir / "workspace")
    config.sandbox.base_container_image = f"yimjk/ale-bench:{ale_args.code_language}-202301"

    solution_file_path = exp_dir / "workspace" / f"solution.{FILE_EXTENSION[ale_args.code_language]}"

    # Start ALE-Bench session
    ale_bench_session = ale_bench.start(ale_args.problem_id, False, num_workers=13, run_visualization_server=True)

    # The time budget is the shorter of the requested limit and the contest duration.
    max_time_in_seconds = math.floor(min(
        datetime.timedelta(hours=ale_args.max_time_in_hours), ale_bench_session.problem.metadata.duration
    ).total_seconds())
    max_time_in_seconds -= 60 * 2  # Reserve 2 minutes for margin

    task_str = INITIAL_MESSAGE.substitute(
        duration_in_hours=round(min(ale_args.max_time_in_hours, ale_bench_session.problem.metadata.duration.total_seconds() / 3600), 1),
        language=CODE_LANGUAGE_STRING[ale_args.code_language],
        file_extension=FILE_EXTENSION[ale_args.code_language],
        libraries=CODE_LANGUAGE_LIBRARIES[ale_args.code_language],
        compile_command=CODE_LANGUAGE_COMPILE_COMMAND[ale_args.code_language],
        run_command=CODE_LANGUAGE_RUN_COMMAND[ale_args.code_language],
        problem_name=ale_bench_session.problem.metadata.title,
        time_limit=ale_bench_session.problem.constraints.time_limit,
        memory_limit=ale_bench_session.problem.constraints.memory_limit // 1024 // 1024,
        problem_statement=ale_bench_session.problem.statement,
    )

    # Default to a NullAction; it is replaced by the real task message unless
    # trajectory replay is enabled.
    initial_user_action: Action = NullAction()
    if config.replay_trajectory_path:
        if task_str:
            raise ValueError(
                'User-specified task is not supported under trajectory replay mode'
            )
    else:
        if not task_str:
            raise ValueError('No task provided. Please specify a task through -t, -f.')

        # Create actual initial user action
        initial_user_action = MessageAction(content=task_str)

    # Set session name
    session_name = args.name
    sid = generate_sid(config, session_name)

    finished_at = datetime.datetime.now(tz=datetime.timezone.utc)
    finished_at += datetime.timedelta(seconds=max_time_in_seconds)

    try:
        asyncio.run(asyncio.wait_for(
            run_controller(
                config=config,
                initial_user_action=initial_user_action,
                sid=sid,
                fake_user_response_fn=None
                if args.no_auto_continue
                else partial(
                    auto_continue_response,
                    ale_bench_session=ale_bench_session,
                    code_language=ale_args.code_language,
                    solution_file_path=solution_file_path,
                    code_history=public_eval_code_history,
                    finished_at=finished_at,
                ),
            ),
            timeout=max_time_in_seconds,
        ))
    except asyncio.TimeoutError:
        _write_log("Time is up.")
    except Exception as e:
        _write_log(f"Error: {e}")

    # Change owner of the workspace directory (caution: we need `sudo`)
    subprocess.run(["sudo", "chown", "-R", f"{os.getuid()}:{os.getgid()}", str(exp_dir / "workspace")])

    # Fall back to an empty submission when no solution file was produced, so
    # the private evaluation below always receives a defined value.
    submission_code = ""
    if solution_file_path.is_file():
        submission_code = solution_file_path.read_text()
    else:
        _write_log(f"Solution file not found: {solution_file_path}")

    try:
        private_result, rank, performance = ale_bench_session.private_eval(
            code=submission_code, code_language=ale_args.code_language,
        )
        _write_log(f"Rank: {rank}, Performance: {performance}")
        private_payload = {
            "problem_id": ale_args.problem_id,
            "rank": rank,
            "performance": performance,
            "solution_code": submission_code,
            "private_result": private_result.model_dump(),
        }
        with (exp_dir / f"private_result_{ale_args.problem_id}.json").open("w") as private_file:
            json.dump(private_payload, private_file)
    except Exception as e:
        # Best effort: record the actual cause and still save the session and
        # the public evaluation history below.
        _write_log(f"[{ale_args.problem_id}] Error: Private evaluation failed. ({e!r})")
    ale_bench_session.save(exp_dir / f"session_{ale_args.problem_id}.json")
    with (exp_dir / f"codes_history_{ale_args.problem_id}.json").open("w") as history_file:
        json.dump([
            [r.model_dump(), cl, c, m] for r, cl, c, m in public_eval_code_history
        ], history_file)
