import os
import numpy as np


def get_analysis_targets_cls(predictions, labels_to_class_names=None, base_dir='', target="accuracy"):
    """Group per-image classification correctness by ground-truth class name.

    Args:
        predictions: Mapping with parallel sequences under ``'preds'``
            (predicted labels), ``'labels'`` (ground-truth labels) and
            ``'paths'`` (image paths).
        labels_to_class_names: Mapping from label value to class name. Every
            ground-truth label in ``predictions['labels']`` must be a key.
            Defaults to an empty mapping.
        base_dir: Directory joined in front of every path.
        target: Metric to compute; only ``'accuracy'`` is implemented.

    Returns:
        ``{class_name: {joined_path: 1 or 0}}``, where 1 means the prediction
        equals the ground truth. Every class name in ``labels_to_class_names``
        is present as a key, even if no sample belongs to it.

    Raises:
        NotImplementedError: If ``target`` is not ``'accuracy'``.
        KeyError: If a ground-truth label is missing from
            ``labels_to_class_names``.
    """
    # Avoid a shared mutable default argument.
    if labels_to_class_names is None:
        labels_to_class_names = {}
    # Validate once, up front, so an unsupported target fails even when
    # there are no predictions to iterate over.
    if target != 'accuracy':
        raise NotImplementedError

    targets_dict = {name: {} for name in labels_to_class_names.values()}
    preds, gts, paths = predictions['preds'], predictions['labels'], predictions['paths']
    for pred, gt, path in zip(preds, gts, paths):
        correctness = 1 if (pred == gt) else 0
        targets_dict[labels_to_class_names[gt]][os.path.join(base_dir, path)] = correctness
    return targets_dict

def get_analysis_targets_det(predictions, labels_to_class_names=None, base_dir='', target="iou"):
    """Score every ground-truth box by its best IoU against the image's predictions.

    Args:
        predictions: Mapping with parallel per-image sequences under
            ``'preds_box'`` (predicted boxes), ``'labels_box'`` (ground-truth
            boxes), ``'labels_cls'`` (one class label per ground-truth box) and
            ``'paths'`` (image path strings). Boxes are ``(cx, cy, w, h)``:
            centre coordinates plus width and height.
        labels_to_class_names: Mapping from class label to class name; every
            label in ``'labels_cls'`` must be a key. Defaults to an empty
            mapping.
        base_dir: Directory joined in front of each image's *file name*; any
            directory part of the original path is dropped.
        target: Metric to compute; only ``'iou'`` is implemented.

    Returns:
        ``{class_name: {(joined_path, tuple(gt_box)): iou}}``. Every
        ground-truth box in an image with no predictions scores 0.

    Raises:
        NotImplementedError: If ``target`` is not ``'iou'``.
        KeyError: If a class label is missing from ``labels_to_class_names``.
    """
    def xywh_to_xyxy(box):
        """Convert a centre-format ``(cx, cy, w, h)`` box to ``[x1, y1, x2, y2]``."""
        x, y, w, h = box
        x1 = x - w / 2
        y1 = y - h / 2
        x2 = x + w / 2
        y2 = y + h / 2
        return [x1, y1, x2, y2]

    def calculate_iou(box1, box2):
        """Return the intersection-over-union of two centre-format boxes."""
        x1_1, y1_1, x2_1, y2_1 = xywh_to_xyxy(box1)
        x1_2, y1_2, x2_2, y2_2 = xywh_to_xyxy(box2)

        inter_x1 = max(x1_1, x1_2)
        inter_y1 = max(y1_1, y1_2)
        inter_x2 = min(x2_1, x2_2)
        inter_y2 = min(y2_1, y2_2)

        # Clamp at 0 so disjoint boxes have no (negative) intersection.
        inter_area = max(0, inter_x2 - inter_x1) * max(0, inter_y2 - inter_y1)

        box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
        box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)

        union = box1_area + box2_area - inter_area
        # Two degenerate (zero-area) boxes have an empty union; score them 0
        # instead of dividing by zero.
        if union <= 0:
            return 0.0
        return inter_area / union

    def compute_ious(gt_boxes, pred_boxes):
        """For each ground-truth box, return its best IoU over ``pred_boxes``.

        ``pred_boxes`` must be non-empty: the max is taken along that axis.
        """
        ious = np.zeros((len(gt_boxes), len(pred_boxes)))
        for i, gt_box in enumerate(gt_boxes):
            for j, pred_box in enumerate(pred_boxes):
                ious[i, j] = calculate_iou(gt_box, pred_box)
        return np.max(ious, axis=1)

    # Avoid a shared mutable default argument.
    if labels_to_class_names is None:
        labels_to_class_names = {}
    # Validate once, up front, so an unsupported target fails even when
    # there are no images to iterate over.
    if target != 'iou':
        raise NotImplementedError

    targets_dict = {name: {} for name in labels_to_class_names.values()}
    per_image = zip(predictions["preds_box"], predictions["labels_box"],
                    predictions["labels_cls"], predictions["paths"])
    for preds, gts, cls_list, path in per_image:
        # len() rather than truthiness: preds may be a numpy array.
        if len(preds) == 0:
            results = [0] * len(gts)
        else:
            results = compute_ious(gts, preds)
        # Keys use only the file name. Normalise Windows separators first so
        # basename strips directories on any OS. This is computed once per
        # image, not once per box.
        image_path = os.path.join(base_dir, os.path.basename(path.replace("\\", "/")))
        for result, gt, cls in zip(results, gts, cls_list):
            targets_dict[labels_to_class_names[cls]][(image_path, tuple(gt))] = result
    return targets_dict

def get_analysis_targets_pose(predictions, base_dir='', target="oks"):
    """Score each pose prediction with an OKS-style keypoint similarity.

    Args:
        predictions: Mapping with parallel sequences under ``'preds'``
            (predicted keypoints per sample, rows of ``(x, y, ...)``),
            ``'labels'`` (ground-truth keypoints, rows of ``(x, y, v)``) and
            ``'paths'``. Each keypoint array must have 17 rows to match the
            built-in sigma table. numpy arrays or nested lists are accepted.
        base_dir: Directory joined in front of every path.
        target: Metric to compute; only ``'oks'`` is implemented.

    Returns:
        ``{'person': {joined_path: score}}``. The score is the mean of
        ``exp(-d^2 / (2 * area * sigma^2))`` over keypoints whose visibility
        flag equals 1, so it lies in ``[0, 1]``. It is NaN when no keypoint
        has visibility flag 1.

    Raises:
        NotImplementedError: If ``target`` is not ``'oks'``.
    """
    # Per-keypoint falloff constants (17 entries); built once, not per call.
    # NOTE(review): these match pycocotools' COCO keypoint sigmas, but
    # pycocotools' computeOks uses variances (2 * sigma) ** 2 and the annotated
    # object area. Here the formula uses sigma ** 2 and the keypoint bounding
    # box area, so scores are not directly comparable to COCO OKS. Confirm
    # this is intended.
    sigma = np.array([.026, .025, .025, .035, .035, .079, .079, .072, .072,
                      .062, .062, .107, .107, .087, .087, .089, .089])

    def oks(pred, label):
        """Return the OKS-style similarity between one prediction and its label."""
        pred = np.asarray(pred)
        label = np.asarray(label)
        x_label = label[:, 0]
        y_label = label[:, 1]
        # Object scale: area of the box spanning *all* label keypoints,
        # regardless of visibility. The epsilon avoids division by zero.
        # NOTE(review): if unlabeled keypoints are stored at (0, 0), they
        # inflate this area; confirm the label format.
        area = (x_label.max() - x_label.min()) * (y_label.max() - y_label.min()) + 1e-10
        dist = (pred[:, 0] - x_label) ** 2 + (pred[:, 1] - y_label) ** 2
        dist_reg = dist / (2 * area * sigma ** 2)
        # Only keypoints flagged exactly 1 are scored.
        # NOTE(review): COCO encodes 0/1/2 and counts v > 0; confirm encoding.
        visible = label[:, 2] == 1
        if not visible.any():
            # Explicit NaN instead of numpy's "Mean of empty slice" warning.
            return np.nan
        return np.exp(-dist_reg[visible]).mean()

    # Validate once, up front, so an unsupported target fails even when
    # there are no predictions to iterate over.
    if target != 'oks':
        raise NotImplementedError

    targets_dict = {'person': {}}
    preds, gts, paths = predictions['preds'], predictions['labels'], predictions['paths']
    for pred, gt, path in zip(preds, gts, paths):
        targets_dict['person'][os.path.join(base_dir, path)] = oks(pred, gt)
    return targets_dict