import logging
import re
from typing import Dict, List, Literal, Optional, Tuple, overload

from lark.exceptions import UnexpectedToken, VisitError
from llm_utils import Chat
from llm_utils.prompt_generation.utils import replace_text
from pddl.core import And, Constant, Formula
from pddl.exceptions import PDDLValidationError
from pddl.logic.base import Not
from pddl.parser.domain import DomainParser
from pddl.parser.problem import ProblemParser
from python_utils.string_utils import remove_comments, get_markup_from_text

from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_operation import PDDLOperation
from tp_lodge.task_planning.models.pddl.pddl_operator import PDDLOperator
from tp_lodge.task_planning.models.pddl.pddl_predicate import PDDLPredicate
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.utils.pddl_domain_syntax import parse_action, parse_formula, parse_predicate
from tp_lodge.utils.pddl_utils import (
    get_pred_change,
    get_predicate_evaluation,
    get_valid_predicates,
    pddls_from_text,
)
from tp_lodge.utils.pddl_verify_utils import verify_action

# Module-level logger; used below to report skipped, added, edited and deleted PDDL predicates/actions.
logger = logging.getLogger(__name__)


@overload
def _get_section(
    text: str, section_name: str, assert_one: bool = True, remove_none: bool = True
) -> Optional[Tuple[int, int]]: ...
@overload
def _get_section(
    text: str, section_name: str, assert_one: bool, remove_none: bool = True
) -> Optional[List[Tuple[int, int]]]: ...
def _get_section(text: str, section_name: str, assert_one: bool = True, remove_none: bool = True):
    pattern = rf"(?:\n|^)\s*### .*{section_name}.*\n *([\w\W]*?)(?=\n *\n *###|---|$)"
    matches = list(re.finditer(pattern, text))
    matches = [
        match.span(1)
        for match in matches
        if not remove_none or all(term not in match.group(1).lower() for term in ["none", "no change"])
    ]
    if len(matches) > 0:
        if assert_one:
            if len(matches) > 1:
                raise ValueError(
                    "Multiple sections found for %s. Only one header (###) should contain that name" % section_name
                )
            else:
                return matches[0]
        else:
            return matches
    return None


def parse_predicate_list_from_text(
    text: str,
    init_formula: Optional[Formula],
    existing_domain: PDDLDomain,
    existing_objects: List[Constant],
    supported_predicates: Optional[List[str]] = None,
) -> Tuple[Formula, List[ValueError]]:
    """Parse a predicate listing (or a PDDL snippet) and merge it into ``init_formula``.

    Two input formats are accepted:

    * exactly one ```pddl`` / ```lisp`` fenced block containing a formula. A ``(:goal ...)``
      wrapper is stripped when it fits on one line. The parsed formula replaces the starting
      evaluation entirely.
    * one ``<predicate>: true|false|remove`` entry per line, optionally prefixed with a ``-``
      bullet. Each entry is applied as a change on top of ``init_formula``'s evaluation.

    Args:
        text: Section text to parse; comments are removed first.
        init_formula: Formula the line-based changes are applied to (``None`` starts empty).
        existing_domain: Domain whose ``verify_predicate`` is called for every resulting predicate.
        existing_objects: Objects passed to ``existing_domain.verify_predicate``.
        supported_predicates: Names of predicates whose entries are ignored in the line-based
            format (they are not allowed to be changed).

    Returns:
        The resulting conjunction of (possibly negated) predicates, and a list of non-fatal errors.

    Raises:
        ValueError: If a non-empty line has no ``:`` separator. The parsing/verification helpers
            may raise further errors (assumed to be ``ValueError``; callers catch it).
        AssertionError: If a line does not describe exactly one predicate, or ``false`` is applied
            to an already negated predicate.
    """
    text = remove_comments(text)
    formula_eval = get_predicate_evaluation(init_formula) if init_formula is not None else {}
    errors: List[ValueError] = []
    parsed = False
    response = get_markup_from_text(text, ["pddl", "lisp"])
    if response is not None and len(response) == 1:
        r = response[0].strip()
        if ":goal" in r:
            # keep only the formula inside a single-line `(:goal ...)` wrapper
            match = re.search(r":goal(.+)\)", r)
            if match is not None:
                r = match.group(1).strip()

        formula = parse_formula(r, only_variables=False)
        try:
            formula_eval = get_predicate_evaluation(formula)
        except RuntimeError as e:
            errors.append(ValueError("Predicate parsing error: %s" % str(e)))
        parsed = True

    if not parsed:
        for line in text.strip().splitlines():
            if "None" in line or "no change" in line.lower():
                # an explicit "None" / "no change" marker ends the list
                break
            if line.strip() == "":
                # NOTE(review): formula_eval already holds the entries of init_formula, so with a
                # non-empty init_formula the first blank line always ends the list.
                if len(formula_eval) > 0:
                    # we already parsed some predicates, so we can stop here
                    break
                else:
                    continue

            if ":" not in line:
                raise ValueError("Invalid list of predicates: %s" % text)

            predicate_str, eval_str = line.split(":")
            eval_str = eval_str.strip().lower()
            if predicate_str.strip().startswith("-"):
                # drop a list bullet such as "- " or "- **" in front of the predicate
                if "(" in predicate_str:
                    predicate_str = predicate_str[predicate_str.index("(") :]
                else:
                    predicate_str = predicate_str.strip()[1:]
            if "(" not in predicate_str and ")" not in predicate_str:
                # bare "name arg1 arg2" -> "(name arg1 arg2)"
                predicate_str = f"({predicate_str.strip()})"

            predicate = parse_formula(predicate_str.strip(), only_variables=False)
            predicate_evals = get_predicate_evaluation(predicate)
            assert len(predicate_evals) == 1
            predicate = list(predicate_evals.keys())[0]
            predicate_eval = list(predicate_evals.values())[0]
            if eval_str == "false":
                assert predicate_eval is True
                predicate_eval = False
            remove_predicate = eval_str == "remove"

            if supported_predicates is not None and predicate.name in supported_predicates:
                # If we have a list of supported predicates, entries for them are ignored:
                # predicates in that list are not allowed to be changed.
                continue

            if remove_predicate:
                if predicate in formula_eval:
                    del formula_eval[predicate]
            else:
                formula_eval[predicate] = predicate_eval

    for predicate in formula_eval:
        if any(not isinstance(t, Constant) for t in predicate.terms):
            # report the offending predicate itself (the text line is not available on the PDDL path)
            errors.append(ValueError("Predicate %s must not contain variables as arguments" % (predicate,)))
            continue

        existing_domain.verify_predicate(predicate, objects=existing_objects)

    updated_formula = And(*[pred if value else Not(pred) for pred, value in formula_eval.items()])

    return updated_formula, errors


def parse_problem_from_text(
    problem: PDDLProblem, existing_domain: PDDLDomain, text: str, supported_predicates: Optional[List[str]]
) -> Tuple[PDDLProblem, List[ValueError]]:
    """Apply the Goal State and Initial State sections found in ``text`` to ``problem``.

    The goal section is parsed first, then the initial-state section. Entries for
    ``supported_predicates`` are ignored only in the initial state. Errors from either section
    are collected instead of raised. Finally the resulting problem is rendered and re-parsed with
    ``ProblemParser``; if that fails, the problem as it was on entry is returned.

    Returns:
        The updated (or original, on validation failure) problem and the collected errors.
    """
    snapshot = problem.copy_with()
    errors: List[ValueError] = []

    # (header regex, problem attribute, supported predicates, syntax-error template, parse-error template)
    state_sections = (
        ("Goal State", "goal_state", None, "Goal state syntax error:\n%s", "Goal state parsing error: %s"),
        (
            "(?:Init|Initial) State",
            "initial_state",
            supported_predicates,
            "Init state syntax error:\n%s",
            "Init state parsing error: %s",
        ),
    )

    for section_name, attribute, allowed_predicates, syntax_template, parsing_template in state_sections:
        try:
            span = _get_section(text, section_name)
            if span is None:
                continue
            new_state, state_errors = parse_predicate_list_from_text(
                text[span[0] : span[1]],
                getattr(problem, attribute),
                existing_domain=existing_domain,
                existing_objects=problem.objects,
                supported_predicates=allowed_predicates,
            )
            if state_errors:
                bullets = "\n".join("- %s" % str(err) for err in state_errors)
                errors.append(ValueError(syntax_template % bullets))
            problem = problem.copy_with(**{attribute: new_state})
        except ValueError as err:
            errors.append(err)
        except AssertionError as err:
            errors.append(ValueError(parsing_template % str(err)))

    if problem.goal_state is None:
        errors.append(ValueError("Please ensure that the goal state is set."))

    # validate with the independent pddl parser; the initial state might still be None, hence force=True
    try:
        ProblemParser()(str(problem.to_pddl(force=True)))
    except Exception as err:
        errors.append(ValueError("Problem syntax error: %s" % str(err)))
        problem = snapshot

    return problem, errors


def get_predicate_change_text(
    before: Formula,
    after: Formula,
) -> str:
    """Render the predicate changes from ``before`` to ``after`` as ``<predicate>: <value>`` lines.

    Uses ``get_pred_change`` to obtain ``(effects, removed, added)``. Every entry of ``effects``
    and ``added`` becomes ``<predicate>: true``, or ``<argument>: false`` when the entry is a
    ``Not``. Every entry of ``removed`` is then listed as ``<predicate>: remove``.

    Raises:
        AssertionError: If 50 or more lines would be produced (plausibility check).
    """
    effects, removed, added = get_pred_change(before, after)

    lines = [
        f"{change.argument}: false" if isinstance(change, Not) else f"{change}: true"
        for change in effects + added
    ]
    lines.extend(f"{gone}: remove" for gone in removed)

    assert len(lines) < 50, "plausibility check failed, too many changes: %d" % len(lines)
    return "\n".join(lines)


def inject_problem_into_text(
    problem: PDDLProblem, text: str, *, prev_problem: Optional[PDDLProblem] = None, only_new: bool = False
) -> str:
    """Fill the Objects, Initial State and Goal State sections of ``text`` from ``problem``.

    Sections that are not present in ``text`` are left untouched. The initial-state section is
    also left untouched when ``problem.initial_state`` is None.

    Args:
        problem: Problem whose definitions are injected.
        text: Markdown text containing ``###`` section headers.
        prev_problem: Previous problem. With ``only_new``, its goal is the baseline of the goal diff.
        only_new: If True, render the initial state as the change from
            ``problem.grounder_initial_state``. When ``prev_problem`` is given, also render the goal
            as the change from ``prev_problem.goal_state``. Both use ``get_predicate_change_text``.

    Returns:
        The text with the sections replaced.

    Raises:
        AssertionError: If a goal section exists but ``problem.goal_state`` is None, or if the
            initial-state diff is requested while ``problem.grounder_initial_state`` is None.
        ValueError: If one of the sections appears more than once (from ``_get_section``).
    """
    span = _get_section(text, "Objects", remove_none=False)
    if span is not None:
        text = replace_text(text, span, problem.get_objects_str())

    span = _get_section(text, "(?:Init|Initial) State", remove_none=False)
    if span is not None and problem.initial_state is not None:
        if only_new:
            assert problem.grounder_initial_state is not None
            effects, removed, added = get_pred_change(problem.grounder_initial_state, problem.initial_state)
            # De-duplicate while keeping first-seen order: iterating a set would make the generated
            # prompt text depend on hash order and thus vary between runs.
            changes = list(dict.fromkeys(effects + [Not(r) for r in removed] + list(added)))

            content = get_predicate_change_text(And(), changes)
        else:
            content = " ".join(map(str, get_valid_predicates(problem.initial_state)))

        text = replace_text(text, span, content)

    span = _get_section(text, "Goal State", remove_none=False)
    if span is not None:
        assert problem.goal_state is not None

        if only_new and prev_problem is not None:
            content = get_predicate_change_text(prev_problem.goal_state or And(), problem.goal_state)
        else:
            # Not a diff against the initial state: the goal lists every predicate that must hold,
            # not just the effects (a predicate already true in the initial state must still be kept).
            content = str(problem.goal_state)

        text = replace_text(text, span, content)

    return text


def parse_domain_from_text(domain: PDDLDomain, text: str) -> Tuple[PDDLDomain, List[ValueError]]:
    """Apply the predicate and action changes described in the response ``text`` to ``domain``.

    Predicates are read from the single ``### ...Predicate(s)...`` section. Lines containing
    "None" and lines not starting with ``-`` are skipped. Every other line must match
    ``- (name ?args): <type>. <description>``. Predefined predicates are never overwritten,
    existing ones are replaced, and unknown ones are added with ``newly_generated=True``.

    Actions are read from the ``### ...Action(s)...`` section. If several such sections exist,
    those whose body mentions "high-level" are dropped first. If more than one still remains, an
    error is reported and no action is changed. Items are numbered, e.g. ``1. name: edit``; the
    operation is one of add/edit/delete and defaults to ``add``.
    ``add``/``edit`` items must contain a PDDL block, which is parsed and checked with
    ``verify_action``. ``delete`` removes the named action if it exists.

    Problems are collected as ``ValueError`` objects instead of being raised. Finally the domain
    is rendered and re-parsed with ``DomainParser``. On a PDDL validation or Lark parse error,
    the copy taken at entry is returned instead.

    Args:
        domain: Domain to update.
        text: Response text containing the Markdown sections.

    Returns:
        The updated (or fallback) domain and the list of collected errors.
    """
    old_domain = domain.copy_with()
    found_section = False
    errors = []
    try:
        response = _get_section(text, r"Predicate\(?s?\)?")
    except ValueError as e:
        # more than one predicate section: report it and skip predicate parsing
        errors.append(e)
        response = None
    if response is not None:
        response = text[response[0] : response[1]]
        found_section = True
        predicates = {p.definition.name: p for p in domain.predicates}
        for line in response.splitlines():
            if "None" in line:
                continue
            line = line.strip()
            if not line.startswith("-"):
                continue

            # expected shape: `- (name ?args): <type>. <description>`, optionally with `*` emphasis
            # NOTE(review): inject_domain_into_text writes `- (name ?args). <type>: <description>`,
            # which this pattern does not accept - confirm which format the prompts ask for.
            matches = re.match(r"\-?[\s\*]*(\(.*\))[\*]*\: (.*?)\. (.*)", line)
            if matches is None:
                errors.append(ValueError("Predicate parsing error: %s" % line))
                continue

            try:
                predicate = parse_predicate(matches.group(1))
            except ValueError as e:
                errors.append(e)
                continue

            pred_type = matches.group(2)
            description = matches.group(3)
            prev_predicate = predicates.get(predicate.name)
            if prev_predicate is not None:
                if prev_predicate.predefined:
                    # if the predicate is predefined, we do not want to overwrite it
                    logger.info("- Skipped predefined PDDL predicate `%s`" % predicate.name)
                    continue
                new_predicate = prev_predicate.copy_with(
                    definition=predicate,
                    description=description,
                    pred_type=pred_type,
                    newly_generated=prev_predicate.newly_generated,
                )
                logger.info("- Edited PDDL predicate `%s`" % predicate.name)
            else:
                try:
                    new_predicate = PDDLPredicate(
                        definition=predicate, description=description, newly_generated=True, pred_type=pred_type
                    )
                except AssertionError as e:
                    errors.append(ValueError("Predicate parsing error: %s" % str(e)))
                    continue
                logger.info("- Added PDDL predicate `%s`" % new_predicate.name)

            predicates[predicate.name] = new_predicate
        domain = domain.copy_with(predicates=list(predicates.values()))

    responses = _get_section(text, r"[Aa]ction\(?s?\)?", assert_one=False)
    responses = responses or []
    responses = [text[response[0] : response[1]] for response in responses]
    if len(responses) > 1:
        # TODO: hotfix needed for decomposition response
        responses = list(filter(lambda m: "high-level" not in m.lower(), responses))
    if len(responses) > 1:
        errors.append(
            ValueError(
                "Multiple sections found that could contain action definition updates. Combine the changes in the sections into one `Action` section that lists the add/edit/remove changes."
            )
        )
        # note: `response` is not read again after this branch
        response = None
    elif len(responses) == 1:
        found_section = True
        response = responses[0]
        # list numbering such as "1. "
        numbering_rgx = r"(?:[\d]+\.\s)"

        # split the section into numbered items; each item runs until the next numbered line
        action_strs = re.finditer(rf"({numbering_rgx}[\w\W]*?)(?=\n\s*{numbering_rgx}|$)", response)
        actions = {a.name: a for a in domain.operators}
        # names already handled in this response; a second change to the same action is rejected
        adapted_actions = set()
        for action_str in action_strs:
            action_str = action_str.group(1).strip()
            # item header: "<n>. [**]name[**][:] [add|edit|delete] ..."
            match = re.match(rf"{numbering_rgx}\**(.*?)\**\:?\**\s*(add|edit|delete)?(?:\s*|\s+.*)(?:\n|$)", action_str)
            if match is None:
                errors.append(ValueError("Syntax error parsing actions: `%s`" % action_str))
                continue
            name = match.group(1).lower().strip()
            operation_type = (match.group(2) or "add").lower()

            description = re.search(r"- \**Description\**:?\** (.*)\n", action_str)
            if description is not None:
                description = description.group(1).strip()

            # NOTE(review): `name` comes from the item header (lower-cased), while add/edit record
            # `action.name` from the PDDL definition; the two may differ, so this check can miss repeats.
            if name in adapted_actions:
                errors.append(ValueError("Action `%s` was already adapted once. Only adapt each action once." % name))
                continue

            if operation_type == "add" or operation_type == "edit":
                pddls_str = pddls_from_text(action_str)
                if len(pddls_str) == 0:
                    errors.append(
                        ValueError(
                            "Action parsing error: %s. Only list actions where the definition must be changed"
                            % action_str
                        )
                    )
                    continue
                # only the first PDDL block of the item is used
                pddl_str = pddls_str[0].strip()

                try:
                    pddl_definition = parse_action(action_str=pddl_str.strip())
                    verify_action(pddl_definition, domain)
                except ValueError as e:
                    errors.append(ValueError("%s: %s" % (name, str(e))))
                    continue

                prev_action = actions.get(pddl_definition.name)
                if prev_action is not None:
                    # NOTE(review): the description parsed above is discarded for existing actions;
                    # confirm that keeping the previous description is intended.
                    description = prev_action.description
                    # NOTE(review): in-place update; if `domain.copy_with()` is shallow, `old_domain`
                    # shares this object and the fallback below would not undo the edit - confirm.
                    prev_action.update_inplace(
                        definition=pddl_definition,
                        description=description,
                    )
                    action = prev_action
                    logger.info("- Edited PDDL action `%s`" % action.name)
                else:
                    action = PDDLOperator(definition=pddl_definition, description=description or "")
                    logger.info("- Added PDDL action `%s`" % action.name)
                actions[action.name] = action
                adapted_actions.add(action.name)

            elif operation_type == "delete":
                if name not in actions:
                    # deleting an unknown action is silently ignored
                    # errors.append(ValueError("Tried to remove action %s, but that action does not exist" % name))
                    continue
                logger.info("- Deleted PDDL action `%s`" % name)
                del actions[name]
                adapted_actions.add(name)

            else:
                errors.append(ValueError("Invalid operation type %s" % operation_type))

        domain = domain.copy_with(operators=list(actions.values()))

    # Nothing was recognized although the text contains an action definition: the headers were
    # probably malformed, so replace the (empty) error list with a hint about the expected syntax.
    if old_domain == domain and not found_section and "(:action" in text and len(errors) == 0:
        errors = [
            ValueError(
                "No action section found in the text, but the text contains an action definition. Ensure you use the correct syntax."
            )
        ]

    # check with fallback parser
    try:
        DomainParser()(str(domain.to_pddl()))
    except (PDDLValidationError, UnexpectedToken, VisitError) as e:
        errors.append(ValueError("Domain syntax error: %s" % str(e)))
        domain = old_domain  # fallback to old domain

    return domain, errors


def get_action_text(action: PDDLOperator) -> str:
    """Render one action as a Markdown list entry for the prompt.

    Removed actions (``last_operation == PDDLOperation.REMOVE``) become ``<name>: delete``.
    Any other action becomes ``<name>: <operation>`` followed by its description and its PDDL
    definition inside an indented ```pddl fence.
    """
    if action.last_operation == PDDLOperation.REMOVE:
        return f"{action.name}: delete"

    fence_indent = " " * 8
    definition_block = "\n".join(fence_indent + line for line in str(action.definition).splitlines())
    return (
        f"{action.name}: {action.last_operation.to_string()}\n"
        f"    - Description: {action.description}\n"
        "    - PDDL Definition:\n"
        f"{fence_indent}```pddl\n"
        f"{definition_block}{fence_indent}\n"
        f"{fence_indent}```"
    )


def get_domain_text(domain: PDDLDomain) -> str:
    """Render ``domain`` as Markdown with ``### Types``, ``### Predicates`` and ``### Actions`` sections.

    A placeholder skeleton is filled via ``inject_domain_into_text``.
    """
    skeleton = "\n\n".join(
        [
            "### Types\nwill be replaced by types",
            "### Predicates\nwill be replaced by predicates",
            "### Actions\nwill be replaced by actions",
        ]
    )
    return inject_domain_into_text(domain, skeleton)


def get_problem_text(problem: PDDLProblem) -> str:
    """Render ``problem`` as Markdown with ``### Objects``, ``### Initial State`` and ``### Goal State`` sections.

    A placeholder skeleton is filled via ``inject_problem_into_text``.
    """
    skeleton = "\n\n".join(
        [
            "### Objects\nwill be replaced by objects",
            "### Initial State\nwill be replaced by initial state",
            "### Goal State\nwill be replaced by goal state",
        ]
    )
    return inject_problem_into_text(problem, skeleton)


def inject_domain_into_text(domain: PDDLDomain, text: str, only_new: bool = False) -> str:
    """Fill the Predicates, Types and Actions sections of ``text`` with definitions from ``domain``.

    Each section is located with ``_get_section`` and its body replaced via ``replace_text``.
    Sections that are absent are left untouched.

    * Predicates: one ``- <definition>. <type>: <description>`` line per predicate.
    * Types: one ``<child types> - <parent type>`` line per parent type.
    * Actions: a numbered list rendered with ``get_action_text``. Sections whose header line
      mentions "high-level" (case-insensitive) are ignored.

    Args:
        domain: Domain whose definitions are injected.
        text: Markdown text containing ``###`` section headers.
        only_new: If True, only inject predicates flagged ``newly_generated`` and actions whose
            ``last_operation`` is not ``PDDLOperation.NONE``.

    Returns:
        The text with the sections replaced.

    Raises:
        ValueError: If more than one Predicates or Types section, or more than one non-high-level
            Action(s) section, is found.
    """
    # raw strings: "\(" is not a valid string escape (SyntaxWarning on Python 3.12+)
    span = _get_section(text, r"Predicate\(?s\)?", remove_none=False)
    if span is not None:
        predicates = domain.predicates
        if only_new:
            predicates = [pred for pred in predicates if pred.newly_generated]
        # NOTE(review): written as `- (pred). type: description`, whereas the regex in
        # parse_domain_from_text expects `- (pred): type. description` - confirm the intended format.
        predicates_str = "\n".join(
            [f"- {pred.definition_str()}. {pred.pred_type}: {pred.description}" for pred in predicates]
        )
        text = replace_text(text, span, predicates_str)

    span = _get_section(text, "Types", remove_none=False)
    if span is not None:
        types_by_parent = domain.get_types_by_parents()
        types_str = "\n".join(
            ["%s - %s" % (" ".join(entries), p_type) for p_type, entries in types_by_parent.items()]
        )
        text = replace_text(text, span, types_str)

    spans = _get_section(text, r"[Aa]ction\(?s?\)?", assert_one=False)
    if spans is not None:
        # the header is the last line before the section body; skip "high-level" action sections
        spans = [s for s in spans if "high-level" not in text[: s[0]].strip().splitlines()[-1].lower()]
        if len(spans) > 1:
            raise ValueError("Multiple sections found")
        if spans:
            span = spans[0]
            header = text[: span[0]].strip().splitlines()[-1]
            assert "###" in header
            actions = domain.operators
            if only_new:
                actions = [action for action in actions if action.last_operation != PDDLOperation.NONE]
            actions_str = "\n".join(
                ["%d. %s" % (i, get_action_text(action)) for i, action in enumerate(actions, start=1)]
            )
            text = replace_text(text, span, actions_str)

    return text


def inject_defs_into_text(
    domain: PDDLDomain,
    problem: PDDLProblem,
    text: str,
    only_new: bool = False,
    prev_problem: Optional[PDDLProblem] = None,
) -> str:
    """Inject both the domain and the problem sections into ``text``.

    The domain sections are filled first (``inject_domain_into_text``), then the problem sections
    (``inject_problem_into_text``). ``only_new`` is forwarded to both, and ``prev_problem`` to
    the problem step.
    """
    with_domain = inject_domain_into_text(domain, text, only_new=only_new)
    return inject_problem_into_text(problem, with_domain, prev_problem=prev_problem, only_new=only_new)


def inject_definitions_into_chat(
    chat: Chat,
    operator_changes: Dict[PDDLOperation, List[PDDLOperator]],
    predicate_changes: Dict[PDDLOperation, List[PDDLPredicate]],
    domain: PDDLDomain,
    problem: PDDLProblem,
    prev_problem: Optional[PDDLProblem] = None,
) -> Chat:
    """Replace the last chat message with a normalized summary of the given changes.

    The summary keeps the ``### Explanation`` section(s) of the current last message. It then
    lists the changed actions, the changed predicates, and the initial/goal state changes, as
    rendered by ``inject_domain_into_text``/``inject_problem_into_text`` with ``only_new=True``.
    Sections with empty content are omitted.

    Args:
        chat: Chat whose last message is replaced.
        operator_changes: Changed operators grouped by the operation applied to them.
        predicate_changes: Changed predicates grouped by operation; all are marked as new.
        domain: Domain the changed definitions are taken into (its operators/predicates are
            replaced by the changed ones only).
        problem: Problem whose state changes are rendered.
        prev_problem: Previous problem, used as the baseline for the goal-state changes.

    Returns:
        The result of ``chat.replace_last_message`` with the summary text.
    """
    domain = domain.copy_with(
        operators=[op.copy_with(last_operation=change) for change, ops in operator_changes.items() for op in ops],
        predicates=[pred.copy_with(newly_generated=True) for preds in predicate_changes.values() for pred in preds],
    )

    last_message = chat.last_message()
    # remove_none=False: explanations are prose and often contain "none"/"no change" as ordinary
    # words (e.g. "nonetheless"), which the default filter would use to drop the whole section.
    # Only sections consisting of nothing but such a marker are discarded.
    empty_markers = ("none", "no change", "no changes")
    explanation_spans = (
        _get_section(text=last_message, section_name="Explanation", assert_one=False, remove_none=False) or []
    )
    explanations = [
        last_message[start:end]
        for start, end in explanation_spans
        if last_message[start:end].strip().lower().rstrip(".") not in empty_markers
    ]

    placeholder_text = """
### Change/Add Actions
will be replaced by actions

### Change/Add Predicates
will be replaced by predicates

### Change Initial State
will be replaced by initial state

### Change Goal State
will be replaced by goal state"""

    text = inject_domain_into_text(domain, placeholder_text, only_new=True)
    text = inject_problem_into_text(problem, text, prev_problem=prev_problem, only_new=True)

    sections = []
    if explanations:
        sections.append(("Explanation", "\n\n".join(explanations)))

    for sec_name in ["Actions", "Predicates", "Initial State", "Goal State"]:
        span = _get_section(text, sec_name, remove_none=False)
        if span is None:
            continue
        content = text[span[0] : span[1]]
        if content.strip() == "":
            continue
        sections.append((sec_name, content))

    output_text = "".join(f"### {sec_name}\n{content.strip()}\n\n" for sec_name, content in sections)

    return chat.replace_last_message(output_text)
