import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Decoupled_module(nn.Module):
    """Decoupling head applied to per-RoI shared features.

    Each RoI feature is passed through ``decouple_layers`` (a stack of 1x1
    Conv1d / BatchNorm1d / ReLU layers) and ``cls_layers`` predicts one
    classification logit per RoI from the decoupled feature.

    Config fields read from ``model_cfg``: FEATURE_DIM, DECOUPLE_FC, CLS_FC,
    DP_RATIO, LOSS_CONFIG.CLS_LOSS and, for ``get_contrastiveloss``,
    CONTRASTIVE_LOSS_CFG.
    """

    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        feature_dim = model_cfg.FEATURE_DIM
        self.num_class = 1
        # Output width equals the input width so forward() can view the result
        # back to (b, n, feature_dim).
        self.decouple_layers = self.make_fc_layers(
            input_channels=feature_dim, output_channels=feature_dim, fc_list=self.model_cfg.DECOUPLE_FC  # e.g. [64]
        )
        self.cls_layers = self.make_fc_layers(
            input_channels=feature_dim, output_channels=self.num_class, fc_list=self.model_cfg.CLS_FC
        )
        # Filled by forward(); read by the loss methods.
        self.forward_ret_dict = None

    def forward(self, data_dict):
        """Decouple RoI features and predict per-RoI classification logits.

        Reads ``data_dict['roi_shared_feature']`` (unpacked as (b, n, c)) and
        ``data_dict['rcnn_cls_labels']``. Writes
        ``data_dict['roi_decouple_feature']`` with shape (b, n, c). Stores the
        logits (b*n, num_class) and the labels in ``self.forward_ret_dict``
        for the loss computation.

        Returns:
            The same ``data_dict``, updated in place.
        """
        roi_shared_feature = data_dict['roi_shared_feature']
        b, n, c = roi_shared_feature.shape
        # (b, n, c) -> (b*n, c, 1): each RoI becomes a length-1 sequence for Conv1d.
        # reshape (unlike view) also accepts non-contiguous inputs.
        roi_shared_feature = roi_shared_feature.reshape(b * n, c, -1)
        decoupled_feature = self.decouple_layers(roi_shared_feature)
        # (b*n, num_class, 1) -> (b*n, 1, num_class) -> (b*n, num_class)
        cls_pred_decouple = self.cls_layers(decoupled_feature).transpose(1, 2).contiguous().squeeze(dim=1)
        data_dict['roi_decouple_feature'] = decoupled_feature.view(b, n, c)
        self.forward_ret_dict = {
            'decouple_cls': cls_pred_decouple,
            'rcnn_cls_labels': data_dict['rcnn_cls_labels'],
        }
        return data_dict

    def get_box_cls_layer_loss(self):
        """Masked classification loss on the logits cached by forward().

        RoIs with a negative label are ignored. The loss is averaged over the
        remaining RoIs (denominator clamped to at least 1).

        Returns:
            (loss tensor, {'decouple_loss_cls': float}).

        Raises:
            NotImplementedError: if LOSS_CONFIG.CLS_LOSS is not
                'BinaryCrossEntropy' or 'CrossEntropy'.
        """
        loss_cfgs = self.model_cfg.LOSS_CONFIG
        rcnn_cls = self.forward_ret_dict['decouple_cls']
        rcnn_cls_labels = self.forward_ret_dict['rcnn_cls_labels'].reshape(-1)
        if loss_cfgs.CLS_LOSS == 'BinaryCrossEntropy':
            # Fused sigmoid + BCE is numerically stable for large |logit|;
            # sigmoid followed by binary_cross_entropy saturates and has its
            # log terms clamped at -100.
            batch_loss_cls = F.binary_cross_entropy_with_logits(
                rcnn_cls.reshape(-1), rcnn_cls_labels.float(), reduction='none'
            )
        elif loss_cfgs.CLS_LOSS == 'CrossEntropy':
            # NOTE(review): with num_class == 1 there is a single logit, so this
            # loss is identically 0; presumably only useful if num_class changes.
            batch_loss_cls = F.cross_entropy(rcnn_cls, rcnn_cls_labels, reduction='none', ignore_index=-1)
        else:
            raise NotImplementedError
        cls_valid_mask = (rcnn_cls_labels >= 0).float()
        decouple_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0)

        tb_dict = {'decouple_loss_cls': decouple_loss_cls.item()}
        return decouple_loss_cls, tb_dict

    def get_loss(self, pred_dict, tb_dict=None):
        """Total decoupling loss and updated tensorboard dict.

        ``pred_dict`` is currently unused. The contrastive term is computed
        separately via ``get_contrastiveloss`` and is not added here.
        """
        tb_dict = {} if tb_dict is None else tb_dict
        decouple_cls_loss, cls_tb_dict = self.get_box_cls_layer_loss()
        tb_dict.update(cls_tb_dict)
        loss_decouple = decouple_cls_loss
        return loss_decouple, tb_dict

    def get_contrastiveloss(self, pred_dict):
        """Contrastive loss over ``pred_dict`` using CONTRASTIVE_LOSS_CFG."""
        loss = get_contrastive_loss(pred_dict, self.model_cfg['CONTRASTIVE_LOSS_CFG'])
        return loss

    def make_fc_layers(self, input_channels, output_channels, fc_list):
        """Build [Conv1d(k=1, no bias) -> BatchNorm1d -> ReLU] per entry of fc_list,
        followed by a final Conv1d(k=1, bias) to ``output_channels``.

        A Dropout(DP_RATIO) is inserted after the first hidden block only, and
        only when DP_RATIO >= 0.
        """
        fc_layers = []
        pre_channel = input_channels
        for k, hidden_channels in enumerate(fc_list):
            fc_layers.extend([
                nn.Conv1d(pre_channel, hidden_channels, kernel_size=1, bias=False),
                nn.BatchNorm1d(hidden_channels),
                nn.ReLU()
            ])
            pre_channel = hidden_channels
            if self.model_cfg.DP_RATIO >= 0 and k == 0:
                fc_layers.append(nn.Dropout(self.model_cfg.DP_RATIO))
        fc_layers.append(nn.Conv1d(pre_channel, output_channels, kernel_size=1, bias=True))
        return nn.Sequential(*fc_layers)


class InfoNCELoss(nn.Module):
    """InfoNCE contrastive loss with dot-product similarity.

    For each anchor the logits are
    ``[anchor . positive, anchor . negative_1, ..., anchor . negative_K] / TEMPERATURE``
    and cross-entropy is taken with the positive at index 0.
    """

    def __init__(self, loss_cfg):
        super().__init__()
        self.T = loss_cfg['TEMPERATURE']
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, pos_feature_1, pos_feature_2, neg_feature):
        """
        Args:
            pos_feature_1: (N, C) anchor features.
            pos_feature_2: (N, C) positive features, row-aligned with the anchors.
            neg_feature: either (N, C), one negative per anchor, or a tensor
                reshapeable to (N, K, C), i.e. K negatives per anchor.

        Returns:
            Scalar cross-entropy loss averaged over the N anchors.
        """
        N, C = pos_feature_1.shape
        anchors = pos_feature_1.unsqueeze(2)  # (N, C, 1)
        l_pos = torch.bmm(pos_feature_2.view(N, 1, C), anchors).view(N, 1)  # (N, 1)
        if neg_feature.shape == pos_feature_1.shape:  # paired: one negative per anchor
            neg_feature = neg_feature.view(N, 1, C)
        else:
            # Negatives are laid out (N, K, C). Reshape keeping C last. The
            # previous view(N, C, -1) reinterpreted memory instead of
            # transposing, scrambling features whenever K > 1.
            neg_feature = neg_feature.reshape(N, -1, C)
        l_neg = torch.bmm(neg_feature, anchors).squeeze(2)  # (N, K)
        logits = torch.cat([l_pos, l_neg], dim=1)
        labels = logits.new_zeros(N, dtype=torch.long)  # positive is class 0
        loss = self.criterion(logits / self.T, labels)
        return loss

def split_instances_by_distance(boxes, pred_inds, loss_cfg):
    """Bucket instances by the Euclidean norm of their first three box columns.

    Args:
        boxes: tensor whose columns 0:3 are used as the position.
        pred_inds: sequence of instance indices aligned with the rows of ``boxes``.
        loss_cfg: mapping with 'SPLIT_DISTANCE', ascending bin edges
            (e.g. [0, 30, 50]). Bin k covers [d_k, d_{k+1}); the last bin
            covers [d_last, inf).

    Returns:
        distance_split_labels: float ndarray of length len(boxes) holding each
            instance's bin index (instances below d_0 keep 0).
        split_distance_dict: {'distance_label_<k>': [pred_inds falling in bin k]}.
    """
    distance_split_labels = np.zeros(boxes.shape[0])
    centers = boxes.detach().cpu().numpy()[:, 0:3]
    distance = np.sqrt(np.sum(centers * centers, axis=1))
    split_distance = loss_cfg['SPLIT_DISTANCE']
    pred_inds_arr = np.array(pred_inds)
    split_distance_dict = {}
    for bin_idx, lower in enumerate(split_distance):
        # Lower edge is inclusive for every bin (including the last one),
        # so an instance exactly on an edge always lands in some bin.
        dis_mask = distance >= lower
        if bin_idx + 1 < len(split_distance):
            dis_mask &= distance < split_distance[bin_idx + 1]
        split_distance_dict['distance_label_' + str(bin_idx)] = pred_inds_arr[dis_mask].tolist()
        distance_split_labels[dis_mask] = bin_idx
    return distance_split_labels, split_distance_dict

def get_contrastive_loss(pred_dict, loss_cfg):
    '''
    Compute the contrastive loss over features sampled from ``pred_dict``.

    Args:
        pred_dict: per-frame prediction dicts, passed to sample_contrastive_pairs.
        loss_cfg: contrastive loss config. 'NAME' selects the loss (only
            'InfoNCELoss' is supported). The whole config is also passed to the
            sampler and the loss.

    Returns:
        The loss tensor, or the float 0. when no pairs could be sampled.

    Raises:
        NotImplementedError: for an unsupported loss_cfg['NAME'].
    '''
    pos_feature_1, pos_feature_2, neg_feature = sample_contrastive_pairs(pred_dict, loss_cfg)
    if pos_feature_1 is None:
        # The sampler returns either all three tensors or none of them.
        assert pos_feature_2 is None and neg_feature is None
        return 0.
    loss_type = loss_cfg['NAME']
    if loss_type == 'InfoNCELoss':
        criterion = InfoNCELoss(loss_cfg)
        return criterion(pos_feature_1, pos_feature_2, neg_feature)
    raise NotImplementedError(f'Unsupported contrastive loss: {loss_type}')

def _filter_bins_by_mask(distance_split_dict, num_bins, keep_mask):
    """Per distance bin, keep only the instance indices where keep_mask is True.

    Returns a list of length num_bins whose entry k holds the kept indices of
    distance_split_dict['distance_label_<k>'], in their original order.
    """
    return [
        [ind for ind in distance_split_dict['distance_label_' + str(bin_idx)] if keep_mask[ind]]
        for bin_idx in range(num_bins)
    ]


def _fill_from_same_bin(target, anchor_bins, candidates_per_bin, roi_feature):
    """Write into target[j] a feature drawn from the same distance bin as anchor j.

    For every bin k, rows j with anchor_bins[j] == k receive roi_feature[c] for
    c sampled uniformly with replacement from candidates_per_bin[k]. Bins with
    no anchors or no candidates are skipped, so those rows keep their current
    value (zeros when target comes from new_zeros). Modifies target in place
    and returns it.
    """
    for bin_idx, candidates in enumerate(candidates_per_bin):
        rows = np.nonzero(anchor_bins == bin_idx)[0].tolist()
        if len(rows) == 0 or len(candidates) == 0:
            continue
        picks = np.random.randint(0, len(candidates), len(rows))
        target[rows] = roi_feature[[candidates[p] for p in picks]]
    return target


def sample_contrastive_pairs(pred_dict, loss_cfg):
    '''
    Sample anchor / positive / negative RoI features for the contrastive loss.

    For every frame in ``pred_dict`` and every class in ``loss_cfg.CLASS``
    (predictions with label ``k + 1`` belong to ``CLASS[k]``):

    * anchors: POS_NUM features drawn with replacement from instances with
      score > POS_THR;
    * positives: another score > POS_THR instance, drawn from the anchor's
      distance bin when POS_CFG.DISTANCE_AWARE, otherwise from any bin;
    * negatives: instances with score < NEG_THR. With NEG_CFG.DISTANCE_AWARE,
      one per anchor from the anchor's distance bin (same shape as the
      anchors). Otherwise NEG_NUM_PER_ANCHOR per anchor from any bin, viewed
      as (POS_NUM, NEG_NUM_PER_ANCHOR, -1).

    A class in a frame is skipped when it has no positives or no negatives. In
    distance-aware modes, anchors whose bin has no candidate keep an all-zero
    positive/negative feature.

    Returns:
        (anchor, positive, negative) tensors concatenated along dim 0, or
        (None, None, None) when nothing could be sampled.

    Future work:
        1. range-aware positives and negatives
        2. a unified representation (prototype) for positive instances
        3. in-scene copy and paste
    '''
    num_bins = len(loss_cfg['SPLIT_DISTANCE'])
    anchor_list, positive_list, negative_list = [], [], []
    for index in range(len(pred_dict)):
        frame = pred_dict[index]
        pred_boxes_ori = frame['pred_boxes']
        pred_scores_ori = frame['pred_scores']
        pred_labels_ori = frame['pred_labels']
        # NOTE(review): original note "select_roi_decouple change select_roi";
        # presumably this key replaced 'select_roi'. Confirm.
        roi_feature_ori = frame['select_roi_decouple']

        for class_idx, cur_class in enumerate(loss_cfg.CLASS):
            cur_loss_cfg = loss_cfg[cur_class]
            pos_cfg = cur_loss_cfg['POS_CFG']
            neg_cfg = cur_loss_cfg['NEG_CFG']
            pos_num = pos_cfg['POS_NUM']

            cls_mask = pred_labels_ori == (class_idx + 1)
            pred_boxes = pred_boxes_ori[cls_mask]
            pred_scores = pred_scores_ori[cls_mask]
            pred_labels = pred_labels_ori[cls_mask]
            roi_feature = roi_feature_ori[cls_mask]

            pred_inds = list(range(pred_labels.size(0)))
            distance_split_labels, distance_split_dict = split_instances_by_distance(pred_boxes, pred_inds, loss_cfg)

            # ---- anchors and positives ----
            positive_mask = (pred_scores > pos_cfg['POS_THR']).detach().cpu().numpy()
            pos_inds = np.nonzero(positive_mask)[0].tolist()
            if len(pos_inds) == 0:
                continue
            # Row 0: anchors. Row 1 (non-distance-aware only): their positives.
            num_draws = 1 if pos_cfg['DISTANCE_AWARE'] else 2
            pos_pair_inds = np.random.randint(0, len(pos_inds), (num_draws, pos_num)).tolist()
            anchor_inds = pos_pair_inds[0]  # indices into pos_inds, not into pred_inds
            pos_feature_1 = roi_feature[pos_inds][anchor_inds]
            # Distance bin of each anchor: map through pos_inds first, because
            # anchor_inds index the positive subset.
            anchor_bins = distance_split_labels[pos_inds][anchor_inds]

            if pos_cfg['DISTANCE_AWARE']:
                pos_bins = _filter_bins_by_mask(distance_split_dict, num_bins, positive_mask)
                pos_feature_2 = _fill_from_same_bin(
                    pos_feature_1.new_zeros(pos_feature_1.shape), anchor_bins, pos_bins, roi_feature)
            else:
                pos_feature_2 = roi_feature[pos_inds][pos_pair_inds[1]]

            # ---- negatives ----
            negative_mask = (pred_scores < neg_cfg['NEG_THR']).detach().cpu().numpy()
            if not negative_mask.any():
                continue
            if neg_cfg['DISTANCE_AWARE']:
                neg_bins = _filter_bins_by_mask(distance_split_dict, num_bins, negative_mask)
                neg_feature = _fill_from_same_bin(
                    pos_feature_1.new_zeros(pos_feature_1.shape), anchor_bins, neg_bins, roi_feature)
            else:
                neg_per_anchor = neg_cfg['NEG_NUM_PER_ANCHOR']
                neg_inds = np.nonzero(negative_mask)[0].tolist()
                neg_pair_inds = np.random.randint(0, len(neg_inds), pos_num * neg_per_anchor).tolist()
                neg_feature = roi_feature[neg_inds][neg_pair_inds].view(pos_num, neg_per_anchor, -1)

            anchor_list.append(pos_feature_1)
            positive_list.append(pos_feature_2)
            negative_list.append(neg_feature)

    if not anchor_list:
        return None, None, None
    return torch.cat(anchor_list, dim=0), torch.cat(positive_list, dim=0), torch.cat(negative_list, dim=0)
