from collections import defaultdict
import json
import copy
from pathlib import Path
from textwrap import indent, dedent
from typing import List

from dotenv import load_dotenv
import frozenlist
from llm_utils import TextGenApi
from natsort import natsorted
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.pddl_generators.pddl_llm_interface import SasPlan
from tp_lodge.utils.pddl_utils import (
    get_pred_change,
    get_predicate_evaluation,
    get_valid_predicates,
    get_variables_in_term,
    get_constants_in_term,
)
from tp_lodge.utils.planning_cache_utils import parameterize_skill
from tqdm import tqdm

from demos.ipc.src.household.household_motion_validator import HouseholdMotionValidator
from demos.ipc.src.planning_benchmark_sample_generator import PlanningBenchmarkSampleGenerator


def _get_plan(path):
    sub_dir = path / "sub-actions"
    if sub_dir.exists():
        plan_file = path / "ai-plan.cache.plan"
        if not plan_file.is_file():
            return []
        plan = SasPlan.from_string(plan_file.read_text())
        if (path / "generated-domain.json").is_file():
            domain = PDDLDomain.loads((path / "generated-domain.json").read_text())
        else:
            domain = PDDLDomain.loads((path / "init-domain.json").read_text())
        plan_so_far = []
        for i, sub_action in enumerate(plan.actions):
            sub_plan_dir = sub_dir / f"{i}-{sub_action.name}"
            if sub_plan_dir.exists():
                try:
                    operator = domain.get_operator(sub_action.name)
                except KeyError:
                    try:
                        if (path / "sub-actions/prev-domain.json").is_file():
                            operator = PDDLDomain.loads(
                                (path / "sub-actions/prev-domain.json").read_text()
                            ).get_operator(sub_action.name)
                    except KeyError:
                        return plan_so_far

                prim = operator.mapped_skill_sequence

                if prim is None:
                    prim_file = sub_plan_dir / "decide-whether-primitive.json"
                    if not prim_file.is_file():
                        return plan_so_far
                    prim = json.loads(prim_file.read_text())["func_calls"]

                if len(prim) == 1:
                    try:
                        parameterized_py_function = parameterize_skill(
                            skill=prim[0], sas_action=sub_action, action=operator
                        )
                        plan_so_far.append(parameterized_py_function)
                    except:
                        return plan_so_far
                else:
                    plan_so_far += _get_plan(sub_plan_dir)

            else:
                # thats how far lodge planned
                break

        return plan_so_far

    else:
        return []


def _get_transitions(out_dir):
    """Replay every distinct plan found under *out_dir* and record its state transitions.

    For each ``task-<id>`` directory in *out_dir*, plans are collected with
    ``_get_plan`` from the task directory and from its snapshots under
    ``out_dir/backups/<task-name>``. Plans are de-duplicated by their string form.
    Each plan is then replayed in a ``HouseholdMotionValidator`` starting from the
    task's generated initial environment state. After every successfully executed
    action, the grounded predicates are re-evaluated against the household domain
    skeleton.

    Args:
        out_dir: ``pathlib.Path`` of one sample directory of a planning run.

    Returns:
        List of unique transition dicts with keys ``"state"``, ``"action"``,
        ``"action_name"`` and ``"next_state"``. A plan is replayed only up to the
        first action that raises; that action is reported via ``print`` and is
        not recorded.
    """
    generator = PlanningBenchmarkSampleGenerator("household", use_hl_types=True)
    motion_validator = HouseholdMotionValidator(hide_env_feedback=False)
    # The skeleton does not depend on the task, so parse it once (uses module-level `root_dir`).
    domain_skeleton = PDDLDomain.loads((root_dir / "data/household" / "domain_skeleton.json").read_text())

    transitions = []
    for task_dir in tqdm(natsorted(out_dir.glob("task-*"))):
        backup_dir = out_dir / "backups" / task_dir.name
        # A task without backups only contributes its final run directory.
        backups = list(backup_dir.iterdir()) if backup_dir.is_dir() else []

        # Keyed by str(plan) so identical plans from different runs are replayed once.
        plans_for_task = {}
        for path in natsorted(backups + [task_dir]):
            plan = _get_plan(path)
            if plan:
                plans_for_task[str(plan)] = plan
        print(task_dir.name, len(plans_for_task))

        task_id = int(task_dir.name.split("-")[-1])
        _instruction, env_state, problem_skeleton = generator.generate(task_id)

        for plan in plans_for_task.values():
            assert hasattr(motion_validator, "env")
            assert hasattr(motion_validator.env, "_init_env_state")
            # env_state is reused for every plan of this task, so hand the env its own copies.
            # NOTE(review): defensive in case `_set_state` keeps a reference and the replay mutates it.
            motion_validator.env._init_env_state = copy.deepcopy(env_state)  # type: ignore
            motion_validator.env._set_state(new_state=copy.deepcopy(env_state))
            motion_validator.init_state(seed=0)

            transitions_for_plan = []
            current_state = get_valid_predicates(
                motion_validator.get_predicates_evaluation(domain=domain_skeleton, problem=problem_skeleton)
            )
            for i, action in enumerate(plan):
                try:
                    motion_validator._run_motion(action)
                except Exception as e:
                    print(f"Error in action {i} of plan {plan}: {e}")
                    break
                next_state = get_valid_predicates(
                    motion_validator.get_predicates_evaluation(domain=domain_skeleton, problem=problem_skeleton)
                )
                transitions_for_plan.append(
                    {
                        "state": current_state,
                        "action": action,
                        "action_name": action.split("(")[0],
                        "next_state": next_state,
                    }
                )
                current_state = next_state

            for t in transitions_for_plan:
                # Transition dicts are unhashable, so de-duplicate with a linear membership test.
                if t not in transitions:
                    transitions.append(t)

    return transitions


load_dotenv()
# Parent of this script's directory; `data/` and `results/` are resolved against it
# (also read as a global by `_get_transitions`).
root_dir = Path(__file__).parent.parent

# Which LLM run and experiment variant to harvest transitions from.
llm = "gpt4.1-mini"
suffix = "iclr-w-dk-w-ai-shared"

model_dir = TextGenApi.default(llm).connections.connections[0].model_dir
out_dir = root_dir / f"results/household/{model_dir}/hi-tamp/{suffix}/sample-0"
function_stubs = (root_dir / "data/household/function_stubs.py").read_text()

transitions = _get_transitions(out_dir)


def compute_object_mapping(effects, cluster_center_effects, positive: bool):
    """Filter both effect lists down to literals of one polarity (work in progress).

    Args:
        effects: Iterable of effect literals; negated ones are ``Not`` instances.
        cluster_center_effects: Iterable of effect literals of a cluster center.
        positive: Keep non-negated literals when True, ``Not`` literals when False.

    Returns:
        None. The mapping itself is not computed yet.
    """
    # `isinstance(e, Not) != positive` keeps non-``Not`` literals for positive=True
    # and ``Not`` literals for positive=False.
    effects = [e for e in effects if isinstance(e, Not) != positive]
    # Filter the cluster-center list itself; it previously re-filtered `effects`
    # and dropped the cluster-center effects entirely.
    cluster_center_effects = [e for e in cluster_center_effects if isinstance(e, Not) != positive]
    # TODO: compute and return the object mapping between the two filtered lists.


# Household action signatures: action name -> ordered argument types.
action_json = [
    {"name": "go_to_a_furniture_piece_or_an_appliance", "arg_types": ["furniture_appliance"]},
    {
        "name": "pick_up_an_object_on_or_in_a_furniture_piece_or_an_appliance",
        "arg_types": ["household_object", "furniture_appliance"],
    },
    {
        "name": "put_an_object_on_or_in_a_furniture_piece_or_an_appliance",
        "arg_types": ["household_object", "furniture_appliance"],
    },
    {"name": "stack_objects", "arg_types": ["household_object", "household_object"]},
    {"name": "unstack_objects", "arg_types": ["household_object", "household_object"]},
    {"name": "open_a_furniture_piece_or_an_appliance", "arg_types": ["furniture_appliance"]},
    {"name": "close_a_furniture_piece_or_an_appliance", "arg_types": ["furniture_appliance"]},
    {"name": "toggle_a_small_appliance_on", "arg_types": ["household_object"]},
    {"name": "toggle_a_small_appliance_off", "arg_types": ["household_object"]},
    {"name": "slice_objects", "arg_types": ["household_object", "household_object"]},
    {"name": "heat_food_with_a_microwave", "arg_types": ["household_object", "furniture_appliance"]},
    {"name": "heat_food_with_pan", "arg_types": ["household_object", "household_object"]},
    {
        "name": "transfer_food_from_one_small_receptacle_to_another",
        "arg_types": ["household_object", "household_object", "household_object"],
    },
    {"name": "put_an_object_onto_or_into_a_small_receptacle", "arg_types": ["household_object", "household_object"]},
    {"name": "pick_up_an_object_on_or_in_a_small_receptacle", "arg_types": ["household_object", "household_object"]},
    {"name": "open_a_small_receptacle", "arg_types": ["household_object"]},
    {"name": "close_a_small_receptacle", "arg_types": ["household_object"]},
    {"name": "mash_food_with_a_blender", "arg_types": ["household_object", "furniture_appliance"]},
    {"name": "wash_an_object", "arg_types": ["household_object"]},
    {"name": "wipe_a_surface", "arg_types": ["furniture_appliance", "household_object"]},
    {"name": "vacuum_a_carpet", "arg_types": ["household_object", "furniture_appliance"]},
    {"name": "empty_a_vacuum_cleaner", "arg_types": ["household_object", "furniture_appliance"]},
]

# Same signatures in the predicate-description format written to action-preds.json,
# keyed by action name; var_types is a fresh copy so edits don't alias action_json.
action_predicates = {
    action["name"]: {
        "name": action["name"],
        "arity": len(action["arg_types"]),
        "var_types": list(action["arg_types"]),
    }
    for action in action_json
}

import ast

data_dir = root_dir / "data/household"
domain_skeleton = PDDLDomain.loads((data_dir / "domain_skeleton.json").read_text())
# Target directory for the JSON files handed over to predicators.
handover_dir = root_dir / "../../3rdparty/predicators/t-data"


def _predicate_signature(pred):
    """Describe one domain predicate declaration as its name, arity and per-argument type."""
    definition = pred.definition
    return {
        "name": definition.name,
        "arity": len(definition.terms),
        # First element of each term's type tags.
        # NOTE(review): arbitrary if type_tags is an unordered collection holding several tags.
        "var_types": [list(term.type_tags)[0] for term in definition.terms],
    }


pddl_preds = {pred.definition.name: _predicate_signature(pred) for pred in domain_skeleton.predicates}

from pddl.logic.base import Not


def parse_preds(ps):
    """Convert predicate literals into JSON-serializable dicts, skipping negated ones.

    Args:
        ps: Iterable of predicate literals, e.g. the state lists recorded by
            ``_get_transitions``. Each non-``Not`` item must expose ``.name`` and
            ``.terms``, and each term must expose ``.name``.

    Returns:
        List of ``{"predicate_name": <name>, "variables": [<term names>]}`` dicts,
        in input order. ``Not`` literals are dropped.
    """
    return [
        {"predicate_name": pred.name, "variables": [term.name for term in pred.terms]}
        for pred in ps
        if not isinstance(pred, Not)
    ]


def parse_action(a):
    """Parse a call-style action string such as ``"stack_objects(3, 5)"``.

    Args:
        a: Source text whose first statement is a call to a plain function name
            with literal positional arguments.

    Returns:
        ``{"action_pred_name": <function name>, "variables": [<argument values>]}``.

    Raises:
        ValueError: if the first statement is not a call to a bare name, or if an
            argument is not a Python literal.
    """
    statement = ast.parse(a).body[0]
    call = getattr(statement, "value", None)
    if not isinstance(call, ast.Call) or not isinstance(call.func, ast.Name):
        raise ValueError(f"expected an action call like 'name(1, 2)', got {a!r}")

    # literal_eval replaces the deprecated `Constant.n` alias (removed in Python 3.14).
    # It also handles negative numbers, which parse as UnaryOp rather than Constant.
    variables = [ast.literal_eval(arg) for arg in call.args]
    return {"action_pred_name": call.func.id, "variables": variables}


# One entry per recorded transition: (state, action, next_state, None).
# json.dumps serializes each tuple as a 4-element list; the meaning of the trailing
# None slot is defined by the predicators consumer (TODO confirm).
episode = [
    (parse_preds(t["state"]), parse_action(t["action"]), parse_preds(t["next_state"]), None)
    for t in transitions
]

# write_text does not create missing parent directories.
handover_dir.mkdir(parents=True, exist_ok=True)
(handover_dir / "preds.json").write_text(json.dumps(pddl_preds, indent=2))
(handover_dir / "action-preds.json").write_text(json.dumps(action_predicates, indent=2))
(handover_dir / "episode.json").write_text(json.dumps(episode, indent=2))
