import numpy as np
#import torch
#import config

# Per-keypoint OKS falloff constants: 17 entries, each divided by 10.
# OKS() indexes this array by keypoint position, so the order must match the
# keypoint order of the poses passed in.
# NOTE(review): the values look like the standard COCO keypoint sigmas -- confirm
# against the dataset in use.
sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0


def ause_oks(errors: np.ndarray, uncertainty: np.ndarray) -> float:
    """Return the Area Under the Sparsification Error curve (AUSE).

    Samples are removed one at a time, most uncertain first, and the mean
    error of the remaining samples is recorded after every removal.  The same
    curve is built for an oracle that removes the largest errors first.  AUSE
    is the mean gap between the two curves; 0 means the uncertainty ranks the
    samples exactly like the true error does.

    Args:
        errors: array-like of N per-sample errors (0 = perfect, larger = worse).
        uncertainty: array-like of N per-sample uncertainties
            (higher = more uncertain).

    Returns:
        The AUSE as a Python float.  It is non-negative up to floating-point
        rounding, and 0.0 for empty input.

    Raises:
        ValueError: if the two inputs do not contain the same number of samples.
    """
    errors = np.asarray(errors, dtype=float).ravel()
    uncertainty = np.asarray(uncertainty, dtype=float).ravel()
    if errors.shape != uncertainty.shape:
        raise ValueError(
            "errors and uncertainty must have the same number of samples, "
            f"got {errors.size} and {uncertainty.size}"
        )

    n = errors.size
    if n == 0:
        # Nothing to sparsify; without this guard the `[-1]` writes below fail.
        return 0.0

    total_err = errors.sum()
    # Size of the kept set after removing 1..N samples, clamped to 1 so the
    # final (empty) step does not divide by zero.
    kept_count = np.maximum(n - np.arange(1, n + 1), 1)

    def _risk_curve(removal_order):
        # Mean error of the kept samples after each successive removal.
        removed_err = np.cumsum(errors[removal_order])
        risk = (total_err - removed_err) / kept_count
        risk[-1] = 0.0  # convention: risk is 0 when nothing is left
        return risk

    risk_unc = _risk_curve(np.argsort(-uncertainty))  # most uncertain first
    risk_oracle = _risk_curve(np.argsort(-errors))    # largest error first

    # The oracle always keeps the smallest possible total error, so the gap
    # cannot be negative (up to rounding).
    return float((risk_unc - risk_oracle).mean())
# def ause_oks(errors: np.ndarray, uncertainty: np.ndarray) -> float:
#     """
#     errors       : shape (N,) –  0 = perfect, larger = worse (e.g. 1‑OKS or |ŷ‑y|)
#     uncertainty  : shape (N,) – higher = more uncertain
#     returns      : scalar AUSE  (smaller is better, 0 = oracle)
#     """
#     N = len(errors)
#     # 1. Sparsification curve – remove most‑uncertain first
#     idx_u = np.argsort(-uncertainty)        # descending
#     risk_u = np.cumsum(errors[idx_u]) / np.arange(1, N + 1)

#     # 2. Oracle curve – remove largest‑error first
#     idx_o = np.argsort(-errors)             # descending
#     risk_o = np.cumsum(errors[idx_o]) / np.arange(1, N + 1)

#     # 3. Sparsification *error* curve
#     se = risk_o - risk_u                    # shape (N,)
#     # 4. Area under sparsification‑error (uniform x‑spacing → mean == trapezoid)
#     return se.mean()                        # identical to np.trapz(se, dx=1/N)


def OKS(pts_gt, pts_pred, area, img_shape):
    """Object Keypoint Similarity between one ground-truth and one predicted pose.

    Only keypoints whose ground-truth flag (column 2) is positive and whose
    ground-truth x/y lie strictly inside ``(0, img_w)`` / ``(0, img_h)`` take
    part in the score.

    Args:
        pts_gt: 2-D array; columns 0/1 are x/y and column 2 is the visibility flag.
        pts_pred: 2-D array indexed by the same keypoint rows; columns 0/1 are x/y.
        area: scale term of the OKS denominator.
        img_shape: ``(img_w, img_h)`` pair used for the bounds check.

    Returns:
        Mean per-keypoint similarity over the usable keypoints, or ``-10`` when
        no keypoint is usable.
    """
    width, height = img_shape
    xs, ys, flags = pts_gt[:, 0], pts_gt[:, 1], pts_gt[:, 2]
    usable = (flags > 0) & (xs < width) & (xs > 0) & (ys < height) & (ys > 0)
    keep = np.nonzero(usable)[0]

    gt_xy = pts_gt[keep, :2]
    pred_xy = pts_pred[keep, :2]
    variances = (sigmas[keep] * 2) ** 2

    sq_dist = np.square(gt_xy - pred_xy).sum(axis=1)
    per_keypoint = np.exp(-sq_dist / (variances * area * 2))
    if len(per_keypoint) == 0:
        # Sentinel: lower than max_oks's starting score, so it never wins.
        return -10
    return per_keypoint.mean()


def max_oks(pts_gt, pt_pred, area, img_shape):
    """Find the ground-truth pose that scores the highest OKS against one prediction.

    Args:
        pts_gt: sequence of ground-truth poses.
        pt_pred: a single predicted pose.
        area: per-ground-truth scale values, indexed like ``pts_gt``.
        img_shape: ``(img_w, img_h)`` pair forwarded to ``OKS``.

    Returns:
        ``(best_oks, best_index)``.  Both are ``-1`` when no ground-truth pose
        scores above ``-1`` (for example when ``pts_gt`` is empty).
    """
    best_score, best_index = -1, -1
    for index, gt_pose in enumerate(pts_gt):
        score = OKS(gt_pose, pt_pred, area[index], img_shape)
        # Strict comparison: on ties the earliest ground-truth pose wins.
        if score > best_score:
            best_score, best_index = score, index
    return best_score, best_index


def oks_one(pts_gt, pts_pred, area, img_shape):
    """Match every predicted pose of one image to its best ground-truth pose.

    Ground-truth poses whose ``area`` is not greater than 10 are discarded
    before matching, so the returned indices refer to the filtered list.

    Args:
        pts_gt: array of ground-truth poses for the image.
        pts_pred: sequence of predicted poses for the image.
        area: array of per-ground-truth scale values.
        img_shape: ``(img_w, img_h)`` pair forwarded to ``max_oks``.

    Returns:
        ``(oks_list, matched_ids)``: for each prediction, the best OKS and the
        index of the matching ground-truth pose, as returned by ``max_oks``.
    """
    keep = np.nonzero(area > 10)[0]
    valid_gt = pts_gt[keep]
    valid_area = area[keep]

    matches = [max_oks(valid_gt, pred, valid_area, img_shape) for pred in pts_pred]
    oks_list = [score for score, _ in matches]
    matched_ids = [gt_index for _, gt_index in matches]
    return oks_list, matched_ids


def oks_batch(pts_gt, pts_pred, area):
    """Run ``oks_one`` over a batch and stack the per-image outputs as float32 arrays.

    The inputs must expose ``.cpu().numpy()`` (presumably torch tensors --
    TODO confirm), with the batch along axis 0.

    NOTE(review): this function is out of sync with ``oks_one`` as defined in
    this file.  ``oks_one`` takes four arguments (``pts_gt, pts_pred, area,
    img_shape``) and returns two values, but it is called here with three
    arguments and its result is unpacked into four names.  As written, any
    non-empty batch raises ``TypeError``.  Confirm the intended contract before
    using this function.

    Returns:
        ``(results, masks, viss, vis_masks)``, each converted with ``np.float32``.
    """
    # Move everything to host memory as NumPy arrays.
    pts_gt = pts_gt.cpu().numpy()
    area = area.cpu().numpy()
    pts_pred = pts_pred.cpu().numpy()

    bsize = pts_gt.shape[0]
    results = []
    masks = []
    viss = []
    vis_masks = []
    for i in range(bsize):
        # NOTE(review): arity mismatch with oks_one (missing img_shape; 4-way unpack of a 2-tuple).
        oks, oks_mask, vis, vis_mask = oks_one(pts_gt[i], pts_pred[i], area[i])
        results.append(oks)
        masks.append(oks_mask)
        viss.append(vis)
        vis_masks.append(vis_mask)
    # Stack the per-image lists into float32 arrays.
    results = np.float32(results)
    masks = np.float32(masks)
    viss = np.float32(viss)
    vis_masks = np.float32(vis_masks)
    return results, masks, viss, vis_masks


def diff_one(pts_gt, pts_pred, area, img_shape=None):
    """Compute ground-truth-minus-prediction keypoint offsets for one image.

    Each predicted pose is matched to its best ground-truth pose with
    ``max_oks``.  Ground-truth poses whose ``area`` is not greater than 10 are
    discarded first.

    Args:
        pts_gt: array of ground-truth poses; per pose, columns 0/1 are x/y and
            column 2 is the visibility flag.
        pts_pred: sequence of predicted poses; per pose, columns 0/1 are x/y.
        area: array of per-ground-truth scale values.
        img_shape: optional ``(img_w, img_h)`` used for matching and for the
            in-bounds checks.  When omitted, the square
            ``(config.out_size, config.out_size)`` is used, which requires a
            ``config`` module to be available in this file's namespace.

    Returns:
        ``(diff_list, vis_list)`` with one entry per prediction.  A matched
        prediction gets the offset ``gt[:, :2] - pred`` and a float32 mask that
        is 1 where the ground-truth flag is positive and both the ground-truth
        and predicted keypoint lie strictly inside the image.  An unmatched
        prediction gets all-zero ``(17, 2)`` offsets and an all-zero ``(17,)`` mask.
    """
    gt_valid = np.where(area > 10)[0]
    pts_gt = pts_gt[gt_valid]
    area = area[gt_valid]

    if len(pts_pred) == 0:
        return [], []

    if img_shape is None:
        # Legacy behaviour: bounds came from the project-level config.
        img_shape = (config.out_size, config.out_size)
    img_w, img_h = img_shape

    diff_list = []
    vis_list = []
    for pred in pts_pred:
        # max_oks requires img_shape; the previous 3-argument call raised TypeError.
        _, idd = max_oks(pts_gt, pred, area, img_shape)

        if idd < 0:
            # No ground-truth pose matched: emit neutral placeholders.
            diff = np.zeros([17, 2])
            idx_visible = np.zeros([17])
        else:
            gt = pts_gt[idd]
            diff = gt[:, :2] - pred
            idx_visible = np.float32(
                (gt[:, 2] > 0)
                & (gt[:, 0] < img_w) & (gt[:, 0] > 0)
                & (gt[:, 1] < img_h) & (gt[:, 1] > 0)
                & (pred[:, 0] < img_w) & (pred[:, 0] > 0)
                & (pred[:, 1] < img_h) & (pred[:, 1] > 0)
            )
        diff_list.append(diff)
        vis_list.append(idx_visible)
    return diff_list, vis_list


def diff_batch(pts_gt, pts_pred, area):
    """Apply ``diff_one`` to every image of a batch and stack the results.

    The inputs must expose ``.cpu().numpy()`` (presumably torch tensors --
    TODO confirm), with the batch along axis 0.

    Returns:
        ``(results, masks)`` as float32 arrays; ``masks`` gets an extra
        trailing axis of length 1.
    """
    gt_np = pts_gt.cpu().numpy()
    area_np = area.cpu().numpy()
    pred_np = pts_pred.cpu().numpy()

    per_image = [
        diff_one(gt_np[k], pred_np[k], area_np[k])
        for k in range(gt_np.shape[0])
    ]
    results = np.float32([offsets for offsets, _ in per_image])
    masks = np.float32([visible for _, visible in per_image])[..., None]
    return results, masks


def make_conf_matrix(pts_gt, pts_pred, area, img_shape=None):
    """Build pairwise "same person" targets and loss weights for the predictions.

    Each prediction is matched to a ground-truth pose with ``max_oks`` and
    labelled as:

    * the matched ground-truth index when the best OKS exceeds 0.6,
    * ``-2`` (ignore) when the best OKS is negative, i.e. ``max_oks`` found no
      ground-truth pose that produced a score,
    * ``-1`` (unmatched) otherwise.

    Args:
        pts_gt: sequence of ground-truth poses.
        pts_pred: sequence of predicted poses.
        area: per-ground-truth scale values, indexed like ``pts_gt``.
        img_shape: optional ``(img_w, img_h)`` forwarded to ``max_oks``.  When
            omitted, no upper bound is applied to keypoint coordinates.
            NOTE(review): confirm this default suits the callers.

    Returns:
        ``(mtx, mask)``, both float32 of shape ``(P, P)`` for ``P`` predictions.
        Only the strict upper triangle (``i < j``) is filled in:
        ``mtx[i, j] = 1`` and ``mask[i, j] = 1`` when both predictions match the
        same ground-truth pose; ``mask[i, j] = 0`` when either one is labelled
        ignore.  Everything else keeps ``mtx = 0`` and ``mask = 0.1``.
    """
    if img_shape is None:
        img_shape = (np.inf, np.inf)

    indices = []
    for pred in pts_pred:
        # max_oks requires img_shape; the previous 3-argument call raised TypeError.
        oks, idd = max_oks(pts_gt, pred, area, img_shape)
        if oks > 0.6:
            indices.append(idd)
        elif oks < 0:
            indices.append(-2)
        else:
            indices.append(-1)

    count = len(indices)
    mtx = np.zeros([count, count], dtype=np.float32)
    mask = np.zeros([count, count], dtype=np.float32) + 0.1
    for i in range(count - 1):
        for j in range(i + 1, count):
            if indices[i] >= 0 and indices[i] == indices[j]:
                # Positive pair: both predictions hit the same ground-truth pose.
                mtx[i, j] = 1
                mask[i, j] = 1
            elif indices[i] == -2 or indices[j] == -2:
                # Pairs that involve an ignored prediction carry no weight.
                mask[i, j] = 0
    return mtx, mask


def conf_matrix_batch(pts_gt, pts_pred, area):
    """Apply ``make_conf_matrix`` to every image of a batch and stack the results.

    The inputs must expose ``.cpu().numpy()`` (presumably torch tensors --
    TODO confirm), with the batch along axis 0.

    Returns:
        ``(mtxs, masks)`` as float32 arrays, one matrix per image.
    """
    gt_np = pts_gt.cpu().numpy()
    area_np = area.cpu().numpy()
    pred_np = pts_pred.cpu().numpy()

    per_image = [
        make_conf_matrix(gt_np[k], pred_np[k], area_np[k])
        for k in range(gt_np.shape[0])
    ]
    mtxs = np.float32([targets for targets, _ in per_image])
    masks = np.float32([weights for _, weights in per_image])
    return mtxs, masks


import numpy as np
def oks_nms(poses, scores, thresh, sigmas=None, in_vis_thre=None):
    """Greedy non-maximum suppression over poses, using OKS as the overlap measure.

    Poses are visited in descending score order.  The best remaining pose is
    kept, and every remaining pose whose OKS with it exceeds ``thresh`` is
    suppressed.

    Args:
        poses: array of shape ``(P, K, 3)`` holding ``(x, y, score)`` per keypoint.
        scores: array of ``P`` pose scores.
        thresh: OKS value above which a pose is suppressed.
        sigmas: per-keypoint constants forwarded to ``oks_iou``.
        in_vis_thre: optional keypoint visibility threshold forwarded to
            ``oks_iou``.

    Returns:
        ``(keep, keep_ind)``.  ``keep`` lists the indices of the surviving
        poses; ``keep_ind[n]`` lists the indices suppressed by ``keep[n]``.
        Empty input yields ``([], [])``, so callers can always unpack two values.
    """
    if len(poses) == 0:
        # Same 2-tuple shape as the normal return path.
        return [], []

    # Pose area is taken as the bounding box of its keypoints.
    areas = (np.max(poses[:, :, 0], axis=1) - np.min(poses[:, :, 0], axis=1)) * \
            (np.max(poses[:, :, 1], axis=1) - np.min(poses[:, :, 1], axis=1))
    poses = poses.reshape(poses.shape[0], -1)

    order = scores.argsort()[::-1]

    keep = []
    keep_ind = []
    while order.size > 0:
        best = order[0]
        rest = order[1:]
        keep.append(best)
        oks_ovr = oks_iou(poses[best], poses[rest], areas[best], areas[rest], sigmas, in_vis_thre)
        suppressed = rest[np.where(oks_ovr > thresh)[0]]
        keep_ind.append(suppressed.tolist())
        order = rest[np.where(oks_ovr <= thresh)[0]]

    return keep, keep_ind

def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
    """Compute the OKS between one reference pose and a set of candidate poses.

    Args:
        g: flattened reference pose ``[x0, y0, v0, x1, y1, v1, ...]``.
        d: 2-D array with one flattened candidate pose per row, same layout as ``g``.
        a_g: area of the reference pose.
        a_d: per-candidate areas, indexed like the rows of ``d``.
        sigmas: per-keypoint falloff constants.  When ``None``, the 17 values
            also used by this module's ``sigmas`` constant are used, so the
            poses must then have 17 keypoints.
        in_vis_thre: optional threshold.  When given, only keypoints whose
            score is ``>= in_vis_thre`` in BOTH the reference and the candidate
            contribute.

    Returns:
        1-D float array with one OKS value per row of ``d``; 0.0 where no
        keypoint passes the visibility filter.
    """
    if sigmas is None:
        # The declared default used to crash on `None * 2`.
        sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72,
                           .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
    else:
        # Accept plain sequences too; `list * 2` would repeat instead of scale.
        sigmas = np.asarray(sigmas, dtype=float)
    variances = (sigmas * 2) ** 2

    xg = g[0::3]
    yg = g[1::3]
    vg = g[2::3]
    ious = np.zeros(d.shape[0])
    for n_d in range(d.shape[0]):
        xd = d[n_d, 0::3]
        yd = d[n_d, 1::3]
        vd = d[n_d, 2::3]
        dx = xd - xg
        dy = yd - yg
        # np.spacing(1) keeps the division finite when both areas are zero.
        e = (dx ** 2 + dy ** 2) / variances / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
        if in_vis_thre is not None:
            # Element-wise AND.  `list(a) and list(b)` just evaluated to the
            # second list, so the reference pose's visibility was ignored.
            visible = (vg >= in_vis_thre) & (vd >= in_vis_thre)
            e = e[visible]
        ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
    return ious
