import torch
import torchvision
import numpy as np
from PIL import Image
import torch.nn as nn
from yolo2 import utils
import torch.nn.functional as F
from mmdet.registry import VISUALIZERS
import mmcv

from mmcv.transforms import Compose
from mmcv.transforms import BaseTransform, TRANSFORMS
from mmdet.structures import DetDataSample

# Extract the probability of the target class from the YOLO prediction results.
class MaxProbExtractor(nn.Module):
    """Reduce per-image detections to a scalar detection loss against one ground-truth box.

    For every image, detections are kept when their label equals ``1`` and their IoU with
    the image's ground-truth box is at least ``iou_thresh``. ``loss_type`` then selects how
    the surviving scores (and, for some variants, IoUs) are combined.

    NOTE(review): ``cls_id``, ``num_cls`` and ``loss_target`` are stored but never read by
    ``forward``; the label filter is hard-coded to ``1`` — confirm this is intended.
    """

    def __init__(self, cls_id, num_cls):
        super(MaxProbExtractor, self).__init__()
        self.cls_id = cls_id
        self.num_cls = num_cls
        self.loss_target = lambda obj, cls: obj

    @staticmethod
    def _logit(prob):
        """Map a probability to its logit, written as ``-log(1 / p - 1)``."""
        return -torch.log(1.0 / prob - 1.0)

    def _reduce_matches(self, scores, ious, loss_type):
        """Return ``(loss_term, reported_prob, count)`` for one image's matched detections.

        ``scores`` and ``ious`` are the non-empty, equally long 1-D tensors left after
        filtering. ``count`` is what the image contributes to the "matched" counter.
        Naming scheme of ``loss_type``:

        * prefix ``max_iou`` / ``max_conf`` / ``softplus_max``: use a single detection
          (best IoU or best score); ``softplus_*`` applies softplus to the score's logit.
        * prefix ``softplus_sum``: use every matched detection and sum.
        * suffix ``_mtiou``: multiply by IoU; ``_adiou``: add IoU; ``_adspiou``: add
          softplus of the IoU's logit. No suffix: IoU is not used, except that plain
          ``softplus_sum`` weights each term by the *detached* IoU.
        """
        def softplus_logit(prob):
            return F.softplus(self._logit(prob))

        # --- variants without an IoU suffix ---------------------------------
        if loss_type == 'max_iou':
            _, best = torch.max(ious, dim=0)  # detection with the largest IoU vs. gt
            return scores[best], scores[best], 1
        if loss_type == 'max_conf':
            top = scores.max()
            return top, top, 1
        if loss_type == 'softplus_max':
            top = scores.max()
            return softplus_logit(top), top, 1
        if loss_type == 'softplus_sum':
            # IoU is a constant weight here: no gradient flows through it.
            return (softplus_logit(scores) * ious.detach()).sum(), scores.mean(), len(scores)

        # --- "_mtiou": multiply by IoU --------------------------------------
        if loss_type == 'max_iou_mtiou':
            _, best = torch.max(ious, dim=0)
            return scores[best] * ious[best], scores[best], 1
        if loss_type == 'max_conf_mtiou':
            _, best = torch.max(scores, dim=0)
            return scores[best] * ious[best], scores[best], 1
        if loss_type == 'softplus_max_mtiou':
            _, best = torch.max(scores, dim=0)
            return softplus_logit(scores[best]) * ious[best], scores[best], 1
        if loss_type == 'softplus_sum_mtiou':
            return (softplus_logit(scores) * ious).sum(), scores.mean(), len(scores)

        # --- "_adiou": add IoU ----------------------------------------------
        if loss_type == 'max_iou_adiou':
            _, best = torch.max(ious, dim=0)
            return scores[best] + ious[best], scores[best], 1
        if loss_type == 'max_conf_adiou':
            _, best = torch.max(scores, dim=0)
            return scores[best] + ious[best], scores[best], 1
        if loss_type == 'softplus_max_adiou':
            _, best = torch.max(scores, dim=0)
            return softplus_logit(scores[best]) + ious[best], scores[best], 1
        if loss_type == 'softplus_sum_adiou':
            return (softplus_logit(scores) + ious).sum(), scores.mean(), len(scores)

        # --- "_adspiou": add softplus(logit(IoU)) ---------------------------
        if loss_type == 'softplus_max_adspiou':
            _, best = torch.max(scores, dim=0)
            return softplus_logit(scores[best]) + softplus_logit(ious[best]), scores[best], 1
        if loss_type == 'softplus_sum_adspiou':
            return (softplus_logit(scores) + softplus_logit(ious)).sum(), scores.mean(), len(scores)

        raise ValueError(f'unknown loss_type: {loss_type!r}')

    def forward(self, output, gt, loss_type, iou_thresh, img_size=416):
        """Compute the batch detection loss.

        Args:
            output: sequence with one mapping per image holding ``'boxes'`` (passed to
                ``torchvision.ops.box_iou`` as xyxy boxes), ``'labels'`` and ``'scores'``.
            gt: indexable per image; ``gt[i][0..3]`` are read as center-x, center-y, width
                and height and multiplied by ``img_size``. Each must be a 1-D tensor because
                they are stacked along dim 1. The IoU matrix is squeezed on dim 1, so
                exactly one ground-truth box per image is presumably expected — TODO confirm.
            loss_type: name of the reduction, see ``_reduce_matches``.
            iou_thresh: minimum IoU with the ground truth for a detection to count.
            img_size: factor that scales ``gt`` to the coordinate system of ``'boxes'``
                (default ``416``, the value that used to be hard-coded).

        Returns:
            ``(det_loss, max_probs)``: the mean of the per-image loss terms and a tensor with
            one reported probability per image. Images without a matching detection
            contribute ``0`` to both.

        Raises:
            ValueError: ``loss_type`` is unknown (only detected once an image has a match).
            RuntimeError: no image in the batch had a matching detection.
        """
        det_loss = []
        max_probs = []
        num = 0
        for i, boxes in enumerate(output):
            cx, cy, bw, bh = gt[i][0], gt[i][1], gt[i][2], gt[i][3]
            # center/size -> corner format, in the detector's coordinate system
            gt_xyxy = torch.stack([(cx - bw / 2) * img_size,
                                   (cy - bh / 2) * img_size,
                                   (bw / 2 + cx) * img_size,
                                   (bh / 2 + cy) * img_size], 1)

            ious = torchvision.ops.box_iou(boxes['boxes'], gt_xyxy).squeeze(1)
            mask = ious.ge(iou_thresh).logical_and(boxes['labels'] == 1)
            ious = ious[mask]
            scores = boxes['scores'][mask]

            if len(ious) > 0:
                loss_term, prob, count = self._reduce_matches(scores, ious, loss_type)
                num += count
            else:
                # keep one entry per image so max_probs stays aligned with the batch
                loss_term = prob = ious.new_zeros(())
            det_loss.append(loss_term)
            max_probs.append(prob)

        if num < 1:
            raise RuntimeError('no detection matched the ground truth in this batch')
        return torch.stack(det_loss).mean(), torch.stack(max_probs)

def truths_length(truths):
    """Return the index of the first row whose element ``[1]`` equals ``-1``.

    Rows are scanned in order; if no row matches, ``len(truths)`` is returned.
    The ``-1`` is presumably a padding sentinel in a fixed-size label array.
    """
    total = len(truths)
    return next((pos for pos in range(total) if truths[pos][1] == -1), total)


def get_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors, only_objectness=1, validation=False,
                     name=None):
    """Decode a raw YOLO feature map into per-image box tensors.

    Args:
        output: tensor of shape ``(batch, (5 + num_classes) * num_anchors, h, w)``; a 3-D
            tensor is treated as a single image.
        conf_thresh: boxes whose confidence is not strictly greater than this are dropped.
        num_classes: number of class channels per anchor.
        anchors: flat sequence with ``len(anchors) // num_anchors`` values per anchor; the
            first two of each group are used as anchor width and height (grid-cell units).
        num_anchors: number of anchors per grid cell.
        only_objectness: if truthy, threshold on objectness alone; otherwise on
            objectness times the best class probability.
        validation: only checked together with a falsy ``only_objectness``, a combination
            that is not implemented.
        name: ``'yolov2'`` (softmax class scores) or ``'yolov3'`` (sigmoid class scores).

    Returns:
        A list with one ``(n_i, 7)`` tensor per image; columns are
        ``[cx, cy, w, h, det_conf, cls_max_conf, cls_max_id]`` with the first four divided
        by the grid width/height.

    Raises:
        AssertionError: the channel dimension does not match ``num_classes``/``num_anchors``.
        ValueError: ``name`` is not one of the supported values.
        NotImplementedError: ``validation`` is set while ``only_objectness`` is falsy.
    """
    anchor_step = len(anchors) // num_anchors
    device = output.device
    if output.dim() == 3:
        output = output.unsqueeze(0)
    batch = output.size(0)
    assert (output.size(1) == (5 + num_classes) * num_anchors)
    h = output.size(2)
    w = output.size(3)
    cells = h * w

    # Bring the per-box attributes to the front: (5 + C, batch * anchors * h * w).
    output = output.view(batch * num_anchors, 5 + num_classes, cells)
    output = output.transpose(0, 1).contiguous()
    output = output.view(5 + num_classes, batch * num_anchors * cells)

    # Cell offsets laid out as (h, w) — rows are y, columns are x — so that flattening
    # matches the h*w layout of ``output`` even when the feature map is not square.
    grid_y = torch.arange(h, device=device).view(h, 1).expand(h, w)
    grid_x = torch.arange(w, device=device).view(1, w).expand(h, w)
    grid_x = grid_x.repeat(batch * num_anchors, 1, 1).flatten()
    grid_y = grid_y.repeat(batch * num_anchors, 1, 1).flatten()
    xs = torch.sigmoid(output[0]) + grid_x
    ys = torch.sigmoid(output[1]) + grid_y

    # One (w, h) prior per anchor, repeated for every image and every cell.
    anchor_tensor = torch.tensor(anchors, device=device).view(num_anchors, anchor_step)
    anchor_w = anchor_tensor[:, 0:1].repeat(batch, 1).repeat(1, 1, cells).view(batch * num_anchors * cells)
    anchor_h = anchor_tensor[:, 1:2].repeat(batch, 1).repeat(1, 1, cells).view(batch * num_anchors * cells)
    ws = torch.exp(output[2]) * anchor_w
    hs = torch.exp(output[3]) * anchor_h

    det_confs = torch.sigmoid(output[4])

    class_logits = output[5:5 + num_classes].transpose(0, 1)
    if name == 'yolov2':
        cls_confs = class_logits.softmax(-1)
    elif name == 'yolov3':
        cls_confs = class_logits.sigmoid()
    else:
        raise ValueError(f"unsupported name {name!r}; expected 'yolov2' or 'yolov3'")

    cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)
    cls_max_confs = cls_max_confs.view(-1)
    # Class ids are integers; cast explicitly so they can share a tensor with the floats.
    cls_max_ids = cls_max_ids.view(-1).to(cls_max_confs.dtype)

    raw_boxes = torch.stack([xs / w, ys / h, ws / w, hs / h, det_confs, cls_max_confs, cls_max_ids], 1)
    raw_boxes = raw_boxes.view(batch, -1, 7)
    if only_objectness:
        conf = det_confs
    else:
        conf = det_confs * cls_max_confs
    inds = (conf > conf_thresh).view(batch, -1)

    all_boxes = [b[i] for b, i in zip(raw_boxes, inds)]

    if (not only_objectness) and validation:
        raise NotImplementedError
    return all_boxes


def xyxy2xywh(x):
    """Convert boxes from corner format ``[x1, y1, x2, y2]`` to center format ``[x, y, w, h]``.

    Args:
        x: ``torch.Tensor`` or NumPy array whose last dimension starts with the four box
            coordinates. Any number of leading dimensions is accepted (a single box of
            shape ``(4,)``, ``(n, 4)``, ``(b, n, 4)``, ...); extra trailing columns are
            copied through unchanged.

    Returns:
        A new tensor/array with the same shape and dtype as ``x``; ``x`` is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y


def xywh2xyxy(x):
    """Convert boxes from center format ``[x, y, w, h]`` to corner format ``[x1, y1, x2, y2]``.

    Args:
        x: ``torch.Tensor`` or NumPy array whose last dimension starts with the four box
            values. Any number of leading dimensions is accepted; extra trailing columns
            are copied through unchanged.

    Returns:
        A new tensor/array with the same shape and dtype as ``x``; ``x`` is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y


def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """Convert normalized center boxes ``[x, y, w, h]`` to pixel corner boxes ``[x1, y1, x2, y2]``.

    Args:
        x: ``torch.Tensor`` or NumPy array whose last dimension starts with the four
            normalized box values; any number of leading dimensions is accepted.
        w: factor applied to the x coordinates (image width in pixels).
        h: factor applied to the y coordinates (image height in pixels).
        padw: offset added to the x coordinates after scaling.
        padh: offset added to the y coordinates after scaling.

    Returns:
        A new tensor/array with the same shape and dtype as ``x``; ``x`` is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
    y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
    return y


def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    """Convert pixel corner boxes ``[x1, y1, x2, y2]`` to normalized center boxes ``[x, y, w, h]``.

    Args:
        x: ``torch.Tensor`` or NumPy array whose last dimension starts with the four box
            coordinates; any number of leading dimensions is accepted.
        w: image width used to normalize x values.
        h: image height used to normalize y values.
        clip: if true, corner coordinates are first limited to ``[0, w - eps]`` (x) and
            ``[0, h - eps]`` (y). Clipping is done on a private copy, never on ``x``.
        eps: margin subtracted from the upper clipping bound; only used when ``clip`` is set.

    Returns:
        A new tensor/array with the same shape and dtype as ``x``; ``x`` is not modified.
    """
    is_tensor = isinstance(x, torch.Tensor)
    src = x
    if clip:
        src = x.clone() if is_tensor else np.copy(x)
        if is_tensor:
            src[..., [0, 2]] = src[..., [0, 2]].clamp(0, w - eps)  # x1, x2
            src[..., [1, 3]] = src[..., [1, 3]].clamp(0, h - eps)  # y1, y2
        else:
            src[..., [0, 2]] = src[..., [0, 2]].clip(0, w - eps)  # x1, x2
            src[..., [1, 3]] = src[..., [1, 3]].clip(0, h - eps)  # y1, y2

    y = x.clone() if is_tensor else np.copy(x)
    y[..., 0] = ((src[..., 0] + src[..., 2]) / 2) / w  # x center
    y[..., 1] = ((src[..., 1] + src[..., 3]) / 2) / h  # y center
    y[..., 2] = (src[..., 2] - src[..., 0]) / w  # width
    y[..., 3] = (src[..., 3] - src[..., 1]) / h  # height
    return y


def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    """Convert normalized ``(x, y)`` points to pixel coordinates.

    Args:
        x: ``torch.Tensor`` or NumPy array whose last dimension starts with the x and y
            values; any number of leading dimensions is accepted (e.g. ``(n, 2)`` segments
            or a single ``(2,)`` point).
        w: factor applied to the x values (image width in pixels).
        h: factor applied to the y values (image height in pixels).
        padw: offset added to x after scaling.
        padh: offset added to y after scaling.

    Returns:
        A new tensor/array with the same shape and dtype as ``x``; ``x`` is not modified.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = w * x[..., 0] + padw  # pixel x
    y[..., 1] = h * x[..., 1] + padh  # pixel y
    return y

@TRANSFORMS.register_module()
class MyLoadImage(BaseTransform):
    """Pack an already-loaded image batch into ``inputs`` / ``data_samples`` form.

    ``results['img']`` is forwarded untouched as ``'inputs'``. One ``DetDataSample`` is
    created per element of its first dimension, and every sample receives the same
    metainfo, copied from ``results`` for each key in ``meta_keys``.
    """

    def __init__(self, meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor')):
        super().__init__()
        self.meta_keys = meta_keys

    def _gather_meta(self, results):
        """Build a fresh metainfo dict from ``results``, asserting that every key exists."""
        meta = {}
        for key in self.meta_keys:
            assert key in results, f'`{key}` is not found in `results`, ' \
                f'the valid keys are {list(results)}.'
            meta[key] = results[key]
        return meta

    def transform(self, results: dict) -> dict:
        """Return ``{'inputs': results['img'], 'data_samples': [DetDataSample, ...]}``."""
        images = results["img"]
        samples = []
        for _ in range(images.size()[0]):
            sample = DetDataSample()
            sample.set_metainfo(self._gather_meta(results))
            samples.append(sample)
        return dict(inputs=images, data_samples=samples)

def prepare_data(image, img_size):
    """Wrap an image batch and its metadata via the ``MyLoadImage`` pipeline.

    The metadata declares an original shape of ``(img_size, img_size)``, an image shape of
    ``(640, 640)`` and a scale factor of ``640 / img_size`` on both axes. The image itself
    is passed through as given; no resizing happens here.

    Returns:
        Whatever the composed ``MyLoadImage`` pipeline returns for this record.
    """
    ratio = 640 / img_size
    record = {
        'img': image,
        'img_id': 0,
        'img_path': '',
        'ori_shape': (img_size, img_size),
        'img_shape': (640, 640),
        'scale_factor': np.array([ratio, ratio], dtype=np.float32),
    }
    pipeline = Compose([dict(type="MyLoadImage")])
    return pipeline(record)


def get_rcnn_loss(darknet_model, p_img, lab_batch, args, kwargs):
    """Sum, over the batch, the highest label-1 score above ``args.conf_thresh``.

    ``darknet_model(p_img)`` must return one mapping per image with ``'boxes'``,
    ``'labels'`` and ``'scores'``. An image contributes its maximum score among detections
    with label ``1`` whose score exceeds ``args.conf_thresh``; images with no such
    detection contribute nothing. ``lab_batch`` and ``kwargs`` are not used.

    Returns:
        ``(det_loss, valid_num)``: the summed scores as a 0-dim tensor and the number of
        contributing images.
    """
    detections = darknet_model(p_img)
    total = p_img.new_zeros([])
    hits = 0
    for img_idx in range(p_img.shape[0]):
        det = detections[img_idx]
        if det['boxes'].shape[0] == 0:
            continue
        class_scores = det['scores'][det['labels'] == 1]
        confident = class_scores[class_scores > args.conf_thresh]
        if confident.shape[0] == 0:
            continue
        total = total + confident.max()
        hits += 1
    return total, hits

def get_detr_loss(darknet_model, p_img, lab_batch, args, kwargs):
    """Sum the scores of all label-0 predictions returned by ``darknet_model.test_step``.

    ``p_img`` is scaled by 255 and wrapped with ``prepare_data`` (using ``p_img.size()[2]``
    as the image size) before being handed to ``test_step``. Each result must expose
    ``pred_instances.labels`` and ``pred_instances.scores``. An image contributes the *sum*
    of its label-0 scores; images without a label-0 prediction contribute nothing.
    ``lab_batch``, ``args`` and ``kwargs`` are not used.

    Returns:
        ``(det_loss, valid_num)``: the accumulated score sum and the number of contributing
        images. If no image contributes, ``det_loss`` is a zero tensor created from
        ``p_img`` (same dtype/device), in line with the other loss helpers.
    """
    zero = p_img.new_zeros([])
    data = prepare_data(p_img * 255.0, p_img.size()[2])
    results = darknet_model.test_step(data)

    per_image = []
    for result in results:
        pred = result.pred_instances
        # Index with the boolean mask directly; wrapping it in a list relies on
        # deprecated sequence-as-tuple indexing.
        scores = pred.scores[pred.labels == 0]
        if scores.numel() != 0:
            per_image.append(scores.sum())

    if not per_image:
        return zero, 0
    return sum(per_image), len(per_image)



def get_v2_loss(train, darknet_model, p_img, lab_batch, args, kwargs):
    """Sum, over the batch, the highest confidence of boxes overlapping a ground-truth box.

    Boxes are decoded with ``utils.get_region_boxes_general``. For each image, boxes whose
    best IoU against the image's labels exceeds ``args.iou_thresh`` are kept, and the
    maximum of their column 4 (used as the detection confidence) is added to the loss.

    Args:
        train: forwarded unchanged to ``utils.get_region_boxes_general``.
        darknet_model: callable model; also forwarded to the box decoder.
        p_img: input batch; its first dimension is the batch size.
        lab_batch: per-image label rows. Rows up to ``truths_length`` are used and column 0
            is dropped before the IoU computation (presumably the class id — TODO confirm).
        args: must provide ``conf_thresh`` and ``iou_thresh``.
        kwargs: mapping with key ``'name'``, which must be ``'yolov2'``.

    Returns:
        ``(det_loss, valid_num)``: the summed confidences as a 0-dim tensor and the number
        of contributing images.

    Raises:
        ValueError: ``kwargs['name']`` is not ``'yolov2'``.
    """
    if kwargs['name'] != 'yolov2':
        raise ValueError(f"get_v2_loss only supports name='yolov2', got {kwargs['name']!r}")

    valid_num = 0
    det_loss = p_img.new_zeros([])
    output = [darknet_model(p_img)]
    all_boxes = utils.get_region_boxes_general(train, output, darknet_model, conf_thresh=args.conf_thresh,
                                               name=kwargs['name'])

    for ii in range(p_img.shape[0]):
        boxes = all_boxes[ii]
        if boxes.shape[0] == 0:
            continue
        truths = lab_batch[ii][:truths_length(lab_batch[ii]), 1:]
        iou_max = utils.bbox_iou_mat(boxes[..., :4], truths, False).max(1)[0]
        det_confs = boxes[iou_max > args.iou_thresh][:, 4]
        # NOTE: an experimental class-score margin loss used to be sketched here; only
        # the detection-confidence term is active.
        if det_confs.shape[0] > 0:
            det_loss = det_loss + det_confs.max()
            valid_num += 1

    return det_loss, valid_num
def get_v3_loss(darknet_model, p_img, lab_batch, args, kwargs):
    """Sum, over the batch, the highest prediction score from ``darknet_model.test_step``.

    ``p_img`` is scaled by 255 and wrapped with ``prepare_data`` (using ``p_img.size()[2]``
    as the image size) before being handed to ``test_step``. Each result must expose
    ``pred_instances.scores``; an image contributes its maximum score, and images without
    predictions contribute nothing. ``lab_batch``, ``args`` and ``kwargs`` are not used.

    NOTE: an earlier implementation decoded the raw YOLO layer outputs with
    ``get_region_boxes`` and IoU-matched them against ``lab_batch``; it is no longer active.

    Returns:
        ``(det_loss, valid_num)``: the accumulated maxima and the number of contributing
        images. If no image contributes, ``det_loss`` is a zero tensor created from
        ``p_img`` (same dtype/device), in line with the other loss helpers.
    """
    zero = p_img.new_zeros([])
    data = prepare_data(p_img * 255.0, p_img.size()[2])
    results = darknet_model.test_step(data)

    per_image = []
    for result in results:
        scores = result.pred_instances.scores
        if scores.numel() != 0:
            per_image.append(scores.max())

    if not per_image:
        return zero, 0
    return sum(per_image), len(per_image)


def get_v5_loss(darknet_model, p_img, lab_batch, args, kwargs, img_size=640.0):
    """Sum, over the batch, the highest confidence of predictions overlapping a ground-truth box.

    The first element of ``darknet_model(p_img)`` is taken as the per-image prediction
    tensor. For each image, rows whose column 4 (used as the detection confidence) exceeds
    ``args.conf_thresh`` are kept, their first four columns are divided by ``img_size`` and
    compared with the image's labels via ``utils.bbox_iou_mat``. Rows whose best IoU
    exceeds ``args.iou_thresh`` are candidates; the maximum confidence among them is added
    to the loss.

    Args:
        darknet_model: callable model; ``darknet_model(p_img)[0]`` is indexed per image.
        p_img: input batch; its first dimension is the batch size.
        lab_batch: per-image label rows. Rows up to ``truths_length`` are used and column 0
            is dropped before the IoU computation (presumably the class id — TODO confirm).
        args: must provide ``conf_thresh`` and ``iou_thresh``.
        kwargs: not used.
        img_size: divisor that brings predicted box coordinates to the scale of the labels
            (default ``640.0``, the value that used to be hard-coded).

    Returns:
        ``(det_loss, valid_num)``: the summed confidences as a 0-dim tensor and the number
        of contributing images.
    """
    valid_num = 0
    det_loss = p_img.new_zeros([])
    predictions = darknet_model(p_img)[0]

    for ii in range(p_img.shape[0]):
        preds = predictions[ii]
        if preds.shape[0] == 0:
            continue
        confident = preds[preds[..., 4] > args.conf_thresh]
        boxes = confident[..., :4] / img_size
        truths = lab_batch[ii][:truths_length(lab_batch[ii]), 1:]
        iou_max = utils.bbox_iou_mat(boxes, truths, False).max(1)[0]
        det_confs = confident[iou_max > args.iou_thresh][:, 4]

        if det_confs.shape[0] > 0:
            det_loss = det_loss + det_confs.max()
            valid_num += 1
    return det_loss, valid_num



def gauss_kernel(ksize=5, sigma=None, conv=False, dtype=np.float32):
    """Build a normalized 2-D Gaussian kernel.

    Args:
        ksize: side length of the square kernel.
        sigma: standard deviation; when ``None`` it is derived from ``ksize`` as
            ``0.3 * ((ksize - 1) / 2 - 1) + 0.8``.
        conv: if true, return a ``(3, 3, ksize, ksize)`` array with the kernel on the
            channel diagonal and zeros elsewhere; otherwise return the ``(ksize, ksize)``
            kernel itself.
        dtype: dtype of the returned array.

    Returns:
        The kernel (its 2-D weights sum to 1 before the dtype cast).
    """
    radius = 0.5 * (ksize - 1)
    if sigma is None:
        sigma = 0.3 * (radius - 1) + 0.8

    taps = np.arange(-radius, radius + 1)
    profile = np.exp(-np.square(taps / sigma) / 2)
    window = np.outer(profile, profile)
    window = window / window.sum()

    if not conv:
        return window.astype(dtype)

    # Each channel is filtered only with itself: fill just the diagonal.
    bank = np.zeros((3, 3, ksize, ksize))
    diag = np.arange(3)
    bank[diag, diag] = window
    return bank.astype(dtype)


def pad_and_scale(img, lab=None, size=(416, 416), color=(127, 127, 127)):
    """Pad an image to a square with ``color`` and resize it to ``size``.

    The shorter side is padded symmetrically (the paste offset is truncated to an int).
    If ``lab`` is given it is updated **in place** so that its coordinates refer to the
    padded square: columns 1 and 3 are rescaled when padding horizontally, columns 2 and 4
    when padding vertically (presumably rows are ``[cls, x, y, w, h]`` in normalized
    units — TODO confirm).

    Args:
        img: image object exposing ``size`` as ``(w, h)``, usable with ``paste``/``resize``.
        lab: optional 2-D label array indexed as ``lab[:, [k]]``.
        size: target ``(width, height)`` passed to ``resize``.
        color: fill color of the padding.

    Returns:
        The resized image, or ``(resized_image, lab)`` when ``lab`` is not ``None``.
    """
    w, h = img.size
    square = img
    if w < h:
        # portrait: pad left/right, so x-center and width change
        offset = (h - w) / 2
        square = Image.new('RGB', (h, h), color=color)
        square.paste(img, (int(offset), 0))
        if lab is not None:
            lab[:, [1]] = (lab[:, [1]] * w + offset) / h
            lab[:, [3]] = (lab[:, [3]] * w / h)
    elif w > h:
        # landscape: pad top/bottom, so y-center and height change
        offset = (w - h) / 2
        square = Image.new('RGB', (w, w), color=color)
        square.paste(img, (0, int(offset)))
        if lab is not None:
            lab[:, [2]] = (lab[:, [2]] * h + offset) / w
            lab[:, [4]] = (lab[:, [4]] * h / w)

    resized = square.resize((size[0], size[1]))
    return resized if lab is None else (resized, lab)

def TVLoss(patch):
    """Return the mean absolute total variation of a 4-D tensor.

    Absolute differences between neighbouring elements along dim 2 and along dim 3 are
    summed and divided by the total number of elements in ``patch``.
    """
    along_dim2 = patch[:, :, 1:, :] - patch[:, :, :-1, :]
    along_dim3 = patch[:, :, :, 1:] - patch[:, :, :, :-1]
    total = along_dim2.abs().sum() + along_dim3.abs().sum()
    return total / patch.numel()