import torch
import torch.nn.functional as F
from torch import nn, Tensor

import math
from typing import Callable, List, Any, Tuple, Dict

from util import box_ops
from util.misc import (NestedTensor, get_world_size, is_dist_avail_and_initialized)

from .decoder import build_vg_decoder
from pytorch_pretrained_bert.modeling import BertModel
import ssl
from torch.autograd import Function

# SECURITY NOTE(review): this replaces the process-wide default HTTPS context
# factory with one that does NOT verify TLS certificates, so any code that
# builds its context via `ssl._create_default_https_context` (e.g. `urllib`
# downloads of pretrained weights) becomes open to man-in-the-middle
# tampering. Presumably a workaround for a certificate problem in the training
# environment -- prefer fixing the CA bundle and removing this override.
ssl._create_default_https_context = ssl._create_unverified_context

class SigmoidGeometricMean(Function):
    """Geometric mean of two sigmoids, ``sqrt(sigmoid(x) * sigmoid(y))``, with
    an analytical backward pass.

    Letting autograd differentiate ``(x.sigmoid() * y.sigmoid()).sqrt()``
    directly can yield NaN gradients when ``x`` and ``y`` are both very small
    (large negative): the sigmoid product can underflow to 0, where the
    derivative of ``sqrt`` is infinite. The closed form used here,

        dz/dx = z * (1 - sigmoid(x)) / 2    (and symmetrically for y),

    never divides by the product and so stays finite.
    """

    @staticmethod
    def forward(ctx, x, y):
        sig_x, sig_y = x.sigmoid(), y.sigmoid()
        out = (sig_x * sig_y).sqrt()
        # Keep the sigmoids and the output so backward needs no recomputation.
        ctx.save_for_backward(sig_x, sig_y, out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        sig_x, sig_y, out = ctx.saved_tensors
        return (
            grad_output * out * (1 - sig_x) / 2,
            grad_output * out * (1 - sig_y) / 2,
        )


# Functional alias: sigmoid_geometric_mean(x, y) computes
# sqrt(sigmoid(x) * sigmoid(y)) using SigmoidGeometricMean's analytical backward.
sigmoid_geometric_mean = SigmoidGeometricMean.apply

class Tob(nn.Module):
    """Visual-grounding model: DINOv2 image encoder + BERT text encoder + VG
    decoder, with text-guided token merging in the last backbone blocks.

    In the last ``conditional_layer`` backbone blocks, every image token is
    scored against the sentence and ``merge_r`` tokens are removed with a
    weighted bipartite soft matching (ToMe-style), so later blocks operate on
    fewer tokens.
    """

    # Tokens removed per merging block (ToMe's ``r``).
    merge_r = 64
    # Weight of the text-relevance term in the matching score.
    merge_alpha = 0.3

    def __init__(self, pretrained_weights, args=None):
        """Initializes the model.

        Args:
            pretrained_weights: checkpoint path forwarded to
                ``load_pretrained_weights``; falsy to skip.
            args: config namespace. This method reads ``bert_model``,
                ``bert_output_dim``, ``hidden_dim`` and ``bert_output_layers``
                and passes it to ``build_vg_decoder``. Required despite the
                ``None`` default.

        Raises:
            ValueError: if ``args`` is None.
        """
        super().__init__()
        if args is None:
            raise ValueError("Tob requires an `args` config namespace.")

        # Image feature encoder: the `dinov2_vitb14` hub model from a local checkout.
        self.backbone = torch.hub.load(
            "facebookresearch_dinov2_main/",
            'dinov2_vitb14', source='local')
        # Backbone blocks whose normed tokens form the multi-level visual feature.
        self.extract_layer = [9, 10, 11]
        self.backbone.mask_token.requires_grad_(False)

        # 768 = backbone token width; 256 = width fed to the VG decoder.
        self.input_proj = nn.Linear(768, 256)

        # Text feature encoder (BERT). forward() discards BERT's second output,
        # so the pooler is frozen.
        self.bert = BertModel.from_pretrained(args.bert_model)
        self.bert_proj = nn.Linear(args.bert_output_dim, args.hidden_dim)
        self.bert_output_layers = args.bert_output_layers
        for v in self.bert.pooler.parameters():
            v.requires_grad_(False)

        # Visual grounding decoder.
        self.trans_decoder = build_vg_decoder(args)

        hidden_dim = 256
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

        # Token merging runs in the last `conditional_layer` of the
        # `merge_layer` backbone blocks, i.e. block indices >= conditional_index.
        self.merge_layer = 12
        self.conditional_layer = 6
        self.conditional_index = self.merge_layer - self.conditional_layer
        # Per-merging-block projections used to score image tokens against words.
        self.linear_word_layers = nn.ModuleList([
            nn.Linear(256, 128) for _ in range(self.conditional_layer)
        ])
        self.linear_x_layers = nn.ModuleList([
            nn.Linear(768, 128) for _ in range(self.conditional_layer)
        ])

        if pretrained_weights:
            self.load_pretrained_weights(pretrained_weights)

    def load_pretrained_weights(self, weights_path):
        """Read the ``'model'`` state dict stored at ``weights_path``.

        NOTE(review): the checkpoint is read but never applied to any
        submodule -- ``load_weights`` below is defined but not called, so
        apart from the file read (and the KeyError if ``'model'`` is missing)
        this is a no-op. Confirm which submodules, if any, should be
        initialized from it.
        """
        def load_weights(module, prefix, weights):
            # Copy every `prefix.<key>` entry of `weights` into `module`,
            # reporting the module keys that have no counterpart.
            update_weights = dict()
            for k in module.state_dict().keys():
                prefix_k = prefix + '.' + k
                if prefix_k in weights:
                    update_weights[k] = weights[prefix_k]
                else:
                    print(f"Weights of {k} are not pre-loaded.")
            module.load_state_dict(update_weights, strict=False)

        weights = torch.load(weights_path, map_location='cpu')['model']

    def dot_product_word_wise(self, visu_src, text_src, word_mask):
        """Score every visual token against the selected words.

        Args:
            visu_src: (B, HW, C) visual features.
            text_src: (L, B, C) word features (sequence-first).
            word_mask: (B, L) boolean mask, True for the words to average
                over (at the call site: the first N valid words; the
                remaining L - N padding positions are False).

        Returns:
            (B, HW, 1): for each visual token, the mean of its dot products
            with the selected words.
        """
        text_src = text_src.permute(1, 0, 2)  # (B, L, C)
        B = word_mask.size(0)
        pairwise_weight = torch.bmm(visu_src, text_src.transpose(1, 2))  # (B, HW, L)
        # Masked sum over words: (B, HW, L) @ (B, L, 1) -> (B, HW, 1). The mask
        # takes the scores' dtype so half/double precision inputs also work.
        mask = word_mask.unsqueeze(-1).to(dtype=pairwise_weight.dtype)
        pairwise_weight = torch.bmm(pairwise_weight, mask)
        # Mean over the selected words; the clamp turns an all-False row into
        # 0 instead of NaN.
        num_words = word_mask.sum(1).clamp(min=1).view(B, 1, 1)
        return pairwise_weight / num_words

    def do_nothing(self, x, mode=None):
        """Identity merge function (returned when nothing can be merged)."""
        return x

    def bipartite_soft_matching(
        self,
        metric: torch.Tensor,
        r: int,
        weight: torch.Tensor = None,
        k: int = 0,
        alpha: float = 0.3,
    ) -> Callable:
        """Build a ToMe-style merge function with a balanced (50%/50%) split.

        Tokens are split alternately into set A (even positions) and set B
        (odd positions). Each A token is matched to its highest-scoring B
        token, and the ``r`` best-matched A tokens are merged into their
        partners.

        NOTE(review): unlike ToMe's ``class_token`` option, token 0 (CLS) is
        not protected and may be merged.

        Args:
            metric: (batch, tokens, channels) features used for matching.
            r: number of tokens to remove; capped at ``tokens // 2``.
            weight: optional (batch, tokens, 1) pre-sigmoid text-relevance
                scores. When given and ``k == 0``, the score becomes
                ``(1 - alpha) * cos_sim + alpha * (1 - sqrt(sig(w_a) * sig(w_b)))``,
                so pairs with lower text relevance are merged first.
            k: 0 enables the weighted score above; any other value (or
                ``weight=None``) uses plain cosine similarity.
            alpha: mixing factor of the weighted score.

        Returns:
            ``merge(x, mode="mean")`` mapping (batch, tokens, c) to
            (batch, tokens - r, c), laid out as [unmerged A tokens, B tokens];
            ``do_nothing`` when ``r <= 0``.
        """
        # We can only reduce by a maximum of 50% tokens.
        n, t, c = metric.shape
        r = min(r, t // 2)

        if r <= 0:
            return self.do_nothing

        with torch.no_grad():
            metric = metric / metric.norm(dim=-1, keepdim=True)
            a, b = metric[..., ::2, :], metric[..., 1::2, :]
            scores = a @ b.transpose(-1, -2)  # (n, |A|, |B|) cosine similarity
            if k == 0 and weight is not None:
                a_weight, b_weight = weight[..., ::2, :], weight[..., 1::2, :]
                # (n, |A|, |B|) grids of the A-side and B-side relevance scores.
                a_weight_expanded = a_weight.repeat(1, 1, b_weight.shape[1])
                b_weight_expanded = b_weight.repeat(1, 1, a_weight.shape[1]).transpose(1, 2)
                score_weight = sigmoid_geometric_mean(a_weight_expanded, b_weight_expanded)
                scores = (1 - alpha) * scores + alpha * (1 - score_weight)

            # Best B partner (and its score) for every A token.
            node_max, node_idx = scores.max(dim=-1)
            edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

            unm_idx = edge_idx[..., r:, :]  # A tokens kept as-is
            src_idx = edge_idx[..., :r, :]  # A tokens merged away
            dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)  # their B partners

        def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
            src, dst = x[..., ::2, :], x[..., 1::2, :]
            n, t1, c = src.shape
            unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
            src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
            dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
            return torch.cat([unm, dst], dim=1)

        return merge

    def merge_wavg(
        self,
        merge: Callable, x: torch.Tensor, weight: torch.Tensor, size: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Merge tokens by size-weighted average, then rescale by relevance.

        ``x`` and ``weight`` are averaged with token sizes as weights. The
        merged features are then multiplied by ``2 * weight``, so a token
        whose merged weight is 0.5 is left unchanged.

        Args:
            merge: function returned by ``bipartite_soft_matching``.
            x: (batch, tokens, c) token features.
            weight: (batch, tokens, 1) per-token weights (the caller passes
                sigmoid scores).
            size: (batch, tokens, 1) number of original tokens each token
                represents; defaults to all ones.

        Returns:
            ``(merged_x, merged_size)``.
        """
        if size is None:
            size = torch.ones_like(x[..., 0, None])
        x = merge(x * size, mode="sum")
        weight = merge(weight * size, mode="sum")
        size = merge(size, mode="sum")
        x = x / size
        weight = weight / size
        x = x * weight * 2
        return x, size

    def merge_wavg_wow(
        self,
        merge: Callable, x: torch.Tensor, size: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Merge tokens by size-weighted average (no relevance rescaling).

        Args:
            merge: function returned by ``bipartite_soft_matching``.
            x: (batch, tokens, c) token features.
            size: (batch, tokens, 1) token sizes; defaults to all ones.

        Returns:
            ``(merged_x, merged_size)``.
        """
        if size is None:
            size = torch.ones_like(x[..., 0, None])
        x = merge(x * size, mode="sum")
        size = merge(size, mode="sum")
        x = x / size
        return x, size

    def tome_block(self, i, word_feat, word_mask, block, x):
        """Run backbone block ``i``; in the merging blocks, also remove
        ``merge_r`` tokens guided by the text.

        Args:
            i: index of ``block`` within the backbone.
            word_feat: (L, B, C) word features (C = 256 for
                ``linear_word_layers``).
            word_mask: (B, L) inverted BERT attention mask as set in
                ``forward`` (presumably True for padding words).
            block: the backbone transformer block.
            x: token tensor, or ``(tokens, size)`` from a previous merge.

        Returns:
            The block output tensor for non-merging blocks, otherwise
            ``(merged_tokens, size)``.
        """
        if isinstance(x, tuple):
            x, size = x
        else:
            size = torch.ones_like(x[..., 0, None])  # every token starts with size 1
        x = block(x)
        if i < self.conditional_index:
            return x

        j = i - self.conditional_index
        word_feat_linear = self.linear_word_layers[j](word_feat)  # (L, B, 128)
        x_linear = self.linear_x_layers[j](x)  # (B, tokens, 128)
        # ~word_mask re-selects the valid words; weight: (B, tokens, 1) pre-sigmoid.
        weight = self.dot_product_word_wise(x_linear, word_feat_linear, ~word_mask)
        merge = self.bipartite_soft_matching(
            metric=x, r=self.merge_r, weight=weight, alpha=self.merge_alpha)
        return self.merge_wavg(merge, x, weight.sigmoid(), size=size)

    def dino_forward(self, word_feat, word_mask, x, masks=None):
        """Run the DINOv2 backbone with text-guided token merging.

        Returns the backbone-normed tokens of the blocks listed in
        ``extract_layer``, each without its first
        ``num_register_tokens + 1`` tokens, concatenated along the token axis.

        NOTE(review): after a merge the token order is [unmerged A, B], so the
        dropped leading tokens are no longer guaranteed to be the CLS/register
        tokens -- confirm this is intended.
        """
        x = self.backbone.prepare_tokens_with_masks(x, masks)

        ml_feature = []
        for i, blk in enumerate(self.backbone.blocks):
            x = self.tome_block(i, word_feat, word_mask, blk, x)
            tokens = x[0] if isinstance(x, tuple) else x
            if i in self.extract_layer:
                ml_feature.append(self.backbone.norm(tokens)[:, self.backbone.num_register_tokens + 1:])
        return torch.cat(ml_feature, dim=1)

    def forward(self, image, image_mask, word_id, word_mask):
        """Predict the box referred to by the text.

        Args:
            image: image batch fed to the DINOv2 backbone.
            image_mask: unused.
            word_id: BERT token ids.
            word_mask: BERT attention mask.

        Returns:
            dict with ``'pred_boxes'`` (last entry of the decoder output's
            first axis, presumably the final stage) and, in training mode,
            ``'aux_outputs'`` with the earlier entries.
        """
        # Text features: mean of the last `bert_output_layers` BERT layers.
        word_feat, _ = self.bert(word_id, token_type_ids=None, attention_mask=word_mask)
        word_feat = torch.stack(word_feat[-self.bert_output_layers:], 1).mean(1)
        word_feat = self.bert_proj(word_feat)
        word_feat = word_feat.permute(1, 0, 2)  # NxLxC -> LxNxC
        word_mask = ~word_mask  # inverted: presumably True now marks padding

        # Image features, with text-guided token merging inside the backbone.
        feature_dino = self.dino_forward(word_feat, word_mask, image)
        tr_input = self.input_proj(feature_dino)
        # Discriminative feature encoding + multi-stage reasoning.
        hs = self.trans_decoder(tr_input.permute(1, 0, 2), None, None, word_feat, word_mask)

        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_boxes': outputs_coord[-1]}
        if self.training:
            out['aux_outputs'] = [{'pred_boxes': b} for b in outputs_coord[:-1]]
        return out

class VGCriterion(nn.Module):
    """Box-regression criterion (L1 + GIoU) for the grounding head, applied to
    the final prediction and to every auxiliary decoder stage."""
    def __init__(self, weight_dict, loss_loc, box_xyxy):
        """ Create the criterion.
        Parameters:
            weight_dict: maps each loss name to its relative weight.
            loss_loc: name of the localization loss; must be a key of
                ``loss_map`` (currently only 'loss_boxes'), else KeyError.
            box_xyxy: True if predicted/target boxes are (x0, y0, x1, y1);
                otherwise they are (cx, cy, w, h) and converted for GIoU.
        """
        super().__init__()
        self.weight_dict = weight_dict
        self.box_xyxy = box_xyxy
        self.loss_map = {'loss_boxes': self.loss_boxes}
        self.loss_loc = self.loss_map[loss_loc]

    def loss_boxes(self, outputs, target_boxes, num_pos):
        """Return ``{'l1': ..., 'giou': ...}``, each summed over all boxes and
        divided by ``num_pos``."""
        assert 'pred_boxes' in outputs
        pred = outputs['pred_boxes']  # [B, #query, 4]
        # Every query of an image is regressed onto that image's target box.
        gt = target_boxes[:, None].expand_as(pred).reshape(-1, 4)  # [B*#query, 4]
        pred = pred.reshape(-1, 4)  # [B*#query, 4]

        l1 = F.l1_loss(pred, gt, reduction='none').sum() / num_pos

        if not self.box_xyxy:
            pred, gt = box_ops.box_cxcywh_to_xyxy(pred), box_ops.box_cxcywh_to_xyxy(gt)
        giou = (1 - box_ops.box_pair_giou(pred, gt)).sum() / num_pos
        return {'l1': l1, 'giou': giou}

    def forward(self, outputs, targets):
        """Compute the localization losses; auxiliary-stage losses get the
        suffix ``_<stage index>``."""
        gt_boxes = targets['bbox']
        B, Q, _ = outputs['pred_boxes'].shape
        num_pos = avg_across_gpus(outputs['pred_boxes'].new_tensor(B * Q))

        losses = {}
        losses.update(self.loss_loc(outputs, gt_boxes, num_pos))

        if 'aux_outputs' not in outputs:
            return losses
        for stage, aux in enumerate(outputs['aux_outputs']):
            stage_losses = self.loss_loc(aux, gt_boxes, num_pos)
            losses.update({name + f'_{stage}': val for name, val in stage_losses.items()})
        return losses


class PostProcess(nn.Module):
    """ This module converts the model's output into the format we expect:
    absolute (x0, y0, x1, y1) boxes in original-image pixel coordinates."""
    def __init__(self, box_xyxy=False):
        """
        Args:
            box_xyxy: True if ``pred_boxes`` are already (x0, y0, x1, y1);
                otherwise they are (cx, cy, w, h) and are converted first.
        """
        super().__init__()
        self.bbox_xyxy = box_xyxy

    @torch.no_grad()
    def forward(self, outputs, target_dict):
        """Map normalized ``outputs['pred_boxes']`` back to original pixels.

        Steps: scale by the resized image size, subtract the ``dxdy`` offset
        (if present), clamp at 0, divide by the resize ``ratio``, and clip to
        ``orig_size`` (if present).

        Args:
            outputs: dict with ``'pred_boxes'``, expected (B, 1, 4); dim 1 is
                squeezed.
            target_dict: mapping with ``'size'`` (B, 2) as (h, w) and
                ``'ratio'`` (B, 2) as (ratio_h, ratio_w); optionally
                ``'dxdy'`` (B, 2) and ``'orig_size'`` (B, 2) as (h, w).

        Returns:
            (B, 4) tensor of (x0, y0, x1, y1) boxes.
        """
        rsz_sizes, ratios = target_dict['size'], target_dict['ratio']
        dxdy = target_dict.get('dxdy')
        # Optional: without it the final clip to the original image is skipped
        # (previously a missing key raised KeyError, making the check dead).
        orig_sizes = target_dict.get('orig_size')

        boxes = outputs['pred_boxes']

        assert len(boxes) == len(rsz_sizes)
        assert rsz_sizes.shape[1] == 2

        boxes = boxes.squeeze(1)

        # Convert to absolute coordinates in the original image.
        if not self.bbox_xyxy:
            boxes = box_ops.box_cxcywh_to_xyxy(boxes)
        img_h, img_w = rsz_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct
        if dxdy is not None:
            boxes = boxes - torch.cat([dxdy, dxdy], dim=1)
        boxes = boxes.clamp(min=0)
        ratio_h, ratio_w = ratios.unbind(1)
        boxes = boxes / torch.stack([ratio_w, ratio_h, ratio_w, ratio_h], dim=1)
        if orig_sizes is not None:
            orig_h, orig_w = orig_sizes.unbind(1)
            boxes = torch.min(boxes, torch.stack([orig_w, orig_h, orig_w, orig_h], dim=1))

        return boxes


def avg_across_gpus(v, min=1):
    """Average a scalar tensor over all distributed processes.

    When torch.distributed is initialized, ``v`` is first all-reduced in place
    (default op: SUM). The result is divided by the world size, clamped from
    below at ``min`` and returned as a Python number. (``min`` shadows the
    builtin but is kept because callers may pass it by keyword.)
    """
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(v)
    averaged = v.float() / get_world_size()
    return averaged.clamp(min=min).item()

class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN): ``num_layers``
    Linear layers with ReLU between them and none after the last."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden
        out_dims = hidden + [output_dim]
        self.layers = nn.ModuleList([nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)])

    def forward(self, x):
        *inner, last = self.layers
        for layer in inner:
            x = F.relu(layer(x))
        return last(x)




def build_vgmodel(args):
    """Build the grounding model, its criterion and its post-processor.

    Returns:
        ``(model, criterion, postprocessor)``. Only the criterion is moved to
        ``args.device`` here.
    """
    device = torch.device(args.device)

    model = Tob(pretrained_weights=args.load_weights_path, args=args)

    # Base loss weights, then any extra coefficients from the config.
    weight_dict = {
        'loss_cls': 1,
        'l1': args.bbox_loss_coef,
        'giou': args.giou_loss_coef,
    }
    weight_dict.update(args.other_loss_coefs)
    if args.aux_loss:
        # Replicate every weight for the auxiliary stages 0 .. dec_layers - 2.
        base_items = list(weight_dict.items())
        weight_dict.update({
            k + f'_{i}': v
            for i in range(args.dec_layers - 1)
            for k, v in base_items
        })

    criterion = VGCriterion(weight_dict=weight_dict, loss_loc=args.loss_loc, box_xyxy=args.box_xyxy)
    criterion.to(device)

    return model, criterion, PostProcess(args.box_xyxy)
