# thor_to_abw_state.py
# Minimal translator: AI2-THOR event.metadata  -> ABW state.json (slim v0.1)

from typing import Dict, Any, List, Optional, Tuple
import math
import re
from collections import defaultdict
import copy

# ----------------------------
# Tunables
# ----------------------------
NEAR_DISTANCE_M = 1.5     # "near" radius in metres; kept equal to THOR's interaction radius
GRID_RESOLUTION_M = 0.25  # grid cell size in metres; only used for a coarse position record
MAX_EXTRA_BY_TYPE = 2     # after slimming, how many extra objects to keep per key type at most

# Functional flags (open / toggle / ...) are filled in only for these object types.
APPLY_OPEN = {
    "Fridge",
    "Cabinet",
    "Drawer",
    "Microwave",
    "Toaster",
    "Window",
}
APPLY_TOGGLE = {
    "LightSwitch",
    "StoveKnob",
    "CoffeeMachine",
    "Microwave",
    "Toaster",
    "Faucet",
}
# Optional: an MVP may leave the "clean" flag unfilled.
APPLY_CLEAN = {
    "Bowl",
    "Mug",
    "Plate",
    "Pot",
    "Pan",
    "SinkBasin",
    "DishSponge",
}
# Optional: an MVP may leave the "sliced" flag unfilled.
APPLY_SLICED = {
    "Bread",
    "Tomato",
    "Lettuce",
    "Apple",
}

# Key "container / surface" types; these are kept preferentially when slimming.
KEY_TYPES = {
    "Drawer",
    "Cabinet",
    "Fridge",
    "CounterTop",
    "Table",
    "Shelf",
    "Sink",
    "SinkBasin",
}

# Common action aliases (useful if action strings have to be parsed externally).
ACTION_ALIASES = {
    "OpenObject": "Open",
    "PickupObject": "Pickup",
    "PutInObject": "PutIn",
    "PutObject": "PutOn",
    "ToggleObject": "Toggle",
    "Rotate": "Rotate",
    "NavigateTo": "NavigateTo",
}

# ----------------------------
# Utils
# ----------------------------
def _dist_xz(p: Dict[str, float], q: Dict[str, float]) -> float:
    return math.sqrt((p.get("x", 0.0) - q.get("x", 0.0))**2 +
                     (p.get("z", 0.0) - q.get("z", 0.0))**2)

def _to_grid_xz(p: Dict[str, float]) -> Dict[str, int]:
    """Quantize a continuous x/z position to integer grid indices.

    Each coordinate is divided by ``GRID_RESOLUTION_M`` and rounded to the
    nearest integer. The world "z" coordinate is stored under the grid key
    "y"; a missing coordinate counts as 0.0.
    """
    grid_x = round(p.get("x", 0.0) / GRID_RESOLUTION_M)
    grid_y = round(p.get("z", 0.0) / GRID_RESOLUTION_M)
    return {"x": int(grid_x), "y": int(grid_y)}

def _is_near(obj_pos: Dict[str, float], agent_pos: Dict[str, float]) -> bool:
    """Tell whether an object lies within ``NEAR_DISTANCE_M`` of an agent.

    The distance is measured in the x/z plane only, and the boundary itself
    counts as near.
    """
    separation = _dist_xz(obj_pos, agent_pos)
    return separation <= NEAR_DISTANCE_M

def _extract_agent_pose(agent_meta: Dict[str, Any]) -> Tuple[Dict[str, float], float]:
    pos = agent_meta.get("position") or agent_meta.get("cameraPosition") or {"x":0.0, "z":0.0}
    rot = agent_meta.get("rotation") or {}
    rot_y = float(rot.get("y", 0.0))
    return {"x": float(pos.get("x", 0.0)), "z": float(pos.get("z", 0.0))}, rot_y

def _first_inventory_id(inv: Any) -> Optional[str]:
    if isinstance(inv, list) and len(inv) > 0:
        oid = inv[0].get("objectId") or inv[0].get("name")
        return str(oid) if oid else None
    return None

def _collect_agents_from_metadata(meta: Dict[str, Any], num_agents: int) -> List[Dict[str, Any]]:
    # 兼容单/多智能体
    if isinstance(meta.get("agents"), list) and meta["agents"]:
        agents_meta = meta["agents"]
    elif isinstance(meta.get("agent"), dict):
        agents_meta = [meta["agent"]]
    else:
        agents_meta = []
    # 如果元数据少于 num_agents，后续会兜底补空
    return agents_meta

def _normalize_action_string(a: str) -> str:
    m = re.match(r"\s*([A-Za-z_][A-Za-z0-9_]*)\s*\((.*)\)\s*$", str(a))
    if not m:
        return str(a).strip()
    verb, args = m.group(1), m.group(2)
    verb = ACTION_ALIASES.get(verb, verb)
    return f"{verb}({args})"

def _actions_referenced_ids(joint_actions: Optional[Dict[str, str]]) -> List[str]:
    """提取 joint_actions 中括号里的各个 token，便于瘦身优先保留相关对象"""
    refs = []
    if not isinstance(joint_actions, dict):
        return refs
    for s in joint_actions.values():
        s = _normalize_action_string(s)
        m = re.match(r".*\((.*)\)", s)
        if m:
            args = [p.strip() for p in m.group(1).split(",") if p.strip()]
            refs.extend(args)
    return refs

# ----------------------------
# Core translator
# ----------------------------
def build_abw_state(
    thor_event_metadata: Dict[str, Any],
    agents_info: Dict[str, Dict[str, Any]],
    task_spec: Dict[str, Any],
    *,
    joint_actions: Optional[Dict[str, str]] = None,
    slim: bool = True
) -> Dict[str, Any]:
    """
    Convert THOR metadata + agents_info + task_spec into a minimal ABW state dict.

    Args:
        thor_event_metadata: ``event.metadata`` (e.g. from
            ``env.controller.last_event.metadata``). ``None`` is treated as empty.
        agents_info: ``{"agent_0": {"position": {"x": .., "z": ..},
            "rotation": float, "holding": <str or None>}, ...}``. Agents are
            looked up under the keys ``agent_<i>``. ``None`` is treated as empty,
            in which case agents are taken from the metadata on a best-effort
            basis. A ``{"y": ..}`` rotation dict is also accepted.
        task_spec: ``{"subtasks": [{"id", "description", ...}, ...],
            "steps_remaining": int, ...}``. ``None`` is treated as empty.
        joint_actions: optional; objects mentioned by these actions are favoured
            when slimming.
        slim: whether to slim the object list (keeping referenced / visible /
            near objects plus a few samples of key types).

    Returns:
        The state dict with the keys "agents", "objects", "constraints",
        "task" and "metadata".
    """
    meta = thor_event_metadata or {}
    # Tolerate missing inputs instead of failing on len(None) / None.get(...).
    agents_info = agents_info or {}
    task_spec = task_spec or {}
    scene_id = meta.get("sceneName", "")

    # 1) Agents
    state_agents: Dict[str, Any] = {}
    # Continuous x/z position per agent, kept alongside the quantized grid cell
    # so that the "near" test below never has to work from the lossy grid value.
    agent_float_pos: Dict[str, Dict[str, float]] = {}
    # Agents absent from agents_info are filled in from the metadata where possible.
    meta_agents = _collect_agents_from_metadata(meta, num_agents=len(agents_info))
    for i in range(max(len(agents_info), len(meta_agents))):
        aid = f"agent_{i}"
        ainfo = agents_info.get(aid)
        if ainfo is not None:
            # "or" also covers an explicit None stored under the key.
            pos = ainfo.get("position") or {"x": 0.0, "z": 0.0}
            raw_rot = ainfo.get("rotation")
            if isinstance(raw_rot, dict):
                # THOR-style rotation dict: the yaw is its "y" entry.
                raw_rot = raw_rot.get("y")
            rot = float(raw_rot or 0.0)
            holding = ainfo.get("holding")
        else:
            am = meta_agents[i] if i < len(meta_agents) else {}
            pos, rot = _extract_agent_pose(am)
            holding = _first_inventory_id(am.get("inventoryObjects", []))

        agent_float_pos[aid] = pos
        state_agents[aid] = {
            "grid": _to_grid_xz(pos),
            "pose": rot,
            # The string "null" (not None) marks an empty hand in the output.
            "holding": holding if holding is not None else "null",
        }

    # 2) Objects (minimal fields)
    objects_out: Dict[str, Any] = {}
    for obj in meta.get("objects") or []:
        oid = obj.get("objectId") or obj.get("name") or obj.get("id")
        otype = obj.get("objectType") or obj.get("type", "Unknown")
        if not oid or not otype:
            continue

        # Object position (only x/z is used); the origin when absent.
        opos = obj.get("position") or {"x": 0.0, "z": 0.0}

        # Minimal object entry: type, visibility and a per-agent "near" flag.
        entry = {
            "type": otype,
            "visible": bool(obj.get("visible", False)),
            "near": {aid: _is_near(opos, apos) for aid, apos in agent_float_pos.items()},
        }

        # Functional flags are filled in only for applicable types, and only
        # when the metadata actually carries the corresponding field.
        if otype in APPLY_OPEN and "isOpen" in obj:
            entry["open"] = bool(obj["isOpen"])
        if otype in APPLY_TOGGLE and "isToggled" in obj:
            entry["toggle"] = "on" if obj["isToggled"] else "off"
        if otype in APPLY_CLEAN and "isDirty" in obj:
            entry["clean"] = not bool(obj["isDirty"])
        if otype in APPLY_SLICED and "isSliced" in obj:
            entry["sliced"] = bool(obj["isSliced"])

        # NOTE: containment relations ("in" / "on") are not populated. If the
        # metadata exposes parent receptacles (e.g. "parentReceptacles"; the
        # field name may differ between versions), this is where they would go.

        objects_out[str(oid)] = entry

    # 3) Constraints (placeholders, meant to be injected externally)
    constraints = {
        "occupied": [],      # e.g., "fridge_1_handle" / "countertop_1_surface"
        "blocked_paths": []  # e.g., {"agent":"agent_0","target":"Drawer_1","reason":"chair block"}
    }

    # 4) Task (copied straight from the task_spec that was passed in)
    task = {
        "subtasks": task_spec.get("subtasks", []),
        "steps_remaining": int(task_spec.get("steps_remaining") or 0),
    }

    state = {
        "agents": state_agents,
        "objects": objects_out,
        "constraints": constraints,
        "task": task,
        "metadata": {
            "scene_id": scene_id
        }
    }

    # 5) Slimming (lower cost, less noise)
    if slim:
        state = _slim_state(state, joint_actions=joint_actions)

    return state

# ----------------------------
# Slimming: keep only relevant/visible/near/key types
# ----------------------------
def _slim_state(state: Dict[str, Any], joint_actions: Optional[Dict[str, str]] = None) -> Dict[str, Any]:
    if "objects" not in state:
        return state
    objs = state["objects"]
    if not isinstance(objs, dict):
        return state

    # a) 动作引用对象（括号内 token）
    refs = set(_actions_referenced_ids(joint_actions))

    # b) 可见对象
    visible_ids = {oid for oid, o in objs.items() if o.get("visible", False)}
    # c) 任一 agent 近邻对象
    near_ids = {oid for oid, o in objs.items() if any(o.get("near", {}).values())}

    keep_ids = set()
    # 先尽量直接留：动作引用（模糊匹配）、可见、近邻
    for oid in objs.keys():
        if oid in visible_ids or oid in near_ids:
            keep_ids.add(oid)
            continue
        # 对 refs 做模糊包含匹配（支持 Drawer_1 -> Drawer|...ID）
        for r in refs:
            if r and (r == oid or r in oid):
                keep_ids.add(oid)
                break

    # d) 关键类型各留最多 MAX_EXTRA_BY_TYPE 个（比如抽屉/柜子/台面/冰箱）
    # 避免把动作附近的容器类都删了
    by_type = defaultdict(list)
    for oid, o in objs.items():
        by_type[o.get("type","Unknown")].append(oid)

    for t in KEY_TYPES:
        if t in by_type:
            # 优先保留“可见或近邻”的，数量不足再补充
            cands = by_type[t]
            pri = [oid for oid in cands if (oid in visible_ids or oid in near_ids)]
            rest = [oid for oid in cands if oid not in pri]
            chosen = pri[:MAX_EXTRA_BY_TYPE] + rest[:max(0, MAX_EXTRA_BY_TYPE - len(pri))]
            keep_ids.update(chosen)

    # e) 构造瘦身后的 objects
    new_objs = {oid: objs[oid] for oid in keep_ids}
    new_state = copy.deepcopy(state)
    new_state["objects"] = new_objs
    return new_state

# ----------------------------
# Example (for quick local test)
# ----------------------------
if __name__ == "__main__":
    # Quick local check with a minimal, hand-made metadata / agents / task fixture.
    import json

    fake_meta = {
        "sceneName": "FloorPlan1_physics",
        "objects": [
            {
                "objectId": "Drawer|+00.95|+00.83|-02.20",
                "objectType": "Drawer",
                "visible": True,
                "position": {"x": 0.9, "z": -2.2},
                "isOpen": False,
            },
            {
                "objectId": "CounterTop|-00.08|+01.15|00.00",
                "objectType": "CounterTop",
                "visible": True,
                "position": {"x": 0.0, "z": 0.0},
            },
            {
                "objectId": "Apple|-00.47|+01.15|+00.48",
                "objectType": "Apple",
                "visible": False,
                "position": {"x": -0.5, "z": 0.5},
            },
        ],
        "agent": {
            "position": {"x": 0.0, "z": 0.0},
            "rotation": {"y": 0.0},
            "inventoryObjects": [],
        },
    }
    agents_info = {
        "agent_0": {
            "position": {"x": 1.5, "z": -1.5},
            "rotation": 270.0,
            "holding": None,
        },
        "agent_1": {
            "position": {"x": 0.0, "z": 0.0},
            "rotation": 0.0,
            "holding": None,
        },
    }
    task_spec = {
        "subtasks": [
            {"id": "active_0", "description": "Open(Drawer_1)"},
            {"id": "active_1", "description": "Rotate(Right)"},
        ],
        "steps_remaining": 40,
    }
    joint_actions = {
        "agent_0": "OpenObject(Drawer_1)",
        "agent_1": "Rotate(Right)",
    }

    state = build_abw_state(
        fake_meta,
        agents_info,
        task_spec,
        joint_actions=joint_actions,
        slim=True,
    )
    print(json.dumps(state, indent=2))
