import math
import os
from typing import List
from torch.nn import functional as F
import torch
from torch import nn
from torch.nn.modules.transformer import _get_clones

from einops import rearrange
from lib.models.layers.nlp_embedding import nlp_embedding

from lib.models.layers.head import build_box_head
from lib.models.atstrack.vit import vit_base_patch16_224, vit_large_patch16_224
from lib.models.atstrack.vit_ce import vit_large_patch16_224_ce, vit_base_patch16_224_ce
from lib.utils.box_ops import box_xyxy_to_cxcywh


class atstrack(nn.Module):
    """ATSTrack: a transformer tracker that fuses language attributes into search features.

    For every search frame the backbone is run on (template, search) together with a
    temporal "track query" carried over from the previous frame. Attribute descriptions
    (class / color / action / location) are fused into the search-region features by
    attention in ``forward_head``, and ``box_head`` predicts the boxes.

    NOTE(review): ``self.track_query`` is module state that persists across ``forward``
    calls (it is only (re)initialised while it is ``None``); callers are presumably
    expected to reset it between sequences / batches — confirm against tracker/trainer.
    """

    def __init__(self, transformer, box_head, aux_loss=False, head_type="CORNER", token_len=2):
        """Initialize the model.

        Parameters:
            transformer: backbone module (stored as ``self.backbone``).
            box_head: prediction head; for "CORNER"/"CENTER" it must expose ``feat_sz``.
            aux_loss: if True, ``box_head`` is cloned 6 times into a ModuleList.
                NOTE(review): ``forward_head`` calls ``self.box_head`` directly, which does
                not work on a ModuleList; ``build_atstrack`` always passes False.
            head_type: "CORNER" or "CENTER".
            token_len: number of leading track-query tokens in the backbone sequence.
        """
        super().__init__()
        self.backbone = transformer
        self.box_head = box_head

        self.nlp_embedding = nlp_embedding()

        self.aux_loss = aux_loss
        self.head_type = head_type
        if head_type == "CORNER" or head_type == "CENTER":
            self.feat_sz_s = int(box_head.feat_sz)
            self.feat_len_s = int(box_head.feat_sz ** 2)

        if self.aux_loss:
            self.box_head = _get_clones(self.box_head, 6)

        # track query: history information carried over from the previous frame
        self.track_query = None
        self.token_len = token_len

        # NOTE(review): this LayerNorm is 512-wide but is applied (in ``forward``) to
        # backbone features that ``forward_head`` treats as 768-dim, and ``fusion_fc1``
        # produces 512-dim outputs that are used as keys/values of the 768-dim attention
        # layers below and reshaped with ``view(..., 768)``. Confirm these widths against
        # the trained checkpoint before relying on this module. Module creation order is
        # kept unchanged so parameter initialisation / state_dict layout stay identical.
        self.norm = nn.LayerNorm(512)

        # Shared projection applied to every language embedding.
        self.fusion_fc1 = nn.Linear(512, 512)
        # Poolings for the color (template 12x12), action (template) and location
        # (search 24x24) branches in ``forward_head``.
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.AvgPool2d(kernel_size=3, stride=3)
        self.pool3 = nn.AvgPool2d(kernel_size=4, stride=4)

        # Projection of pooled search tokens in the location branch.
        self.fcq1 = nn.Linear(768, 768)

        # Maps the location-token axis from 24 to 36 entries (location branch).
        self.fc1 = nn.Linear(24, 36)

        self.fusion_visual_textual = nn.MultiheadAttention(
            embed_dim=768,
            num_heads=1,
            dropout=0,
            batch_first=True
        )
        self.fusion_visual_class = nn.MultiheadAttention(
            embed_dim=768,
            num_heads=1,
            dropout=0,
            batch_first=True
        )
        self.fusion_visual_color = nn.MultiheadAttention(
            embed_dim=768,
            num_heads=1,
            dropout=0,
            batch_first=True
        )
        self.fusion_visual_action = nn.MultiheadAttention(
            embed_dim=768,
            num_heads=1,
            dropout=0,
            batch_first=True
        )

    def forward(self, template: torch.Tensor,
                search: torch.Tensor,
                attr_feat=None,
                query=None,
                attr=None,
                ce_template_mask=None,
                ce_keep_rate=None,
                nlp=None,
                return_last_attn=False,
                ):
        """Run the tracker on each search input in ``search``.

        Parameters:
            template: template input; ``.copy()`` is called on it for every frame, so it is
                presumably a list of tensors despite the annotation — TODO confirm.
            search: list of search inputs, one per frame (asserted below).
            attr_feat: precomputed [class, color, action, location] embeddings; when None
                they are computed from ``attr`` with ``self.nlp_embedding``.
            query: optional language-query embedding used to initialise the track query.
            attr: dict with 'class', 'color', 'action' and 'location' descriptions.
            ce_template_mask, ce_keep_rate, return_last_attn: unused by the current
                backbone call; kept for interface compatibility.
            nlp: language description; ``len(nlp)`` is the batch size in training mode.

        Returns:
            list with one output dict per search input: the head outputs, the backbone's
            aux dict entries, and 'backbone_feat' (the raw backbone output).
        """
        assert isinstance(search, list), "The type of search is not List"

        out_dict = []
        # Evaluation processes a single sequence at a time.
        bs = len(nlp) if self.training else 1

        if self.track_query is None:
            if query is None:
                query = self.fusion_fc1(self.nlp_embedding(nlp))
            # Initial track query: pooled language query followed by one zero token.
            # The zero token follows the query's device instead of a hard-coded .cuda().
            zero_token = torch.zeros(bs, 1, 768, device=query.device)
            self.track_query = torch.cat([query.view(bs, 1, -1, 768).mean(dim=2), zero_token], 1)

        if attr_feat is None:
            attr_feat = [self.nlp_embedding(attr[key])
                         for key in ('class', 'color', 'action', 'location')]

        for search_frame in search:
            x, aux_dict = self.backbone(z=template.copy(), x=search_frame,
                                        temporal_query=self.track_query,
                                        token_len=self.token_len)
            # The backbone may return per-layer outputs; use the last one.
            feat_last = x[-1] if isinstance(x, list) else x

            # Sequence layout used by the slicing below:
            # [token_len leading tokens | template tokens | feat_len_s search tokens].
            temp_feat = feat_last[:, self.token_len:-self.feat_len_s]
            enc_opt = feat_last[:, -self.feat_len_s:]  # search-region tokens (B, HW, C)

            # Weight every search token by its dot product with the first token.
            # Uses ``feat_last`` (not ``x``) so a list-valued backbone output also works.
            att = torch.matmul(enc_opt, feat_last[:, :1].transpose(1, 2))  # (B, HW, 1)
            opt = (enc_opt.unsqueeze(-1) * att.unsqueeze(-2)).contiguous()  # (B, HW, C, 1)
            opt = self.norm(opt.squeeze(-1)).unsqueeze(-1).permute((0, 3, 2, 1))  # (B, 1, C, HW)

            out, attr_query = self.forward_head(opt, temp_feat, attr_feat, attr, None)

            if self.backbone.add_cls_token:
                # Carry history to the next frame; detached to stop gradients through time.
                track_query = feat_last[:, :self.token_len - 1].clone().detach()  # (B, N, C)
                self.track_query = torch.cat([attr_query, track_query], 1)

            out.update(aux_dict)
            out['backbone_feat'] = x

            out_dict.append(out)

        return out_dict

    def forward_head(self, opt, temp, attr_feat, attr, gt_score_map=None):
        """Fuse attribute embeddings into the search features and run the box head.

        Parameters:
            opt: search-region features of shape (B, Nq, C, HW).
            temp: template tokens from the backbone; the last 144 tokens are viewed as a
                12x12 grid in the color branch.
            attr_feat: [class, color, action, location] embeddings, indexed per sample.
            attr: dict of per-sample descriptions; the string 'None' disables the color,
                action or location branch for that sample.
            gt_score_map: forwarded to the CENTER head.

        Returns:
            (out, attr_query): ``out`` holds 'pred_boxes' of shape (B, Nq, 4) plus the head's
            maps; ``attr_query`` is the detached pooled fused-language feature (B, 1, 768).

        Raises:
            NotImplementedError: if ``head_type`` is neither "CORNER" nor "CENTER".
        """
        late_t = temp[:, -144:]

        class_emb = attr_feat[0]
        color_emb = attr_feat[1]
        action_emb = attr_feat[2]
        location_emb = attr_feat[3]
        bs, Nq, C, HW = opt.size()

        # Per-sample results, concatenated along the batch dimension after the loop
        # (device-agnostic replacement for growing ``torch.Tensor().cuda()`` buffers).
        opt_feat_list = []
        attr_query_list = []
        for i in range(bs):
            opt_feat = rearrange(opt[i], 't c l -> t l c')  # (Nq, HW, C)

            class_feat = self.fusion_fc1(class_emb[i].unsqueeze(0))

            # Color branch: class-guided re-weighting of pooled template tokens, then
            # cross-attention to the color description.
            if attr['color'][i] != 'None':
                vis_feat = rearrange(late_t[i].unsqueeze(0), 't l c -> t c l')
                vis_feat = self.pool1(vis_feat.view(1, -1, 12, 12)).squeeze(0)
                vis_feat = rearrange(vis_feat.view(1, 768, -1), 't c l -> t l c')

                class_feat = rearrange(class_feat, 't l c -> t c l')

                # Scaled similarity (28 is presumably ~sqrt(768) — confirm).
                msim_c = torch.matmul(vis_feat, class_feat) / 28
                msim_c = F.softmax(msim_c, dim=-1)

                class_feat_v = F.softmax(class_feat, dim=-1)
                class_feat_v = rearrange(class_feat_v, 't c l -> t l c')
                class_map = torch.matmul(msim_c, class_feat_v) / 28

                # Min-max normalise over the token dimension; eps avoids division by zero.
                x_min = class_map.min(dim=1, keepdim=True)[0]
                x_max = class_map.max(dim=1, keepdim=True)[0]
                class_map = (class_map - x_min) / (x_max - x_min + 1e-8)

                # Collapse channels, then rescale the per-token weights to roughly [a, b].
                class_map = class_map.mean(dim=-1)
                a, b = 0.2, 1.0
                c_min, c_max = class_map.min(), class_map.max()
                class_map = (b - a) * (class_map - c_min) / (c_max - c_min + 1e-8) + a

                fused_feat = class_map.unsqueeze(2) * vis_feat
                class_feat = rearrange(class_feat, 't c l -> t l c')
                color_feat = self.fusion_fc1(color_emb[i].unsqueeze(0))

                fused_color = self.fusion_visual_color(
                    query=fused_feat,
                    key=color_feat,
                    value=color_feat,
                )[0]
            else:
                fused_color = class_feat

            # Action branch: pooled template tokens attend to the action description.
            if attr['action'][i] != 'None':
                hs_feat = rearrange(temp[i].unsqueeze(0), 't l c -> t c l')
                hs_feat = self.pool2(hs_feat.view(1, 768, 12, -1)).squeeze(0)
                hs_feat = rearrange(hs_feat.view(1, 768, -1), 't c l -> t l c')
                act_feat = self.fusion_fc1(action_emb[i].unsqueeze(0))

                fused_act = self.fusion_visual_action(
                    query=hs_feat,
                    key=act_feat,
                    value=act_feat,
                )[0]
            else:
                fused_act = class_feat

            # Location branch: pooled search tokens vs. location tokens, with a soft gate
            # that suppresses similarities below (median + std / 2).
            if attr['location'][i] != 'None':
                search_feat = opt[i].unsqueeze(0)
                search_feat = self.pool3(search_feat.view(1, -1, 24, 24)).squeeze(0)
                search_feat = rearrange(search_feat.view(1, 768, -1), 't c l -> t l c')
                search_feat = self.fcq1(search_feat)
                loc_feat = self.fusion_fc1(location_emb[i].unsqueeze(0))
                loc_feat = self.fc1(rearrange(loc_feat, 't l c -> t c l'))
                msim_l = torch.matmul(search_feat, loc_feat) / 28
                msim_l = F.softmax(msim_l, dim=-1)
                mid, _ = torch.median(msim_l, dim=-1, keepdim=True)
                sigma = msim_l.std(dim=-1, keepdim=True)
                tau = mid + sigma / 2
                # Steep sigmoid (slope 50): a differentiable approximation of a hard
                # threshold at tau.
                gate = torch.sigmoid(50 * (msim_l - tau))

                msim_l = msim_l * gate
                msim_l = msim_l / (msim_l.sum(dim=-1, keepdim=True) + 1e-8)

                # Gated attention over location tokens plus a residual connection.
                fused_loc = torch.matmul(msim_l, loc_feat.transpose(-1, -2)) + loc_feat.transpose(-1, -2)
            else:
                fused_loc = fused_color

            fused_nlp = torch.cat([class_feat, fused_color, fused_act, fused_loc], 1)

            # Residual cross-attention: search tokens attend to the fused language tokens.
            res_feat = self.fusion_visual_textual(
                query=opt_feat,
                key=fused_nlp,
                value=fused_nlp,
            )[0]
            opt_feat = opt_feat + res_feat
            opt_feat = rearrange(opt_feat, 't c l -> t l c').view(-1, C, self.feat_sz_s, self.feat_sz_s)
            opt_feat_list.append(opt_feat)
            # Pooled, detached fused-language feature; forward() prepends it to the next
            # frame's track query.
            attr_query_list.append(fused_nlp.view(1, 1, -1, 768).mean(dim=2).detach())

        opt_feat = torch.cat(opt_feat_list, 0)
        attr_query = torch.cat(attr_query_list, 0)

        if self.head_type == "CORNER":
            # run the corner head
            pred_box, score_map = self.box_head(opt_feat, True)
            outputs_coord = box_xyxy_to_cxcywh(pred_box)
            outputs_coord_new = outputs_coord.view(bs, Nq, 4)
            out = {'pred_boxes': outputs_coord_new,
                   'score_map': score_map,
                   }
            return out, attr_query

        elif self.head_type == "CENTER":
            # run the center head; it already returns boxes in cxcywh form
            score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat, gt_score_map)

            outputs_coord = bbox
            outputs_coord_new = outputs_coord.view(bs, Nq, 4)

            out = {'pred_boxes': outputs_coord_new,
                   'score_map': score_map_ctr,
                   'size_map': size_map,
                   'offset_map': offset_map}

            return out, attr_query
        else:
            raise NotImplementedError(f"Unsupported head type: {self.head_type}")


def build_atstrack(cfg, training=True):
    """Build an ``atstrack`` model from a config.

    Args:
        cfg: config node. Reads MODEL.PRETRAIN_FILE, MODEL.BACKBONE.{TYPE, ADD_CLS_TOKEN,
            ATTN_TYPE, CE_LOC, CE_KEEP_RATIO, TOKEN_LEN}, MODEL.HEAD.TYPE and
            TRAIN.DROP_PATH_RATE (only the fields needed by the selected backbone).
        training: a pretrained-checkpoint path is passed to the backbone only when this is
            True, PRETRAIN_FILE is set, and it is not an 'OSTrack' checkpoint.

    Returns:
        The constructed ``atstrack`` model.

    Raises:
        NotImplementedError: if the backbone type is unsupported or unavailable here.
    """
    # Directory containing this file; pretrained weights are looked up three levels up.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    pretrained_path = os.path.join(current_dir, '../../../pretrained_networks')
    if cfg.MODEL.PRETRAIN_FILE and ('OSTrack' not in cfg.MODEL.PRETRAIN_FILE) and training:
        pretrained = os.path.join(pretrained_path, cfg.MODEL.PRETRAIN_FILE)
    else:
        pretrained = ''

    backbone_type = cfg.MODEL.BACKBONE.TYPE
    drop_path_rate = cfg.TRAIN.DROP_PATH_RATE

    if backbone_type == 'vit_base_patch16_224':
        backbone = vit_base_patch16_224(pretrained, drop_path_rate=drop_path_rate,
                                        add_cls_token=cfg.MODEL.BACKBONE.ADD_CLS_TOKEN,
                                        attn_type=cfg.MODEL.BACKBONE.ATTN_TYPE)

    elif backbone_type == 'vit_large_patch16_224':
        backbone = vit_large_patch16_224(pretrained, drop_path_rate=drop_path_rate,
                                         add_cls_token=cfg.MODEL.BACKBONE.ADD_CLS_TOKEN,
                                         attn_type=cfg.MODEL.BACKBONE.ATTN_TYPE)

    elif backbone_type == 'vit_base_patch16_224_ce':
        backbone = vit_base_patch16_224_ce(pretrained, drop_path_rate=drop_path_rate,
                                           ce_loc=cfg.MODEL.BACKBONE.CE_LOC,
                                           ce_keep_ratio=cfg.MODEL.BACKBONE.CE_KEEP_RATIO,
                                           add_cls_token=cfg.MODEL.BACKBONE.ADD_CLS_TOKEN)

    elif backbone_type == 'vit_large_patch16_224_ce':
        backbone = vit_large_patch16_224_ce(pretrained, drop_path_rate=drop_path_rate,
                                            ce_loc=cfg.MODEL.BACKBONE.CE_LOC,
                                            ce_keep_ratio=cfg.MODEL.BACKBONE.CE_KEEP_RATIO,
                                            add_cls_token=cfg.MODEL.BACKBONE.ADD_CLS_TOKEN)

    elif backbone_type == 'hivit_base':
        # The constructor for this backbone (fast_itpn_base_3324_patch16_224) is not
        # imported in this module, so calling it would raise NameError; fail clearly.
        raise NotImplementedError(
            "Backbone 'hivit_base' is not available: "
            "fast_itpn_base_3324_patch16_224 is not imported in this module"
        )

    else:
        raise NotImplementedError(f"Unsupported backbone type: {backbone_type}")

    hidden_dim = backbone.embed_dim
    patch_start_index = 1

    backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index)

    box_head = build_box_head(cfg, hidden_dim)

    model = atstrack(
        backbone,
        box_head,
        aux_loss=False,
        head_type=cfg.MODEL.HEAD.TYPE,
        token_len=cfg.MODEL.BACKBONE.TOKEN_LEN,
    )

    return model
