import pandas as pd
import torch
import json
import numpy as np
import argparse
def _match_counts(pred_points, pred_labels, gt_points, gt_labels, num_classes, threshold, batch_size):
    """Return per-class float64 tensors (TP, FP, FN, n_gt) from nearest-neighbour matching.

    Each ground-truth point is matched to its nearest predicted point. The
    distance matrix is built in chunks of ``batch_size`` GT points so its size
    stays at ``batch_size x N``. For a GT point of class ``g`` whose nearest
    prediction has label ``p`` at distance ``d``:

    * ``p == g`` and ``d < threshold`` -> true positive for ``g``
    * anything else                    -> false negative for ``g``
    * ``p != g`` and ``d < threshold`` -> additionally a false positive for ``p``

    Labels outside ``[0, num_classes)`` are ignored. Without this, a label
    such as ``-1`` would wrap around to the last class when used as an index.
    """
    device = pred_points.device
    # Accept (N,) or (N, 1) label tensors; float labels are truncated like int().
    pred_labels = pred_labels.reshape(-1).to(device).long()
    gt_labels = gt_labels.reshape(-1).to(device).long()
    # cdist needs both operands on the same device with the same dtype.
    gt_points = gt_points.to(device=device, dtype=pred_points.dtype)
    M = gt_points.shape[0]

    def in_range(labels):
        return (labels >= 0) & (labels < num_classes)

    def count(labels):
        # Histogram of class ids; float64 keeps large point counts exact.
        return torch.bincount(labels, minlength=num_classes).to(torch.float64)

    gt_valid = in_range(gt_labels)
    n_i = count(gt_labels[gt_valid])
    TP = torch.zeros(num_classes, dtype=torch.float64, device=device)
    FP = torch.zeros_like(TP)
    FN = torch.zeros_like(TP)

    for start in range(0, M, batch_size):
        print(start, ' out of ', M, ' Processed!')
        end = min(start + batch_size, M)
        dists = torch.cdist(gt_points[start:end], pred_points)  # [B, N]
        min_dists, nn_idx = dists.min(dim=1)  # [B], [B]
        batch_pred = pred_labels[nn_idx]
        batch_gt = gt_labels[start:end]
        gt_ok = gt_valid[start:end]

        near = min_dists < threshold
        correct = batch_pred == batch_gt
        hit = near & correct
        TP += count(batch_gt[hit & gt_ok])
        # Every GT point that is not a hit is a miss, including correctly
        # labelled points whose nearest prediction is too far away.
        FN += count(batch_gt[~hit & gt_ok])
        FP += count(batch_pred[near & ~correct & in_range(batch_pred)])

    return TP, FP, FN, n_i


def compute_metrics(pred_points, pred_labels, gt_points, gt_labels, num_classes, threshold, eval_classes, batch_size=1000):
    """Compute mIoU, mAcc and frequency-weighted IoU between two labelled point clouds.

    Per-class TP/FP/FN come from nearest-neighbour matching of every
    ground-truth point to the predictions (see ``_match_counts``). The
    per-class values are then computed as follows:

    * IoU = TP / (TP + FP + FN)
    * Acc = TP / (TP + FN)
    * F-IoU weight = class GT count / total GT count of the evaluated classes

    Only classes listed in ``eval_classes`` that have at least one GT point
    contribute. Per-class counts and scores are printed as they are computed.

    Args:
        pred_points: (N, D) floating tensor of predicted coordinates (D=3 for xyz).
        pred_labels: (N,) or (N, 1) tensor of predicted class ids.
        gt_points: (M, D) tensor of ground-truth coordinates.
        gt_labels: (M,), (M, 1) or 0-d tensor of ground-truth class ids.
        num_classes: number of class ids; valid ids are ``0 .. num_classes-1``.
        threshold: a match requires nearest-neighbour distance strictly below this.
        eval_classes: iterable of class ids (ints or numpy ints) to average over.
        batch_size: number of GT points per distance-matrix chunk.

    Returns:
        dict with float values for 'mIoU', 'mAcc' and 'F-IoU'. Each value is NaN
        when no evaluated class has ground-truth points.

    Raises:
        ValueError: if ``pred_points`` is empty, so there is nothing to match against.
    """
    if pred_points.shape[0] == 0:
        raise ValueError("compute_metrics: pred_points is empty, nothing to match ground truth against")

    TP, FP, FN, n_i = _match_counts(
        pred_points, pred_labels, gt_points, gt_labels, num_classes, threshold, batch_size
    )
    tp_all, fp_all, fn_all, n_all = TP.tolist(), FP.tolist(), FN.tolist(), n_i.tolist()

    # A set of plain ints gives O(1) membership tests, whatever the int type of eval_classes.
    eval_set = {int(c) for c in eval_classes}
    evaluated = [c for c in range(num_classes) if c in eval_set and n_all[c] > 0]
    total_points = sum(n_all[c] for c in evaluated)

    ious, accs = [], []
    f_iou = 0.0
    for c in evaluated:
        tp, fp, fn = tp_all[c], fp_all[c], fn_all[c]
        print("Object id: ", c)
        print('TP: ', tp, 'FP: ', fp, 'FN: ', fn)
        denom = tp + fp + fn
        if denom != 0:
            iou = tp / denom
            weight = n_all[c] / total_points
            ious.append(iou)
            f_iou += weight * iou
            print('IoU: ', iou)
            print('FIoU: ', weight * iou)
        if tp + fn != 0:
            acc = tp / (tp + fn)
            accs.append(acc)
            print('Acc: ', acc)

    nan = float('nan')
    return {
        'mIoU': sum(ious) / len(ious) if ious else nan,
        'mAcc': sum(accs) / len(accs) if accs else nan,
        'F-IoU': f_iou if ious else nan,
    }

def main():
    """Command-line entry point: score one scene's predicted labels against ground truth.

    Reads ``<path>/predicted_labels/<dataset>/<scene>_predicted_labels.csv``
    (columns ``x, y, z, labels``) and
    ``<path>/ground_truth/<dataset>/<scene>_ground_truth.csv`` (columns
    ``x, y, z, label``). Both label columns are shifted down by one before use.
    The script prints the metrics from ``compute_metrics`` and, with ``--save``,
    writes them to ``<scene>_results.json`` in the current directory.
    """
    import os

    def str2bool(value):
        # Keeps the old "--save True" form working while making "--save False" mean False.
        return str(value).strip().lower() in ('1', 'true', 'yes', 'y', 'on')

    parser = argparse.ArgumentParser()
    parser.add_argument("--scene", type=str, required=True)
    parser.add_argument("--path", type=str, default='')
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--save", type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument("--threshold", type=float, default=0.25,
                        help="max nearest-neighbour distance for a GT point to count as matched")
    args = parser.parse_args()

    path = args.path
    dataset = args.dataset
    scene = args.scene

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    pred_df = pd.read_csv(os.path.join(path, 'predicted_labels', dataset, scene + '_predicted_labels.csv'))
    # Extract XYZ coordinates and labels. Predicted labels keep shape [N, 1].
    pred_points = torch.tensor(pred_df[['x', 'y', 'z']].to_numpy()).to(device)
    pred_labels = torch.tensor(pred_df[['labels']].to_numpy()).to(device) - 1

    gt_df = pd.read_csv(os.path.join(path, 'ground_truth', dataset, scene + '_ground_truth.csv'))
    gt_points = torch.tensor(gt_df[['x', 'y', 'z']].to_numpy()).to(device)
    # reshape(-1) instead of squeeze(): a single-point cloud stays 1-D, not 0-d.
    gt_labels = torch.tensor(gt_df[['label']].to_numpy()).to(device).reshape(-1) - 1

    if dataset == 'Replica':
        class_num = 101
        eval_classes = list(range(class_num))
    else:
        eval_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 23, 27, 32, 33, 35, 38]
        class_num = 40

    metrics = compute_metrics(pred_points, pred_labels, gt_points, gt_labels, class_num, args.threshold, eval_classes)

    print(metrics)
    if args.save:
        with open(scene + '_results.json', 'w') as f:
            json.dump({key: metrics[key] for key in ('mIoU', 'mAcc', 'F-IoU')}, f)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()