import torch
import clip
import torch.nn as nn

from PIL import Image
from PIL import ImageFile

import torch.nn.functional as F

import os
import h5py

from utils.misc import rescale_bboxes

import torchvision
import clip
from clip.model import build_model
from models.seg_layers_1_3_1 import MHAttentionMap, GA, MLP, CrossAttention, HA #HA2 as HA
from models.seg_Bridger_1_3_1 import Bridger_ViT_1 as Bridger_VL_1
from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
from scipy.optimize import linear_sum_assignment


def batch_dice_loss(inputs, targets):
    """
    Compute the pairwise DICE loss (similar to a generalized IoU for masks)
    between every prediction and every target.

    Args:
        inputs: A float tensor of mask logits with shape (N, *). A sigmoid is
                applied and all non-leading dimensions are flattened.
        targets: A tensor with shape (M, *) storing the binary label of each
                element (0 for the negative class and 1 for the positive
                class). It is flattened the same way and cast to the dtype of
                ``inputs``; the flattened size must match that of ``inputs``.

    Returns:
        An (N, M) tensor whose entry [n, m] is the dice loss between prediction
        n and target m, with +1 smoothing in numerator and denominator.
    """
    inputs = inputs.sigmoid().flatten(1)
    # Flatten/cast targets too so they match the documented "same shape as
    # inputs" contract and einsum never sees mixed dtypes (e.g. bool masks).
    targets = targets.flatten(1).to(inputs.dtype)
    numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
    return 1 - (numerator + 1) / (denominator + 1)
    

def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Pairwise sigmoid focal loss, as used in RetinaNet for dense detection:
    https://arxiv.org/abs/1708.02002. Every prediction is scored against every
    target, and the result is averaged over elements.

    Args:
        inputs: A float tensor of logits with shape (N, *); all non-leading
                dimensions are flattened.
        targets: A tensor with shape (M, *) storing the binary label of each
                element (0 for the negative class and 1 for the positive
                class). It is flattened the same way and cast to the dtype of
                ``inputs``; the flattened size must match that of ``inputs``.
        alpha: Weighting factor in range (0, 1) applied to positive elements
                (``1 - alpha`` is applied to negatives). Default = 0.25; pass a
                negative value to disable the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) that balances easy
                vs hard examples.
    Returns:
        An (N, M) loss tensor, normalised by the number of elements per mask.
    """
    inputs = inputs.flatten(1)
    targets = targets.flatten(1).to(inputs.dtype)
    # Number of elements per flattened mask (taken after flattening).
    hw = inputs.shape[1]

    prob = inputs.sigmoid()
    # Per-element focal terms for the two possible labels; they are combined
    # with the actual targets via einsum below.
    focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits(
        inputs, torch.ones_like(inputs), reduction="none"
    )
    focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits(
        inputs, torch.zeros_like(inputs), reduction="none"
    )
    if alpha >= 0:
        focal_pos = focal_pos * alpha
        focal_neg = focal_neg * (1 - alpha)

    loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum(
        "nc,mc->nm", focal_neg, (1 - targets)
    )

    return loss / hw
        
    



def matcher(predict_bbox, predict_mask, predict_cls, mask, bbox, cost_mask_coef, cost_class_coef, cost_dice_coef, cost_bbox_coef, cost_giou_coef):
    """Hungarian-match predicted queries to ground-truth targets, one image at a time.

    The matching cost for each (query, target) pair is the weighted sum of:
    a class cost (negative probability of class 0, i.e. every target is
    treated as class 0), a sigmoid focal cost and a dice cost between masks,
    an L1 cost between ``cxcywh`` boxes and a negative generalized-IoU cost.

    Args:
        predict_bbox: predicted boxes in ``cxcywh``; ``predict_bbox[b]`` is
            assumed to be ``(num_queries, 4)``.
        predict_mask: mask logits; ``predict_mask[b]`` is flattened to
            ``(num_queries, H*W)``.
        predict_cls: class logits of shape ``(bs, num_queries, num_classes)``.
        mask: target masks; ``mask[b]`` is flattened to ``(num_targets, H*W)``.
        bbox: target boxes in ``cxcywh``; ``bbox[b]`` is assumed to be
            ``(num_targets, 4)``.
        cost_mask_coef, cost_class_coef, cost_dice_coef, cost_bbox_coef,
        cost_giou_coef: weights of the individual cost terms.

    Returns:
        A list with one ``(query_indices, target_indices)`` pair of int64
        tensors per image, as produced by
        ``scipy.optimize.linear_sum_assignment``.
    """
    bs, num_queries = predict_cls.shape[:2]

    indices = []
    # The cost matrix is only used to choose an assignment, so no autograd
    # graph is needed for it.
    with torch.no_grad():
        for b in range(bs):
            out_prob = predict_cls[b].softmax(-1)  # [num_queries, num_classes]
            out_mask = predict_mask[b].flatten(1)  # [num_queries, H*W]
            out_bbox = predict_bbox[b]
            tgt_mask = mask[b].flatten(1)  # [num_targets, H*W]
            tgt_bbox = bbox[b]

            # [num_queries, 1]; broadcast over all targets (all are class 0).
            cost_class = -out_prob[:, :1]

            # Focal and dice costs between masks: [num_queries, num_targets].
            cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask)
            cost_dice = batch_dice_loss(out_mask, tgt_mask)

            # L1 and negative-GIoU costs between boxes: [num_queries, num_targets].
            cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
            convert_out_bbox = torchvision.ops.box_convert(out_bbox, in_fmt="cxcywh", out_fmt="xyxy")
            convert_tgt_bbox = torchvision.ops.box_convert(tgt_bbox, in_fmt="cxcywh", out_fmt="xyxy")
            cost_giou = -generalized_box_iou(convert_out_bbox, convert_tgt_bbox)

            # Final cost matrix
            C = (
                cost_mask_coef * cost_mask
                + cost_class_coef * cost_class
                + cost_dice_coef * cost_dice
                + cost_bbox_coef * cost_bbox
                + cost_giou_coef * cost_giou
            )
            C = C.reshape(num_queries, -1).detach().cpu()

            indices.append(linear_sum_assignment(C))

    return [
        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
        for i, j in indices
    ]



class LayerNorm2d(nn.Module):
    """Layer normalisation across the channel axis of an ``(N, C, H, W)`` tensor.

    Each spatial location is normalised independently over its ``C`` channels,
    using the biased variance with ``eps`` added before the square root. A
    learnable per-channel scale and shift is then applied.

    Args:
        num_channels: number of channels ``C``. It sizes the learnable
            ``weight`` (initialised to ones) and ``bias`` (initialised to zeros).
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Registration order (weight, then bias) fixes the parameter/state_dict order.
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channel_mean = x.mean(1, keepdim=True)
        centred = x - channel_mean
        channel_var = centred.pow(2).mean(1, keepdim=True)
        normalised = centred / torch.sqrt(channel_var + self.eps)
        # View the (C,) parameters as (C, 1, 1) so they broadcast over H and W.
        scale = self.weight.view(-1, 1, 1)
        shift = self.bias.view(-1, 1, 1)
        return scale * normalised + shift

class Model(nn.Module):
    """Image-text model: CLIP backbone, Bridger adapter, HA neck and a GA query decoder.

    For every decoder query the model predicts:

    * a box in ``cxcywh`` format, squashed to ``[0, 1]`` by a sigmoid;
    * a mask-logit map;
    * 2-way class logits.

    Class index 0 marks the query matched to the target and index 1 marks every
    other query. The training criterion builds its class targets this way, and
    inference keeps the query with the highest class-0 probability.

    Args:
        config: mapping that must contain ``'bridger_stages'``,
            ``'aggregate_layers'``, ``'num_reg'``, ``'vit_type'`` (``'vit_32'``
            or ``'vit_16'``) and ``'image_res'``.

    Raises:
        ValueError: if ``config['vit_type']`` is not a supported CLIP variant.
        KeyError: if a required config key is missing.
    """

    # Local CLIP checkpoint for each supported ``vit_type``.
    _VIT_PATHS = {
        'vit_32': '/ckpt/clip/ViT-B-32.pt',
        'vit_16': '/ckpt/clip/ViT-B-16.pt',
    }

    # Weight of the summed auxiliary (Bridger reg-token) losses in the loss dict.
    _AUX_LOSS_WEIGHT = 0.1

    def __init__(self, config):
        super().__init__()
        # These two keys are not used by this variant, but they are still
        # looked up so that an incomplete config fails early with KeyError.
        bridger_stages = config['bridger_stages']  # noqa: F841
        aggregate_layers = config['aggregate_layers']  # noqa: F841
        num_reg = config['num_reg']
        vit_type = config['vit_type']
        if vit_type not in self._VIT_PATHS:
            # Raise (rather than print + exit the interpreter) so callers can handle it.
            raise ValueError(
                f"Unsupported vit_type {vit_type!r}; expected one of {sorted(self._VIT_PATHS)}"
            )
        vit_path = self._VIT_PATHS[vit_type]
        backbone, _ = clip.load(vit_path, resolution=config['image_res'])
        self.backbone = backbone.cuda().float()

        ladder_dim, nhead = 128, 8
        self.bridger = Bridger_VL_1(d_model=ladder_dim, nhead=nhead, num_reg=num_reg)

        # Neck over the three visual feature levels returned by the Bridger.
        fpn_in = [768, 768, 512]
        fpn_out = [256, 512, 1024]
        stride = [1, 1, 1]
        self.neck = HA(in_channels=fpn_in, out_channels=fpn_out, stride=stride)

        # Query decoder.
        num_layers = 2
        vis_dim = 512
        num_head = 8
        dim_ffn = 512
        dropout = 0.1
        intermediate = False
        self.decoder = GA(num_layers=num_layers,
                          d_model=vis_dim,
                          nhead=num_head,
                          dim_ffn=dim_ffn,
                          dropout=dropout,
                          return_intermediate=intermediate)

        # Prediction heads: a cxcywh box and 2 class logits per query.
        self.bbox_embed = MLP(dim_ffn, dim_ffn, 4, 3)
        self.cls_embed = MLP(dim_ffn, dim_ffn // 8, 2, 3)

        # Weights shared by the Hungarian matching cost and the training losses.
        self.cost_bbox_coef = 5.
        self.cost_giou_coef = 2.
        self.cost_class_coef = 1.
        self.cost_dice_coef = 0.5
        self.cost_mask_coef = 2.5

        self.reg_proj = nn.Linear(ladder_dim, dim_ffn)
        # Not applied in forward(); kept so the parameter set / state_dict keys are unchanged.
        self.tgt_proj = nn.Linear(ladder_dim, dim_ffn)

        bias = False
        self.fusion = nn.Sequential(
            nn.Conv2d(dim_ffn, dim_ffn // 2, 3, padding=1, bias=bias),
            nn.BatchNorm2d(dim_ffn // 2),
            nn.ReLU(),
            nn.Conv2d(dim_ffn // 2, dim_ffn, 3, padding=1, bias=bias),
            nn.BatchNorm2d(dim_ffn),
            nn.ReLU(),
        )

        # Two stride-2 transposed convs: 4x spatial upsampling, dim_ffn -> dim_ffn // 8 channels.
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(dim_ffn, dim_ffn // 4, kernel_size=2, stride=2),
            LayerNorm2d(dim_ffn // 4),
            nn.GELU(),
            nn.ConvTranspose2d(dim_ffn // 4, dim_ffn // 8, kernel_size=2, stride=2),
            nn.GELU(),
        )
        # Per-query mask embedding, compared against the upscaled pixel features.
        self.mask_mlps = MLP(dim_ffn, dim_ffn, dim_ffn // 8, 3)

        # NOTE(review): not read by the criterion below (which uses the *_coef
        # attributes); kept for compatibility with external users.
        self.cost_mask = 10.0
        self.cost_class = 1.0
        self.cost_dice = 1.0

    def dice_loss(self, inputs, targets, num_masks):
        """
        Compute the element-matched DICE loss (similar to a generalized IoU for masks).

        Args:
            inputs: mask logits. A sigmoid is applied and the loss is reduced
                    over the last dimension, so masks should already be
                    flattened, e.g. (B, H*W).
            targets: binary labels with the same shape as inputs
                    (0 for the negative class and 1 for the positive class).
            num_masks: normaliser applied to the summed per-mask losses.

        Returns:
            Scalar tensor: the sum of per-mask dice losses divided by num_masks.
        """
        inputs = inputs.sigmoid()
        numerator = 2 * (inputs * targets).sum(-1)
        denominator = inputs.sum(-1) + targets.sum(-1)
        loss = 1 - (numerator + 1) / (denominator + 1)
        return loss.sum() / num_masks

    def _match(self, pred_box, pred_mask, pred_cls, mask, bbox):
        """Match each image's target to one query and compute the class loss.

        Args:
            pred_box: (B, N_q, 4) predicted boxes (cxcywh).
            pred_mask: (B, N_q, H, W) mask logits.
            pred_cls: (B, N_q, num_classes) class logits.
            mask: (B, 1, H, W) target masks.
            bbox: target boxes (cxcywh), indexed per image by ``matcher``.

        Returns:
            ``(batch_idx, query_idx, loss_ce)``:
            * ``batch_idx`` and ``query_idx`` select the matched query of every image;
            * ``loss_ce`` is the cross-entropy over all queries, using class 0
              for the matched query and class 1 otherwise.
        """
        indices = matcher(pred_box, pred_mask, pred_cls, mask, bbox,
                          cost_mask_coef=self.cost_mask_coef,
                          cost_class_coef=self.cost_class_coef,
                          cost_dice_coef=self.cost_dice_coef,
                          cost_giou_coef=self.cost_giou_coef,
                          cost_bbox_coef=self.cost_bbox_coef)
        device = pred_cls.device
        # Exactly one matched query per image is expected; .item() raises otherwise.
        query_idx = torch.tensor([src.item() for src, _ in indices], device=device)
        batch_idx = torch.arange(len(indices), device=device)

        cls_tgt = torch.ones(pred_cls.shape[:2], dtype=torch.long, device=device)  # B,N
        cls_tgt[batch_idx, query_idx] = 0
        loss_ce = F.cross_entropy(pred_cls.reshape(-1, pred_cls.shape[-1]), cls_tgt.flatten())
        return batch_idx, query_idx, loss_ce

    @staticmethod
    def _box_losses(pred_box, tgt_box):
        """Return (L1, 1 - GIoU) losses for matched (B, 4) cxcywh boxes, each summed and divided by B."""
        batch_size = pred_box.size(0)
        loss_l1 = F.l1_loss(pred_box, tgt_box, reduction='none').sum() / batch_size
        pred_xyxy = torchvision.ops.box_convert(pred_box, in_fmt="cxcywh", out_fmt="xyxy")
        tgt_xyxy = torchvision.ops.box_convert(tgt_box, in_fmt="cxcywh", out_fmt="xyxy")
        loss_giou = (1 - torch.diag(generalized_box_iou(pred_xyxy, tgt_xyxy))).sum() / batch_size
        return loss_l1, loss_giou

    def _mask_losses(self, pred_mask, tgt_mask):
        """Return (mean BCE-with-logits, dice) losses for matched, flattened (B, H*W) masks."""
        loss_bce = F.binary_cross_entropy_with_logits(pred_mask, tgt_mask)
        loss_dice = self.dice_loss(pred_mask, tgt_mask, 1)
        return loss_bce, loss_dice

    def set_aux_criterion(self, output):
        """Sum the unweighted auxiliary losses over the Bridger's reg tokens.

        For every ``reg_tokens[i]``:
        * boxes come from ``self.bridger.aux_heads[i]`` (sigmoid);
        * class logits come from ``self.bridger.aux_cls[i]``;
        * mask logits come from ``output['pred_mask_ls'][i]``, bilinearly
          resized to ``output['img_shape']`` if needed.

        Each token is matched independently and contributes
        BCE + dice + CE + GIoU + L1.

        Returns:
            Tensor of shape (1,) on ``output['mask'].device``.
        """
        bbox, reg_tokens, mask, pred_mask_ls, img_shape = (
            output['bbox'], output['reg_tokens'], output['mask'],
            output['pred_mask_ls'], output['img_shape'])
        # Matched-target tensors are the same for every reg token.
        tgt_mask = mask.squeeze(1).flatten(1)
        tgt_box = bbox.squeeze(1)

        loss_aux = torch.zeros(1, device=mask.device)
        for i, rt in enumerate(reg_tokens):
            predict_xy = torch.sigmoid(self.bridger.aux_heads[i](rt)).transpose(0, 1)  # B,N,4
            pred_mask = pred_mask_ls[i]
            if pred_mask.shape[-2:] != img_shape:
                pred_mask = F.interpolate(pred_mask, img_shape, mode='bilinear', align_corners=True)
            pred_cls = self.bridger.aux_cls[i](rt).transpose(0, 1)  # B,N,2

            batch_idx, query_idx, loss_ce = self._match(predict_xy, pred_mask, pred_cls, mask, bbox)
            loss_bce, loss_dice = self._mask_losses(pred_mask[batch_idx, query_idx].flatten(1), tgt_mask)
            loss_l1, loss_giou = self._box_losses(predict_xy[batch_idx, query_idx], tgt_box)

            loss_aux = loss_aux + loss_bce + loss_dice + loss_ce + loss_giou + loss_l1

        return loss_aux

    def set_criterion(self, output):
        """Compute the weighted training losses for the final decoder output.

        Each image's single target is Hungarian-matched to one query.
        * The matched query gets mask losses (BCE + dice) and box losses
          (L1 + GIoU, both on cxcywh boxes).
        * Every query gets a 2-way classification loss.
        * Auxiliary losses on the Bridger's reg tokens are added under
          ``'loss_aux'``.

        Args:
            output: dict with the following keys:
                * ``'pred_box'``: (N_q, B, 4), transposed here;
                * ``'pred_mask'``: (B, N_q, H, W);
                * ``'pred_cls'``: (B, N_q, 2);
                * ``'mask'``: (B, 1, H, W);
                * ``'bbox'``: assumed (B, 1, 4), since dim 1 is squeezed;
                * the keys read by ``set_aux_criterion``.

        Returns:
            dict with ``loss_mask``, ``loss_dice``, ``loss_ce``, ``loss_bbox``,
            ``loss_giou`` and ``loss_aux``.
        """
        pred_box, pred_mask, pred_cls, mask, bbox = (
            output['pred_box'], output['pred_mask'], output['pred_cls'],
            output['mask'], output['bbox'])
        pred_box = pred_box.transpose(0, 1)  # (N_q, B, 4) -> (B, N_q, 4)

        batch_idx, query_idx, loss_ce = self._match(pred_box, pred_mask, pred_cls, mask, bbox)
        loss_bbox, loss_giou = self._box_losses(pred_box[batch_idx, query_idx], bbox.squeeze(1))
        loss_mask, loss_dice = self._mask_losses(
            pred_mask[batch_idx, query_idx].flatten(1), mask.squeeze(1).flatten(1))

        losses = {}
        losses['loss_mask'] = loss_mask * self.cost_mask_coef
        losses['loss_dice'] = loss_dice * self.cost_dice_coef
        losses['loss_ce'] = loss_ce * self.cost_class_coef
        losses['loss_bbox'] = loss_bbox * self.cost_bbox_coef
        losses['loss_giou'] = loss_giou * self.cost_giou_coef
        losses['loss_aux'] = self.set_aux_criterion(output) * self._AUX_LOSS_WEIGHT
        return losses

    def forward(self, image, text_ids, mask=None, bbox=None, idx=None, epoch=None, training=False):
        """Run the model. Return the loss dict when ``training``, else the predictions.

        Args:
            image: input images (batch first; the last two dims are spatial).
            text_ids: token ids; id 0 is treated as padding.
            mask: ground-truth masks, assumed (B, H, W); required when ``training``.
            bbox: ground-truth cxcywh boxes; required when ``training``.
            idx, epoch: unused; accepted for caller compatibility.
            training: if True, compute and return the loss dict.

        Returns:
            When ``training``: the dict returned by ``set_criterion``.
            Otherwise: a dict with
            * ``'pred_mask'``: the best query's mask logits per image, at image size;
            * ``'pred_box'``: the best query's cxcywh box;
            * ``'pred_cls'``: all class logits, (B, N_q, 2);
            * ``'img_shape'``.

        Observed Bridger visual features (example batch of 32):
            torch.Size([32, 768, 16, 16])
            torch.Size([32, 768, 16, 16])
            torch.Size([32, 512, 16, 16])
        """
        # Padding mask used in the decoder: True at padding tokens (id 0).
        pad_mask = text_ids == 0

        vis, word, state, reg_tokens, _feat_ls, pred_mask_ls, _attn_ls = self.bridger(
            image, text_ids, self.backbone)

        reg = self.reg_proj(reg_tokens[-1])
        # Zero query target (placeholder; self.tgt_proj is not applied).
        tgt = torch.zeros_like(reg)

        fq = self.neck(vis, state)

        # reg_token: (N_q, B, 512) per the decoder's output; vis: B,C,H,W
        reg_token, vis, _attn = self.decoder(reg + tgt, fq, word, pad_mask=pad_mask,
                                             query_pos=None, return_attn=True)

        # Residual with the neck features, scaled by word[:, 0] broadcast over H and W
        # (assumes word[:, 0] is (B, C) — TODO confirm).
        _vis = (vis + fq) * word[:, 0].unsqueeze(-1).unsqueeze(-1)
        vis = self.fusion(_vis)

        vis = self.output_upscaling(vis).flatten(-2)
        mask_token = self.mask_mlps(reg_token)
        # Cosine similarity between each query's mask token and every pixel embedding.
        pred_mask = F.normalize(mask_token.transpose(0, 1), dim=-1, p=2) @ F.normalize(vis, dim=1, p=2)
        B, N_q, N = pred_mask.shape
        # Assumes the upscaled feature map is square.
        H = int(N ** 0.5)
        W = N // H
        pred_mask = pred_mask.reshape(B, N_q, H, W)
        if pred_mask.shape[-2:] != image.shape[-2:]:
            pred_mask = F.interpolate(pred_mask, image.shape[-2:], mode='bilinear', align_corners=True)

        pred_cls = self.cls_embed(reg_token).transpose(0, 1)  # B,N_q,2
        pred_box = self.bbox_embed(reg_token).sigmoid()  # N_q,B,4
        output = {'pred_box': pred_box, 'pred_mask': pred_mask, 'pred_cls': pred_cls,
                  'img_shape': image.shape[-2:]}

        if training:
            # Ground-truth mask as a float (B, 1, H, W) map at prediction resolution.
            # Cast before interpolating: 'nearest' does not support bool masks.
            mask = mask.unsqueeze(1).float()
            if mask.shape[-2:] != pred_mask.shape[-2:]:
                mask = F.interpolate(mask, pred_mask.shape[-2:], mode='nearest')
            output.update(reg_tokens=reg_tokens, pred_mask_ls=pred_mask_ls,
                          mask=mask.detach(), bbox=bbox)
            return self.set_criterion(output)

        # Inference: keep, per image, the query with the highest class-0 probability.
        index = torch.argmax(pred_cls.softmax(-1)[:, :, 0], dim=-1)
        batch_idx = torch.arange(len(index), device=index.device)
        output['pred_mask'] = pred_mask[batch_idx, index]
        output['pred_box'] = pred_box.transpose(0, 1)[batch_idx, index]
        return output





