import torch
import numpy as np
import torch.nn.functional as F
from functools import reduce

def bbox_overlaps(boxes, query_boxes):
    """
    Compute the pairwise intersection-over-union of two sets of boxes.

    Rows are read as (x1, y1, x2, y2); widths and heights are measured as
    ``x2 - x1 + 1`` and ``y2 - y1 + 1``.

    Parameters
    ----------
    boxes: (N, 4) ndarray or tensor or variable
    query_boxes: (K, 4) ndarray or tensor or variable
        Each argument may independently be an ndarray or a tensor.
    Returns
    -------
    overlaps: (N, K) overlap between boxes and query_boxes.
        An ndarray when ``boxes`` was given as an ndarray, otherwise a tensor.
    """
    # The return type follows `boxes`, exactly as before.
    return_numpy = isinstance(boxes, np.ndarray)

    # Convert each argument on its own so mixed ndarray/tensor inputs work.
    if return_numpy:
        boxes = torch.from_numpy(boxes)
    if isinstance(query_boxes, np.ndarray):
        query_boxes = torch.from_numpy(query_boxes)

    box_areas = (boxes[:, 2] - boxes[:, 0] + 1) * \
            (boxes[:, 3] - boxes[:, 1] + 1)
    query_areas = (query_boxes[:, 2] - query_boxes[:, 0] + 1) * \
            (query_boxes[:, 3] - query_boxes[:, 1] + 1)

    # (N, 1) against (1, K) broadcasts to the full (N, K) grid of pairs.
    # Disjoint pairs give a negative extent, which is clamped to zero.
    iw = (torch.min(boxes[:, 2:3], query_boxes[:, 2:3].t())
          - torch.max(boxes[:, 0:1], query_boxes[:, 0:1].t()) + 1).clamp(min=0)
    ih = (torch.min(boxes[:, 3:4], query_boxes[:, 3:4].t())
          - torch.max(boxes[:, 1:2], query_boxes[:, 1:2].t()) + 1).clamp(min=0)

    intersection = iw * ih
    union = box_areas.view(-1, 1) + query_areas.view(1, -1) - intersection
    overlaps = intersection / union

    if return_numpy:
        # detach() is a no-op for ndarray-only inputs; it matters only when
        # a tensor `query_boxes` carries autograd history.
        return overlaps.detach().numpy()
    return overlaps

def bbox_transform(ex_rois, gt_rois):
    """
    Compute regression targets that map ``ex_rois`` onto ``gt_rois``.

    Both inputs are indexed as ``[:, 0..3]`` and read as (x1, y1, x2, y2);
    sizes are measured as ``x2 - x1 + 1`` and ``y2 - y1 + 1``.

    Returns
    -------
    targets: tensor whose dim 1 holds (dx, dy, dw, dh): centre offsets
        divided by the ``ex_rois`` size, and the log of the size ratios.
    """
    def _size_and_centre(rois):
        w = rois[:, 2] - rois[:, 0] + 1.0
        h = rois[:, 3] - rois[:, 1] + 1.0
        return w, h, rois[:, 0] + 0.5 * w, rois[:, 1] + 0.5 * h

    src_w, src_h, src_cx, src_cy = _size_and_centre(ex_rois)
    dst_w, dst_h, dst_cx, dst_cy = _size_and_centre(gt_rois)

    offsets = (
        (dst_cx - src_cx) / src_w,
        (dst_cy - src_cy) / src_h,
        torch.log(dst_w / src_w),
        torch.log(dst_h / src_h),
    )
    return torch.stack(offsets, 1)

def bbox_transform_batch(ex_rois, gt_rois):
    """
    Batched version of the box regression-target computation.

    ``gt_rois`` is always indexed with three dimensions. ``ex_rois`` may be
    2-D, in which case the same boxes are expanded across dim 0 of
    ``gt_rois``, or 3-D, in which case it is used element-wise.

    Returns
    -------
    targets: tensor whose dim 2 holds (dx, dy, dw, dh).

    Raises
    ------
    ValueError: when ``ex_rois`` is neither 2-D nor 3-D.
    """
    rank = ex_rois.dim()
    if rank != 2 and rank != 3:
        raise ValueError('ex_roi input dimension is not correct.')

    # Ellipsis indexing addresses the last axis for both accepted ranks.
    src_w = ex_rois[..., 2] - ex_rois[..., 0] + 1.0
    src_h = ex_rois[..., 3] - ex_rois[..., 1] + 1.0
    src_cx = ex_rois[..., 0] + 0.5 * src_w
    src_cy = ex_rois[..., 1] + 0.5 * src_h

    dst_w = gt_rois[:, :, 2] - gt_rois[:, :, 0] + 1.0
    dst_h = gt_rois[:, :, 3] - gt_rois[:, :, 1] + 1.0
    dst_cx = gt_rois[:, :, 0] + 0.5 * dst_w
    dst_cy = gt_rois[:, :, 1] + 0.5 * dst_h

    if rank == 2:
        # Replicate the per-box values across dim 0 of the ground truth.
        dx = (dst_cx - src_cx.view(1, -1).expand_as(dst_cx)) / src_w
        dy = (dst_cy - src_cy.view(1, -1).expand_as(dst_cy)) / src_h
        dw = torch.log(dst_w / src_w.view(1, -1).expand_as(dst_w))
        dh = torch.log(dst_h / src_h.view(1, -1).expand_as(dst_h))
    else:
        dx = (dst_cx - src_cx) / src_w
        dy = (dst_cy - src_cy) / src_h
        dw = torch.log(dst_w / src_w)
        dh = torch.log(dst_h / src_h)

    return torch.stack((dx, dy, dw, dh), 2)

def bbox_transform_inv(boxes, deltas, batch_size):
    """
    Apply regression deltas to boxes and return the predicted boxes.

    ``boxes`` is indexed as ``[:, :, 0..3]`` = (x1, y1, x2, y2). ``deltas``
    holds (dx, dy, dw, dh) groups interleaved with stride 4 along dim 2.
    ``batch_size`` is accepted but not used.

    Returns
    -------
    pred_boxes: a clone of ``deltas`` whose stride-4 slots are overwritten
        with (x1, y1, x2, y2) of the decoded boxes.
    """
    box_w = boxes[:, :, 2] - boxes[:, :, 0] + 1.0
    box_h = boxes[:, :, 3] - boxes[:, :, 1] + 1.0

    # A trailing axis lets each box broadcast over every group of four.
    w = box_w.unsqueeze(2)
    h = box_h.unsqueeze(2)
    cx = (boxes[:, :, 0] + 0.5 * box_w).unsqueeze(2)
    cy = (boxes[:, :, 1] + 0.5 * box_h).unsqueeze(2)

    new_cx = deltas[:, :, 0::4] * w + cx
    new_cy = deltas[:, :, 1::4] * h + cy
    half_w = 0.5 * (torch.exp(deltas[:, :, 2::4]) * w)
    half_h = 0.5 * (torch.exp(deltas[:, :, 3::4]) * h)

    decoded = deltas.clone()
    decoded[:, :, 0::4] = new_cx - half_w  # x1
    decoded[:, :, 1::4] = new_cy - half_h  # y1
    decoded[:, :, 2::4] = new_cx + half_w  # x2
    decoded[:, :, 3::4] = new_cy + half_h  # y2
    return decoded

def clip_boxes_batch(boxes, im_shape, batch_size):
    """
    Clip boxes to image boundaries, in place.

    Parameters
    ----------
    boxes: tensor indexed as ``[b, n, 0..3]`` = (x1, y1, x2, y2).
    im_shape: tensor indexed as ``[b, 0]`` and ``[b, 1]``; column 0 bounds
        the y coordinates and column 1 bounds the x coordinates.
    batch_size: kept for interface compatibility; the batch dimension is
        taken from ``im_shape`` itself.

    Returns
    -------
    boxes: the same tensor object. Every negative entry is set to 0, and
        columns 0-3 are capped at ``im_shape - 1`` of the matching image.
    """
    boxes[boxes < 0] = 0

    # Shape (B, 1) so each image's limit lines up with that image's rows.
    # A flat (B,) limit only broadcasts against (B, N) when B == 1 or N == B.
    max_x = (im_shape[:, 1] - 1).view(-1, 1).to(boxes.dtype)
    max_y = (im_shape[:, 0] - 1).view(-1, 1).to(boxes.dtype)

    for column, limit in enumerate((max_x, max_y, max_x, max_y)):
        coords = boxes[:, :, column]
        # Slice assignment keeps the in-place contract callers rely on.
        boxes[:, :, column] = torch.min(coords, limit.expand_as(coords))

    return boxes

def clip_boxes(boxes, im_shape, batch_size):
    """
    Clamp boxes to image boundaries, in place, one image at a time.

    For each of the first ``batch_size`` images, the stride-4 x slots
    (offsets 0 and 2) are clamped to ``[0, im_shape[i, 1] - 1]`` and the
    y slots (offsets 1 and 3) to ``[0, im_shape[i, 0] - 1]``.

    Returns the same ``boxes`` object.
    """
    for idx in range(batch_size):
        x_limit = im_shape[idx, 1] - 1
        y_limit = im_shape[idx, 0] - 1
        per_image = boxes[idx]  # a view, so clamp_ writes through to `boxes`
        for offset, limit in enumerate((x_limit, y_limit, x_limit, y_limit)):
            per_image[:, offset::4].clamp_(0, limit)

    return boxes

def compute_targets_pytorch(ex_rois, gt_rois):
    """Compute normalized bounding-box regression targets for a batch.

    Both inputs must be 3-D with matching size along dim 1 and a last
    dimension of 4. The raw targets from ``bbox_transform_batch`` are
    shifted by a zero mean and divided by fixed standard deviations
    (0.1, 0.1, 0.2, 0.2).

    Returns a tensor shaped like the raw targets, on the same device.
    """
    assert ex_rois.size(1) == gt_rois.size(1)
    assert ex_rois.size(2) == 4
    assert gt_rois.size(2) == 4

    targets = bbox_transform_batch(ex_rois, gt_rois)

    # Build the constants on the targets' device rather than forcing CUDA,
    # so the function also works on CPU. float32 matches the FloatTensor
    # constants used previously.
    means = torch.tensor([0., 0., 0., 0.],
                         dtype=torch.float32, device=targets.device)
    stds = torch.tensor([0.1000, 0.1000, 0.2000, 0.2000],
                        dtype=torch.float32, device=targets.device)

    # Normalize targets by the fixed mean and stdev.
    targets = (targets - means.expand_as(targets)) / stds.expand_as(targets)

    return targets

def get_bbox_regression_labels_pytorch(bbox_target_data, labels_batch):
    """Keep regression targets only for ROIs with a positive label.

    Args:
        bbox_target_data: tensor indexed as ``[b, n, :]``; each selected
            row is copied into a 4-wide output row.
        labels_batch: tensor indexed as ``[b, n]``; an ROI is selected
            when its label is greater than 0.

    Returns:
        bbox_targets: b x N x 4 tensor holding the target row for every
            selected ROI and zeros elsewhere.
        bbox_inside_weights: b x N x 4 tensor that is 1 for selected ROIs
            and 0 elsewhere.

    Both outputs share the dtype and device of ``bbox_target_data``. Each
    ROI is selected by its own label, so results differ from a per-image
    "label sum is zero" check only when labels contain negative values.
    """
    batch_size = labels_batch.size(0)
    rois_per_image = labels_batch.size(1)

    bbox_targets = bbox_target_data.new_zeros((batch_size, rois_per_image, 4))
    bbox_inside_weights = bbox_target_data.new_zeros(bbox_targets.size())

    # One boolean mask replaces the per-ROI Python loop and the hard-coded
    # CUDA constant, so the function works on any device.
    selected = (labels_batch > 0).to(bbox_target_data.device)
    bbox_targets[selected] = bbox_target_data[selected]
    bbox_inside_weights[selected] = 1.0

    return bbox_targets, bbox_inside_weights

def smooth_l1_loss(input, target, beta=1. / 9, size_average=True, reduction=True):
    """
    Smooth L1 loss between ``input`` and ``target``.

    Absolute errors below ``beta`` are penalised quadratically
    (``0.5 * err ** 2 / beta``); larger errors linearly
    (``err - 0.5 * beta``).

    ``size_average`` is accepted but never read. When ``reduction`` compares
    equal to ``False`` the element-wise loss is returned; otherwise the sum.
    """
    abs_err = (input - target).abs()
    quadratic = 0.5 * abs_err ** 2 / beta
    linear = abs_err - 0.5 * beta
    per_element = torch.where(abs_err < beta, quadratic, linear)

    # `== False` is deliberate: `not reduction` would also match None.
    return per_element if reduction == False else per_element.sum()

def normalize_atten_maps(atten_maps):
    """
    Min-max normalize each map over its last two dimensions.

    The last two dimensions are flattened, the minimum is subtracted and
    the result is divided by ``max - min + 1e-7``. The small constant
    avoids division by zero on constant maps. The output has the input's
    shape.
    """
    full_shape = atten_maps.size()
    flat = atten_maps.view(full_shape[0:-2] + (-1,))

    lowest = flat.min(dim=-1, keepdim=True)[0]
    highest = flat.max(dim=-1, keepdim=True)[0]

    scaled = (flat - lowest) / (highest - lowest + 1e-7)
    return scaled.view(full_shape)

def ca_lw(rois_conv, con_loss):
    """
    Accumulate an attention-consistency loss over two feature entries.

    For ``rois_conv[0]`` and ``rois_conv[1]``, item 0 and item 1 are
    averaged over dim 0 and passed through a sigmoid; the map from item 1 is
    additionally flipped along its dim 1. The element-wise maximum of the
    two maps, detached from the graph, is the target, and the MSE of each
    map against it is added to ``con_loss``.

    ``con_loss`` is updated with ``+=`` (in place for a tensor) and returned.
    """
    # Build every attention map before touching con_loss.
    pairs = []
    for entry in (rois_conv[0], rois_conv[1]):
        plain = torch.sigmoid(torch.mean(entry[0], dim=0))
        mirrored = torch.sigmoid(torch.mean(entry[1], dim=0)).flip(dims=[1])
        pairs.append((plain, mirrored))

    # Targets are detached so gradients only flow through the raw maps.
    references = [torch.max(plain, mirrored).clone().detach()
                  for plain, mirrored in pairs]

    for reference, (plain, mirrored) in zip(references, pairs):
        con_loss += F.mse_loss(reference, plain)
        con_loss += F.mse_loss(reference, mirrored)
    return con_loss

def ca_iw(keep_inds, rois, num_rois, con_loss):
    """
    Accumulate an attention-consistency loss across four groups of ROIs.

    Parameters
    ----------
    keep_inds: sequence whose first four items are index arrays; only the
        indices present in all four are used.
    rois: tensor indexed along dim 0 at ``index + k * num_rois`` for
        k = 0..3; it is averaged over dim 1 to form one map per ROI.
        Presumably the four groups are differently augmented views of the
        same ROIs, with groups 1 and 3 mirrored -- confirm with the caller.
    num_rois: offset between consecutive groups along dim 0 of ``rois``.
    con_loss: running loss, updated with ``+=`` and returned.

    Returns
    -------
    con_loss: the updated loss. It is returned unchanged when the four
        index arrays share no common index.
    """
    common_inds = reduce(np.intersect1d, (keep_inds[0], keep_inds[1], keep_inds[2], keep_inds[3]))
    num_each = int(common_inds.shape[0])
    if num_each == 0:
        # No ROI survives in all four groups. A zero-length batch cannot be
        # normalized and would yield NaN losses, so it contributes nothing.
        return con_loss

    gathered = np.concatenate((common_inds,
                               common_inds + num_rois,
                               common_inds + num_rois * 2,
                               common_inds + num_rois * 3))

    attention = torch.mean(rois[gathered], dim=1)
    attention = torch.sigmoid(normalize_atten_maps(attention))

    # Split into the four groups, in the order they were gathered.
    groups = [attention[k * num_each:(k + 1) * num_each] for k in range(4)]

    # The target is built from detached copies so it receives no gradient.
    # Groups 1 and 3 are flipped along dim 2 before the maximum is taken.
    frozen = attention.clone().detach()
    frozen_groups = [frozen[k * num_each:(k + 1) * num_each] for k in range(4)]
    target_a = torch.max(frozen_groups[0], frozen_groups[1].flip(dims=[2]))
    target_b = torch.max(frozen_groups[2], frozen_groups[3].flip(dims=[2]))
    target = torch.max(target_a, target_b)

    con_loss += F.mse_loss(groups[0], target)
    con_loss += F.mse_loss(groups[1], target.flip(dims=[2]))
    con_loss += F.mse_loss(groups[2], target)
    con_loss += F.mse_loss(groups[3], target.flip(dims=[2]))
    return con_loss
