from functools import partial

import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from importlib import import_module
from ltr.models.backbone.bert import *
from ltr.models.head.head import Corner_Predictor, Pyramid_Corner_Predictor
from ltr.utils.box_helper import *

class Self_Tracker(nn.Module):
    """Single-object tracker that feeds a pooled language token to a ViT backbone.

    The phrase is passed through BERT embeddings and a BERT encoder, mean-pooled
    over its unmasked positions into a single token, and handed to the backbone
    together with the template, online template and search images. The head
    selected by ``settings.model.head.name`` turns the backbone output into one
    box per sample.
    """

    def __init__(self,
                 settings,
                 device=None,
                 test=False):
        """Build the backbone, the box head and the BERT text encoder.

        Parameters:
            settings: configuration object. It is stored on the module and
                forwarded to ``build_backbone_vit`` and ``build_head``;
                ``settings.model.bert.config`` is the JSON file the BERT
                configuration is read from.
            device: stored on the module and forwarded to ``build_head``.
            test: forwarded to ``build_backbone_vit``, which skips loading the
                pretrained backbone checkpoint when it is true.
        """
        super().__init__()
        self.device = device
        self.settings = settings

        self.backbone = build_backbone_vit(settings, test)
        self.head = build_head(settings, device)

        bert_config = BertConfig.from_json_file(settings.model.bert.config)
        self.bert_embeddings = BertEmbeddings(bert_config)
        self.bert = BertEncoder(bert_config)

    def forward(self, data):
        """Run the tracker on a batch dictionary.

        Parameters:
            data: mapping with the keys ``'template'`` (index 0 along dim 1 is
                the template, index 1 the online template), ``'search'``,
                ``'phrase_ids'`` and ``'phrase_attnmask'``.

        Returns:
            dict with ``'pred_boxes'`` viewed as ``(b, 1, 4)``; the ``"mlp"``
            head additionally yields ``'prob_l'``, ``'prob_t'``, ``'prob_b'``
            and ``'prob_r'``.

        Raises:
            ValueError: if ``settings.model.head.name`` is not a supported head.
        """
        # data['template'] (B, NS, 3, 128, 128)  NS=2
        # data['search'] (B, 1, 3, 320, 320)
        template = data['template'][:, 0, :, :, :]
        online_template = data['template'][:, 1, :, :, :]
        search = data['search']
        nlp_ids = data['phrase_ids']
        nlp_attnmask = data['phrase_attnmask']

        return self._track(template, online_template, search, nlp_ids, nlp_attnmask)

    def forward_test(self, template, online_template, search, nlp_ids=None, nlp_attnmask=None):
        """Run the tracker on already separated inputs.

        Parameters:
            template: template image tensor.
            online_template: online template image tensor.
            search: search image tensor.
            nlp_ids: phrase token ids, passed to the BERT embedding layer.
            nlp_attnmask: phrase attention mask (1 for tokens to keep, 0 for
                tokens to ignore). Required despite the ``None`` default.

        Returns:
            Same dictionary as ``forward``.

        Raises:
            ValueError: if ``nlp_attnmask`` is ``None`` or the configured head
                is not supported.
        """
        if nlp_attnmask is None:
            # Without this guard the call dies later on ``None.unsqueeze``.
            raise ValueError("forward_test requires nlp_attnmask (got None).")

        return self._track(template, online_template, search, nlp_ids, nlp_attnmask)

    def _track(self, template, online_template, search, nlp_ids, nlp_attnmask):
        """Shared body of ``forward`` and ``forward_test``."""
        # Drop a singleton dimension from 5-D inputs.
        # NOTE(review): templates squeeze dim 0 while search squeezes dim 1 --
        # kept exactly as in the original code; confirm this is intended.
        if template.dim() == 5:
            template = template.squeeze(0)
        if online_template.dim() == 5:
            online_template = online_template.squeeze(0)
        if search.dim() == 5:
            search = search.squeeze(1)

        nlp_token = self._encode_phrase(nlp_ids, nlp_attnmask)

        # Only the search output and the regression tokens are consumed here.
        _, _, search, _, reg_tokens = self.backbone(template, online_template, search, nlp_token)

        return self._predict(search, reg_tokens)

    def _encode_phrase(self, nlp_ids, nlp_attnmask):
        """Encode the phrase and mean-pool it into one token per sample."""
        # 0 where the mask is 1, -10000 where the mask is 0; two extra dims are
        # inserted after the batch dim (presumably to broadcast over attention
        # heads and query positions inside the encoder -- confirm).
        ext_attnmask = (1.0 - nlp_attnmask.unsqueeze(1).unsqueeze(2)) * -10000.0
        embeddings = self.bert_embeddings(nlp_ids)
        last_hidden_state = self.bert(embeddings, ext_attnmask.to(dtype=embeddings.dtype))[-1]

        # Masked mean over dim 1: masked-out positions contribute nothing.
        input_mask_expanded = nlp_attnmask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)  # sum
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)  # avoid division by zero for an all-zero mask
        mean_embeddings = sum_embeddings / sum_mask  # average

        # Re-insert dim 1 so the pooled phrase is a one-token sequence.
        return mean_embeddings.unsqueeze(1)

    def _predict(self, search, reg_tokens):
        """Run the configured head and package its output as a dictionary."""
        head_name = self.settings.model.head.name
        b = search.size(0)

        if head_name == "mlp":
            pred_boxes, prob_l, prob_t, prob_r, prob_b = self.head(reg_tokens, softmax=True)
            return {
                'pred_boxes': pred_boxes.sigmoid().view(b, 1, 4),
                'prob_l': prob_l,
                'prob_t': prob_t,
                'prob_b': prob_b,
                'prob_r': prob_r,
            }

        if head_name == "pmlp":
            # Head output goes through sigmoid, then xywh -> xyxy -> cxcywh helpers.
            outputs_coord = box_xyxy_to_cxcywh(box_xywh_to_xyxy(self.head(reg_tokens).sigmoid()))
            return {'pred_boxes': outputs_coord.view(b, 1, 4)}

        if "corner" in head_name:
            # Corner heads consume the search output rather than the regression tokens.
            outputs_coord = box_xyxy_to_cxcywh(self.head(search))
            return {'pred_boxes': outputs_coord.view(b, 1, 4)}

        raise ValueError("HEAD TYPE %s is not supported." % self.settings.model.head.name)


def build_backbone_vit(settings, test=False):
    """Build the ViT backbone and, outside test mode, load pretrained weights.

    Parameters:
        settings: configuration object. Reads ``settings.model.version`` (selects
            the module ``ltr.models.backbone.backbone_<version>``),
            ``settings.model.head.name``, ``settings.model.backbone.type``,
            ``settings.model.backbone.pretrain``, ``settings.train.search_size``
            and ``settings.train.template_size``.
        test: when true, the pretrained checkpoint is not loaded.

    Returns:
        The ``VisionTransformer`` instance created by the selected module.

    Raises:
        KeyError: if ``settings.model.backbone.type`` is neither
            ``'large_patch16'`` nor ``'base_patch16'``.
    """
    backbone_settings = settings.model.backbone
    backbone_module = import_module(f"ltr.models.backbone.backbone_{settings.model.version}")

    # Number of regression tokens requested from the backbone for each head type.
    head_name = settings.model.head.name
    if head_name == "mlp":
        num_reg_tokens = 4
    elif head_name == "pmlp":
        num_reg_tokens = 1
    else:
        num_reg_tokens = 0

    img_size_s = settings.train.search_size
    img_size_t = settings.train.template_size

    # The two supported variants differ only in width, depth and head count.
    vit_type = backbone_settings.type
    if vit_type == 'large_patch16':
        variant = dict(embed_dim=1024, depth=24, num_heads=16)
    elif vit_type == 'base_patch16':
        variant = dict(embed_dim=768, depth=12, num_heads=12)
    else:
        raise KeyError(
            f"VIT_TYPE should be set to 'large_patch16' or 'base_patch16', got {vit_type!r}")

    backbone = backbone_module.VisionTransformer(
        img_size_s=img_size_s, img_size_t=img_size_t,
        patch_size=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path_rate=0.1,
        num_reg_tokens=num_reg_tokens, **variant)

    if (not test) and backbone_settings.pretrain:
        # Loading is best-effort: a failure is reported and the randomly
        # initialised backbone is returned.
        try:
            # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
            ckpt = torch.load(backbone_settings.pretrain, map_location='cpu')['model']
            # Skip positional embeddings and mask tokens (fixed pos embed is used).
            ckpt = {k: v for k, v in ckpt.items()
                    if 'pos_embed' not in k and 'mask_token' not in k}

            missing_keys, unexpected_keys = backbone.load_state_dict(ckpt, strict=False)
            # Missing positional embeddings are expected because they were filtered out.
            real_miss = [x for x in missing_keys if 'pos_embed' not in x]
            if real_miss:
                print("missing keys", real_miss)
        except Exception as e:
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
            print(f"Warning: Pretrained CVT weights are not loaded ({e!r})")

    return backbone



def build_head(settings, device):
    """Build the box-prediction head named by ``settings.model.head.name``.

    Parameters:
        settings: configuration object. Reads ``settings.model.head.name``,
            ``settings.model.head.hidden_dim``, ``settings.train.search_size``
            and, for corner heads, the optional ``settings.model.head.head_dim``
            (default 384) and ``settings.model.head.head_freeze_bn``
            (default ``False``).
        device: forwarded to the corner heads; unused by the MLP heads.

    Returns:
        ``MlpHead`` for ``"mlp"``, ``MLP`` for ``"pmlp"``, ``Corner_Predictor``
        for ``"corner"`` or ``Pyramid_Corner_Predictor`` for ``"corner_up"``.

    Raises:
        ValueError: if the head name is not one of the supported names.
    """
    head_cfg = settings.model.head
    head_name = head_cfg.name

    if head_name == "mlp":
        # Feature size is hard-coded; the stride follows from the search size.
        feat_sz = 18
        stride = settings.train.search_size / feat_sz
        hidden_dim = head_cfg.hidden_dim
        head_module = import_module("ltr.models.head.head")
        return head_module.MlpHead(
            in_dim=hidden_dim,
            hidden_dim=hidden_dim,
            feat_sz=feat_sz,
            num_layers=2,
            stride=stride,
            norm=True,
        )

    if head_name == "pmlp":
        hidden_dim = head_cfg.hidden_dim
        head_module = import_module("ltr.models.head.head")
        return head_module.MLP(hidden_dim, hidden_dim, 4, 6)

    if "corner" in head_name:
        channel = getattr(head_cfg, "head_dim", 384)
        freeze_bn = getattr(head_cfg, "head_freeze_bn", False)
        if head_name == "corner":
            stride = 16
            feat_sz = int(settings.train.search_size / stride)
            return Corner_Predictor(inplanes=head_cfg.hidden_dim, channel=channel,
                                    feat_sz=feat_sz, stride=stride, freeze_bn=freeze_bn,
                                    device=device)
        if head_name == "corner_up":
            stride = 4
            feat_sz = int(settings.train.search_size / stride)
            return Pyramid_Corner_Predictor(inplanes=head_cfg.hidden_dim, channel=channel,
                                            feat_sz=feat_sz, stride=stride, freeze_bn=freeze_bn,
                                            device=device)
        # A name that merely contains "corner" is not a known variant; say so.
        raise ValueError(
            "HEAD TYPE %s is not supported (corner heads: 'corner', 'corner_up')." % head_name)

    raise ValueError("HEAD TYPE %s is not supported." % head_name)


# @model_constructor
def tracker_model(settings, device=None, test=False):
    """Construct and return a ``Self_Tracker``.

    All three arguments are passed to the ``Self_Tracker`` constructor by
    keyword, unchanged.
    """
    return Self_Tracker(settings=settings, device=device, test=test)
