import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.meta_arch import META_ARCH_REGISTRY
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.modeling.proposal_generator import build_proposal_generator

# from detectron2.modeling.roi_heads import build_roi_heads
from detectron2.structures import ImageList
from detectron2.utils.logger import log_first_n

import logging
import math

from ..self_supervised import build_ss_head
from ..roi_heads import build_roi_heads
import pdb
from torchvision.transforms import functional as F1
import torchvision.transforms as T

__all__ = ["SSRCNN"]


@META_ARCH_REGISTRY.register()
class SSRCNN(nn.Module):
    """
    Detection + self-supervised meta-architecture.

    Only the detection path (backbone -> proposal generator -> ROI heads) is
    currently active; the self-supervised heads are disabled (their code is
    kept as no-op string literals in ``__init__`` and ``forward``).
    """

    def __init__(self, cfg):
        super().__init__()
        # pylint: disable=no-member
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.backbone = build_backbone(cfg)
        self.proposal_generator = build_proposal_generator(
            cfg, self.backbone.output_shape()
        )
        # Must run after nn.Module.__init__: it registers buffers.
        self.from_config(cfg)
        self.roi_heads = build_roi_heads(cfg, self.backbone.output_shape())
        # Self-supervised heads are disabled; the string literal below is a
        # no-op kept for reference.
        '''
        self.ss_head = build_ss_head(
            cfg, self.backbone.bottom_up.output_shape()
        )

        for i in range(len(self.ss_head)):
            setattr(self, "ss_head_{}".format(i), self.ss_head[i])
        '''
        self.to(self.device)

    def from_config(self, cfg):
        """Read self-supervision options and set up input normalization.

        ``cfg.MODEL.PIXEL_MEAN`` / ``cfg.MODEL.PIXEL_STD`` are registered as
        non-persistent buffers, so they move with the module on ``.to()`` /
        ``.cuda()`` and are neither written to nor expected in checkpoints.
        """
        # only train/eval the ss branch for debugging.
        self.ss_only = cfg.MODEL.SS.ONLY
        self.feat_level = cfg.MODEL.SS.FEAT_LEVEL  # res4

        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
        num_channels = len(cfg.MODEL.PIXEL_MEAN)
        # Shaped (C, 1, 1) so they broadcast over the spatial dims of a
        # channel-first image.
        self.register_buffer(
            "pixel_mean",
            torch.tensor(
                cfg.MODEL.PIXEL_MEAN, dtype=torch.float32, device=self.device
            ).view(num_channels, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "pixel_std",
            torch.tensor(
                cfg.MODEL.PIXEL_STD, dtype=torch.float32, device=self.device
            ).view(num_channels, 1, 1),
            persistent=False,
        )

    def normalizer(self, x):
        """Return ``(x - pixel_mean) / pixel_std`` using the registered buffers.

        This used to be a lambda closing over tensors captured at construction
        time. That lambda did not follow later device moves and prevented
        pickling; the call signature ``self.normalizer(x)`` is unchanged.
        """
        return (x - self.pixel_mean) / self.pixel_std

    def forward(self, batched_inputs):
        """
        Training methods, which jointly train the detector and the
        self-supervised task (the self-supervised part is currently disabled).

        In eval mode this returns :meth:`inference` instead.

        Args:
            batched_inputs (list[dict]): one dict per image with an ``"image"``
                tensor and, optionally, ``"instances"`` (ground truth) and
                ``"proposals"`` (required when there is no proposal generator).

        Returns:
            dict: proposal-generator and ROI-head losses merged into one dict.

        Raises:
            FloatingPointError: if any loss contains NaN.
        """
        if not self.training:
            return self.inference(batched_inputs)
        losses = {}
        accuracies = {}  # only filled by the disabled self-supervised code
        # Disabled image-level self-supervised heads (no-op string literal).
        '''
        # torch.save(batched_inputs, "inputs.pt")
        for i in range(len(self.ss_head)):
            """Using images as the input to SS tasks."""
            head = getattr(self, "ss_head_{}".format(i))
            if head.input != "images":
                continue
            out, tar, ss_losses = head(
                batched_inputs, self.backbone.bottom_up, self.feat_level
            )  # attach new parameters
            losses.update(ss_losses)
            acc = (out.argmax(axis=1) == tar).float().mean().item() * 100
            accuracies["accuracy_ss_{}".format(head.name)] = {
                "accuracy": acc,
                "num": len(tar),
            }
        '''
        # for detection part
        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [
                x["instances"].to(self.device) for x in batched_inputs
            ]
        else:
            gt_instances = None
        # The backbone also returns per-stage magnitudes/statistics that only
        # the disabled style loss below uses; drop the references right away.
        features, mag_ls, stat_ls = self.backbone(images.tensor)
        # Disabled style/content loss (no-op string literal).
        '''
        all_loss = []
        for cnt, mag in enumerate(mag_ls):
            mean, std = stat_ls[cnt]
            mean_diff = torch.abs(mean[1] - mean[0]) / torch.max(torch.abs(mean), 0).values
            mean_target = mean_diff > torch.median(mean_diff) # style

            std_diff = torch.abs(std[1] - std[0]) / torch.max(torch.abs(mean), 0).values
            diff_target = std_diff > torch.median(std_diff) # style
            target = mean_target * diff_target # 1 for style, 0 for content
            target = 1 - target.to(torch.float32) # 1 for content, 0 for style
            target = torch.cat([target.unsqueeze(0), target.unsqueeze(0)], 0) 
            style_loss = F.binary_cross_entropy_with_logits(mag, target)
            all_loss.append(style_loss)
            del mean
            del std
            del mean_diff
            del std_diff
        '''
        del mag_ls
        del stat_ls

        if self.proposal_generator:
            proposals, proposal_losses = self.proposal_generator(
                images, features, gt_instances
            )
        else:
            assert "proposals" in batched_inputs[0]
            proposals = [
                x["proposals"].to(self.device) for x in batched_inputs
            ]
            proposal_losses = {}

        _, detector_losses = self.roi_heads(
            images, features, proposals, gt_instances
        )
        # Disabled ROI-level self-supervised heads (no-op string literal).
        '''
        if isinstance(detector_losses, tuple):
            detector_losses, box_features = detector_losses

            for i in range(len(self.ss_head)):
                head = getattr(self, "ss_head_{}".format(i))
                if head.input != "ROI":
                    continue
                # during training, the paired of inputs are put in one batch
                ss_losses, acc = head(box_features)
                losses.update(ss_losses)
                accuracies["accuracy_ss_{}".format(head.name)] = {
                    "accuracy": acc,
                    "num": 1,
                }
        '''
        losses.update(detector_losses)
        losses.update(proposal_losses)
        # Disabled style loss:
        #losses['loss_style'] = sum(all_loss) / len(all_loss) * 0.05 #* 0.1

        # Fail fast on divergence. torch.isnan accepts tensors of any shape
        # (math.isnan only accepts one-element tensors). Unlike an assert, this
        # check survives `python -O`, and the message names the offending
        # losses instead of dumping the whole input batch.
        nan_losses = [
            k for k, v in losses.items() if torch.isnan(torch.as_tensor(v)).any()
        ]
        if nan_losses:
            raise FloatingPointError(
                "NaN loss(es) {} for inputs {}".format(
                    nan_losses,
                    [x.get("file_name") for x in batched_inputs],
                )
            )

        return losses

    def det_inference(
        self, batched_inputs, detected_instances=None, do_postprocess=True
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            tuple: ``(results, box_features)``. With ``do_postprocess``,
            ``results`` is a list of ``{"instances": ...}`` dicts produced by
            ``detector_postprocess`` with each input's ``height``/``width``
            (falling back to the preprocessed image size); otherwise it is
            the raw ROI-head output. ``box_features`` is the second element
            of the ROI heads' tuple output when they return one, else ``None``.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        # The two extra backbone outputs are not needed for inference.
        features, _, _ = self.backbone(images.tensor)

        box_features = None
        if detected_instances is None:
            if self.proposal_generator:
                proposals, _ = self.proposal_generator(images, features, None)
            else:
                assert "proposals" in batched_inputs[0]
                proposals = [
                    x["proposals"].to(self.device) for x in batched_inputs
                ]

            results, others = self.roi_heads(images, features, proposals, None)
            # The ROI heads may return (others, box_features).
            if isinstance(others, tuple):
                _, box_features = others
        else:
            detected_instances = [
                x.to(self.device) for x in detected_instances
            ]
            results = self.roi_heads.forward_with_given_boxes(
                features, detected_instances
            )

        if not do_postprocess:
            return results, box_features

        processed_results = []
        for results_per_image, input_per_image, image_size in zip(
            results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            processed_results.append({"instances": r})
        return processed_results, box_features

    def inference(
        self, batched_inputs, detected_instances=None, do_postprocess=True
    ):
        """ used for standard detectron2 test method"""
        results, _ = self.det_inference(
            batched_inputs, detected_instances, do_postprocess
        )
        return results

    def preprocess_image(self, batched_inputs):
        """Normalize, pad and batch the input images.

        Training and inference share the same pipeline: move each ``"image"``
        to ``self.device``, normalize it with the configured pixel mean/std,
        then batch with ``ImageList.from_tensors`` using the backbone's
        ``size_divisibility``.
        """
        # NOTE: an experimental, commented-out eval-only high-frequency
        # augmentation used to live here; recover it from version control.
        images = [
            self.normalizer(x["image"].to(self.device)) for x in batched_inputs
        ]
        return ImageList.from_tensors(images, self.backbone.size_divisibility)


def instance_whitening_loss(f_map, eye, mask_matrix, margin, num_remove_cov):
    """Hinge penalty on the masked entries of the channel-correlation matrix.

    Args:
        f_map: feature map, unpacked as ``(B, C, H, W)`` by
            ``get_covariance_matrix``.
        eye: matrix forwarded to ``get_covariance_matrix``; ``None`` lets that
            helper build an identity itself.
        mask_matrix: multiplied element-wise with the ``(B, C, C)`` matrix;
            presumably a 0/1 mask selecting the entries to suppress — TODO confirm.
        margin: subtracted from each sample's masked absolute sum.
        num_remove_cov: divisor for each sample's excess; presumably the number
            of selected entries in ``mask_matrix`` — confirm with the caller.

    Returns:
        Scalar tensor: ``max(0, (sum|masked| - margin) / num_remove_cov)``
        computed per sample, summed over samples and divided by B.
    """
    cov, batch = get_covariance_matrix(f_map, eye=eye)
    masked_abs = torch.abs(cov * mask_matrix)
    # One value per sample, kept with shape (B, 1, 1).
    excess = masked_abs.sum(dim=(1, 2), keepdim=True) - margin
    per_sample = torch.div(excess, num_remove_cov).clamp(min=0)
    return per_sample.sum() / batch


def get_covariance_matrix(f_map, eye=None):
    """Return the per-sample channel correlation matrix of ``f_map``.

    For each sample, with ``f`` the ``(C, H*W)`` flattening of its features,
    computes ``f @ f^T / max(H*W - 1, 1) + eps * eye``. The features are NOT
    mean-centred here, so the result is a true covariance only if the caller
    has already normalized them (presumably instance-normalized — TODO confirm).

    Args:
        f_map: tensor of shape ``(B, C, H, W)``.
        eye: optional matrix broadcastable to ``(B, C, C)`` that is scaled by
            ``eps`` and added (normally the identity). When ``None``, a
            ``C x C`` identity is built on ``f_map``'s device.

    Returns:
        tuple: ``(f_cor, B)`` where ``f_cor`` has shape ``(B, C, C)`` and ``B``
        is the batch size.
    """
    eps = 1e-5
    B, C, H, W = f_map.shape
    HW = H * W
    if eye is None:
        # Build on the input's device: the old `.cuda()` broke CPU runs and
        # used the *current* CUDA device rather than f_map's.
        eye = torch.eye(C, device=f_map.device)
    f_map = f_map.contiguous().view(B, C, -1)  # B x C x (H*W)
    # (HW - 1) normalization; clamp so a 1x1 map does not divide by zero.
    denom = max(HW - 1, 1)
    f_cor = torch.bmm(f_map, f_map.transpose(1, 2)).div(denom) + (eps * eye)  # B x C x C

    return f_cor, B
