from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import resize_pos_embed
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from lib.models.layers.patch_embed import PatchEmbed
from lib.models.ostrack.utils import combine_tokens, recover_tokens


class BaseBackbone(nn.Module):
    """Base class for ViT-style backbones used for tracking.

    Holds the configuration attributes shared by concrete backbones and
    provides :meth:`finetune_track`, which resizes the patch positional
    embedding to the token grid of the search region.

    The attributes below are only initialised to defaults here; concrete
    subclasses are presumably expected to overwrite them (in particular
    ``pos_embed``) -- NOTE(review): confirm against the subclasses.
    """

    def __init__(self):
        super().__init__()

        # for original ViT
        self.pos_embed = None  # must be set before finetune_track() is called
        self.img_size = [224, 224]  # (height, width) the stored pos_embed grid is derived from
        self.patch_size = 16  # divisor applied to img_size to get the stored grid size
        self.embed_dim = 384

        # Not read in this class; presumably the mode handed to
        # combine_tokens / recover_tokens -- TODO confirm.
        self.cat_mode = 'direct'

        # Filled in by finetune_track(): positional embedding resized for the search region.
        self.pos_embed_x = None

        # Not read in this class; presumably optional segment embeddings
        # used when add_sep_seg is enabled -- TODO confirm.
        self.template_segment_pos_embed = None
        self.search_segment_pos_embed = None

        # Not read in this class; presumably control returning intermediate
        # block outputs in subclasses -- TODO confirm.
        self.return_inter = False
        self.return_stage = [2, 5, 8, 11]

        # Feature flags; not read in this class.
        self.add_cls_token = False
        self.add_sep_seg = False

    def finetune_track(self, cfg, patch_start_index=1):
        """Resize the patch positional embedding to the search-region grid.

        The entries of ``self.pos_embed`` from ``patch_start_index`` onwards
        (along dim 1) are reshaped to a 2-D grid of
        ``img_size // patch_size`` cells, bicubically interpolated to
        ``cfg.DATA.SEARCH.SIZE // cfg.MODEL.BACKBONE.STRIDE`` cells, flattened
        back to a token sequence and stored as the parameter
        ``self.pos_embed_x``.

        Args:
            cfg: Configuration object; ``cfg.DATA.SEARCH.SIZE`` (int or pair)
                and ``cfg.MODEL.BACKBONE.STRIDE`` are read.
            patch_start_index: Number of leading entries along dim 1 of
                ``self.pos_embed`` to skip (presumably non-patch tokens such
                as a class token -- TODO confirm).

        Raises:
            RuntimeError: If ``self.pos_embed`` has not been set, or if the
                number of remaining positional entries does not match the
                grid implied by ``img_size`` and ``patch_size``.
        """
        # __init__ leaves pos_embed as None; fail with a clear message instead
        # of "'NoneType' object is not subscriptable".
        if self.pos_embed is None:
            raise RuntimeError(
                "finetune_track() requires self.pos_embed to be set, but it is None"
            )

        search_size = to_2tuple(cfg.DATA.SEARCH.SIZE)
        new_patch_size = cfg.MODEL.BACKBONE.STRIDE

        # for patch embedding: (B, Q, E) -> (B, E, Q) so tokens can be laid out as a grid
        patch_pos_embed = self.pos_embed[:, patch_start_index:, :]
        patch_pos_embed = patch_pos_embed.transpose(1, 2)
        B, E, Q = patch_pos_embed.shape
        P_H, P_W = self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size
        # Check the token count explicitly so a mismatch is reported in terms of
        # the configuration rather than as an opaque view() shape error.
        if Q != P_H * P_W:
            raise RuntimeError(
                f"pos_embed holds {Q} patch positions after skipping the first "
                f"{patch_start_index}, but img_size={self.img_size} with "
                f"patch_size={self.patch_size} implies a {P_H}x{P_W} grid "
                f"({P_H * P_W} positions)"
            )
        patch_pos_embed = patch_pos_embed.view(B, E, P_H, P_W)

        # for search region
        H, W = search_size
        new_P_H, new_P_W = H // new_patch_size, W // new_patch_size
        search_patch_pos_embed = nn.functional.interpolate(patch_pos_embed, size=(new_P_H, new_P_W), mode='bicubic',
                                                           align_corners=False)
        # back to a token sequence: (B, E, new_P_H, new_P_W) -> (B, new_P_H * new_P_W, E)
        search_patch_pos_embed = search_patch_pos_embed.flatten(2).transpose(1, 2)

        self.pos_embed_x = nn.Parameter(search_patch_pos_embed)


    # def forward_features(self, z, x, tgt):
    #     B, H, W = x.shape[0], x.shape[2], x.shape[3]

    #     x = self.patch_embed(x)
    #     z = self.patch_embed(z)

    #     if self.add_cls_token:
    #         cls_tokens = self.cls_token.expand(B, -1, -1)
    #         cls_tokens = cls_tokens + self.cls_pos_embed

    #     z += self.pos_embed_z
    #     x += self.pos_embed_x

    #     z += identity[:, 0, :].repeat(B, self.pos_embed_z0.shape[1], 1)
    #     x += identity[:, 1, :].repeat(B, self.pos_embed_z1.shape[1], 1)
 

    #     if self.add_sep_seg:
    #         x += self.search_segment_pos_embed
    #         z += self.template_segment_pos_embed

    #     x = combine_tokens(z, x, mode=self.cat_mode)
    #     if self.add_cls_token:
    #         x = torch.cat([cls_tokens, x], dim=1)

    #     x = self.pos_drop(x)

    #     for i, blk in enumerate(self.blocks):
    #         x = blk(x)

    #     lens_z = self.pos_embed_z.shape[1]
    #     lens_x = self.pos_embed_x.shape[1]
    #     x = recover_tokens(x, lens_z, lens_x, mode=self.cat_mode)

    #     aux_dict = {"attn": None}
    #     return self.norm(x), aux_dict

    # def forward(self, z, x, **kwargs):
    #     """
    #     Joint feature extraction and relation modeling for the basic ViT backbone.
    #     Args:
    #         z (torch.Tensor): template feature, [B, C, H_z, W_z]
    #         x (torch.Tensor): search region feature, [B, C, H_x, W_x]

    #     Returns:
    #         x (torch.Tensor): merged template and search region feature, [B, L_z+L_x, C]
    #         attn : None
    #     """
    #     x, aux_dict = self.forward_features(z, x,)

    #     return x, aux_dict
