# import os
# import json
# import argparse
# import numpy as np
# import pandas as pd
# import torch
# import open_clip

# from evaluation_3d_seg import compute_metrics


# def class_names(path, dataset):
#     class_names = []
#     class_indices = []
#     if dataset == "Replica":
#         with open(path, "r") as file:
#             data = json.load(file)
#             for i in range(len(data["classes"])):
#                 class_names.append("a photo of " + data["classes"][i]["name"])
#                 class_indices.append(int(data["classes"][i]["id"]) - 1)
#     elif dataset == "ScanNet":
#         with open(path, "r") as file:
#             data = json.load(file)
#             for key in data.keys():
#                 class_indices.append(int(key))
#                 if data[key] != "picture":
#                     class_names.append("a photo of " + data[key])
#                 else:
#                     class_names.append("a photo of a picture on wall")
#     return class_names, class_indices


# def load_clip_text_embeddings(model, clip_model_name, class_texts, device):
#     # For compatibility with your original code, keep tokenizer = ViT-H-14
#     tokenizer = open_clip.get_tokenizer("ViT-H-14")
#     tokens = tokenizer(class_texts).to(device)
#     model.eval()
#     model.to(device)
#     with torch.no_grad():
#         text_emb = model.encode_text(tokens).float()  # [C, D]
#     return text_emb  # torch.Tensor


# def load_core3d_embeddings(emb_path):
#     """
#     Load per-object component embeddings from *_ids_to_embeddings_ctx.json.
#     Returns:
#       obj_ids: list of int object ids
#       E_s, E_l, E_h, E_hide, E_mask: [N_obj, D] np arrays
#     """
#     with open(emb_path, "r") as f:
#         data = json.load(f)

#     obj_ids = sorted(list(map(int, data.keys())))
#     E_s_list = []
#     E_l_list = []
#     E_h_list = []
#     E_hide_list = []
#     E_mask_list = []

#     for oid in obj_ids:
#         entry = data[str(oid)]
#         E_s_list.append(np.array(entry["E_s"], dtype=np.float32))
#         E_l_list.append(np.array(entry["E_l"], dtype=np.float32))
#         E_h_list.append(np.array(entry["E_h"], dtype=np.float32))
#         E_hide_list.append(np.array(entry["E_hide"], dtype=np.float32))
#         E_mask_list.append(np.array(entry["E_mask"], dtype=np.float32))

#     E_s = np.stack(E_s_list, axis=0)
#     E_l = np.stack(E_l_list, axis=0)
#     E_h = np.stack(E_h_list, axis=0)
#     E_hide = np.stack(E_hide_list, axis=0)
#     E_mask = np.stack(E_mask_list, axis=0)

#     return obj_ids, E_s, E_l, E_h, E_hide, E_mask


# def main():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--path", type=str, default="")
#     parser.add_argument("--dataset", type=str, choices=["Replica", "ScanNet"])
#     parser.add_argument("--scene", type=str, required=True)
#     parser.add_argument("--clip_model", type=str, default="EVA02-L-14-336")
#     parser.add_argument(
#         "--param",
#         type=str,
#         default="all",
#         choices=["all", "alpha_h", "alpha_l", "alpha_o", "alpha_m"],
#         help="Which weight to sweep in sensitivity analysis.",
#     )
#     parser.add_argument(
#         "--factors",
#         type=float,
#         nargs="+",
#         default=[0.5, 0.75, 1.0, 1.25, 1.5],
#         help="Multiplicative factors around default weights.",
#     )
#     args = parser.parse_args()

#     base_path = args.path
#     dataset = args.dataset
#     scene = args.scene
#     clip_model_name = args.clip_model
#     param_mode = args.param
#     factors = args.factors

#     device = "cuda" if torch.cuda.is_available() else "cpu"

#     # ----------------------------------------------------------------------
#     # 1. Load CLIP model & weights
#     # ----------------------------------------------------------------------
#     model, _, _ = open_clip.create_model_and_transforms(clip_model_name, pretrained=None)
#     ckpt_path = os.path.join(base_path, "models", "open_clip_pytorch_model.bin")
#     state = torch.load(ckpt_path, map_location="cpu")
#     missing, unexpected = model.load_state_dict(state, strict=False)
#     print(
#         f"[CLIP] Loaded weights from {ckpt_path}. "
#         f"Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}"
#     )

#     # ----------------------------------------------------------------------
#     # 2. Load points-to-object map & component embeddings
#     # ----------------------------------------------------------------------
#     emb_dir = os.path.join(base_path, "embeddings")
#     pts_file = os.path.join(emb_dir, f"{scene}_points_to_ids_ctx.csv")
#     ctx_file = os.path.join(emb_dir, f"{scene}_ids_to_embeddings_ctx.json")

#     print(f"[LOAD] Reading points-to-ids from: {pts_file}")
#     df_pts = pd.read_csv(pts_file)  # columns: x, y, z, Object id

#     print(f"[LOAD] Reading ctx embeddings from: {ctx_file}")
#     obj_ids, E_s_np, E_l_np, E_h_np, E_hide_np, E_mask_np = load_core3d_embeddings(ctx_file)

#     # Assume object ids are contiguous 0..N_obj-1 as in your pipeline
#     N_obj, D = E_s_np.shape
#     print(f"[INFO] Loaded {N_obj} objects with embedding dim {D}")

#     # Map object-id -> index
#     max_obj_id = max(obj_ids)
#     if max_obj_id + 1 != N_obj:
#         print("[WARN] Object ids are not exactly 0..N_obj-1; using a mapping.")
#         oid_to_idx = {oid: idx for idx, oid in enumerate(obj_ids)}
#     else:
#         oid_to_idx = None  # identity mapping

#     # Convert component embeddings to torch
#     E_s = torch.tensor(E_s_np, device=device)       # [N_obj, D]
#     E_l = torch.tensor(E_l_np, device=device)
#     E_h = torch.tensor(E_h_np, device=device)
#     E_hide = torch.tensor(E_hide_np, device=device)
#     E_mask = torch.tensor(E_mask_np, device=device)

#     # ----------------------------------------------------------------------
#     # 3. Load class names & text embeddings
#     # ----------------------------------------------------------------------
#     if dataset == "Replica":
#         class_file = os.path.join(base_path, "dataset", "Replica-data", "info_semantic.json")
#     else:
#         class_file = os.path.join(base_path, "dataset", "ScanNet", "classes.json")

#     classes, class_indices = class_names(class_file, dataset)
#     print(f"[LOAD] {len(classes)} classes from {class_file}")

#     class_embeddings = load_clip_text_embeddings(model, clip_model_name, classes, device)
#     # shape: [C, D]
#     C = class_embeddings.shape[0]

#     # ----------------------------------------------------------------------
#     # 4. Load default alphas from YAML config
#     # ----------------------------------------------------------------------
#     import yaml

#     if dataset == "Replica":
#         config_file = os.path.join(base_path, "core_configs", "config_Replica.yaml")
#     else:
#         config_file = os.path.join(base_path, "core_configs", "config_ScanNet.yaml")

#     with open(config_file, "r") as f:
#         config = yaml.safe_load(f)

#     alpha_h0 = float(config["alpha_h"])
#     alpha_l0 = float(config["alpha_l"])
#     alpha_o0 = float(config["alpha_o"])
#     alpha_m0 = float(config["alpha_m"])
#     print(
#         f"[CONFIG] Default alphas: "
#         f"alpha_h={alpha_h0}, alpha_l={alpha_l0}, alpha_o={alpha_o0}, alpha_m={alpha_m0}"
#     )

#     # Decide which params to sweep
#     if param_mode == "all":
#         param_list = ["alpha_h", "alpha_l", "alpha_o", "alpha_m"]
#     else:
#         param_list = [param_mode]

#     # ----------------------------------------------------------------------
#     # 5. Load ground truth for evaluation (same format as evaluation_3d_seg.py)
#     # ----------------------------------------------------------------------
#     gt_dir = os.path.join(base_path, "ground_truth")
#     gt_file = os.path.join(gt_dir, f"{scene}_ground_truth.csv")
#     print(f"[LOAD] Reading ground truth from: {gt_file}")
#     df_gt = pd.read_csv(gt_file)

#     gt_points = torch.tensor(df_gt[["x", "y", "z"]].to_numpy(), device=device)
#     gt_labels = torch.tensor(df_gt["label"].to_numpy(), device=device).long() - 1

#     if dataset == "Replica":
#         eval_classes = list(np.arange(101))
#         class_num = 101
#     else:
#         eval_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 23, 27, 32, 33, 35, 38]
#         class_num = 40

#     # Pre-build per-point object index tensor
#     obj_ids_points = df_pts["Object id"].to_numpy()
#     if oid_to_idx is None:
#         obj_idx_points = torch.tensor(obj_ids_points, device=device, dtype=torch.long)
#     else:
#         mapped = [oid_to_idx[int(o)] for o in obj_ids_points]
#         obj_idx_points = torch.tensor(mapped, device=device, dtype=torch.long)

#     pred_points = torch.tensor(df_pts[["x", "y", "z"]].to_numpy(), device=device)

#     # Prepare output dirs
#     pred_dir = os.path.join(base_path, "predicted_labels1")
#     os.makedirs(pred_dir, exist_ok=True)
#     sens_dir = os.path.join(base_path, "sensitivity")
#     os.makedirs(sens_dir, exist_ok=True)

#     # For mapping from class index (0..C-1) to label id used in GT
#     class_indices_arr = torch.tensor(class_indices, device=device, dtype=torch.long)

#     # ----------------------------------------------------------------------
#     # 6. Sweep each parameter and compute metrics
#     # ----------------------------------------------------------------------
#     results = []  # for CSV summary

#     for param_name in param_list:
#         print(f"\n=== Sweeping {param_name} over factors {factors} ===")
#         for factor in factors:
#             # Start from defaults
#             alpha_h = alpha_h0
#             alpha_l = alpha_l0
#             alpha_o = alpha_o0
#             alpha_m = alpha_m0

#             if param_name == "alpha_h":
#                 alpha_h = alpha_h0 * factor
#             elif param_name == "alpha_l":
#                 alpha_l = alpha_l0 * factor
#             elif param_name == "alpha_o":
#                 alpha_o = alpha_o0 * factor
#             elif param_name == "alpha_m":
#                 alpha_m = alpha_m0 * factor

#             print(
#                 f"[SWEEP] {param_name} factor={factor:.2f} -> "
#                 f"alpha_h={alpha_h:.4f}, alpha_l={alpha_l:.4f}, "
#                 f"alpha_o={alpha_o:.4f}, alpha_m={alpha_m:.4f}"
#             )

#             # Combine embeddings: E = alpha_h*E_h + alpha_l*E_l + E_s - alpha_o*E_hide + alpha_m*E_mask
#             E_comb = (
#                 alpha_h * E_h
#                 + alpha_l * E_l
#                 + E_s
#                 - alpha_o * E_hide
#                 + alpha_m * E_mask
#             )  # [N_obj, D]

#             # Compute class logits: [N_obj, C]
#             logits = E_comb @ class_embeddings.T  # [N_obj, C]

#             # Argmax per object
#             top_idx = torch.argmax(logits, dim=1)  # [N_obj]
#             # Map to dataset label id via class_indices
#             labels_obj = class_indices_arr[top_idx]  # [N_obj]

#             # Map per-object labels to per-point labels
#             point_labels = labels_obj[obj_idx_points]  # [N_points]

#             # Save predicted labels CSV (for reproducibility)
#             tag = f"{param_name}_x{int(round(factor * 100))}"
#             out_csv = os.path.join(pred_dir, f"{scene}_predicted_labels_{tag}.csv")
#             df_out = df_pts.copy()
#             df_out["labels"] = point_labels.cpu().numpy().astype(int)
#             df_out.to_csv(out_csv, index=False)
#             print(f"[SAVE] Predicted labels saved to: {out_csv}")

#             # Compute 3D metrics using same logic as evaluation_3d_seg.py
#             pred_labels_for_eval = (point_labels - 1).unsqueeze(1)  # match evaluation_3d_seg
#             metrics = compute_metrics(
#                 pred_points,
#                 pred_labels_for_eval,
#                 gt_points,
#                 gt_labels,
#                 class_num,
#                 threshold=0.25,
#                 eval_classes=eval_classes,
#                 batch_size=1000,
#             )
#             print(
#                 f"[METRICS] {param_name} factor={factor:.2f} -> "
#                 f"mIoU={metrics['mIoU']:.4f}, "
#                 f"mAcc={metrics['mAcc']:.4f}, "
#                 f"F-IoU={metrics['F-IoU']:.4f}"
#             )

#             results.append(
#                 {
#                     "scene": scene,
#                     "dataset": dataset,
#                     "param": param_name,
#                     "factor": factor,
#                     "alpha_h": alpha_h,
#                     "alpha_l": alpha_l,
#                     "alpha_o": alpha_o,
#                     "alpha_m": alpha_m,
#                     "mIoU": metrics["mIoU"],
#                     "mAcc": metrics["mAcc"],
#                     "F-IoU": metrics["F-IoU"],
#                 }
#             )

#     # ----------------------------------------------------------------------
#     # 7. Save sensitivity summary for plotting
#     # ----------------------------------------------------------------------
#     df_res = pd.DataFrame(results)
#     out_sens = os.path.join(sens_dir, f"{scene}_ctx_zero.csv")
#     df_res.to_csv(out_sens, index=False)
#     print(f"\n[SUMMARY] Sensitivity results saved to: {out_sens}")


# if __name__ == "__main__":
#     main()


import os
import json
import argparse
import numpy as np
import pandas as pd
import torch
import open_clip
import yaml

from evaluation_3d_seg import compute_metrics


# ---------------------- Class name loader ---------------------- #
def class_names(path, dataset):
    """Load CLIP prompt strings and their label ids for a dataset's classes.

    Args:
        path: JSON class file.
            Replica: ``{"classes": [{"id": <int>, "name": <str>}, ...]}``.
            ScanNet: ``{"1": "wall", "2": "floor", ...}``.
        dataset: "Replica" or "ScanNet" (case-insensitive).

    Returns:
        (prompts, indices): ``prompts[i]`` is the text prompt
        ("a photo of <name>") and ``indices[i]`` its label id -- ``id - 1``
        for Replica, the raw integer key for ScanNet. ScanNet entries are
        ordered by numeric id so prompts and indices line up.

    Raises:
        ValueError: if ``dataset`` is neither Replica nor ScanNet. This is
            checked before the file is opened.
    """
    ds = str(dataset).lower()
    if ds not in ("replica", "scannet"):
        raise ValueError(f"Unknown dataset: {dataset}")

    with open(path, "r", encoding="utf-8") as file:
        data = json.load(file)

    prompts = []
    indices = []

    if ds == "replica":
        # data["classes"] is a list of {"id": int, "name": str}
        for entry in data["classes"]:
            prompts.append("a photo of " + entry["name"])
            # if your config/GT is 1..101, this makes them 0..100
            indices.append(int(entry["id"]) - 1)
    else:
        # keys are strings ("1", "2", ...); sort numerically, not lexically
        for key in sorted(data, key=int):
            name = data[key]
            indices.append(int(key))  # keep raw label id
            # presumably a bare "picture" prompt is too ambiguous for CLIP,
            # so it gets extra context
            if name == "picture":
                prompts.append("a photo of a picture on wall")
            else:
                prompts.append("a photo of " + name)

    return prompts, indices


# ---------------------- CLIP text embeddings ---------------------- #
def load_clip_text_embeddings(
    model, clip_model_name, class_texts, device, tokenizer_name="ViT-H-14"
):
    """Encode class prompt strings into CLIP text embeddings.

    Args:
        model: open_clip model exposing ``encode_text``. It is switched to
            eval mode and moved to ``device`` in place.
        clip_model_name: accepted for call-site compatibility; it does not
            select the tokenizer (see ``tokenizer_name``).
        class_texts: list of prompt strings.
        device: torch device for the tokens and the model.
        tokenizer_name: open_clip tokenizer to use. Defaults to "ViT-H-14",
            kept for compatibility with the original pipeline.

    Returns:
        torch.Tensor of shape [len(class_texts), D], cast to float32 and
        not normalized.
    """
    tokenizer = open_clip.get_tokenizer(tokenizer_name)
    model.eval()
    model.to(device)
    tokens = tokenizer(class_texts).to(device)
    with torch.no_grad():
        text_emb = model.encode_text(tokens).float()  # [C, D]
    return text_emb


# ---------------------- Core3D embeddings loader ---------------------- #
def load_core3d_embeddings(emb_path):
    """
    Load per-object component embeddings from *_ids_to_embeddings_ctx.json.

    The JSON maps object-id strings to dicts holding the component vectors
    "E_s", "E_l", "E_h", "E_hide" and "E_mask".

    Returns:
      obj_ids: list of int object ids, sorted ascending; row i of every
        array below belongs to obj_ids[i]
      E_s, E_l, E_h, E_hide, E_mask: [N_obj, D] float32 np arrays

    Raises:
      ValueError: if the file contains no objects (np.stack also raises
        ValueError if an object's vectors have mismatched lengths).
    """
    with open(emb_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not data:
        raise ValueError(f"No object embeddings found in {emb_path}")

    # Key by integer id directly, so keys such as "07" still resolve and the
    # ordering is numeric rather than lexical.
    entries = {int(key): value for key, value in data.items()}
    obj_ids = sorted(entries)

    components = ("E_s", "E_l", "E_h", "E_hide", "E_mask")
    stacked = [
        np.stack(
            [np.asarray(entries[oid][name], dtype=np.float32) for oid in obj_ids],
            axis=0,
        )
        for name in components
    ]

    return (obj_ids, *stacked)


# ---------------------- Main ---------------------- #
def _parse_args():
    """Parse the command-line options for the weight-sensitivity sweep."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, default="")
    parser.add_argument(
        "--dataset",
        type=str,
        # Required: every later step branches on the lower-cased name.
        required=True,
        choices=["Replica", "ScanNet", "replica", "scannet"],
        help="Dataset name (case-insensitive).",
    )
    parser.add_argument("--scene", type=str, required=True)
    parser.add_argument("--clip_model", type=str, default="EVA02-L-14-336")
    parser.add_argument(
        "--param",
        type=str,
        default="all",
        choices=["all", "alpha_h", "alpha_l", "alpha_o", "alpha_m"],
        help="Which weight to sweep in sensitivity analysis.",
    )
    parser.add_argument(
        "--factors",
        type=float,
        nargs="+",
        default=[0.5, 0.75, 1.0, 1.25, 1.5],
        help="Multiplicative factors around default weights.",
    )
    parser.add_argument(
        "--emb_suffix",
        type=str,
        default="_scannet",
        help=(
            "Suffix of the embedding files: "
            "<scene>_points_to_ids_ctx<suffix>.csv and "
            "<scene>_ids_to_embeddings_ctx<suffix>.json. "
            'Pass "" for files without a suffix.'
        ),
    )
    return parser.parse_args()


def _point_object_indices(obj_ids, point_obj_ids):
    """Map per-point object ids to row indices of the per-object embedding arrays.

    Args:
        obj_ids: object ids in embedding-row order (row ``i`` holds ``obj_ids[i]``).
        point_obj_ids: per-point object ids (the "Object id" column).

    Returns:
        Integer np array with one embedding-row index per point.

    Raises:
        ValueError: if there are no objects, or if any point references an
            object id that has no embedding. Every id is checked explicitly,
            so an id such as -1 cannot silently wrap around to the last object.
    """
    known = np.asarray(obj_ids, dtype=np.int64)
    if known.size == 0:
        raise ValueError("No object ids available to map points onto.")
    order = np.argsort(known, kind="stable")
    sorted_known = known[order]
    queries = np.asarray(point_obj_ids).astype(np.int64)
    # searchsorted gives each id's candidate slot; clipping keeps ids larger
    # than every known id in range so the equality check can reject them.
    pos = np.clip(np.searchsorted(sorted_known, queries), 0, sorted_known.size - 1)
    unknown = sorted_known[pos] != queries
    if unknown.any():
        bad_ids = np.unique(queries[unknown])
        raise ValueError(
            f"{int(unknown.sum())} points reference object ids with no "
            f"embedding, e.g. {bad_ids[:10].tolist()}"
        )
    return order[pos]


def main():
    """Sweep the embedding-fusion weights and record 3D segmentation metrics.

    For each selected weight (alpha_h, alpha_l, alpha_o, alpha_m) and each
    factor, that weight is scaled from its config default while the others
    stay at their defaults. Each sweep point then runs these steps:

    1. Fuse the per-object component embeddings.
    2. Label each object with the class whose text embedding has the largest
       dot product with it.
    3. Broadcast the object labels to points and save them under
       ``<path>/predicted_labels1``.
    4. Evaluate the point labels with ``compute_metrics``.

    A summary CSV is written to ``<path>/sensitivity/<scene>_ctx_zero.csv``.
    """
    args = _parse_args()

    base_path = args.path
    dataset = args.dataset
    ds = dataset.lower()
    scene = args.scene
    clip_model_name = args.clip_model
    param_mode = args.param
    factors = args.factors

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ----------------------------------------------------------------------
    # 1. Load CLIP model & weights
    # ----------------------------------------------------------------------
    model, _, _ = open_clip.create_model_and_transforms(
        clip_model_name, pretrained=None
    )
    ckpt_path = os.path.join(base_path, "models", "open_clip_pytorch_model.bin")
    state = torch.load(ckpt_path, map_location="cpu")
    # strict=False tolerates key mismatches; the counts printed below are the
    # only signal that the checkpoint did not fully match the architecture.
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(
        f"[CLIP] Loaded weights from {ckpt_path}. "
        f"Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}"
    )

    # ----------------------------------------------------------------------
    # 2. Load points-to-object map & component embeddings
    # ----------------------------------------------------------------------
    emb_dir = os.path.join(base_path, "embeddings")
    suffix = args.emb_suffix
    pts_file = os.path.join(emb_dir, f"{scene}_points_to_ids_ctx{suffix}.csv")
    ctx_file = os.path.join(emb_dir, f"{scene}_ids_to_embeddings_ctx{suffix}.json")

    print(f"[LOAD] Reading points-to-ids from: {pts_file}")
    df_pts = pd.read_csv(pts_file)  # expected columns: x, y, z, Object id

    print(f"[LOAD] Reading ctx embeddings from: {ctx_file}")
    obj_ids, E_s_np, E_l_np, E_h_np, E_hide_np, E_mask_np = load_core3d_embeddings(
        ctx_file
    )

    N_obj, D = E_s_np.shape
    print(f"[INFO] Loaded {N_obj} objects with embedding dim {D}")

    if list(obj_ids) != list(range(N_obj)):
        print("[WARN] Object ids are not exactly 0..N_obj-1; using a mapping.")

    # Convert component embeddings to torch
    E_s = torch.tensor(E_s_np, device=device)  # [N_obj, D]
    E_l = torch.tensor(E_l_np, device=device)
    E_h = torch.tensor(E_h_np, device=device)
    E_hide = torch.tensor(E_hide_np, device=device)
    E_mask = torch.tensor(E_mask_np, device=device)

    # ----------------------------------------------------------------------
    # 3. Load class names & text embeddings
    # ----------------------------------------------------------------------
    if ds == "replica":
        class_file = os.path.join(
            base_path, "dataset", "Replica-data", "info_semantic.json"
        )
    else:  # ScanNet
        class_file = os.path.join(base_path, "dataset", "ScanNet", "classes.json")

    classes, class_indices = class_names(class_file, dataset)
    print(f"[LOAD] {len(classes)} classes from {class_file}")

    class_embeddings = load_clip_text_embeddings(
        model, clip_model_name, classes, device
    )  # [C, D]
    C = class_embeddings.shape[0]
    print(f"[INFO] Text embedding shape: {C} classes x {class_embeddings.shape[1]} dim")

    # ----------------------------------------------------------------------
    # 4. Load default alphas from YAML config
    # ----------------------------------------------------------------------
    if ds == "replica":
        config_file = os.path.join(base_path, "core_configs", "config_Replica.yaml")
    else:
        config_file = os.path.join(base_path, "core_configs", "config_ScanNet.yaml")

    with open(config_file, "r") as f:
        config = yaml.safe_load(f)

    alpha_names = ("alpha_h", "alpha_l", "alpha_o", "alpha_m")
    defaults = {name: float(config[name]) for name in alpha_names}
    print(
        f"[CONFIG] Default alphas: "
        f"alpha_h={defaults['alpha_h']}, alpha_l={defaults['alpha_l']}, "
        f"alpha_o={defaults['alpha_o']}, alpha_m={defaults['alpha_m']}"
    )

    param_list = list(alpha_names) if param_mode == "all" else [param_mode]

    # ----------------------------------------------------------------------
    # 5. Load ground truth for evaluation
    # ----------------------------------------------------------------------
    gt_dir = os.path.join(base_path, "ground_truth")
    gt_file = os.path.join(gt_dir, f"{scene}_ground_truth.csv")
    print(f"[LOAD] Reading ground truth from: {gt_file}")
    df_gt = pd.read_csv(gt_file)

    gt_points = torch.tensor(df_gt[["x", "y", "z"]].to_numpy(), device=device)

    # NOTE: keep your original -1 offset; this assumes GT labels start from 1
    gt_labels = torch.tensor(df_gt["label"].to_numpy(), device=device).long() - 1

    if ds == "replica":
        eval_classes = list(np.arange(101))
        class_num = 101
    else:
        # Your ScanNet evaluation subset (after -1 shift)
        eval_classes = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
            10, 11, 13, 15, 23, 27, 32, 33, 35, 38
        ]
        class_num = 40

    # ----------------------------------------------------------------------
    # 6. Pre-build per-point object index tensor
    # ----------------------------------------------------------------------
    obj_idx_points = torch.tensor(
        _point_object_indices(obj_ids, df_pts["Object id"].to_numpy()),
        device=device,
        dtype=torch.long,
    )

    pred_points = torch.tensor(df_pts[["x", "y", "z"]].to_numpy(), device=device)

    # Output dirs
    pred_dir = os.path.join(base_path, "predicted_labels1")
    os.makedirs(pred_dir, exist_ok=True)
    sens_dir = os.path.join(base_path, "sensitivity")
    os.makedirs(sens_dir, exist_ok=True)

    # Map from class index (0..C-1) to dataset label id
    class_indices_arr = torch.tensor(class_indices, device=device, dtype=torch.long)

    # ----------------------------------------------------------------------
    # 7. Sweep each parameter and compute metrics
    # ----------------------------------------------------------------------
    results = []

    for param_name in param_list:
        print(f"\n=== Sweeping {param_name} over factors {factors} ===")
        for factor in factors:
            # Start from defaults and scale only the swept weight.
            alphas = dict(defaults)
            alphas[param_name] = defaults[param_name] * factor
            alpha_h, alpha_l, alpha_o, alpha_m = (alphas[n] for n in alpha_names)

            print(
                f"[SWEEP] {param_name} factor={factor:.2f} -> "
                f"alpha_h={alpha_h:.4f}, alpha_l={alpha_l:.4f}, "
                f"alpha_o={alpha_o:.4f}, alpha_m={alpha_m:.4f}"
            )

            # Combine embeddings: E = alpha_h*E_h + alpha_l*E_l + E_s - alpha_o*E_hide + alpha_m*E_mask
            E_comb = (
                alpha_h * E_h
                + alpha_l * E_l
                + E_s
                - alpha_o * E_hide
                + alpha_m * E_mask
            )  # [N_obj, D]

            # Class logits (unnormalized dot products): [N_obj, C]
            logits = E_comb @ class_embeddings.T

            # Argmax per object, then translate to dataset label ids
            top_idx = torch.argmax(logits, dim=1)  # [N_obj]
            labels_obj = class_indices_arr[top_idx]  # [N_obj]

            # Map per-object labels to per-point labels
            point_labels = labels_obj[obj_idx_points]  # [N_points]

            # Save predicted labels CSV (for reproducibility)
            tag = f"{param_name}_x{int(round(factor * 100))}"
            out_csv = os.path.join(pred_dir, f"{scene}_predicted_labels_{tag}.csv")
            df_out = df_pts.copy()
            df_out["labels"] = point_labels.cpu().numpy().astype(int)
            df_out.to_csv(out_csv, index=False)
            print(f"[SAVE] Predicted labels saved to: {out_csv}")

            # keep the same -1 shift you used in compute_metrics pipeline
            # NOTE(review): for ScanNet, class_indices are raw ids, so this -1
            # mirrors the -1 applied to gt_labels. For Replica, class_names()
            # already subtracted 1, so predictions become id-2 while GT becomes
            # label-1; if GT "label" uses the same ids as info_semantic.json,
            # that is an off-by-one -- confirm against evaluation_3d_seg.
            pred_labels_for_eval = (point_labels - 1).unsqueeze(1)

            metrics = compute_metrics(
                pred_points,
                pred_labels_for_eval,
                gt_points,
                gt_labels,
                class_num,
                threshold=0.25,
                eval_classes=eval_classes,
                batch_size=1000,
            )
            print(
                f"[METRICS] {param_name} factor={factor:.2f} -> "
                f"mIoU={metrics['mIoU']:.4f}, "
                f"mAcc={metrics['mAcc']:.4f}, "
                f"F-IoU={metrics['F-IoU']:.4f}"
            )

            results.append(
                {
                    "scene": scene,
                    "dataset": dataset,
                    "param": param_name,
                    "factor": factor,
                    "alpha_h": alpha_h,
                    "alpha_l": alpha_l,
                    "alpha_o": alpha_o,
                    "alpha_m": alpha_m,
                    "mIoU": metrics["mIoU"],
                    "mAcc": metrics["mAcc"],
                    "F-IoU": metrics["F-IoU"],
                }
            )

    # ----------------------------------------------------------------------
    # 8. Save sensitivity summary
    # ----------------------------------------------------------------------
    df_res = pd.DataFrame(results)
    out_sens = os.path.join(sens_dir, f"{scene}_ctx_zero.csv")
    df_res.to_csv(out_sens, index=False)
    print(f"\n[SUMMARY] Sensitivity results saved to: {out_sens}")


if __name__ == "__main__":
    main()
