import os
import json
import argparse

import numpy as np
import pandas as pd


def _object_label_and_score(entry):
    """Return ``(class_idx, score)`` for one object's entry in the embeddings JSON.

    ``entry`` is the value stored under an object id (a mapping with an
    ``"embedding"`` key), or ``None`` when that id is absent from the JSON.
    The embedding is converted to float32 and flattened, so a vector
    accidentally stored as ``[1, C]`` is handled like ``[C]``. The argmax
    index is the class label and the value at that index is the score.
    A missing entry or an empty embedding yields ``(-1, 0.0)``.
    """
    if entry is None:
        return -1, 0.0
    emb_vec = np.asarray(entry["embedding"], dtype=np.float32).reshape(-1)
    if emb_vec.size == 0:
        # argmax of an empty array raises; treat it like a missing object instead.
        return -1, 0.0
    class_idx = int(emb_vec.argmax())
    return class_idx, float(emb_vec[class_idx])


def main():
    """Attach an OVSeg class label and score to every point of one scene.

    Command-line arguments:
        --path     Root directory (default: current directory).
        --dataset  Required, but currently unused: the embeddings
                   subdirectory is fixed to ``Replica``.
        --scene    Scene name used as the file-name prefix.

    Reads ``<path>/embeddings/Replica/<scene>_points_to_ids_ovseg.csv``, which
    must have an ``"Object id"`` column. Also reads
    ``<path>/embeddings/Replica/<scene>_ids_to_embeddings_ovseg.json``, which
    maps a stringified object id to ``{"embedding": [...]}``.

    Writes the points table with two extra columns to
    ``<path>/predicted_labels1/<scene>_predicted_labels_ovseg.csv``:
      * ``ovseg_label``: argmax index of the point's object embedding, or -1
        when the object id is missing from the JSON or its embedding is empty.
      * ``ovseg_score``: the embedding value at that index, or 0.0 in the
        fallback case.

    Raises:
        FileNotFoundError: if either input file does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, default="")
    parser.add_argument("--dataset", type=str, required=True)  # kept just for consistency/logging
    parser.add_argument("--scene", type=str, required=True)
    args = parser.parse_args()

    # Strip trailing slashes for tidy joined paths, but keep "/" itself:
    # "/".rstrip("/") is "", which would silently make the root a relative path.
    path = args.path.rstrip("/") or args.path
    scene = args.scene

    # ----- Paths to OVSeg embeddings -----
    embeddings_dir = os.path.join(path, "embeddings/Replica")
    points_csv = os.path.join(embeddings_dir, f"{scene}_points_to_ids_ovseg.csv")
    embeds_json = os.path.join(embeddings_dir, f"{scene}_ids_to_embeddings_ovseg.json")

    if not os.path.exists(points_csv):
        raise FileNotFoundError(f"Points file not found: {points_csv}")
    if not os.path.exists(embeds_json):
        raise FileNotFoundError(f"Embeddings file not found: {embeds_json}")

    # Load data
    df_points_to_id = pd.read_csv(points_csv)
    with open(embeds_json, "r", encoding="utf-8") as f:
        ids_to_embeddings = json.load(f)

    # Many points share an object id, so resolve each distinct object once and
    # broadcast the result, rather than re-parsing its embedding for every point.
    obj_ids = df_points_to_id["Object id"].astype("int64")
    label_by_id = {}
    score_by_id = {}
    for obj_id in obj_ids.unique():
        # JSON object keys are strings; int() normalises the numpy scalar first.
        label_by_id[obj_id], score_by_id[obj_id] = _object_label_and_score(
            ids_to_embeddings.get(str(int(obj_id)))
        )

    # Add labels to dataframe
    df_points_to_id["ovseg_label"] = obj_ids.map(label_by_id).astype("int64")
    df_points_to_id["ovseg_score"] = obj_ids.map(score_by_id).astype("float64")

    # Save
    out_dir = os.path.join(path, "predicted_labels1")
    os.makedirs(out_dir, exist_ok=True)
    out_csv = os.path.join(out_dir, f"{scene}_predicted_labels_ovseg.csv")
    df_points_to_id.to_csv(out_csv, index=False)

    print(f"Saved OVSeg labels to: {out_csv}")


if __name__ == "__main__":
    # Script entry point: parse CLI arguments and write the labelled points CSV.
    main()
