from mmrotate.datasets import DOTADataset
from mmrotate.datasets.builder import ROTATED_DATASETS

from collections import defaultdict
from functools import partial
import re

import mmcv
import numpy as np
import torch
from mmcv.ops import nms_rotated


@ROTATED_DATASETS.register_module()
class DebugDOTADataset(DOTADataset):
    """DOTA dataset variant for debugging patch-to-image result merging.

    Only :meth:`merge_det` is overridden. By default it merges in the
    current process regardless of ``nproc``; see ``force_single_process``.
    """

    # When True, ``merge_det`` ignores ``nproc`` and always merges in the
    # current process (the debugging default). Set to False, e.g. on an
    # instance or in a subclass, to honour ``nproc`` again.
    force_single_process = True

    # Patch ids contain ``__<x>___<y>``: the offset that is added to the
    # patch-local box coordinates. Compiled once rather than per iteration.
    _PATCH_OFFSET_PATTERN = re.compile(r"__(\d+)___(\d+)")

    def merge_det(self, results, nproc=4, iou_thr=0.1):
        """Merge patch detections into detections for each original image.

        Args:
            results (list): Per-patch results indexed like ``self.img_ids``.
                Each entry is a per-class list of arrays whose last column is
                the score; columns 0 and 1 are shifted by the patch offset
                parsed from the image id (presumably the box centre
                ``cx, cy`` -- confirm against the model head).
            nproc (int): Number of worker processes. Only used when
                ``force_single_process`` is False and ``nproc > 1``.
            iou_thr (float): IoU threshold for the per-class rotated NMS
                applied after merging. Defaults to 0.1.

        Returns:
            zip: Iterator over two tuples: the original image names and, in
            the same order, their per-class merged detections.

        Raises:
            ValueError: If an image id has no ``__<x>___<y>`` offset
                component.
        """
        collector = defaultdict(list)
        for idx in range(len(self)):
            img_id = self.img_ids[idx]
            # The original image name is everything before the first "__".
            oriname = img_id.split("__")[0]
            match = self._PATCH_OFFSET_PATTERN.search(img_id)
            if match is None:
                raise ValueError(
                    f"Cannot parse patch offset from img_id {img_id!r}: "
                    "expected a '__<x>___<y>' component."
                )
            offset = np.array(
                [int(match.group(1)), int(match.group(2))], dtype=np.float32
            )

            label_dets = []
            for label, dets in enumerate(results[idx]):
                bboxes, scores = dets[:, :-1], dets[:, [-1]]
                ori_bboxes = bboxes.copy()
                # Shift patch-local coordinates into original-image space.
                ori_bboxes[..., :2] = ori_bboxes[..., :2] + offset
                # Prepend the class index so all classes can be stacked.
                labels = np.zeros((bboxes.shape[0], 1)) + label
                label_dets.append(
                    np.concatenate([labels, ori_bboxes, scores], axis=1)
                )
            collector[oriname].append(np.concatenate(label_dets, axis=0))

        merge_func = partial(_merge_func, CLASSES=self.CLASSES, iou_thr=iou_thr)
        if nproc <= 1 or self.force_single_process:
            print("Single processing")
            merged_results = mmcv.track_iter_progress(
                (map(merge_func, collector.items()), len(collector))
            )
        else:
            print("Multiple processing")
            merged_results = mmcv.track_parallel_progress(
                merge_func, list(collector.items()), nproc
            )

        return zip(*merged_results)


def _merge_func(info, CLASSES, iou_thr):
    """Merging patch bboxes into full image.

    Args:
        CLASSES (list): Label category.
        iou_thr (float): Threshold of IoU.
    """
    img_id, label_dets = info
    label_dets = np.concatenate(label_dets, axis=0)

    labels, dets = label_dets[:, 0], label_dets[:, 1:]

    big_img_results = []
    for i in range(len(CLASSES)):
        if len(dets[labels == i]) == 0:
            big_img_results.append(dets[labels == i])
        else:
            try:
                cls_dets = torch.from_numpy(dets[labels == i]).cuda()
            except:  # noqa: E722
                cls_dets = torch.from_numpy(dets[labels == i])
            nms_dets, keep_inds = nms_rotated(cls_dets[:, :5], cls_dets[:, -1], iou_thr)
            big_img_results.append(nms_dets.cpu().numpy())
    return img_id, big_img_results
