import argparse
import json

import numpy as np
import pandas as pd
import torch


def compute_metrics(
    pred_points,
    pred_labels,
    gt_points,
    gt_labels,
    num_classes,
    threshold,
    eval_classes,
    batch_size=1000,
    count_unmatched_as_fn=False,
):
    """
    Compute nearest-neighbour segmentation metrics (mIoU, mAcc, F-IoU).

    Every GT point is matched to its nearest predicted point (Euclidean
    distance via ``torch.cdist``). If that distance is strictly below
    ``threshold``:
      * same label      -> TP for the GT class
      * different label -> FP for the predicted class and FN for the GT class
    GT points with no prediction closer than ``threshold`` are ignored, unless
    ``count_unmatched_as_fn`` is True, in which case they count as FN for their
    GT class.

    GT points whose label lies outside [0, num_classes) (e.g. -1 / 255
    "unlabeled" markers) are skipped entirely. Predicted labels outside that
    range never receive FP counts. Previously such labels either wrapped around
    via negative indexing or raised IndexError.

    Args:
        pred_points: [N_pred, 3] float tensor
        pred_labels: [N_pred] long tensor (0-based)
        gt_points:   [N_gt, 3] float tensor (same device as pred_points)
        gt_labels:   [N_gt] long tensor (0-based)
        num_classes: total number of classes (e.g. 101 for Replica)
        threshold:   distance threshold to count matches
        eval_classes: iterable of class indices to evaluate (subset of
            [0..num_classes-1]); out-of-range indices are ignored.
        batch_size:  GT points per distance computation; bounds the
            [batch_size, N_pred] distance matrix held in memory at once.
        count_unmatched_as_fn: count GT points with no prediction within
            ``threshold`` as FN (default False keeps the original behaviour).

    Returns:
        dict of Python floats with keys "mIoU", "mAcc" and "F-IoU". Each value
        is 0.0 when no evaluated class contributes to it.
    """
    device = pred_points.device
    num_pred, num_gt = pred_points.shape[0], gt_points.shape[0]
    # Keep all label tensors on one device so comparisons/bincount line up.
    pred_labels = pred_labels.to(device)
    gt_labels = gt_labels.to(device)

    def _class_histogram(labels):
        """Per-class counts of ``labels`` (length num_classes), skipping out-of-range values."""
        in_range = (labels >= 0) & (labels < num_classes)
        return torch.bincount(labels[in_range], minlength=num_classes)

    # Integer counters stay exact regardless of point-cloud size.
    tp_counts = torch.zeros(num_classes, dtype=torch.long, device=device)
    fp_counts = torch.zeros_like(tp_counts)
    fn_counts = torch.zeros_like(tp_counts)
    gt_class_counts = _class_histogram(gt_labels)  # n_i: GT points per class

    for start in range(0, num_gt, batch_size):
        print(start, "out of", num_gt, "GT points processed!")
        end = min(start + batch_size, num_gt)

        batch_gt_points = gt_points[start:end]  # [B, 3]
        batch_gt_labels = gt_labels[start:end]  # [B]

        if num_pred > 0:
            # [B, N_pred] distances -> nearest prediction for each GT point.
            min_dists, nn_indices = torch.cdist(batch_gt_points, pred_points).min(dim=1)
            batch_pred_labels = pred_labels[nn_indices]
            close = min_dists < threshold
        else:
            # No predictions at all: nothing can be matched (torch.min would fail).
            batch_pred_labels = torch.full_like(batch_gt_labels, -1)
            close = torch.zeros_like(batch_gt_labels, dtype=torch.bool)

        valid_gt = (batch_gt_labels >= 0) & (batch_gt_labels < num_classes)
        same = batch_pred_labels == batch_gt_labels
        hit = close & same & valid_gt
        miss = close & ~same & valid_gt

        tp_counts += _class_histogram(batch_gt_labels[hit])
        fp_counts += _class_histogram(batch_pred_labels[miss])
        fn_counts += _class_histogram(batch_gt_labels[miss])
        if count_unmatched_as_fn:
            fn_counts += _class_histogram(batch_gt_labels[~close & valid_gt])

    # Move the small per-class tables to Python ints once; the rest is scalar math.
    tp_list = tp_counts.tolist()
    fp_list = fp_counts.tolist()
    fn_list = fn_counts.tolist()
    gt_list = gt_class_counts.tolist()

    eval_set = {int(c) for c in eval_classes}
    # Total GT points only over eval_classes (out-of-range entries ignored).
    total_points = sum(gt_list[c] for c in eval_set if 0 <= c < num_classes)

    iou_values, acc_values = [], []
    f_iou = 0.0
    # Ascending class order, restricted to evaluated classes present in the GT.
    for c in range(num_classes):
        if c not in eval_set or gt_list[c] == 0:
            continue
        tp, fp, fn = tp_list[c], fp_list[c], fn_list[c]

        print("Class id:", c)
        print("TP:", tp, "FP:", fp, "FN:", fn)

        denom = tp + fp + fn
        if denom != 0:
            iou = tp / denom
            weighted_iou = gt_list[c] / total_points * iou
            iou_values.append(iou)
            f_iou += weighted_iou
            print("IoU:", iou)
            print("FIoU (weighted):", weighted_iou)

        if tp + fn != 0:
            acc = tp / (tp + fn)
            acc_values.append(acc)
            print("Acc:", acc)

    return {
        "mIoU": sum(iou_values) / len(iou_values) if iou_values else 0.0,
        "mAcc": sum(acc_values) / len(acc_values) if acc_values else 0.0,
        "F-IoU": float(f_iou),
    }


def main():
    """Command-line entry point: evaluate one scene and print/save its metrics.

    Reads, relative to ``--path`` (default: the current directory):
      * predicted_labels1/<scene>_predicted_labels_ov_scannet.csv
        (columns x, y, z, labels)
      * ground_truth/<scene>_ground_truth.csv (columns x, y, z, label)
    then prints mIoU / mAcc / F-IoU and, with ``--save``, writes them to
    <scene>_results_ov_scannet.json in the current working directory.
    """
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument("--scene", type=str, required=True)
    parser.add_argument(
        "--path",
        type=str,
        default="",
        help="Root directory containing predicted_labels1/ and ground_truth/ "
        "(default: current directory)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help='"Replica" uses 101 classes; any other value uses the ScanNet setup',
    )
    parser.add_argument(
        "--save",
        action="store_true",
        help="If set, save metrics to <scene>_results_ov_scannet.json",
    )
    args = parser.parse_args()

    # Path("") is the current directory; the old `"".rstrip("/") + "/"`
    # turned the default into "/" and read from the filesystem root.
    root = Path(args.path)
    scene = args.scene

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ------------------------------------------------------------------
    # 1) Load predictions (OVSeg-based labels)
    # ------------------------------------------------------------------
    df_pred = pd.read_csv(root / "predicted_labels1" / f"{scene}_predicted_labels_ov_scannet.csv")
    pred_points = torch.tensor(
        df_pred[["x", "y", "z"]].to_numpy(), dtype=torch.float32, device=device
    )  # [N_pred, 3]
    # Pred labels are already 0-based from OVSeg label generation.
    pred_labels = torch.tensor(
        df_pred["labels"].to_numpy(), dtype=torch.long, device=device
    )

    # ------------------------------------------------------------------
    # 2) Load ground truth
    # ------------------------------------------------------------------
    df_gt = pd.read_csv(root / "ground_truth" / f"{scene}_ground_truth.csv")
    gt_points = torch.tensor(
        df_gt[["x", "y", "z"]].to_numpy(), dtype=torch.float32, device=device
    )
    # GT labels are used exactly as stored (no 1-based -> 0-based shift), so
    # they must already be 0-based to line up with pred_labels.
    # NOTE(review): confirm this matches how the GT CSVs were generated.
    gt_labels = torch.tensor(
        df_gt["label"].to_numpy(), dtype=torch.long, device=device
    )

    # ------------------------------------------------------------------
    # 3) Define evaluation classes & number of classes
    # ------------------------------------------------------------------
    if args.dataset == "Replica":
        num_classes = 101
        eval_classes = list(range(num_classes))  # 0..100
    else:
        # ScanNet setup; presumably the 20 benchmark classes as 0-based
        # NYU40 ids -- confirm.
        num_classes = 40
        eval_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 23, 27, 32, 33, 35, 38]

    # ------------------------------------------------------------------
    # 4) Compute metrics
    # ------------------------------------------------------------------
    metrics = compute_metrics(
        pred_points,
        pred_labels,
        gt_points,
        gt_labels,
        num_classes,
        threshold=0.25,
        eval_classes=eval_classes,
    )
    print("Final metrics:", metrics)

    # ------------------------------------------------------------------
    # 5) Optionally save metrics
    # ------------------------------------------------------------------
    if args.save:
        out_path = scene + "_results_ov_scannet.json"
        with open(out_path, "w") as f:
            json.dump({key: metrics[key] for key in ("mIoU", "mAcc", "F-IoU")}, f, indent=4)
        print("Saved metrics to", out_path)


if __name__ == "__main__":
    main()
