from functools import partial
import json
from pathlib import Path
from typing import Callable, Dict, List, Optional
from pddl.parser.domain import DomainParser
from pddl.parser.problem import ProblemParser

from tp_lodge.task_planning.models.sas.sas_action import SasAction
from tp_lodge.task_planning.pddl_planner.ai_planner.fastdownward_ai_planner import FastDownwardAIPlanner
from tp_lodge.task_planning.pddl_generators.pddl_llm_interface import SasPlan
from tp_lodge.utils.pddl_lib_utils import copy_domain_w_args

def domain_data_dir(task: str) -> Path:
    """Return the data directory that belongs to ``task``.

    The base directory is the third ancestor of this file's path. Tasks whose
    name starts with ``"fb-"`` are instead looked up in a sibling
    ``furniturebench`` directory next to that base. In both cases the result
    is ``<base>/data/<task>``; the path is not checked for existence.
    """
    package_root = Path(__file__).parent.parent.parent
    base_dir = package_root.parent / "furniturebench" if task.startswith("fb-") else package_root
    return base_dir / "data" / task

def get_task_centric_domain(
    domain: str,
    config: dict,
    task: str,
    task_name: str,
    skills: Optional[List[str]],
) -> str:
    """Return ``domain`` reduced to the actions needed for the given task.

    Args:
        domain: PDDL domain text.
        config: Lookup from a domain action name to a skill name; an action is
            kept when its looked-up value is one of the requested skill names.
        task: Task identifier, used to locate the sample plan and the function
            mapping when ``skills`` is not given.
        task_name: Task number string forwarded to ``get_sample_plan``.
        skills: Skill call strings; only the part before the first ``"("`` is
            used. When ``None``, the skills are derived by mapping every action
            of the task's sample plan through the task's function mapping.

    Returns:
        The pruned domain as text, or the (hotfixed) input text when the skill
        list is empty.
    """
    # hotfix: rewrite empty "(or )" expressions before handing the text to the parser
    domain = domain.replace("(or )", "(and)")
    parsed_domain = DomainParser()(domain)

    if skills is None:
        plan = get_sample_plan(task=task, task_nr=task_name)
        mapping_config = get_function_mapping(task=task)
        action_to_call = get_function_mapping_from_config(mapping_config)
        # Deduplicate the resulting skill calls.
        skills = list({action_to_call[action.name](action) for action in plan.actions})

    skill_names = [skill.split("(")[0] for skill in skills]
    if not skill_names:
        return domain

    required_actions = []
    for action in parsed_domain.actions:
        if config.get(action.name) in skill_names:
            required_actions.append(action)

    pruned_domain = copy_domain_w_args(parsed_domain, actions=required_actions)
    # hotfix: the serialized domain may contain "(or )" again, so rewrite it once more
    return str(pruned_domain).replace("(or )", "(and)")


def get_sample_plan(task: str, task_nr: str) -> SasPlan:
    """Load the sample plan of a task's problem, generating it first if missing.

    ``task_nr`` must contain ``"task-"`` followed by an integer index; the
    matching problem file uses that index plus one. If no cached plan file
    exists, the domain and problem are parsed and planned with
    ``FastDownwardAIPlanner``, and the plan is written to disk. The plan is
    always read back from the plan file.
    """
    problems_dir = domain_data_dir(task) / "problems"

    # pddl files are 1-indexed
    problem_idx = int(task_nr.split("task-")[1]) + 1
    problem_stem = f"p{problem_idx:02d}"

    sample_plan_file = problems_dir / f"{problem_stem}-sample_plan.plan"
    if not sample_plan_file.is_file():
        print(f"Sample plan file {sample_plan_file} does not exist, generating it...")
        domain_text = (problems_dir.parent / "domain.pddl").read_text()
        domain = DomainParser()(domain_text)
        problem_text = (problems_dir / f"{problem_stem}.pddl").read_text()
        problem = ProblemParser()(problem_text)

        # The first return value is only used for the failure message below.
        planner_info, succ, plan = FastDownwardAIPlanner().plan(domain=domain, problem=problem)
        assert succ, f"Planning failed for p{problem_idx}: {planner_info}"
        assert plan is not None, "No plan found"

        sample_plan_file.write_text(plan.to_string())

    return SasPlan.from_string(sample_plan_file.read_text())


def get_function_mapping(task: str) -> dict:
    """Read ``function_mapping.json`` from the task's data directory.

    Returns the parsed JSON content.

    Raises:
        FileNotFoundError: If the mapping file does not exist.
    """
    mapping_path = domain_data_dir(task) / "function_mapping.json"

    if mapping_path.is_file():
        return json.loads(mapping_path.read_text())

    raise FileNotFoundError(f"Function mapping file {mapping_path} does not exist.")


def map_action(
    function_mapping: dict, *, gt_function_signatures: Optional[Dict[str, int]] = None
) -> Callable[[SasAction], str]:
    """Build a callable that renders a plan action as a function-call string.

    Args:
        function_mapping: Mapping config with two keys: ``"name"`` (the target
            function name) and ``"arg_mapping"`` (one entry per action
            argument, giving the position that argument takes in the target
            call, or ``None`` to drop the argument).
        gt_function_signatures: Optional lookup from function name to expected
            argument count. When given, the mapped call is validated against it.

    Returns:
        A callable taking an action (an object with ``name`` and ``args``) and
        returning ``"<name>(<arg>, <arg>, ...)"``.

    Raises:
        AssertionError: Raised by the returned callable when a target position
            is left unfilled, when the target function is not in
            ``gt_function_signatures``, or when the argument count differs
            from the expected one. The checks are raised explicitly (not via
            ``assert``) so that they stay active under ``python -O``.
    """

    def mapping_func(sas_action: SasAction, *, mapping_config, gt_function_signatures: Optional[Dict[str, int]]) -> str:
        new_name = mapping_config["name"]
        arg_mapping = mapping_config["arg_mapping"]

        # Size the target argument list by the highest position referenced.
        target_positions = [pos for pos in arg_mapping if pos is not None]
        args = [None] * (max(target_positions) + 1) if target_positions else []

        # Place every kept action argument at its target position.
        for arg, pos in zip(sas_action.args, arg_mapping):
            if pos is not None:
                args[pos] = arg

        # A remaining None means a target position was never assigned.
        if any(arg is None for arg in args):
            raise AssertionError(f"Not all arguments are mapped for action '{sas_action.name}': {args}")

        if gt_function_signatures is not None:
            if new_name not in gt_function_signatures:
                raise AssertionError(
                    f"Function '{new_name}' not found in ground truth signatures: {gt_function_signatures.keys()}"
                )
            if len(args) != gt_function_signatures[new_name]:
                raise AssertionError(
                    f"Function '{new_name}' expects {gt_function_signatures[new_name]} arguments, "
                    f"but got {len(args)}: {args}"
                )

        return "{f_name}({args})".format(f_name=new_name, args=", ".join(args))

    return partial(mapping_func, mapping_config=function_mapping, gt_function_signatures=gt_function_signatures)


def get_function_mapping_from_config(
    config: Dict[str, dict], *, gt_function_signatures: Optional[Dict[str, int]] = None
) -> Dict[str, Callable[[SasAction], str]]:
    """Turn every entry of ``config`` into a mapping callable via ``map_action``.

    The returned dict has the same keys as ``config``; each value is the
    callable built from the corresponding config entry, with
    ``gt_function_signatures`` forwarded unchanged.
    """
    mappers = {}
    for key, entry in config.items():
        mappers[key] = map_action(entry, gt_function_signatures=gt_function_signatures)
    return mappers
