import argparse
import ast
import json
import re
import shutil
from typing import Optional

from natsort import natsorted
import pandas as pd
from tp_lodge.motion_planning.dummy_motion_validator import PDDLProblem
from tp_lodge.motion_planning.remote_motion_validator import RemoteMotionValidator
import tqdm
from tp_lodge.motion_planning.motion_validator import MotionSimulationException, MotionValidator
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from pathlib import Path
from pddl.core import And

from llm_utils.openai_api.chat import Chat
from llm_utils.openai_api.text_message_content import TextMessageContent
from llm_utils.openai_api.user_message import UserMessage
from llm_utils.openai_api.message import Message
from llm_utils.openai_api.message_role import MessageRole

from llm_utils.prompt_generation.utils import replace_text
from llm_utils.textgen_api.textgen_api import TextGenApi, logger

import logging

from tp_lodge.utils.pddl_utils import get_effects_from_pred_change, get_valid_predicates
from tp_lodge.utils.python_parse_utils import get_python_code_from_text

# Configure root logging at INFO level with a bare message format (no timestamp/level prefix).
logging.basicConfig(level=logging.INFO, format="%(message)s")


def get_prompt() -> str:
    """Return the raw prompt template used to ask the LLM for a skill sequence.

    The template contains the placeholders ``{function_stubs}``, ``{object_list}``,
    ``{current_state}``, ``{domain_knowledge}`` and ``{instruction}`` (substituted by
    ``run_sample``) plus ``{output}``, which is not substituted there and so reaches the
    LLM literally as a stand-in for the expected code.
    """
    template_lines = (
        "",
        "### Python Skills",
        "```python",
        "{function_stubs}",
        "```",
        "",
        "### Objects",
        "{object_list}",
        "",
        "### Current State",
        "{current_state}",
        "",
        "### Domain Knowledge",
        "{domain_knowledge}",
        "",
        "### User Instruction",
        "{instruction}",
        "",
        "Your task is to generate a sequence of python skills that can be used to solve the problem described in the user instruction.",
        "Enclose the python skills in triple backticks. Only output one python code block.",
        "",
        "### Output Format",
        "<-- any other text that helps you to come to the correct answer. -->",
        "```python",
        "{output}",
        "```",
        "",
    )
    # Leading "" yields the initial newline, trailing "" the final newline.
    return "\n".join(template_lines)


def _get_domain(data_dir: Path):
    """Load the PDDL domain skeleton from ``domain_skeleton.json`` inside *data_dir*.

    NOTE(review): ``run_generation`` loads ``domain-skeleton.json`` with ``PDDLDomain.loads``
    instead, and this helper is not called in the visible code — confirm which file name
    and loader are intended.
    """
    skeleton_path = data_dir / "domain_skeleton.json"
    raw_domain = json.loads(skeleton_path.read_text())
    return PDDLDomain.from_json(raw_domain)


def run_sample(
    textgen_api: TextGenApi,
    motion_validator: MotionValidator,
    instruction: str,
    problem_skeleton,
    function_stubs: str,
    domain_knowledge: str,
    domain_skeleton: PDDLDomain,
    last_result: Optional[dict],
    n_retries: int,
    last_chat: Optional[Chat] = None,
):
    """Ask the LLM for a skill sequence, execute it line by line and report its effect.

    A failed attempt triggers a recursive retry with ``n_retries - 1``. Failed attempts are:
    a response without exactly one (AST-identical) python code block, or a motion that raises
    ``ValueError``, ``MotionSimulationException`` or ``AssertionError``. The failure is passed
    on as ``last_result`` so it is appended to the next prompt.

    Args:
        textgen_api: LLM client (``do_call``).
        motion_validator: Validator used to inject initial predicates, run motions and
            evaluate predicates before/after execution.
        instruction: Natural-language user instruction.
        problem_skeleton: PDDL problem, passed through
            ``motion_validator.inject_init_predicates`` before prompting.
        function_stubs: Python skill stubs shown to the LLM.
        domain_knowledge: Free-form domain knowledge shown to the LLM.
        domain_skeleton: PDDL domain.
        last_result: Previous failed attempt (keys ``plan``, ``error``, ``chat``, ``code``)
            or ``None`` on the first attempt.
        n_retries: Remaining attempts; when ``<= 0`` no LLM call is made.
        last_chat: Optional chat to continue instead of starting a fresh one (requires
            ``last_result``).

    Returns:
        ``(result, chat)``. ``result`` has keys ``code``, ``instruction``, ``effect``,
        ``n_retries`` and ``error``; ``effect`` is ``None`` whenever no effect was computed.
    """
    if n_retries <= 0:
        # last_result may be None if the caller started with no attempts left.
        return {
            "code": last_result["code"] if last_result is not None else None,
            "instruction": instruction,
            "effect": None,
            "n_retries": n_retries,
            "error": "Exceeded maximum number of retries",
        }, (last_result["chat"] if last_result is not None else last_chat)

    problem_skeleton = motion_validator.inject_init_predicates(domain=domain_skeleton, problem=problem_skeleton)

    prompt = get_prompt()
    prompt = replace_text(prompt, "{function_stubs}", function_stubs)
    prompt = replace_text(prompt, "{object_list}", problem_skeleton.get_objects_str())
    prompt = replace_text(prompt, "{current_state}", str(And(*get_valid_predicates(problem_skeleton.initial_state))))
    prompt = replace_text(prompt, "{domain_knowledge}", domain_knowledge)
    prompt = replace_text(prompt, "{instruction}", instruction)

    if last_chat is not None:
        chat = last_chat
        error_msg = "Your last plan was incorrect.\nError:\n%s" % last_result["error"]
        chat = chat.add_message(UserMessage([TextMessageContent(error_msg)]))
    else:
        if last_result is not None:
            prompt = (
                prompt
                + "\n\n"
                + "Your last plan was incorrect.\nPlan:\n%s\n\nError:\n%s" % (last_result["plan"], last_result["error"])
            )

        chat = Chat(
            messages=[
                Message(
                    role=MessageRole.SYSTEM,
                    content=[TextMessageContent("You are a helpful assistant.")],
                ),
                UserMessage([TextMessageContent(prompt)]),
            ]
        )

    llm_message = textgen_api.do_call(chat=chat)
    chat = chat.add_message(llm_message)
    text = llm_message.content[0].text

    def _retry(plan: str, error_msg: str, failed_code: Optional[str]):
        # Recurse with one attempt fewer, recording this attempt's failure and chat.
        return run_sample(
            textgen_api=textgen_api,
            motion_validator=motion_validator,
            instruction=instruction,
            problem_skeleton=problem_skeleton,
            function_stubs=function_stubs,
            domain_knowledge=domain_knowledge,
            domain_skeleton=domain_skeleton,
            last_result={"plan": plan, "error": error_msg, "chat": chat, "code": failed_code},
            n_retries=n_retries - 1,
        )

    # Only the extraction/parsing is guarded, so exceptions raised by a nested retry
    # propagate instead of silently triggering another retry.
    try:
        code_cells = get_python_code_from_text(text=text, max_cells=None)
        # Normalize through the AST so cells differing only in formatting compare equal.
        normalized_cells = [ast.unparse(ast.parse(cell)) for cell in code_cells]
    except Exception as e:
        return _retry(text, "Could not parse the python code block: %s" % e, None)
    if not code_cells:
        return _retry(text, "No python code block found in the response.", None)
    if any(cell != normalized_cells[0] for cell in normalized_cells):
        return _retry(text, "Found multiple different python code blocks. Output exactly one.", None)
    code = code_cells[0]

    state_before_execution = motion_validator.get_predicates_evaluation(
        domain=domain_skeleton, problem=problem_skeleton
    )

    error = None
    for line in code.splitlines():
        # Skip comments and blank lines; every other line is executed as one motion.
        if re.match(r"^\s*#", line) is not None or line.strip() == "":
            continue
        try:
            motion_validator._run_motion(line.strip())
        except (ValueError, MotionSimulationException, AssertionError) as e:
            message = e.message if isinstance(e, MotionSimulationException) else str(e)
            print("MotionSimulationException: %s. Retry" % str(message))
            return _retry(code, "Failed executing %s: %s" % (line, str(message)), code)
        except Exception as e:
            error = "Exception during code execution: %s" % str(e)
            # Do not execute further motions after an unexpected failure.
            break

    effect = None
    if error is None:
        state_after_execution = motion_validator.get_predicates_evaluation(
            domain=domain_skeleton, problem=problem_skeleton
        )
        effect = get_effects_from_pred_change(
            prior_predicates=state_before_execution,
            post_predicates=state_after_execution,
        )

    result = {
        "code": code,
        "instruction": instruction,
        # Keep None (JSON null) rather than the string "None" when no effect was computed.
        "effect": str(effect) if effect is not None else None,
        "n_retries": n_retries,
        "error": error,
    }
    return result, chat


def run_generation(engine: str, variant: str) -> None:
    """Generate one LLM plan for the ``envs/lamp`` sample and store the result.

    Args:
        engine: Connection name passed to ``TextGenApi.default``.
        variant: ``"retries"`` allows up to 20 attempts; any other value allows one.

    Writes ``response.json``, ``chat.json`` and ``chat.txt`` into the output directory.
    If ``response.json`` already exists, generation is skipped unless its ``error``
    contains ``"More than"``, in which case the directory is wiped and regenerated.
    """
    textgen_api = TextGenApi.default(connection=engine)
    # Exactly one connection is expected; check before indexing connections[0].
    assert len(textgen_api.connections.connections) == 1
    engine = textgen_api.connections.connections[0].model_dir
    logger.info(f"Using engine: {textgen_api.connections.connections[0].model}")

    sample_dir = Path(__file__).parent.parent / "envs/lamp"
    # NOTE(review): the run timestamp is hard-coded, so every invocation targets the same directory.
    out_dir = (
        Path(__file__).parent.parent
        / "results"
        / engine
        / f"llm_planning_baseline-{variant}"
        / "2025-05-09T07:36:28.608729"
    )
    out_dir.mkdir(exist_ok=True, parents=True)

    # Check for an existing result before loading inputs or connecting to the validator.
    response_path = out_dir / "response.json"
    if response_path.exists():
        previous = json.loads(response_path.read_text())
        previous_error = previous.get("error")
        if previous_error is not None and "More than" in previous_error:
            shutil.rmtree(out_dir)
            # rmtree removes out_dir itself; recreate it so the result files can be written.
            out_dir.mkdir(exist_ok=True, parents=True)
        else:
            return

    motion_validator = RemoteMotionValidator(ip="localhost", port=8800)

    function_stubs = (sample_dir / "function_stubs.py").read_text()
    domain_knowledge = (sample_dir / "domain_knowledge.md").read_text()
    domain_skeleton = PDDLDomain.loads((sample_dir / "domain-skeleton.json").read_text())
    problem_skeleton = PDDLProblem.loads((sample_dir / "problem-skeleton.json").read_text())
    instruction = (sample_dir / "instruction.txt").read_text().strip()

    response, chat = run_sample(
        textgen_api=textgen_api,
        motion_validator=motion_validator,
        instruction=instruction,
        problem_skeleton=problem_skeleton,
        function_stubs=function_stubs,
        domain_knowledge=domain_knowledge,
        domain_skeleton=domain_skeleton,
        n_retries=20 if variant == "retries" else 1,
        last_result=None,
    )
    print(f"Response: {response['error']}")
    (out_dir / "response.json").write_text(json.dumps(response, indent=2))
    (out_dir / "chat.json").write_text(json.dumps(chat.to_dict(), indent=2))
    (out_dir / "chat.txt").write_text(str(chat))


def run_evaluation(engine: str, task: str, variant: str):
    """Interactively label generated samples as correct/incorrect and print summary stats.

    Each sub-directory of ``results/<task>/<model_dir>/llm_planning_baseline-<variant>`` is a
    sample with a ``response.json``. The label is decided as follows:

    - Already-labelled samples keep their ``"correct"`` value.
    - Samples with an ``error`` count as incorrect.
    - All other samples are shown to the user, who answers ``y``/``n``. That answer is
      written back into ``response.json``.

    A summary is printed and saved to ``llm-baseline.csv`` in the current working directory.

    NOTE(review): ``run_generation`` writes to
    ``results/<model_dir>/llm_planning_baseline-<variant>/<timestamp>`` (no ``<task>`` level);
    confirm both functions address the same layout.
    """
    assert task in ["logistics", "household"]

    textgen_api = TextGenApi.default(connection=engine)
    # Exactly one connection is expected; check before indexing connections[0].
    assert len(textgen_api.connections.connections) == 1
    engine = textgen_api.connections.connections[0].model_dir
    out_dir = Path(__file__).parent.parent / "results" / task / engine / f"llm_planning_baseline-{variant}"
    out_dir.mkdir(exist_ok=True, parents=True)
    logger.info(f"Using engine: {textgen_api.connections.connections[0].model}")

    eval_data = []
    # Only sub-directories hold samples; stray files would otherwise fail the assertion below.
    sample_dirs = natsorted([path for path in out_dir.iterdir() if path.is_dir()])
    for sample_dir in tqdm.tqdm(sample_dirs):
        print("-" * 100)
        print(f"Sample dir: {sample_dir.name}")
        assert (sample_dir / "response.json").exists(), sample_dir
        data = json.loads((sample_dir / "response.json").read_text())
        instruction = data["instruction"]
        sample_eval_data = {
            "sample_dir": sample_dir.name,
            "n_code": None,
        }
        # "code" can be null for failed samples; only count lines of actual code.
        code = data.get("code")
        if code is not None:
            sample_eval_data["n_code"] = len(ast.unparse(ast.parse(code)).splitlines())

        if "correct" in data:
            print(f"Already evaluated as {data['correct']}")
            if not data["correct"]:
                print(data["error"])
            correct = data["correct"]

        elif data["error"] is not None:
            print(f"Error: {data['error']}")
            correct = False

        else:
            print(instruction)
            print(f"Code: \n{data['code']}\n")
            print(f"Effect: {data['effect']}")
            print("")
            while True:
                correct_str = input("Is the effect correct? (y/n): ")
                if correct_str == "y":
                    correct = True
                    break
                elif correct_str == "n":
                    correct = False
                    break
            data["correct"] = correct
            (sample_dir / "response.json").write_text(json.dumps(data, indent=2))

        sample_eval_data["correct"] = correct
        eval_data.append(sample_eval_data)

    df = pd.DataFrame(eval_data)
    print(df)
    # An empty frame has no "correct" column; report instead of raising KeyError.
    if df.empty:
        print(f"No samples found in {out_dir}")
        return
    print("Count correct: ", df["correct"].sum())
    print("total_count:", len(df))
    print("Avg plan length: ", df[df["correct"]]["n_code"].mean())
    df.to_csv("llm-baseline.csv")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Construct action models using LLMs.")
    parser.add_argument("--engine", type=str, help="Specify the LLM engine to use.", required=True)
    parser.add_argument(
        "--variant",
        type=str,
        choices=["no-retries", "retries"],
        help="Specify the variant ('no-retries', 'retries').",
        required=True,
    )
    # run_evaluation needs a task; previously args.task was read without being defined.
    parser.add_argument(
        "--task",
        type=str,
        choices=["logistics", "household"],
        default=None,
        help="Specify the task to evaluate ('logistics', 'household'); required with --eval.",
    )
    parser.add_argument("--eval", action="store_true", help="Run evaluation instead of generation.")
    args = parser.parse_args()

    if args.eval:
        if args.task is None:
            parser.error("--task is required when --eval is given")
        run_evaluation(engine=args.engine, task=args.task, variant=args.variant)
    else:
        run_generation(engine=args.engine, variant=args.variant)
