import pdb
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.models as models
from torch.autograd import Variable
import numpy as np
from model.utils.config import cfg
from model.rpn.rpn import _RPN
from model.rpn.bbox_transform import bbox_transform_inv, clip_boxes

from model.roi_layers import ROIAlign, ROIPool, nms

# from model.roi_pooling.modules.roi_pool import _RoIPooling
# from model.roi_align.modules.roi_align import RoIAlignAvg

from model.rpn.proposal_target_layer_cascade import _ProposalTargetLayer
from model.utils.net_utils import _smooth_l1_loss, _crop_pool_layer, _affine_grid_gen, _affine_theta

class _fasterRCNN(nn.Module):
    """Faster R-CNN base network that also emits one feature vector per object.

    Besides the usual detection outputs (rois, class probabilities, box deltas
    and losses), ``forward`` pools a feature vector for every ground-truth box
    (when training, or in eval when ``mode`` is ``'sgc'``/``'predcls'``) or for
    every detection that survives thresholding and NMS (any other eval mode).
    With ``use_context=True`` it also pools the union box of every ordered pair
    of objects in an image.

    Subclasses must set ``dout_base_model`` before this ``__init__`` runs and
    provide ``_init_modules``, ``RCNN_base``, ``_head_to_tail``,
    ``RCNN_cls_score`` and ``RCNN_bbox_pred``; none of them are defined here.
    """

    def __init__(self, classes, class_agnostic, mode='train', bias_box=False):
        super(_fasterRCNN, self).__init__()
        self.classes = classes
        self.n_classes = len(classes)
        self.class_agnostic = class_agnostic
        # 'sgc' / 'predcls' make eval-mode forward() describe the gt boxes
        self.mode = mode
        # when True (and training), gt boxes are jittered with deform_box()
        self.bias_box = bias_box

        # loss
        self.RCNN_loss_cls = 0
        self.RCNN_loss_bbox = 0

        # define rpn (dout_base_model comes from the subclass)
        self.RCNN_rpn = _RPN(self.dout_base_model)
        self.RCNN_proposal_target = _ProposalTargetLayer(self.n_classes)

        # 1/16 is presumably the backbone's total stride -- TODO confirm against RCNN_base
        self.RCNN_roi_pool = ROIPool((cfg.POOLING_SIZE, cfg.POOLING_SIZE), 1.0/16.0)
        self.RCNN_roi_align = ROIAlign((cfg.POOLING_SIZE, cfg.POOLING_SIZE), 1.0/16.0, 0)

    def deform_box(self, box, img_height, img_width, bias_rate = 0.1):
        """Randomly jitter the corners of a box (applied with probability 0.5).

        Each corner moves by at most ``int(bias_rate * side_length)`` pixels,
        stays inside ``[0, img_width] x [0, img_height]`` where possible, and
        the result keeps ``x2 > x1`` and ``y2 > y1``. Boxes narrower or shorter
        than 1 pixel are returned unchanged (as a copy).

        Args:
            box: 4 coordinates ``x1, y1, x2, y2`` (tensor, array or sequence).
            img_height: upper bound for the y coordinates.
            img_width: upper bound for the x coordinates.
            bias_rate: maximum shift as a fraction of the box side.

        Returns:
            A new box of the same kind; ``box`` itself is never modified.
        """
        x1, y1, x2, y2 = box
        # Work on a copy: in forward() ``box`` is a view into the caller's
        # gt_boxes tensor, and writing into it would corrupt the annotations.
        if torch.is_tensor(box):
            out = box.clone()
        elif hasattr(box, 'copy'):
            out = box.copy()
        else:
            out = list(box)

        def _randint(lo, hi):
            # randint requires lo <= hi; a box touching/exceeding the image
            # edge can invert the range, so fall back to the lower bound
            # (which is what keeps x2 > x1 / y2 > y1).
            lo, hi = int(lo), int(hi)
            return random.randint(lo, max(lo, hi))

        width = x2 - x1
        height = y2 - y1
        coin = random.random()

        if width >= 1 and height >= 1 and coin >= 0.5:
            max_bias_width = int(bias_rate * width)
            max_bias_height = int(bias_rate * height)

            x1_ = _randint(max(0, x1 - max_bias_width), min(x2, x1 + max_bias_width))
            out[0] = x1_

            y1_ = _randint(max(0, y1 - max_bias_height), min(y2, y1 + max_bias_height))
            out[1] = y1_

            x2_ = _randint(max(x1_ + 1, x2 - max_bias_width), min(img_width, x2 + max_bias_width))
            out[2] = x2_

            y2_ = _randint(max(y1_ + 1, y2 - max_bias_height), min(img_height, y2 + max_bias_height))
            out[3] = y2_

        return out

    def _pool_rois(self, base_feat, rois):
        """Pool ``rois`` (rows ``[batch_idx, x1, y1, x2, y2]``) from ``base_feat``.

        Raises:
            ValueError: if ``cfg.POOLING_MODE`` is neither 'align' nor 'pool'.
        """
        if cfg.POOLING_MODE == 'align':
            return self.RCNN_roi_align(base_feat, rois)
        if cfg.POOLING_MODE == 'pool':
            return self.RCNN_roi_pool(base_feat, rois)
        raise ValueError('Unsupported cfg.POOLING_MODE: {}'.format(cfg.POOLING_MODE))

    @staticmethod
    def _pair_context_rois(img_rois, img_id):
        """Return the union box of every ordered pair of boxes of one image.

        Args:
            img_rois: (N, 4) tensor of ``x1, y1, x2, y2`` boxes.
            img_id: batch index written into column 0 of every output row.

        Returns:
            (N * N, 5) tensor of ``[img_id, x1, y1, x2, y2]`` rows; row
            ``i * N + j`` is the union of box ``i`` and box ``j``.
        """
        first = img_rois.unsqueeze(1)   # (N, 1, 4): box i
        second = img_rois.unsqueeze(0)  # (1, N, 4): box j
        top_left = torch.min(first[..., :2], second[..., :2])
        bottom_right = torch.max(first[..., 2:], second[..., 2:])
        union = torch.cat([top_left, bottom_right], dim=-1)  # (N, N, 4)
        batch_col = torch.ones_like(union[..., :1]) * img_id
        return torch.cat([batch_col, union], dim=-1).view(-1, 5)

    @staticmethod
    def _pair_ids(img_ids):
        """Return (N * N, 2) ``[id_i, id_j]`` rows in the order used by
        ``_pair_context_rois`` for a 1-D tensor of N object ids."""
        n = img_ids.shape[0]
        first = img_ids.unsqueeze(1).expand(n, n)
        second = img_ids.unsqueeze(0).expand(n, n)
        return torch.stack([first, second], dim=-1).view(-1, 2)

    @staticmethod
    def _cat_or_empty(tensors, like, width=None, dtype=None):
        """Concatenate ``tensors`` along dim 0; if the list is empty return an
        empty ``(0,)`` / ``(0, width)`` tensor on ``like``'s device (dtype of
        ``like`` unless ``dtype`` is given) instead of letting torch.cat fail."""
        if len(tensors) > 0:
            return torch.cat(tensors, dim=0)
        shape = (0,) if width is None else (0, width)
        return like.new_zeros(shape, dtype=dtype)

    def _decode_pred_boxes(self, boxes, box_deltas, im_info_row):
        """Turn rois of one image into per-class predicted boxes.

        Args:
            boxes: (1, R, 4) roi coordinates.
            box_deltas: (1, R, 4) or (1, R, 4 * n_classes) regression output.
            im_info_row: (1, ...) im_info row used to clip the boxes.

        Returns:
            (1, R, 4) boxes when class-agnostic, else (1, R, 4 * n_classes).
        """
        if not cfg.TEST.BBOX_REG:
            # No regression: reuse the roi once per class column group.
            if self.class_agnostic:
                return boxes
            return boxes.repeat(1, 1, self.n_classes)

        if cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED:
            # Undo the target normalisation applied at training time.
            stds = box_deltas.new_tensor(cfg.TRAIN.BBOX_NORMALIZE_STDS)
            means = box_deltas.new_tensor(cfg.TRAIN.BBOX_NORMALIZE_MEANS)
            box_deltas = box_deltas.view(-1, 4) * stds + means
            cols = 4 if self.class_agnostic else 4 * self.n_classes
            box_deltas = box_deltas.view(1, -1, cols)

        pred_boxes = bbox_transform_inv(boxes, box_deltas, 1)
        return clip_boxes(pred_boxes, im_info_row, 1)

    def forward(self, im_data, im_info, gt_boxes, num_boxes, thr = 0.05, use_context = False):
        """Run detection and pool one feature vector per object.

        Args:
            im_data: image batch; ``size(0)`` is the batch size and
                ``size(2)``/``size(3)`` are used as image height/width for
                ``deform_box``.
            im_info: per-image info; row ``img_id`` is passed to
                ``clip_boxes`` at inference.
            gt_boxes: ground truth indexed ``[img, box, col]`` with columns
                ``x1, y1, x2, y2, cls`` and, optionally, an object id in
                column 5 (without it, the box index within its image is used).
            num_boxes: number of valid gt boxes per image.
            thr: class-probability threshold applied before NMS (inference
                branch only).
            use_context: also pool the union box of every ordered object pair.

        Returns:
            Training: ``(rois, cls_prob, bbox_pred, rpn_loss_cls,
            rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label,
            output_feat, output_label, output_id)`` followed by
            ``context_dict`` (``(id_i, id_j) -> pair feature``) when
            ``use_context``.
            Evaluation: ``(rois, cls_prob, bbox_pred, rpn_loss_cls,
            rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label,
            output_feat, output_label, output_conf, output_roi)`` followed by
            ``output_context_feat`` when ``use_context``. Outside training the
            losses are 0 and ``rois_label`` is None.

        Raises:
            ValueError: if ``cfg.POOLING_MODE`` is neither 'align' nor 'pool'.
        """
        batch_size = im_data.size(0)

        im_info = im_info.data
        gt_boxes = gt_boxes.data
        num_boxes = num_boxes.data

        gt_boxes_pos = gt_boxes[:, :, :4]
        gt_boxes_pos_cls = gt_boxes[:, :, :5]
        gt_boxes_cls = gt_boxes[:, :, 4]
        # Object ids are needed by every run of the gt branch below, which also
        # executes in eval for 'sgc'/'predcls' -- not only when training.
        if gt_boxes.size(2) > 5:
            gt_boxes_id = gt_boxes[:, :, 5]
        else:
            gt_boxes_id = torch.arange(gt_boxes.size(1), dtype=gt_boxes.dtype,
                                       device=gt_boxes.device).unsqueeze(0).expand(gt_boxes.size(0), -1)

        # feed image data to base model to obtain base feature map
        base_feat = self.RCNN_base(im_data)

        # feed base feature map to RPN to obtain rois
        rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes_pos_cls, num_boxes)

        # if it is training phase, then use ground truth bboxes for refining
        if self.training:
            roi_data = self.RCNN_proposal_target(rois, gt_boxes_pos_cls, num_boxes)
            rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data

            rois_label = Variable(rois_label.view(-1).long())
            rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
            rois_inside_ws = Variable(rois_inside_ws.view(-1, rois_inside_ws.size(2)))
            rois_outside_ws = Variable(rois_outside_ws.view(-1, rois_outside_ws.size(2)))
        else:
            rois_label = None
            rois_target = None
            rois_inside_ws = None
            rois_outside_ws = None
            rpn_loss_cls = 0
            rpn_loss_bbox = 0

        rois = Variable(rois)

        # do roi pooling based on predicted rois, then feed the top model
        pooled_feat = self._head_to_tail(self._pool_rois(base_feat, rois.view(-1, 5)))

        # compute bbox offset
        bbox_pred = self.RCNN_bbox_pred(pooled_feat)
        if self.training and not self.class_agnostic:
            # select the 4 columns belonging to each roi's label
            bbox_pred_view = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
            bbox_pred_select = torch.gather(bbox_pred_view, 1, rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
            bbox_pred = bbox_pred_select.squeeze(1)

        # compute object classification probability
        cls_score = self.RCNN_cls_score(pooled_feat)
        cls_prob = F.softmax(cls_score, 1)

        RCNN_loss_cls = 0
        RCNN_loss_bbox = 0

        if self.training:
            # classification loss
            RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)

            # bounding box regression L1 loss
            RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)

        cls_prob = cls_prob.view(batch_size, rois.size(1), -1)
        bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)

        # for training (and sgc/predcls eval), get the embedding and category for each gt object
        if self.training or self.mode in ['sgc', 'predcls']:
            gt_rois_list = []
            gt_rois_label_list = []
            gt_rois_id_list = []
            gt_context_list = []
            id_pair_list = []

            for img_id in range(num_boxes.shape[0]):
                img_rois_list = []
                img_rois_id_list = []
                for gt_box_id in range(num_boxes[img_id].long().item()):
                    tmp_gt_box_pos = gt_boxes_pos[img_id][gt_box_id]
                    if self.training and self.bias_box:
                        tmp_gt_box_pos = self.deform_box(tmp_gt_box_pos, im_data.size(2), im_data.size(3))
                    gt_roi = torch.cat([torch.ones_like(tmp_gt_box_pos[0].unsqueeze(0)) * img_id, tmp_gt_box_pos],
                                       dim=0)
                    gt_rois_list.append(gt_roi.unsqueeze(0))
                    img_rois_list.append(tmp_gt_box_pos)

                    gt_rois_label_list.append(gt_boxes_cls[img_id][gt_box_id].unsqueeze(0))

                    tmp_gt_rois_id = gt_boxes_id[img_id][gt_box_id].unsqueeze(0)
                    gt_rois_id_list.append(tmp_gt_rois_id)
                    img_rois_id_list.append(tmp_gt_rois_id)

                if use_context and len(img_rois_list) != 0:
                    img_rois = torch.stack(img_rois_list, dim=0)
                    gt_context_list.append(self._pair_context_rois(img_rois, img_id))
                    id_pair_list.append(self._pair_ids(torch.cat(img_rois_id_list, dim=0)))

            gt_rois = self._cat_or_empty(gt_rois_list, gt_boxes, 5)
            gt_rois_label = self._cat_or_empty(gt_rois_label_list, gt_boxes)
            gt_rois_id = self._cat_or_empty(gt_rois_id_list, gt_boxes)

            # define the outputs
            pooled_gt_feat = self._head_to_tail(self._pool_rois(base_feat, gt_rois))
            output_feat = pooled_gt_feat
            output_id = gt_rois_id
            output_roi = gt_rois

            if use_context:
                gt_context = self._cat_or_empty(gt_context_list, gt_boxes, 5)
                id_pair = self._cat_or_empty(id_pair_list, gt_boxes, 2)
                pooled_context_feat = self._head_to_tail(self._pool_rois(base_feat, gt_context))
                output_context_feat = pooled_context_feat
                id_pair_tuple = [tuple(tmp_id_pair) for tmp_id_pair in id_pair.long().tolist()]
                context_dict = dict(zip(id_pair_tuple, pooled_context_feat))

            # classify the gt boxes for scene graph classification, otherwise return gt classes
            if self.mode == 'sgc':
                # softmax so output_conf is a probability like in the other paths
                output_cls_prob = F.softmax(self.RCNN_cls_score(pooled_gt_feat), 1)
                output_conf, output_label = torch.max(output_cls_prob, dim=1)
            else:
                output_label = gt_rois_label
                output_conf = torch.ones_like(output_label)

        # for inference, get the embedding and category for high-confidence rois
        else:
            pred_rois_list = []
            pred_rois_label_list = []
            pred_rois_conf_list = []
            pred_context_list = []

            for img_id in range(batch_size):
                tmp_boxes = rois[img_id, :, 1:5].unsqueeze(0).data
                tmp_cls_prob = cls_prob[img_id, ...].unsqueeze(0).data
                tmp_bbox_pred = bbox_pred[img_id, ...].unsqueeze(0).data
                img_rois_list = []

                # Transform the rois according to the regression results
                pred_boxes = self._decode_pred_boxes(tmp_boxes, tmp_bbox_pred, im_info[img_id, :].unsqueeze(0))

                # squeeze(0) only: a plain squeeze() would also drop the roi
                # axis when a single roi is present
                pred_boxes = pred_boxes.squeeze(0)
                tmp_cls_prob = tmp_cls_prob.squeeze(0)

                # Get the rois with high confidence, per foreground class
                for j in range(1, self.n_classes):
                    tmp_inds = torch.nonzero(tmp_cls_prob[:, j] > thr).view(-1)
                    if tmp_inds.numel() == 0:
                        continue
                    tmp_cls_score = tmp_cls_prob[:, j][tmp_inds]
                    _, tmp_order = torch.sort(tmp_cls_score, 0, True)
                    if self.class_agnostic:
                        tmp_cls_boxes = pred_boxes[tmp_inds, :]
                    else:
                        tmp_cls_boxes = pred_boxes[tmp_inds][:, j * 4:(j + 1) * 4]

                    tmp_cls_boxes = tmp_cls_boxes[tmp_order, :]
                    tmp_cls_score = tmp_cls_score[tmp_order]
                    tmp_keep = nms(tmp_cls_boxes, tmp_cls_score, cfg.TEST.NMS)
                    tmp_cls_boxes = tmp_cls_boxes[tmp_keep.view(-1).long(), :]
                    tmp_cls_score = tmp_cls_score[tmp_keep.view(-1).long()]

                    if tmp_cls_boxes.shape[0] != 0:
                        pred_roi = torch.cat([torch.ones_like(tmp_cls_boxes[:, 0].unsqueeze(1)) * img_id, tmp_cls_boxes],
                                             dim=1)
                        pred_rois_list.append(pred_roi)
                        img_rois_list.append(tmp_cls_boxes)
                        pred_roi_label = torch.ones_like(tmp_keep.view(-1)) * j
                        pred_rois_label_list.append(pred_roi_label)
                        pred_rois_conf_list.append(tmp_cls_score)

                # get the context between each object pair
                if use_context and len(img_rois_list) != 0:
                    img_rois = torch.cat(img_rois_list, dim=0)
                    pred_context_list.append(self._pair_context_rois(img_rois, img_id))

            pred_rois = self._cat_or_empty(pred_rois_list, rois, 5)
            pred_rois_label = self._cat_or_empty(pred_rois_label_list, rois, dtype=torch.long)
            pred_rois_conf = self._cat_or_empty(pred_rois_conf_list, cls_prob)

            # define the outputs
            pooled_pred_feat = self._head_to_tail(self._pool_rois(base_feat, pred_rois))
            output_feat = pooled_pred_feat
            output_label = pred_rois_label
            output_conf = pred_rois_conf
            output_roi = pred_rois

            if use_context:
                pred_context = self._cat_or_empty(pred_context_list, rois, 5)
                output_context_feat = self._head_to_tail(self._pool_rois(base_feat, pred_context))

        if self.training:
            if use_context:
                return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, \
                       rois_label, output_feat, output_label, output_id, context_dict
            return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, \
                   rois_label, output_feat, output_label, output_id
        if use_context:
            return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, \
                   rois_label, output_feat, output_label, output_conf, output_roi, output_context_feat
        return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, \
               rois_label, output_feat, output_label, output_conf, output_roi

    @staticmethod
    def _normal_init(m, mean, stddev, truncated=False):
        """Initialise ``m.weight`` from a (pseudo-truncated) normal and zero ``m.bias``.

        With ``truncated=True`` standard-normal draws are wrapped into (-2, 2)
        by ``fmod(2)`` before scaling -- not a perfect truncated normal.
        The bias is zeroed in both cases.
        """
        if truncated:
            m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean)
        else:
            m.weight.data.normal_(mean, stddev)
        if m.bias is not None:
            m.bias.data.zero_()

    def _init_weights(self):
        """Initialise the RPN and R-CNN prediction layers."""
        normal_init = self._normal_init
        normal_init(self.RCNN_rpn.RPN_Conv, 0, 0.01, cfg.TRAIN.TRUNCATED)
        normal_init(self.RCNN_rpn.RPN_cls_score, 0, 0.01, cfg.TRAIN.TRUNCATED)
        normal_init(self.RCNN_rpn.RPN_bbox_pred, 0, 0.01, cfg.TRAIN.TRUNCATED)
        normal_init(self.RCNN_cls_score, 0, 0.01, cfg.TRAIN.TRUNCATED)
        normal_init(self.RCNN_bbox_pred, 0, 0.001, cfg.TRAIN.TRUNCATED)

    def create_architecture(self):
        """Build the subclass modules, then initialise the prediction layers."""
        self._init_modules()
        self._init_weights()
