# Copyright (c) Facebook, Inc. and its affiliates.
# 
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random

import torch
import torch.nn as nn
import numpy as np
import sys
import os
import time
# sys.path.append(os.path.join(os.getcwd(), "lib"))  # HACK add the lib folder
# from lib.ap_helper.ap_helper_fcos import parse_predictions
# from utils.box_util import get_3d_box, get_3d_box_batch, box3d_iou
import torch.nn.functional as F
from utils.rle import rle_decode, rle_encode
from config.config import CONF
from lib.loss_helper.loss_joint import get_iou

# def eval_ref_one_sample(pred_bbox, gt_bbox):
#     """ Evaluate one reference prediction

#     Args:
#         pred_bbox: 8 corners of prediction bounding box, (8, 3)
#         gt_bbox: 8 corners of ground truth bounding box, (8, 3)
#     Returns:
#         iou: intersection over union score
#     """

#     iou = box3d_iou(pred_bbox, gt_bbox)

#     return iou


@torch.no_grad()
def get_eval_clip(data_dict, k=3):
    """Compute top-1, top-k and oracle mask IoUs for each valid language query.

    For every language slot ``j`` with ``j < lang_num[i]``:

    * ``ref_iou``: IoU between the ground-truth mask and the proposal with the
      highest ``coarse_weights`` score (argmax).
    * ``topk_iou``: best IoU among the ``k`` highest-scoring proposals. Indices
      that are not below ``len(pred_masks[i])`` are skipped. If every one is
      skipped, the value is 0.0.
    * ``max_iou``: the maximum of ``get_iou`` between all of the scene's
      ``pred_masks`` and the ground-truth mask.

    Args:
        data_dict: batch dictionary. Reads ``ann_id_list`` (only the first two
            dimensions, batch_size and len_num_max, are used),
            ``coarse_weights`` (reshaped to ``(batch_size * len_num_max, -1)``),
            ``lang_num``, ``gt_masks`` and ``pred_masks``.
        k: number of top-scoring proposals considered for ``topk_iou``.
            Clamped to the number of scores available per query.

    Returns:
        The same ``data_dict`` with ``ref_iou``, ``topk_iou`` and ``max_iou``
        added. Each is a list of Python floats with one entry per valid query.
    """
    batch_size, len_num_max = data_dict['ann_id_list'].shape[:2]
    weights = data_dict["coarse_weights"].reshape(batch_size * len_num_max, -1)
    # torch.topk raises if k is larger than the size of the scored dimension.
    k = min(k, weights.shape[1])
    pred_ref = torch.argmax(weights, 1).reshape(batch_size, len_num_max)
    pred_ref_topk = torch.topk(weights, k=k, dim=1)[1].reshape(
        batch_size, len_num_max, -1)
    lang_num = data_dict["lang_num"]

    ious = []
    topk_ious = []
    max_ious = []

    for i in range(batch_size):
        for j in range(len_num_max):
            # Slots at or beyond lang_num[i] do not hold a real query.
            if j >= lang_num[i]:
                continue
            scene_masks = data_dict["pred_masks"][i]
            gt_mask = data_dict["gt_masks"][i][j]

            # top-1: proposal chosen by argmax of the scores
            iou = get_iou(scene_masks[pred_ref[i][j]], gt_mask)
            ious.append(iou.item())

            # top-k: skip indices that do not address an existing proposal
            num_proposals = len(scene_masks)
            best_iou = 0.0
            for ref_id in pred_ref_topk[i][j]:
                if ref_id >= num_proposals:
                    continue
                iou = get_iou(scene_masks[ref_id], gt_mask)
                if iou >= best_iou:
                    best_iou = iou
            # float() accepts both the initial Python 0.0 and a 0-d tensor.
            # Calling .item() unconditionally crashed when every top-k id was skipped.
            topk_ious.append(float(best_iou))

            # oracle: best IoU over all of the scene's proposals
            max_ious.append(get_iou(scene_masks, gt_mask).max().item())

    # store
    data_dict["ref_iou"] = ious
    data_dict["topk_iou"] = topk_ious
    data_dict["max_iou"] = max_ious
    return data_dict

@torch.no_grad()
def get_eval_m3d(data_dict, thr=0.5):
    """Compute IoUs for multi-target grounding by thresholding proposal scores.

    For every valid language slot ``j`` (``j < lang_num[i]``), each proposal
    whose ``coarse_weights[i][j]`` score is strictly greater than ``thr`` is
    selected. A prediction mask is then built that is set to 1 wherever any
    selected ``pred_masks`` entry equals 1, and it is compared against the
    ground-truth mask with ``get_iou``. If nothing is selected, the prediction
    stays all zeros.

    Args:
        data_dict: batch dictionary. Reads ``ann_id_list`` (only the first two
            dimensions are used), ``coarse_weights``, ``lang_num``,
            ``gt_masks`` and ``pred_masks``.
        thr: score threshold a proposal must exceed to be selected.

    Returns:
        The same ``data_dict`` with two lists added, one entry per valid query:
        ``ref_iou`` (Python floats) and ``nt_labels`` (1.0 when no proposal
        passed the threshold, otherwise 0.0).
    """
    batch_size, len_num_max = data_dict['ann_id_list'].shape[:2]
    scores = data_dict["coarse_weights"]
    valid_counts = data_dict["lang_num"]

    ref_ious = []
    empty_flags = []
    for b in range(batch_size):
        for q in range(len_num_max):
            # padded language slot
            if q >= valid_counts[b]:
                continue
            target = data_dict["gt_masks"][b][q]

            selected = (scores[b][q] > thr).nonzero().view(-1)
            masks = data_dict["pred_masks"][b]
            merged = torch.zeros_like(masks[0])
            # union of all selected proposal masks
            for proposal in selected:
                merged[masks[proposal] == 1] = 1
            empty_flags.append(1.0 if selected.shape[0] == 0 else 0.0)

            ref_ious.append(get_iou(merged, target).item())

    # store
    data_dict["ref_iou"] = ref_ious
    data_dict["nt_labels"] = empty_flags
    return data_dict

@torch.no_grad()
def get_eval_cls(data_dict):
    """Compute object-category classification accuracy from language features.

    Takes the argmax of ``lang_scores`` over dim 1 and compares it with
    ``object_cat_list`` after flattening that tensor's first two dimensions.
    Stores the mean match rate, as a float tensor, under ``lang_acc``.

    Returns:
        The same ``data_dict``, with ``lang_acc`` added.
    """
    predicted = data_dict["lang_scores"].argmax(dim=1)
    target = torch.flatten(data_dict["object_cat_list"], start_dim=0, end_dim=1)
    hits = predicted.eq(target).float()
    data_dict["lang_acc"] = hits.mean()
    return data_dict