# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
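"""
GLEE model: a detectron2 backbone, a CLIP text encoder, a pixel decoder and a
transformer decoder, driven by text prompts and visual (mask) prompts.
"""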
import torch
import torch.nn.functional as F
from torch import nn
# from ..backbone import build_backbone, Backbone
# from ..body.encoder import build_encoder
# from ..body.decoder import build_decoder

from detectron2.modeling import  build_backbone

from .pixel_decoder.maskdino_encoder import build_pixel_decoder
from .transformer_decoder.maskdino_decoder import build_transformer_decoder

import random
from transformers import AutoTokenizer
from collections import OrderedDict
from ..modules.point_features import point_sample
from timm.models.layers import trunc_normal_
from transformers import CLIPTokenizer,CLIPTextModel
from .vos_utils import masks_to_boxes, FeatureFuser
import numpy as np 
import math


def rand_sample(x, max_len):
    """Randomly keep at most ``max_len`` entries along dimension 1 of ``x``.

    When ``x`` already has ``max_len`` or fewer entries along dimension 1 it is
    returned unchanged (the very same object). Otherwise a random subset of
    ``max_len`` distinct indices is drawn with ``torch.randperm`` and used to
    index dimension 1.
    """
    num_candidates = x.shape[1]
    if num_candidates > max_len:
        keep = torch.randperm(num_candidates)[:max_len]
        return x[:, keep]
    return x


def agg_lang_feat(features, mask, pool_type="average"):
    """Pool token-level language features into one vector per sequence.

    Args:
        features: token features, (bs, seq_len, C) per the original comments.
        mask: per-token validity mask, (bs, seq_len). Nonzero / True marks a
            valid token; both boolean and integer (0/1) masks are accepted.
        pool_type: ``"average"`` (masked mean) or ``"max"`` (masked max).

    Returns:
        Tensor with one pooled feature vector per sequence, (bs, C). A sequence
        without any valid token yields a zero vector.

    Raises:
        ValueError: if ``pool_type`` is neither ``"average"`` nor ``"max"``.
    """
    if pool_type == "average":
        # use mask to zero out invalid token features
        embedded = features * mask.unsqueeze(-1).float()
        counts = mask.sum(-1).unsqueeze(-1).float()
        # A fully masked sequence has a zero numerator; dividing by 1 instead
        # of 0 yields zeros rather than NaN. Nonzero counts are left untouched.
        counts = counts.masked_fill(counts == 0, 1.0)
        aggregate = embedded.sum(1) / counts
    elif pool_type == "max":
        # Cast to bool so that integer 0/1 masks select rows by mask instead of
        # being interpreted as row indices 0 and 1.
        valid = mask.bool()
        out = []
        for feat_i, valid_i in zip(features, valid):
            if valid_i.any():
                pool_feat, _ = torch.max(feat_i[valid_i], 0)  # (L, C) -> (C, )
            else:
                # No valid token: mirror the average branch and return zeros.
                pool_feat = feat_i.new_zeros(feat_i.shape[1:])
            out.append(pool_feat)
        aggregate = torch.stack(out, dim=0)  # (bs, C)
    else:
        raise ValueError("pool_type should be average or max")
    return aggregate

class GLEE_Model(nn.Module):
    """
    Prompt-conditioned detection / segmentation model.

    Wires together a detectron2 backbone, a CLIP text encoder followed by a
    learned linear projection, a pixel decoder and a transformer decoder.
    Text prompts (category names or grounding expressions) and visual prompts
    (binary masks) are encoded, handed to the pixel decoder as early-fusion
    features, and forwarded to the transformer decoder through an ``extra``
    dict.
    """

    # Side length of the square template crop in ``get_template``, expressed as
    # a multiple of sqrt(box_w * box_h). Kept as a class-level default so it is
    # available on every instance and can be overridden per instance.
    search_area_factor = 2

    def __init__(self, cfg, matcher, device, video_info, contras_mean):
        """Build all sub-modules and move the model to ``device``.

        Args:
            cfg: config node; this constructor reads
                ``MODEL.LANGUAGE_BACKBONE.LANG_DIM``, ``MODEL.DIM_PROJ``,
                ``MODEL.SEM_SEG_HEAD.CONVS_DIM`` and ``MODEL.TRACK_VERSION``,
                and passes ``cfg`` on to the backbone / decoder builders.
            matcher: stored on the instance; not used inside this class.
            device: device the whole module is moved to.
            video_info: stored on the instance; not used inside this class.
            contras_mean: stored on the instance; not used inside this class.
        """
        super().__init__()
        self.cfg = cfg
        self.matcher = matcher
        self.backbone = build_backbone(cfg)
        output_channels = list(self.backbone._out_feature_channels.values())
        # Fuser for template features; it is fed the channel counts of the last
        # three backbone outputs and an output width of 256.
        self.sot_fuser = FeatureFuser(output_channels[-3:], 256)

        self.tokenizer = CLIPTokenizer.from_pretrained('GLEE/clip_vit_base_patch32')
        self.tokenizer.add_special_tokens({'cls_token': self.tokenizer.eos_token})
        self.text_encoder = CLIPTextModel.from_pretrained('GLEE/clip_vit_base_patch32')
        self.lang_encoder = None
        # Projects text-encoder hidden states (LANG_DIM) to DIM_PROJ.
        self.lang_projection = nn.Parameter(torch.rand(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, cfg.MODEL.DIM_PROJ))
        self.text_encode_type = 'clip_teacher'

        self.pixel_decoder = build_pixel_decoder(cfg, self.backbone.output_shape())
        transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        self.predictor = build_transformer_decoder(cfg, transformer_predictor_in_channels, lang_encoder = self.lang_encoder, mask_classification=True,)

        self.video_info = video_info
        self.contras_mean = contras_mean

        self.track_loss_version = cfg.MODEL.TRACK_VERSION

        self.no_mask_tasks = ['obj365', 'obj365_clip','openimage', 'openimage_clip', 'vg', 'grit', 'bdd_det', 'bdd_track_box']

        # for visual prompt
        hidden_dim = 256
        # Maximum number of sampled prompt points per multi-scale level.
        self.max_spatial_len = [512,512,512,512]
        self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for _ in range(4)])
        for spatial_embed in self.mask_sptial_embed:
            trunc_normal_(spatial_embed, std=.02)
        # learnable positive negative indicator
        self.pn_indicator = nn.Embedding(2, hidden_dim)

        # Move to the target device only after *all* parameters exist;
        # ``Module.to`` affects only parameters registered at call time, so
        # calling it earlier would leave the visual-prompt parameters on CPU.
        self.to(device)

    @property
    def device(self):
        """Device of the model parameters (read from ``lang_projection``)."""
        return self.lang_projection.device

    def _sample_mask_points(self, masks, divisor, max_len):
        """Return, per mask, at most ``max_len`` nonzero-pixel coordinates divided by ``divisor``."""
        return [rand_sample((m.nonzero()[:,1:]/divisor).t(), max_len).t() for m in masks]

    def _mean_spatial_query(self, mask_features, masks, divisor):
        """Average ``mask_features`` sampled at the (sub-sampled) nonzero points of each mask."""
        points = self._sample_mask_points(masks, divisor, self.max_spatial_len[-1])
        points = nn.utils.rnn.pad_sequence(points, padding_value=-1).permute(1,0,2)
        # Padded entries are (-1, -1), so their coordinate sum is negative.
        is_padding = (points.sum(dim=-1) < 0)
        sampled = point_sample(mask_features, points.flip(dims=(2,)).type(mask_features.dtype), align_corners=True)
        # Mean over valid points; an empty selection gives NaN, mapped to 0.
        return torch.stack([x[m].mean(dim=0, keepdim=True) for x, m in zip(sampled.transpose(1,2), ~is_padding)]).transpose(0,1).nan_to_num()

    def forward(self, images, prompts, task, targets=None, batch_name_list=None, is_train = True, visual_prompt_type='scribble'):
        """Run the model on ``images`` conditioned on text and/or visual prompts.

        Args:
            images: either a ``torch.Tensor`` or an object exposing ``.tensor``;
                ``.device`` and ``len()`` are also used on it.
            prompts: container tested for the keys ``'grounding'`` (text passed
                to the tokenizer) and ``'spatial'`` (sequence of masks; the
                code unpacks each mask's shape as three dimensions).
            task: task name; for anything other than ``'grounding'`` / ``'rvos'``
                the names in ``batch_name_list`` are encoded as class embeddings.
            targets: forwarded unchanged to the transformer decoder.
            batch_name_list: category names; required (non-empty) whenever
                ``task`` is not ``'grounding'`` / ``'rvos'``.
            is_train: accepted for interface compatibility; not read here.
            visual_prompt_type: passed to ``get_template`` as ``prompt_mode``.

        Returns:
            Whatever ``self.predictor`` returns.
        """
        extra =  {}
        early_semantic = None

        if self.text_encode_type == "clip_teacher":
            if task not in ['grounding','rvos']:
                assert batch_name_list
                class_name_list = batch_name_list
                tokenized = self.tokenizer.batch_encode_plus(class_name_list,
                        max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, # 256
                        padding='max_length' if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest", # max_length
                        return_special_tokens_mask=True,
                        return_tensors='pt',
                        truncation=True).to(images.device)
                texts = (tokenized['input_ids'], tokenized['attention_mask'])
                token_x = self.text_encoder(*texts)['last_hidden_state']
                token_x = token_x @ self.lang_projection
                # One pooled embedding per category name.
                lang_feat_pool = agg_lang_feat(token_x, tokenized['attention_mask'], pool_type="average")
                extra['class_embeddings'] = lang_feat_pool

                # Early fusion: gather every valid token of every category name
                # into one sequence and share it across the image batch.
                gather_all_classtoken = token_x.flatten(0,1)[tokenized['attention_mask'].flatten(0,1)>0]
                gather_all_classtoken = gather_all_classtoken.unsqueeze(0).repeat(len(images),1,1) #[bs,L,C]
                gather_all_classtoken_mask = torch.ones_like(gather_all_classtoken[:,:,0])>0  #[bs,L]
                early_semantic = {"hidden":gather_all_classtoken.float(),"masks":gather_all_classtoken_mask}

        if 'grounding' in prompts:
            if self.text_encode_type == 'clip_frozen' or self.text_encode_type == 'clip_teacher':
                tokens = self.tokenizer(
                    prompts['grounding'], padding='max_length', truncation=True, max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN, return_tensors='pt'
                    )
                tokens = {key: value.to(images.device) for key, value in tokens.items()}

                texts = (tokens['input_ids'], tokens['attention_mask'])
                token_x = self.text_encoder(*texts)['last_hidden_state']
                token_x = token_x @ self.lang_projection

                extra['grounding_tokens'] = token_x.permute(1,0,2) #[len,bz,C]

                non_zero_query_mask = tokens['attention_mask']
                # True marks padded (invalid) token positions.
                extra['grounding_nonzero_mask'] = ~non_zero_query_mask.bool()  # [bz,len]
                extra['grounding_class'] = agg_lang_feat(token_x, non_zero_query_mask, pool_type="average") #[bz,C]
                # The grounding expression replaces any category-name early fusion.
                early_semantic = {"hidden":token_x.float(),"masks":tokens['attention_mask']>0}

        features = self.backbone(images if isinstance(images, torch.Tensor) else images.tensor)

        if 'spatial' in prompts:
            key_images = [ images ]  #bz*[1,3,H,W]
            key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']] #bz*[1,1,H,W]

            ref_feats, ref_masks = self.get_template(key_images, key_promptmasks, visual_prompt_type)
            early_fusion = {"hidden":ref_feats,"masks":ref_masks}
            if early_semantic is None:
                early_semantic = early_fusion
            else:
                # Append the visual-prompt tokens after the text tokens.
                early_semantic["hidden"] = torch.cat([early_semantic["hidden"],early_fusion["hidden"]],dim=1)
                early_semantic["masks"] = torch.cat([early_semantic["masks"],early_fusion["masks"]],dim=1)

        mask_features, _, multi_scale_features, _ = self.pixel_decoder.forward_features(features, masks=None, early_fusion = early_semantic)

        if 'spatial' in prompts:
            pos_masks = prompts['spatial']
            # Negative prompts are disabled: every negative mask is all-False.
            neg_masks = [p&False for p in prompts['spatial']]

            extra.update({'spatial_query_pos_mask': pos_masks, 'spatial_query_neg_mask': neg_masks})

            _, h, w = pos_masks[0].shape
            # Dividing (row, col) pixel indices by (h, w) normalizes them.
            divisor = torch.tensor([h,w], device=mask_features.device)[None,]

            # NOTE(review): the mean positive / negative queries are not used
            # below nor stored in ``extra``. They are still computed because
            # their point sampling may advance the global RNG, and dropping
            # them would change which points the layerwise sampling draws for
            # a fixed seed. Confirm whether they can be removed.
            spatial_query_pos = self._mean_spatial_query(mask_features, pos_masks, divisor) # [1,bz,C]
            spatial_query_neg = self._mean_spatial_query(mask_features, neg_masks, divisor)

            # Get layerwise spatial query
            src_spatial_queries = []
            src_spatial_maskings = []
            for i, level_features in enumerate(multi_scale_features):
                # (bs, C, h, w) -> (h, w, bs, C) so the per-level embedding acts on channels.
                src_mask_features = level_features.permute(2,3,0,1)
                src_mask_features = src_mask_features @ self.mask_sptial_embed[i]

                non_zero_query_point_pos = self._sample_mask_points(pos_masks, divisor, self.max_spatial_len[i])
                non_zero_query_point_neg = self._sample_mask_points(neg_masks, divisor, self.max_spatial_len[i])
                non_zero_query_point = [torch.cat([x,y], dim=0) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
                # +1 for positive points, -1 for negative points, 0 for padding.
                pos_neg_indicator = [torch.cat([torch.ones(x.shape[0], device=x.device), -torch.ones(y.shape[0], device=y.device)]) for x,y in zip(non_zero_query_point_pos, non_zero_query_point_neg)]
                pos_neg_indicator = nn.utils.rnn.pad_sequence(pos_neg_indicator, padding_value=0)
                non_zero_query_point = nn.utils.rnn.pad_sequence(non_zero_query_point, padding_value=-1).permute(1,0,2)
                non_zero_query_mask = (non_zero_query_point.sum(dim=-1) < 0)
                # Point padded entries at a valid location; they are masked out via non_zero_query_mask.
                non_zero_query_point[non_zero_query_mask] = 0

                spatial_tokens = point_sample(src_mask_features.permute(2,3,0,1), non_zero_query_point.flip(dims=(2,)).type(src_mask_features.dtype), align_corners=True).permute(2,0,1)
                spatial_tokens[pos_neg_indicator==1] += self.pn_indicator.weight[0:1]
                spatial_tokens[pos_neg_indicator==-1] += self.pn_indicator.weight[1:2]

                src_spatial_queries += [spatial_tokens]
                src_spatial_maskings += [non_zero_query_mask]

            extra['visual_prompt_tokens'] = src_spatial_queries #[len,bz,C]
            extra['visual_prompt_nonzero_mask'] = src_spatial_maskings  # [bz,len]

        outputs = self.predictor(multi_scale_features, mask_features, extra=extra, task=task, masks=None, targets=targets)
        return outputs

    def get_template(self, imgs, pad_masks, prompt_mode='scribble'):
        """Encode a crop around each prompt mask into template tokens.

        For every (image, mask) pair a square window centred on the mask's
        bounding box, with side ``ceil(sqrt(w * h) * search_area_factor)``, is
        cut out (falling back to the full image when the window is empty),
        resized to 256x256 and passed through the backbone under ``no_grad``.
        For ``'scribble'`` / ``'point'`` prompts the mask is added onto the
        image before cropping.

        Args:
            imgs: iterable of image batches, (N, 3, H, W) per the original docs.
            pad_masks: iterable of prompt masks, (N, 1, H, W) per the original
                docs; only the first box of each mask is used.
            prompt_mode: visual prompt type.

        Returns:
            ``(ref_feats, ref_masks)``: fused template features flattened to
            (bs, L, C) and an all-True mask of shape (bs, L).
        """
        croped_img_with_mask = []

        for image_i, mask_i in zip(imgs, pad_masks):
            if prompt_mode in ['scribble','point']:
                image_with_mask = image_i + mask_i.to(image_i)
            else:
                image_with_mask = image_i

            box_i = masks_to_boxes(mask_i[0])  #[xyxy]
            box_i[:, 2:] = box_i[:, 2:] - box_i[:, :2] #xywh
            x, y, w, h = box_i[0].long().tolist()

            crop_sz = math.ceil(math.sqrt(w * h) * self.search_area_factor)
            x1 = max(0,round(x + 0.5 * w - crop_sz * 0.5))
            x2 = x1 + crop_sz
            y1 = max(0,round(y + 0.5 * h - crop_sz * 0.5))
            y2 = y1 + crop_sz

            im_crop = image_with_mask[:, :, y1:y2, x1:x2]
            # An empty crop (e.g. degenerate box) falls back to the full image.
            if im_crop.shape[-1] ==0 or im_crop.shape[-2] ==0 :
                im_crop = image_with_mask
            im_crop = F.interpolate(im_crop, (256,256), mode='bilinear', align_corners=False)
            croped_img_with_mask.append(im_crop)
        croped_img_with_mask = torch.cat(croped_img_with_mask,dim=0) #[bz,3,256,256]
        with torch.no_grad():
            ref_srcs = self.backbone(croped_img_with_mask.contiguous())
        ref_srcs = list(ref_srcs.values())
        # The first backbone output is skipped; the rest go to the fuser.
        ref_feats = self.sot_fuser(ref_srcs[1:]).float() #[bz,256,32,32]

        ref_feats = ref_feats.flatten(-2).permute(0, 2, 1) # (bs, L, C)
        ref_masks = torch.ones_like(ref_feats[:,:,0])>0  #[bs,L]

        return ref_feats, ref_masks

