import argparse
from pddl.parser.problem import ProblemParser
from python_utils.string_utils import remove_docstrings
import ast
import json
import re
import shutil
from typing import Optional
from uuid import uuid4

from dotenv import load_dotenv
from natsort import natsorted
import pandas as pd
from tqdm import trange
import tqdm
from tp_lodge.motion_planning.motion_validator import MotionSimulationException
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
import copy
from pathlib import Path
from pddl.core import And

from llm_utils.openai_api.chat import Chat
from llm_utils.openai_api.text_message_content import TextMessageContent
from llm_utils.openai_api.user_message import UserMessage
from llm_utils.openai_api.message import Message
from llm_utils.openai_api.message_role import MessageRole

from llm_utils.prompt_generation.utils import replace_text
from llm_utils.textgen_api.textgen_api import TextGenApi, logger

from demos.ipc.src.household.household_motion_validator import HouseholdMotionValidator
from demos.ipc.src.logistics.logistics_motion_validator import LogisticsMotionValidator
from demos.ipc.src.planning_benchmark_sample_generator import PlanningBenchmarkSampleGenerator
from tp_lodge.motion_planning.local_motion_validator import LocalMotionValidator
import logging

from tp_lodge.utils.pddl_utils import get_effects_from_pred_change, get_valid_predicates, is_predicates_subset
from tp_lodge.utils.pddl_parse_utils import parse_formula
from tp_lodge.utils.python_parse_utils import get_python_code_from_text

# Configure the root logger at import time: INFO level, with records printed
# as the bare message (no level or logger-name prefix).
logging.basicConfig(level=logging.INFO, format="%(message)s")


def get_prompt() -> str:
    """Return the prompt template used to ask the LLM for a plan.

    The template contains the placeholders ``{function_stubs}``,
    ``{object_list}``, ``{current_state}``, ``{domain_knowledge}``,
    ``{instruction}`` and ``{output}``. It starts and ends with a newline.
    """
    template_lines = [
        "",
        "### Python Skills",
        "```python",
        "{function_stubs}",
        "```",
        "",
        "### Objects",
        "{object_list}",
        "",
        "### Current State",
        "{current_state}",
        "",
        "### Domain Knowledge",
        "{domain_knowledge}",
        "",
        "### User Instruction",
        "{instruction}",
        "",
        "Your task is to generate a sequence of python skills that can be used to solve the problem described in the user instruction.",
        "Enclose the python skills in triple backticks. Only output one python code block.",
        "",
        "### Output Format",
        "<-- any other text that helps you to come to the correct answer. -->",
        "```python",
        "{output}",
        "```",
        "",
    ]
    return "\n".join(template_lines)


def _get_domain(data_dir: Path):
    """Load the PDDL domain from ``domain_skeleton.json`` inside *data_dir*.

    The file is read as text, decoded as JSON and handed to
    ``PDDLDomain.from_json``; the resulting domain object is returned.
    """
    skeleton_file = data_dir / "domain_skeleton.json"
    skeleton_json = json.loads(skeleton_file.read_text())
    return PDDLDomain.from_json(skeleton_json)


def run_sample(
    textgen_api: TextGenApi,
    motion_validator: LocalMotionValidator,
    generator: PlanningBenchmarkSampleGenerator,
    i: int,
    instruction: str,
    problem_skeleton,
    function_stubs: str,
    domain_knowledge: str,
    domain_skeleton: PDDLDomain,
    last_result: Optional[dict],
    n_tries: int,
    last_chat: Optional[Chat] = None,
):
    """Ask the LLM for a plan for sample *i*, execute it and report the outcome.

    The environment of ``motion_validator`` is reset to the sample's initial
    state, the LLM is prompted (or re-prompted with the previous error when
    ``last_chat`` is given), and the returned python code is executed line by
    line through ``motion_validator._run_motion``. Whenever the answer cannot
    be used (no usable code block, or a line fails with ``ValueError``,
    ``MotionSimulationException`` or ``AssertionError``) the function calls
    itself again with the error as feedback. Every such retry consumes one
    try, so the recursion depth is bounded by ``n_tries``.

    Args:
        textgen_api: API used to query the LLM via ``do_call``.
        motion_validator: Validator whose environment executes the plan.
        generator: Source of the sample; ``generator.generate(i)`` is called
            on every (re-)entry.
        i: Index of the sample.
        instruction: Ignored; overwritten by the value from ``generator``.
        problem_skeleton: Ignored; overwritten by the value from ``generator``.
        function_stubs: Python stubs of the available skills (prompt text).
        domain_knowledge: Domain knowledge text inserted into the prompt.
        domain_skeleton: PDDL domain used to evaluate predicates.
        last_result: Result of the previous failed attempt (keys ``plan``,
            ``error``, ``chat``, ``code``) or ``None`` on the first attempt.
        n_tries: Number of remaining attempts.
        last_chat: Chat of the previous attempt; when given, the conversation
            is continued instead of starting a fresh one.

    Returns:
        A tuple ``(result, chat)`` where ``result`` is a dict with the keys
        ``code``, ``instruction``, ``effect``, ``n_tries`` and ``error``.

    Raises:
        ValueError: If ``n_tries`` is not positive and there is no previous
            result to report.
    """
    # The sample is always regenerated, so the `instruction` and
    # `problem_skeleton` arguments are replaced here.
    instruction, env_state, problem_skeleton = generator.generate(i)
    assert isinstance(motion_validator, LocalMotionValidator)
    assert hasattr(motion_validator.env, "_init_env_state")
    # Reset the simulation so every attempt starts from the initial state.
    motion_validator.env._init_env_state = copy.deepcopy(env_state)
    motion_validator.env._set_state(new_state=copy.deepcopy(env_state))

    if n_tries <= 0:
        if last_result is None:
            # Nothing was attempted, so there is no code/chat to report.
            raise ValueError("n_tries must be at least 1, got %d" % n_tries)
        return {
            "code": last_result["code"],
            "instruction": instruction,
            "effect": None,
            "n_tries": n_tries,
            "error": "Exceeded maximum number of tries",
        }, last_result["chat"]

    problem_skeleton = motion_validator.inject_init_predicates(domain=domain_skeleton, problem=problem_skeleton)

    prompt = get_prompt()
    prompt = replace_text(prompt, "{function_stubs}", function_stubs)
    prompt = replace_text(prompt, "{object_list}", problem_skeleton.get_objects_str())
    prompt = replace_text(prompt, "{current_state}", str(And(*get_valid_predicates(problem_skeleton.initial_state))))
    prompt = replace_text(prompt, "{domain_knowledge}", domain_knowledge)
    prompt = replace_text(prompt, "{instruction}", instruction)

    if last_chat is not None:
        # Continue the previous conversation and only append the error feedback.
        previous_error = last_result.get("error", "Unknown error") if last_result is not None else "Unknown error"
        error_msg = "Your last plan was incorrect.\nError:\n%s" % previous_error
        chat = last_chat.add_message(UserMessage([TextMessageContent(error_msg)]))
    else:
        if last_result is not None:
            prompt = (
                prompt
                + "\n\n"
                + "Your last plan was incorrect.\nPlan:\n%s\n\nError:\n%s" % (last_result["plan"], last_result["error"])
            )

        chat = Chat(
            messages=[
                Message(
                    role=MessageRole.SYSTEM,
                    content=[TextMessageContent("You are a helpful assistant.")],
                ),
                UserMessage([TextMessageContent(prompt)]),
            ]
        )

    response = textgen_api.do_call(chat=chat)
    chat = chat.add_message(response)
    text = response.content[0].text

    def _retry(failure: dict):
        """Re-run this sample with *failure* as feedback, consuming one try."""
        return run_sample(
            textgen_api=textgen_api,
            motion_validator=motion_validator,
            instruction=instruction,
            generator=generator,
            i=i,
            problem_skeleton=problem_skeleton,
            function_stubs=function_stubs,
            domain_knowledge=domain_knowledge,
            domain_skeleton=domain_skeleton,
            last_result=failure,
            last_chat=chat,
            n_tries=n_tries - 1,
        )

    # --- Extract the plan from the answer -----------------------------------
    code = None
    parse_error = None
    try:
        code_cells = get_python_code_from_text(text=text, max_cells=None)
        # Normalising through the AST validates the syntax and lets repeated
        # code blocks be compared independent of formatting.
        normalized_cells = [ast.unparse(ast.parse(cell)) for cell in code_cells]
    except Exception as e:  # the exception types of the extractor are not known here
        parse_error = "Could not extract valid python code from the response: %s" % str(e)
    else:
        if not code_cells:
            parse_error = "The response did not contain a python code block."
        elif any(cell != normalized_cells[0] for cell in normalized_cells):
            parse_error = "Expected exactly one python code block, but found %d different ones." % len(code_cells)
        else:
            # As before, the raw first cell is what gets executed and stored.
            code = code_cells[0]

    if parse_error is not None:
        print("Invalid response: %s. Retry" % parse_error)
        previous_code = last_result.get("code") if last_result is not None else None
        return _retry(
            {
                "plan": text,
                "error": parse_error,
                "chat": chat,
                # Keep "code" a string so the stored result stays parseable.
                "code": previous_code if previous_code is not None else "",
            }
        )

    # --- Execute the plan line by line --------------------------------------
    error = None
    effect = None
    state_before_execution = motion_validator.get_predicates_evaluation(
        domain=domain_skeleton, problem=problem_skeleton
    )

    for line in code.splitlines():
        statement = line.strip()
        if not statement or statement.startswith("#"):
            continue  # skip blank lines and comments

        try:
            motion_validator._run_motion(statement)
        except (ValueError, MotionSimulationException, AssertionError) as e:
            message = e.message if isinstance(e, MotionSimulationException) else str(e)
            print("%s: %s. Retry" % (type(e).__name__, str(message)))
            return _retry(
                {
                    "plan": code,
                    "error": "Failed executing %s: %s" % (line, str(message)),
                    "chat": chat,
                    "code": code,
                }
            )
        except Exception as e:
            # Unexpected failure: record it and stop, the remaining lines
            # would run against an undefined environment state.
            error = "Exception during code execution: %s" % str(e)
            break

    if error is None:
        state_after_execution = motion_validator.get_predicates_evaluation(
            domain=domain_skeleton, problem=problem_skeleton
        )

        effect = get_effects_from_pred_change(
            prior_predicates=state_before_execution,
            post_predicates=state_after_execution,
        )

    result = {
        "code": code,
        "instruction": instruction,
        "effect": str(effect),
        "n_tries": n_tries,
        "error": error,
    }
    return result, chat


def _get_dir(args):
    suffix = f"{args.tries}-tries"

    if args.no_docstrings:
        suffix += "-no-docstrings"
    else:
        suffix += "-with-docstrings"

    return (
        Path(__file__).parent.parent
        / "results"
        / args.task
        / args.engine
        / "llm_planning_baseline"
        / suffix
    )


def run_generation(args) -> None:
    """Generate and execute an LLM plan for every sample of ``args.task``.

    For each sample a directory ``task-<i>`` is created below the directory
    returned by ``_get_dir`` and filled with ``response.json``, ``chat.json``
    and ``chat.txt``. Samples that already have a ``response.json`` are
    skipped, unless its error contains ``"More than"``; in that case the
    sample directory is wiped and the sample is generated again.

    Side effect: ``args.engine`` is replaced by the ``model_dir`` of the
    resolved connection.

    Args:
        args: Parsed command-line arguments providing ``task``, ``engine``,
            ``tries`` and ``no_docstrings``.
    """
    assert args.task in ["logistics", "household"]

    textgen_api = TextGenApi.default(connection=args.engine)
    # Check the connection count before indexing into the list.
    assert len(textgen_api.connections.connections) == 1
    args.engine = textgen_api.connections.connections[0].model_dir
    data_dir = Path(__file__).parent.parent / "data" / args.task
    out_dir = _get_dir(args)
    out_dir.mkdir(exist_ok=True, parents=True)

    logger.info(f"Using engine: {textgen_api.connections.connections[0].model}")

    if args.task == "logistics":
        motion_validator = LogisticsMotionValidator(hide_env_feedback=False)
    elif args.task == "household":
        motion_validator = HouseholdMotionValidator(hide_env_feedback=False)

    function_stubs = (data_dir / "function_stubs.py").read_text()
    if args.no_docstrings:
        function_stubs = remove_docstrings(function_stubs)
    domain_knowledge = (data_dir / "domain_knowledge.md").read_text()
    domain_skeleton = _get_domain(data_dir=data_dir)

    generator = PlanningBenchmarkSampleGenerator(args.task, use_hl_types=True)

    pbar = trange(len(generator))
    for i in pbar:
        s_out_dir = out_dir / f"task-{i}"
        s_out_dir.mkdir(exist_ok=True)

        response_path = s_out_dir / "response.json"
        if response_path.exists():
            data = json.loads(response_path.read_text())
            if data["error"] is not None and "More than" in data["error"]:
                # Discard the stale result and recreate the directory, the
                # new result is written into it further below.
                shutil.rmtree(s_out_dir)
                s_out_dir.mkdir()
            else:
                continue

        # Only the instruction is needed here (progress bar text);
        # run_sample generates the sample again itself.
        instruction, env_state, problem_skeleton = generator.generate(i)
        pbar.set_description_str(f"Task {i}: {instruction}")
        response, chat = run_sample(
            textgen_api=textgen_api,
            motion_validator=motion_validator,
            instruction=instruction,
            problem_skeleton=problem_skeleton,
            function_stubs=function_stubs,
            domain_knowledge=domain_knowledge,
            domain_skeleton=domain_skeleton,
            generator=generator,
            i=i,
            n_tries=args.tries,
            last_result=None,
        )
        print(f"Response: {response['error']}")
        response_path.write_text(json.dumps(response, indent=2))
        (s_out_dir / "chat.json").write_text(json.dumps(chat.to_dict(), indent=2))
        (s_out_dir / "chat.txt").write_text(str(chat))


def run_evaluation(args):
    """Evaluate the stored results of a generation run and write a summary.

    Every ``task-*`` directory below the directory returned by ``_get_dir``
    must contain a ``response.json``. A sample counts as correct when it has
    no error and ``is_predicates_subset`` accepts the goal of the matching
    problem file ``p<idx>.pddl`` against the recorded effect. The verdict is
    written back into ``response.json`` (key ``correct``) and is reused on
    later runs. The summary is written to ``llm-baseline.json``.

    With no samples, or no correct samples, ``avg_plan_length`` is ``None``
    (JSON ``null``) rather than NaN.

    Side effect: ``args.engine`` is replaced by the ``model_dir`` of the
    resolved connection.

    Args:
        args: Parsed command-line arguments providing ``task``, ``engine``,
            ``tries`` and ``no_docstrings``.
    """
    assert args.task in ["logistics", "household"]

    textgen_api = TextGenApi.default(connection=args.engine)
    # Check the connection count before indexing into the list.
    assert len(textgen_api.connections.connections) == 1
    args.engine = textgen_api.connections.connections[0].model_dir
    out_dir = _get_dir(args)
    problems_dir = Path(__file__).parent.parent / "data" / args.task / "problems"
    out_dir.mkdir(exist_ok=True, parents=True)

    logger.info(f"Using engine: {textgen_api.connections.connections[0].model}")
    eval_data = []
    for sample_dir in tqdm.tqdm(natsorted(out_dir.glob("task-*"))):
        print("-" * 100)
        print(f"Sample dir: {sample_dir.name}")
        response_path = sample_dir / "response.json"
        assert response_path.exists(), sample_dir
        data = json.loads(response_path.read_text())
        sample_eval_data = {
            "sample_dir": sample_dir.name,
            "n_code": None,
        }
        if data.get("code") is not None:
            # Plan length = number of lines of the AST-normalised code.
            sample_eval_data["n_code"] = len(ast.unparse(ast.parse(data["code"])).splitlines())

        if "correct" in data:
            print(f"Already evaluated as {data['correct']}")
            if not data["correct"]:
                print(data["error"])
            correct = data["correct"]

        elif data["error"] is not None:
            print(f"Error: {data['error']}")
            correct = False

        else:
            # Sample directories are 0-based, problem files are 1-based.
            idx = int(sample_dir.name.split("task-")[1]) + 1
            gt_goal = ProblemParser()((problems_dir / f"p{idx:02d}.pddl").read_text()).goal
            assert gt_goal is not None, f"Goal is None for task {idx:02d}"
            exp_goal = parse_formula(data["effect"], only_variables=False)
            correct = is_predicates_subset(gt_goal, exp_goal, enforce_parent_set_has_negated=False)

        data["correct"] = correct
        response_path.write_text(json.dumps(data, indent=2))

        sample_eval_data["correct"] = correct
        eval_data.append(sample_eval_data)

    df = pd.DataFrame(eval_data)

    # Aggregate on the plain records: an empty run yields a frame without a
    # "correct" column, and a mean over no rows would be NaN.
    count_correct = sum(1 for row in eval_data if row["correct"])
    correct_lengths = [row["n_code"] for row in eval_data if row["correct"] and row["n_code"] is not None]
    avg_plan_length = sum(correct_lengths) / len(correct_lengths) if correct_lengths else None

    out_json = {
        "engine": args.engine,
        "task": args.task,
        "tries": args.tries,
        "no_docstrings": args.no_docstrings,
        "count_correct": count_correct,
        "total_count": len(df),
        "avg_plan_length": avg_plan_length,
        "eval_data": eval_data,
    }
    (out_dir / "llm-baseline.json").write_text(json.dumps(out_json, indent=2))

    print(df)
    print("Count correct: ", count_correct)
    print("total_count:", len(df))
    print("Avg plan length: ", avg_plan_length)


if __name__ == "__main__":
    # Command-line entry point: parse the options, load environment variables
    # from a .env file, then run either the evaluation or the generation.
    cli = argparse.ArgumentParser(description="Construct action models using LLMs.")
    cli.add_argument("--engine", required=True, type=str, help="Specify the LLM engine to use.")
    cli.add_argument("--task", required=True, type=str, help="Specify the task.")
    cli.add_argument("--tries", default=1, type=int, help="Number of tries for generating a valid plan.")
    cli.add_argument("--no-docstrings", action="store_true")
    cli.add_argument("--eval", action="store_true", help="Run evaluation instead of generation.")
    cli_args = cli.parse_args()

    load_dotenv()

    # --eval switches from plan generation to scoring of stored results.
    entry_point = run_evaluation if cli_args.eval else run_generation
    entry_point(cli_args)
