import re
import ast
import json
import copy
import argparse
import shutil
import subprocess
from tqdm import tqdm
from pathlib import Path
from natsort import natsorted
from dotenv import load_dotenv
from pddl.logic.base import Not
from pddl.parser.domain import DomainParser

from llm_utils import TextGenApi
from state_estimation import PredicateGrounder
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.pddl_generators.pddl_llm_interface import SasPlan
from tp_lodge.utils.pddl_domain_syntax import parse_action, parse_formula
from tp_lodge.utils.pddl_utils import (
    get_valid_predicates,
    get_list_of_predicates,
)
from tp_lodge.utils.planning_cache_utils import parameterize_skill

from demos.ipc.src.common.utils import copy_domain_w_args
from demos.ipc.src.household.household_motion_validator import HouseholdMotionValidator
from demos.ipc.src.logistics.logistics_motion_validator import LogisticsMotionValidator
from demos.ipc.src.planning_benchmark_sample_generator import PlanningBenchmarkSampleGenerator
from demos.furniturebench.scripts.fb_variable_parser import FBVariableParser


def _get_plan(path):
    """Reconstruct the primitive skill sequence that LODGE planned under ``path``.

    ``path`` is a planning directory containing ``ai-plan.cache.plan`` and a
    ``sub-actions`` directory with one ``{i}-{action_name}`` sub-directory per
    refined plan step. A step whose operator maps to exactly one skill becomes a
    call produced by ``parameterize_skill``. A step mapping to any other number
    of skills is expanded recursively from its own sub-directory.

    Args:
        path: ``pathlib.Path`` of the planning directory.

    Returns:
        list: Parameterized skill calls in plan order. Expansion stops at the
        first step without a sub-directory (that is how far LODGE planned), or
        whose operator / primitive mapping cannot be resolved. In either case
        the prefix collected so far is returned. Returns ``[]`` if there is no
        ``sub-actions`` directory or no plan file.
    """
    sub_dir = path / "sub-actions"
    plan_file = path / "ai-plan.cache.plan"
    if not sub_dir.exists() or not plan_file.is_file():
        return []

    plan = SasPlan.from_string(plan_file.read_text())

    domain_file = path / "generated-domain.json"
    if not domain_file.is_file():
        domain_file = path / "init-domain.json"
    domain = PDDLDomain.loads(domain_file.read_text())
    prev_domain_file = sub_dir / "prev-domain.json"

    plan_so_far = []
    for i, sub_action in enumerate(plan.actions):
        sub_plan_dir = sub_dir / f"{i}-{sub_action.name}"
        if not sub_plan_dir.exists():
            # That's how far LODGE planned.
            break

        try:
            operator = domain.get_operator(sub_action.name)
        except KeyError:
            # Fall back to the `prev-domain.json` snapshot stored next to the sub-actions.
            # Without it there is no operator to resolve this step against, so stop here.
            if not prev_domain_file.is_file():
                return plan_so_far
            try:
                operator = PDDLDomain.loads(prev_domain_file.read_text()).get_operator(sub_action.name)
            except KeyError:
                return plan_so_far

        prim = operator.mapped_skill_sequence
        if prim is None:
            prim_file = sub_plan_dir / "decide-whether-primitive.json"
            if not prim_file.is_file():
                return plan_so_far
            prim = json.loads(prim_file.read_text())["func_calls"]

        if len(prim) == 1:
            try:
                plan_so_far.append(parameterize_skill(skill=prim[0], sas_action=sub_action, action=operator))
            except Exception:
                # Best effort: a step that cannot be parameterized ends the reconstruction.
                return plan_so_far
        else:
            plan_so_far += _get_plan(sub_plan_dir)

    return plan_so_far


def _get_transitions_from_ipc(args, data_dir: Path, out_dir: Path):
    """Replay LODGE's plans for every IPC task and record the observed transitions.

    Each ``task-*`` directory in ``out_dir`` is processed together with its
    snapshots under ``out_dir / "backups" / <task>``. The plans reconstructed by
    ``_get_plan`` are de-duplicated by their string form and executed in the
    motion validator, starting from the task's initial environment state. After
    every successfully executed action, the valid predicates of the skeleton
    domain are evaluated.

    Args:
        args: Parsed CLI arguments. ``args.domain`` selects the benchmark:
            ``"household"`` uses the household validator, anything else the
            logistics validator.
        data_dir: Benchmark data directory holding ``domain_skeleton.json``.
        out_dir: LODGE results directory of one sample.

    Returns:
        list[dict]: Unique transitions with keys ``"state"``, ``"action"``,
        ``"action_name"`` and ``"next_state"``. Replay of a plan stops at its
        first failing action.
    """
    generator = PlanningBenchmarkSampleGenerator(args.domain, use_hl_types=True)
    if args.domain == "household":
        motion_validator = HouseholdMotionValidator(hide_env_feedback=False)
    else:
        motion_validator = LogisticsMotionValidator(hide_env_feedback=False)

    # The skeleton domain does not depend on the task, so parse it once.
    domain_skeleton = PDDLDomain.loads((data_dir / "domain_skeleton.json").read_text())

    transitions = []
    for task_dir in tqdm(natsorted(out_dir.glob("task-*"))):
        to_traverse = [task_dir]
        backups_dir = out_dir / "backups" / task_dir.name
        if backups_dir.exists():
            to_traverse += list(backups_dir.iterdir())

        # Key by string form so identical plans found in several snapshots are replayed once.
        plans_for_task = {}
        for path in natsorted(to_traverse):
            plan = _get_plan(path)
            if plan:
                plans_for_task[str(plan)] = plan
        print(task_dir.name, len(plans_for_task))

        task_id = int(task_dir.name.split("-")[-1])
        _instruction, env_state, problem_skeleton = generator.generate(task_id)

        for plan in plans_for_task.values():
            assert hasattr(motion_validator, "env")
            assert hasattr(motion_validator.env, "_init_env_state")
            motion_validator.env._init_env_state = copy.deepcopy(env_state)  # type: ignore
            # Give the env its own copy as well, so replaying one plan cannot alter the start
            # state of the next. NOTE(review): only matters if `_set_state` keeps a reference — confirm.
            motion_validator.env._set_state(new_state=copy.deepcopy(env_state))
            motion_validator.init_state(seed=0)

            transitions_for_plan = []
            current_state = get_valid_predicates(
                motion_validator.get_predicates_evaluation(domain=domain_skeleton, problem=problem_skeleton)
            )
            for i, action in enumerate(plan):
                try:
                    motion_validator._run_motion(action)
                except Exception as e:
                    print(f"Error in action {i} of plan {plan}: {e}")
                    break
                next_state = get_valid_predicates(
                    motion_validator.get_predicates_evaluation(domain=domain_skeleton, problem=problem_skeleton)
                )

                transitions_for_plan.append(
                    {
                        "state": current_state,
                        "action": action,
                        "action_name": action.split("(")[0],
                        "next_state": next_state,
                    }
                )
                current_state = next_state

            # Transition dicts are not hashable, hence the linear membership test.
            for t in transitions_for_plan:
                if t not in transitions:
                    transitions.append(t)

    return transitions


def _get_transitions_from_fb(args, data_dir: Path, out_dir: Path, *, use_grounding: bool):
    """Collect (state, action, next-state) transitions from a FurnitureBench replay buffer.

    Reads ``out_dir / "reply_buffer/states.json"``. Every recorded state with a
    ``prev_state_hash`` yields one transition for its ``executed_skill``.

    Args:
        args: Parsed CLI arguments. Not read here; kept for signature parity with
            ``_get_transitions_from_ipc``.
        data_dir: Benchmark data directory. Used only to locate the grounder's code
            API file (``data_dir.parent.parent / "scripts/llm_code_interface.py"``).
        out_dir: Result directory of one sample.
        use_grounding: If True, every recorded state is grounded with
            ``PredicateGrounder`` against the predicates of ``out_dir / "domain.json"``.
            Otherwise, predicate evaluations are read from the matching
            ``predicate-learning-eval/<llm>/vlm/<sample>/predictions.json``, and
            predicates evaluated as false become ``Not(...)`` literals.

    Returns:
        list[dict]: Transitions with keys ``"state"``, ``"action"``,
        ``"action_name"`` and ``"next_state"``, in replay-buffer order.
    """
    reply_buffer = json.loads((out_dir / "reply_buffer/states.json").read_text())
    recorded_states = reply_buffer["states"]

    if use_grounding:
        grounder = PredicateGrounder(
            code_api_file=data_dir.parent.parent / "scripts/llm_code_interface.py",
            out_dir=out_dir / "grounding_code",
            textgen_api=None,
            domain_knowledge=None,
        )
        domain = PDDLDomain.loads((out_dir / "domain.json").read_text())
        # Ground every recorded state once up front.
        grounded_states = {}
        for s_hash, state in recorded_states.items():
            grounded_states[s_hash] = grounder.ground_state(
                predicates=domain.predicates,
                variables=[FBVariableParser().from_dict(v) for v in state["variables"]],
                verify=False,
            )

        def state_of(s_hash):
            return grounded_states[s_hash]

    else:
        llm_dir_name = out_dir.parent.parent.parent.name
        sample_dir_name = out_dir.name
        predictions = json.loads(
            (
                out_dir.parent.parent.parent.parent
                / "predicate-learning-eval"
                / llm_dir_name
                / "vlm"
                / sample_dir_name
                / "predictions.json"
            ).read_text()
        )

        def state_of(s_hash):
            # Predicates evaluated as false are kept as negated literals.
            return [
                parse_formula(p, only_variables=False) if holds else Not(parse_formula(p, only_variables=False))
                for p, holds in predictions[s_hash].items()
            ]

    transitions = []
    for s_hash, state in recorded_states.items():
        if state["prev_state_hash"] is None:
            # The initial state has no incoming transition.
            continue

        transitions.append(
            {
                "state": state_of(state["prev_state_hash"]),
                "action": state["executed_skill"],
                "action_name": state["executed_skill"].split("(")[0],
                "next_state": state_of(s_hash),
            }
        )

    return transitions


def _parse_function_fb(expr):
    """Describe a FurnitureBench skill stub as an action predicate.

    Args:
        expr: ``ast.FunctionDef`` of the skill stub.

    Returns:
        dict: ``name`` (function name), ``arity`` (number of positional
        parameters) and ``var_types``, which types every parameter as ``"part"``.
    """
    params = expr.args.args
    return {
        "name": expr.name,
        "arity": len(params),
        "var_types": ["part"] * len(params),
    }


def _parse_function_ipc(expr, args):
    """Describe an IPC skill stub as an action predicate, inferring parameter types from names.

    Each parameter name is matched by substring, in both directions, against
    the known type names of the domain. Household type names are matched
    part-by-part (split on ``_``). The first match in table order wins, and
    household subtypes are mapped to their parent type.

    Args:
        expr: ``ast.FunctionDef`` of the skill stub.
        args: Parsed CLI arguments. ``args.domain == "logistics"`` selects the
            logistics types; any other value selects the household types.

    Returns:
        dict: ``name``, ``arity`` and ``var_types`` of the stub.

    Raises:
        ValueError: if a parameter name matches no known type.
    """
    if args.domain == "logistics":
        # Flat type hierarchy: no parents.
        type_parents = dict.fromkeys(("city", "location", "package", "plane", "truck"))
    else:
        # Type name -> parent type (None for top-level types). Order matters: first match wins.
        type_parents = {
            "agent": None,
            "household_object": None,
            "furniture_appliance": None,
            "knife": "household_object",
            "food": "household_object",
            "microwave": "furniture_appliance",
            "pan": "household_object",
            "receptacle": "household_object",
            "blender": "furniture_appliance",
            "surface": "furniture_appliance",
            "cloth": "household_object",
            "vacuum_cleaner": "household_object",
            "carpet": "furniture_appliance",
            "trash_can": "furniture_appliance",
        }

    def resolve(arg_name: str) -> str:
        for type_name, parent in type_parents.items():
            if any(arg_name in part or part in arg_name for part in type_name.split("_")):
                return type_name if parent is None else parent
        raise ValueError(f"Unknown var type for arg {arg_name}")

    params = expr.args.args
    return {
        "name": expr.name,
        "arity": len(params),
        "var_types": [resolve(param.arg) for param in params],
    }


def _demos_dir(args):
    """Return the demo root for ``args.domain``.

    The root is ``furniturebench`` for ``fb-*`` domains and ``ipc`` otherwise,
    resolved two directory levels above this file.
    """
    subdir = "furniturebench" if args.domain.startswith("fb-") else "ipc"
    return Path(__file__).parent.parent / subdir


def _data_source_dir(args):
    """Return the LODGE result sample directories (``sample-*``) used as data sources.

    FurnitureBench (``fb-*``) domains read from ``hi-tamp/planning-with-pred-learning``.
    IPC domains read from ``hi-tamp/iclr-w-dk-w-ai-shared``. Both live under
    ``results/<domain>/<model_dir>`` of the demo root, where ``model_dir`` is
    taken from the first connection of ``TextGenApi.default(args.llm)``.

    Returns:
        list[pathlib.Path]: Naturally sorted, so samples are processed in a reproducible order.
    """
    is_fb = args.domain.startswith("fb-")
    model_dir = TextGenApi.default(args.llm).connections.connections[0].model_dir
    suffix = "planning-with-pred-learning" if is_fb else "iclr-w-dk-w-ai-shared"
    results_dir = _demos_dir(args) / f"results/{args.domain}/{model_dir}/hi-tamp/{suffix}"
    # Path.glob yields entries in filesystem order; sort for deterministic runs.
    return natsorted(results_dir.glob("sample-*"))


def _load_data(handover_dir: Path, args, data_source_dir: Path):
    """Export one sample as LOFT input files into ``handover_dir``.

    Writes three JSON files:

    * ``preds.json``: state predicates (name, arity, variable types). They come
      from the sample's ``domain.json`` for FurnitureBench domains, otherwise
      from the benchmark's ``domain.pddl``.
    * ``action-preds.json``: one action predicate per function defined in the
      benchmark's ``function_stubs.py``.
    * ``episode.json``: a list of ``[state, action, next_state, null]`` entries.
      Only positive literals are kept in the states.

    Args:
        handover_dir: Existing directory the files are written to.
        args: Parsed CLI arguments (``domain``, ``use_grounding`` for ``fb-*`` domains).
        data_source_dir: Result directory of the sample to export.
    """
    is_fb = args.domain.startswith("fb-")

    demos_dir = _demos_dir(args)
    data_dir = demos_dir / "data" / args.domain

    # Only function definitions describe skills; skip imports, docstrings, etc.
    function_stubs = [
        node
        for node in ast.parse((data_dir / "function_stubs.py").read_text()).body
        if isinstance(node, ast.FunctionDef)
    ]

    if is_fb:
        transitions = _get_transitions_from_fb(args, data_dir, data_source_dir, use_grounding=args.use_grounding)
    else:
        transitions = _get_transitions_from_ipc(args, data_dir, data_source_dir)

    action_predicates = {
        stub.name: _parse_function_fb(stub) if is_fb else _parse_function_ipc(stub, args) for stub in function_stubs
    }

    if is_fb:
        gen_domain = PDDLDomain.loads((data_source_dir / "domain.json").read_text())
        predicates = [p.definition for p in gen_domain.predicates]
    else:
        gen_domain = DomainParser()((data_dir / "domain.pddl").read_text())
        predicates = gen_domain.predicates
    pddl_preds = {
        pred.name: {
            "name": pred.name,
            "arity": len(pred.terms),
            # NOTE(review): takes an arbitrary element of the type-tag set; ambiguous if a term has several tags.
            "var_types": [list(term.type_tags)[0] for term in pred.terms],
        }
        for pred in predicates
    }

    def parse_literals(literals):
        # Keep positive literals only; negated ones (`Not`) are dropped.
        return [
            {"predicate_name": lit.name, "variables": [term.name for term in lit.terms]}
            for lit in literals
            if not isinstance(lit, Not)
        ]

    def parse_skill_call(call_str):
        # Turn a call string such as "skill(1, 2, obj=3)" into its function name and the literal
        # argument values (positional first, then keyword values in order).
        # `.value` replaces the deprecated `.n` alias, which was removed in Python 3.14.
        call = ast.parse(call_str).body[0].value
        variables = [arg.value for arg in call.args]
        variables += [kw.value.value for kw in call.keywords]
        return {"action_pred_name": call.func.id, "variables": variables}

    episode = [
        # state, action, next-state, None
        (parse_literals(t["state"]), parse_skill_call(t["action"]), parse_literals(t["next_state"]), None)
        for t in transitions
    ]

    (handover_dir / "preds.json").write_text(json.dumps(pddl_preds, indent=2))
    (handover_dir / "action-preds.json").write_text(json.dumps(action_predicates, indent=2))
    (handover_dir / "episode.json").write_text(json.dumps(episode, indent=2))


def _out_dir(args):
    """Return the LOFT ("cluster-intersect") output directory for ``args.domain`` and ``args.llm``.

    The directory is ``results/<domain>/<model_dir>/cluster-intersect`` under
    the demo root. FurnitureBench (``fb-*``) domains get a further
    ``grounder-based`` or ``vlm-based`` sub-directory, depending on
    ``args.use_grounding``.
    """
    connection = TextGenApi.default(args.llm).connections.connections[0]
    # Reuse `_demos_dir` so the demo-root rule lives in one place.
    out_dir = _demos_dir(args) / "results" / args.domain / connection.model_dir / "cluster-intersect"
    if args.domain.startswith("fb-"):
        out_dir = out_dir / ("grounder-based" if args.use_grounding else "vlm-based")
    return out_dir


def _cluster_dir():
    """Return the LOFT (IROS 2021) baseline checkout under ``3rdparty``.

    The ``3rdparty`` directory is resolved three directory levels above this file.
    """
    base_dir = Path(__file__).parent.parent.parent
    return base_dir / "3rdparty" / "LOFT_IROS_2021"


def run_cluster(args):
    """Export LOFT input for every data-source sample and run the LOFT learner on it.

    For each sample, ``_load_data`` writes the handover files into
    ``<_out_dir(args)>/<sample name>``. Then ``learn-nrst.py`` from the LOFT
    checkout is run on that directory with the checkout's pixi Python
    environment.

    Raises:
        RuntimeError: if ``learn-nrst.py`` exits with a non-zero status.
    """
    load_dotenv()

    exp_dir = _out_dir(args)
    exp_dir.mkdir(parents=True, exist_ok=True)

    baseline_root = _cluster_dir()
    python_path = baseline_root / ".pixi/envs/default/bin/python"

    for data_source_dir in _data_source_dir(args):
        exp_sample_dir = exp_dir / data_source_dir.name
        print("Running", data_source_dir)
        exp_sample_dir.mkdir(exist_ok=True)

        _load_data(handover_dir=exp_sample_dir, args=args, data_source_dir=data_source_dir)

        # Output is not captured: the learner's stdout/stderr stream straight to the console,
        # so there is nothing extra to print on failure.
        response = subprocess.run(
            [python_path, baseline_root / "learn-nrst.py", "--out_dir", exp_sample_dir],
            cwd=baseline_root,
        )
        if response.returncode != 0:
            print("Error running learn-nrst.py:")
            print(response.returncode)
            raise RuntimeError(
                f"Failed to run learn-nrst.py (exit code {response.returncode}) for {exp_sample_dir}"
            )


def post_process_cluster(args):
    """Build a PDDL domain from the LOFT-learned operators of every sample.

    For each data-source sample:

    1. The problem files are copied into ``<sample>/problem``.
    2. ``<sample>/operators.txt`` is split at every ``(:action <name>`` header,
       and each chunk is parsed with ``parse_action``.
    3. The parsed actions are combined via ``copy_domain_w_args`` with a base
       domain. The base is the sample's generated domain for FurnitureBench,
       otherwise the benchmark's ``domain.pddl``. The result is named after the
       ground-truth domain and written to ``<sample>/domain.pddl``.
    """
    exp_dir = _out_dir(args)
    is_fb = args.domain.startswith("fb-")

    data_dir = _demos_dir(args) / "data" / args.domain
    gt_domain = DomainParser()((data_dir / "domain.pddl").read_text())

    for data_source_dir in _data_source_dir(args):
        sample_dir = exp_dir / data_source_dir.name

        if is_fb:
            base_domain = PDDLDomain.loads((data_source_dir / "domain.json").read_text()).to_pddl()
            problems_src = data_source_dir / "problem"
        else:
            base_domain = DomainParser()((data_dir / "domain.pddl").read_text())
            problems_src = data_dir / "problems"
        shutil.copytree(problems_src, sample_dir / "problem", dirs_exist_ok=True)

        operators_text = (sample_dir / "operators.txt").read_text()

        # Each action spans from its header to the next header (or the end of the file).
        bounds = [match.start() for match in re.finditer(r"(\(:action\s+\w+)", operators_text)]
        bounds.append(len(operators_text))
        actions = [parse_action(operators_text[begin:end].strip()) for begin, end in zip(bounds, bounds[1:])]

        domain = copy_domain_w_args(base_domain, actions=actions, name=gt_domain.name)
        (sample_dir / "domain.pddl").write_text(str(domain))


def main(args):
    """Run the LOFT baseline end to end: learn operators, then write the PDDL domains."""
    for stage in (run_cluster, post_process_cluster):
        stage(args)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--domain", type=str, required=True)
    argparser.add_argument("--llm", type=str, required=True)
    load_dotenv()
    args = argparser.parse_args()
    # `use_grounding` is only read for FurnitureBench ("fb-*") domains, where it selects the
    # transition source and the output sub-directory (see `_load_data` and `_out_dir`).
    # For IPC domains both settings would redo identical work into the same directory.
    grounding_modes = [True, False] if args.domain.startswith("fb-") else [False]
    for use_grounding in grounding_modes:
        args.use_grounding = use_grounding
        main(args)
