"""
Basic OSTrack model.
"""
import math
import os
from typing import List

import torch
from torch import nn
from torch.nn.modules.transformer import _get_clones

from lib.models.layers.head import build_box_head
from lib.models.ostrack.vit import vit_large_patch16_224, vit_base_patch16_224
from lib.utils.box_ops import box_xyxy_to_cxcywh
import math
from lib.models.layers.position_encoding import build_position_encoding

from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class OSTrack(nn.Module):
    """ This is the base class for OSTrack.

    Besides the backbone and the box head, the model keeps a small recurrent state of
    [target, context] token tensors (``info_prev`` / ``info_cur``) that is fed back into
    the backbone from one search frame to the next.
    """

    def __init__(self, transformer, box_head, query_embed4tgt, query_embed4ctx, pos_embed4tgt, pos_embed4ctx, identity, aux_loss=False, head_type="CORNER"):
        """ Initializes the model.
        Parameters:
            transformer: torch module of the transformer architecture; stored as ``self.backbone``.
            box_head: prediction head. For "CORNER"/"CENTER" it must expose ``feat_sz``, the side
                length of the square feature map it consumes.
            query_embed4tgt: tensor whose shape (its values are not read) sets the target token grid.
            query_embed4ctx: tensor whose shape (its values are not read) sets the context token grid.
            pos_embed4tgt: callable; ``pos_embed4tgt(1)`` is added to the target token grid.
            pos_embed4ctx: callable; ``pos_embed4ctx(1)`` is added to the context token grid.
            identity: tensor indexed as ``identity[:, k, :]``; slot 0 is added to every target
                token and slot 1 to every context token.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            head_type: "CORNER" or "CENTER"; any other value makes ``forward_head`` raise.
        """
        super().__init__()
        self.backbone = transformer
        self.box_head = box_head

        self.query_embed4tgt = query_embed4tgt
        self.query_embed4ctx = query_embed4ctx
        self.pos_embed4tgt = pos_embed4tgt
        self.pos_embed4ctx = pos_embed4ctx
        self.identity = identity

        # Recurrent tracking state; (re)set by initialize_info() at the start of every forward().
        self.info_prev = None
        self.info_cur = None
        self.info_query = None

        self.aux_loss = aux_loss
        self.head_type = head_type
        if head_type in ("CORNER", "CENTER"):
            self.feat_sz_s = int(box_head.feat_sz)
            self.feat_len_s = int(box_head.feat_sz ** 2)

        if self.aux_loss:
            # NOTE(review): this replaces box_head with 6 clones, which forward_head then calls
            # directly as a single module -- confirm aux_loss=True is actually supported.
            self.box_head = _get_clones(self.box_head, 6)

    @staticmethod
    def _detach_all(tensors):
        """Return a list of detached copies of ``tensors`` (cut from the autograd graph)."""
        return [t.clone().detach() for t in tensors]

    @staticmethod
    def _pool_attention(attn):
        """Average ``attn`` over dims 1 and 2 and fold its last dim into a square map.

        Returns a tensor of shape (attn.shape[0], 1, S, S) with S = int(sqrt(attn.shape[-1])).
        """
        side = int(math.sqrt(attn.shape[-1]))
        pooled = attn.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        # Keep the leading (batch) dim instead of hard-coding 1, so batched tracking works.
        return pooled.reshape(attn.shape[0], 1, side, side)

    def initialize_info(self, B):
        """Reset the recurrent state and build the initial [target, context] query tokens.

        Args:
            B: batch size the returned tokens are expanded to.

        Returns:
            [tgt, ctx]: each built from a zero grid shaped like the corresponding query
            embedding, plus its positional encoding and its ``identity`` slot, flattened to
            (1, N, C) and expanded (without copying) to (B, N, C).
        """
        self.info_prev = None
        self.info_cur = None
        self.info_query = None

        tgt = torch.zeros_like(self.query_embed4tgt)
        ctx = torch.zeros_like(self.query_embed4ctx)

        tgt += self.pos_embed4tgt(1)
        ctx += self.pos_embed4ctx(1)

        # Flatten the token grids to (1, N, C); each uses its own channel dim.
        tgt = tgt.reshape(1, -1, tgt.shape[-1])
        ctx = ctx.reshape(1, -1, ctx.shape[-1])

        # Tag every token with its role: identity slot 0 for target, slot 1 for context.
        # identity[:, k:k+1, :] is (1, 1, C) and broadcasts over all N tokens.
        tgt += self.identity[:, 0:1, :]
        ctx += self.identity[:, 1:2, :]

        tgt = tgt.expand(B, -1, -1)
        ctx = ctx.expand(B, -1, -1)

        return [tgt, ctx]

    def forward_track(self, info, info_query, search_list):
        """Run one tracking step on a search input using the given recurrent state.

        Args:
            info: current [target, context] state passed to the backbone.
            info_query: initial query tokens (as returned by ``initialize_info``).
            search_list: search input handed to the backbone as ``x``.

        Returns:
            The ``forward_head`` output dict, extended with 'backbone_feat', 'info' (the updated
            state from the backbone) and 'attn_tgt' / 'attn_ctx', each of shape (B, 1, S, S).
        """
        info, x, attn_tgt, attn_ctx = self.backbone(info=info, info_query=info_query, x=search_list, identity=self.identity)

        attn_tgt = self._pool_attention(attn_tgt)
        attn_ctx = self._pool_attention(attn_ctx)

        # Forward head on the last feature when the backbone returns a list of features.
        feat_last = x[-1] if isinstance(x, list) else x
        out = self.forward_head(feat_last, None)

        out['backbone_feat'] = x
        out['info'] = info
        out['attn_tgt'] = attn_tgt
        out['attn_ctx'] = attn_ctx
        return out

    def forward(self, template,
                search_list,
                weight,
                ):
        """Encode the template, then run every search frame with an older and a newer state.

        Args:
            template: template input; the target pass uses ``template * weight`` and the context
                pass uses ``template * (1 - weight)``.
            search_list: sequence of search inputs, processed in order.
            weight: factor applied to ``template`` as described above.

        Returns:
            (out_prev, out_cur): ``forward_head`` outputs for the features produced with the
            older state and with the most recent state, concatenated over all search frames
            along dim 0.
        """
        self.info_query = self.initialize_info(template.shape[0])

        # Template passes: each starts from a fresh detached copy of the query tokens.
        info_tgt, _, _, _ = self.backbone(info=self._detach_all(self.info_query), info_query=self.info_query, x=template * weight, identity=self.identity)
        info_ctx, _, _, _ = self.backbone(info=self._detach_all(self.info_query), info_query=self.info_query, x=template * (1 - weight), identity=self.identity)
        # Initial state: target tokens from the target pass, context tokens from the context pass.
        self.info_prev = [info_tgt[0].clone().detach(), info_ctx[1].clone().detach()]

        feats_prev = []
        feats_cur = []
        for search in search_list:
            if self.info_cur is None:
                # First frame: only one state exists, so both streams share a single pass.
                info_out, feat, _, _ = self.backbone(info=self.info_prev, info_query=self.info_query, x=search, identity=self.identity)
                self.info_cur = self._detach_all(info_out)
                feats_prev.append(feat)
                feats_cur.append(feat)
            else:
                # Later frames: run with the older and with the newer state, then shift the window.
                _, feat_prev, _, _ = self.backbone(info=self.info_prev, info_query=self.info_query, x=search, identity=self.identity)
                info_out, feat_cur, _, _ = self.backbone(info=self.info_cur, info_query=self.info_query, x=search, identity=self.identity)
                self.info_prev = self._detach_all(self.info_cur)
                self.info_cur = self._detach_all(info_out)
                feats_prev.append(feat_prev)
                feats_cur.append(feat_cur)

        out_prev = self.forward_head(torch.cat(feats_prev, dim=0))
        out_cur = self.forward_head(torch.cat(feats_cur, dim=0))
        return out_prev, out_cur

    def forward_head(self, enc_opt, gt_score_map=None):
        """Run the box head on backbone output tokens.

        Args:
            enc_opt: backbone features, handled as (B, HW, C). They are reshaped to
                (-1, C, feat_sz_s, feat_sz_s), so HW is expected to equal feat_sz_s ** 2.
            gt_score_map: passed through to the "CENTER" head; unused by the "CORNER" head.

        Returns:
            dict with 'pred_boxes' of shape (B, 1, 4) and 'score_map'; the "CENTER" head also
            returns 'size_map' and 'offset_map'.

        Raises:
            NotImplementedError: if ``head_type`` is neither "CORNER" nor "CENTER".
        """
        # Check first: feat_sz_s only exists for the supported head types.
        if self.head_type not in ("CORNER", "CENTER"):
            raise NotImplementedError(f"Unsupported head_type: {self.head_type}")

        # (B, HW, C) -> (B, 1, C, HW)
        opt = enc_opt.unsqueeze(-1).permute((0, 3, 2, 1)).contiguous()
        bs, Nq, C, HW = opt.size()
        opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s)

        if self.head_type == "CORNER":
            # run the corner head; its boxes are converted from xyxy to cxcywh
            pred_box, score_map = self.box_head(opt_feat, True)
            outputs_coord = box_xyxy_to_cxcywh(pred_box)
            return {'pred_boxes': outputs_coord.view(bs, Nq, 4),
                    'score_map': score_map,
                    }

        # run the center head; its boxes are used as returned
        score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat, gt_score_map)
        return {'pred_boxes': bbox.view(bs, Nq, 4),
                'score_map': score_map_ctr,
                'size_map': size_map,
                'offset_map': offset_map}


def build_ostrack(cfg, training=True):
    """Build an OSTrack model from a config.

    Args:
        cfg: config node; reads MODEL.PRETRAIN_FILE, MODEL.BACKBONE.TYPE, TRAIN.DROP_PATH_RATE,
            MODEL.ST_REPRESENTATION.TARGET_SIZE / CONTEXT_SIZE and MODEL.HEAD.TYPE.
        training: when True, MODEL.PRETRAIN_FILE is used. A file whose name does not contain
            'OSTrack' is passed to the backbone constructor; a file whose name contains 'OSTrack'
            is loaded (non-strictly) into the whole model from its "net" entry.

    Returns:
        The assembled OSTrack module.

    Raises:
        NotImplementedError: for an unsupported MODEL.BACKBONE.TYPE.
    """
    # Directory containing this source file (not the project root); weights live three levels up.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    pretrained_path = os.path.join(current_dir, '../../../pretrained_models')
    pretrain_file = cfg.MODEL.PRETRAIN_FILE
    if pretrain_file and ('OSTrack' not in pretrain_file) and training:
        pretrained = os.path.join(pretrained_path, pretrain_file)
    else:
        pretrained = ''

    backbone_builders = {
        'vit_base_patch16_224': vit_base_patch16_224,
        'vit_large_patch16_224': vit_large_patch16_224,
    }
    backbone_type = cfg.MODEL.BACKBONE.TYPE
    if backbone_type not in backbone_builders:
        raise NotImplementedError(f"Unsupported backbone type: {backbone_type}")
    backbone = backbone_builders[backbone_type](pretrained, drop_path_rate=cfg.TRAIN.DROP_PATH_RATE,)
    hidden_dim = backbone.embed_dim
    patch_start_index = 1

    backbone.finetune_track(cfg=cfg, patch_start_index=patch_start_index)

    tgt_size = cfg.MODEL.ST_REPRESENTATION.TARGET_SIZE
    ctx_size = cfg.MODEL.ST_REPRESENTATION.CONTEXT_SIZE
    # Zero-initialized rather than torch.empty: OSTrack.initialize_info only uses these for their
    # shape (via zeros_like), and uninitialized memory could leave NaN/garbage in checkpoints.
    query_embed4tgt = nn.Parameter(torch.zeros(tgt_size, tgt_size, hidden_dim))
    query_embed4ctx = nn.Parameter(torch.zeros(ctx_size, ctx_size, hidden_dim))

    pos_embed4tgt = build_position_encoding(tgt_size, hidden_dim)
    pos_embed4ctx = build_position_encoding(ctx_size, hidden_dim)

    # Three role embeddings indexed along dim 1 (slot 0: target tokens, slot 1: context tokens).
    identity = torch.nn.Parameter(torch.zeros(1, 3, hidden_dim))
    identity = trunc_normal_(identity, std=.02)

    box_head = build_box_head(cfg, hidden_dim)

    model = OSTrack(
        backbone,
        box_head,
        query_embed4tgt,
        query_embed4ctx,
        pos_embed4tgt,
        pos_embed4ctx,
        identity,
        aux_loss=False,
        head_type=cfg.MODEL.HEAD.TYPE,
    )

    # Guard against a None PRETRAIN_FILE (the membership test would raise TypeError).
    if pretrain_file and 'OSTrack' in pretrain_file and training:
        checkpoint_path = os.path.join(current_dir, '../../../pretrained_models/', pretrain_file)
        # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint["net"], strict=False)
        print('Load pretrained model from: ' + checkpoint_path)
        # strict=False tolerates architecture differences; report them instead of dropping them.
        if missing_keys:
            print('Missing keys when loading checkpoint: {}'.format(missing_keys))
        if unexpected_keys:
            print('Unexpected keys when loading checkpoint: {}'.format(unexpected_keys))

    return model
