"""
Basic mamba_fetrack model.
"""
import math
import os
from typing import List

import torch
from torch import nn
from torch.nn.modules.transformer import _get_clones
import torch.nn.functional as F

from lib.models.layers.head import build_box_head
from lib.utils.box_ops import box_xyxy_to_cxcywh
from lib.models.mamba_fetrack.models_mamba import create_block
from timm.models import create_model
import torch.nn.functional as F
from lib.models.mamba_fetrack.mamba_cross import CrossMamba
from thop import profile


class Mamba_FEtrack(nn.Module):
    """Base RGB-Event tracking model for mamba_fetrack.

    The same backbone object encodes the RGB pair and the event pair. The
    trailing search-region tokens of both streams are concatenated along the
    channel axis and handed to the box head.
    """

    # Head types that ``forward_head`` knows how to run.
    _SUPPORTED_HEADS = ("CORNER", "CENTER")

    def __init__(self, visionmamba, cross_mamba, box_head, aux_loss=False, head_type="CORNER"):
        """Initialize the model.

        Parameters:
            visionmamba: backbone module. It must expose ``embed_dim``,
                ``forward_features`` and ``forward_features_event``.
            cross_mamba: cross-modal module. It is stored on the model but is
                not called by ``forward`` in this class.
            box_head: prediction head. ``box_head.feat_sz`` is read for the
                "CORNER" and "CENTER" head types.
            aux_loss: if True, the head is replaced by 6 clones of itself.
            head_type: "CORNER" or "CENTER"; any other value makes
                ``forward`` / ``forward_head`` raise ``NotImplementedError``.
        """
        super().__init__()
        self.backbone = visionmamba
        self.cross_mamba = cross_mamba
        self.box_head = box_head
        self.hidden_dim = visionmamba.embed_dim

        # Layers used by ``feature_fusion`` (gate G and residual DeltaF).
        self.gate_linear = nn.Linear(2, 1)
        self.delta_fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.delta_fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)

        # Projections from hidden_dim to 2 * hidden_dim channels (e.g. 384 -> 768
        # when hidden_dim is 384). They are not called by ``forward`` in this
        # class but are kept so the module's parameter set stays the same.
        self.fusion_output_proj = nn.Linear(self.hidden_dim, self.hidden_dim * 2)
        self.rgb_projection = nn.Linear(self.hidden_dim, self.hidden_dim * 2)

        self.aux_loss = aux_loss
        self.head_type = head_type
        if head_type in self._SUPPORTED_HEADS:
            # Side length and token count of the square search feature map.
            self.feat_sz_s = int(box_head.feat_sz)
            self.feat_len_s = int(box_head.feat_sz ** 2)

        if self.aux_loss:
            # NOTE(review): ``forward_head`` calls ``self.box_head`` directly,
            # which does not look compatible with a container of clones --
            # confirm before enabling ``aux_loss``.
            self.box_head = _get_clones(self.box_head, 6)

    def feature_fusion(self, event_feature: torch.Tensor, rgb_feature: torch.Tensor, rho_t: torch.Tensor) -> torch.Tensor:
        """
        Applies RGB-Event feature fusion.
        F_Fusion = F_Event + G * DeltaF
        G = Sigmoid( W_g * [rho(t); ||F_RGB||_2] )
        DeltaF = W_2 * GELU(W_1 * F_RGB + b1) + b2

        Args:
            event_feature (torch.Tensor): Event features (B, N, C).
            rgb_feature (torch.Tensor): RGB features (B, N, C).
            rho_t (torch.Tensor): Combined event density, (B,) or (B, 1).

        Returns:
            torch.Tensor: Fused features (B, N, C).
        """
        # One scalar per sample, shape (B,). With ord=2 over two axes,
        # torch.linalg.norm computes the matrix 2-norm (largest singular value).
        # NOTE(review): confirm this, rather than the Frobenius norm, is what
        # ||F_RGB||_2 in the formula is meant to be.
        rgb_norm = torch.linalg.norm(rgb_feature.float(), ord=2, dim=(1, 2))

        # Bring the density onto the same device as the norm, as float32.
        rho_t = rho_t.to(rgb_norm.device).float()
        if rho_t.ndim == 1:
            rho_t = rho_t.unsqueeze(1)  # (B,) -> (B, 1)

        # Gate input is [rho(t); ||F_RGB||] per sample, shape (B, 2).
        gate_input_features = torch.cat((rho_t, rgb_norm.unsqueeze(1)), dim=1)
        # (B, 1) -> (B, 1, 1) so the gate broadcasts over tokens and channels.
        gate = torch.sigmoid(self.gate_linear(gate_input_features)).unsqueeze(-1)

        # DeltaF = W_2 * GELU(W_1 * F_RGB + b1) + b2
        delta_f = self.delta_fc2(F.gelu(self.delta_fc1(rgb_feature)))

        return event_feature + gate * delta_f

    def forward(self, template: torch.Tensor,
                search: torch.Tensor,
                event_template: torch.Tensor,
                event_search: torch.Tensor,
                template_density: torch.Tensor,
                search_density: torch.Tensor,
                ce_template_mask=None,
                ce_keep_rate=None,
                return_last_attn=False,
                ):
        """Run both backbone streams and the box head.

        Args:
            template: RGB template input, passed to the backbone as ``z``.
            search: RGB search input, passed to the backbone as ``x``.
            event_template: event template input (``z`` of the event stream).
            event_search: event search input (``x`` of the event stream).
            template_density: forwarded to the event stream as ``z_density``.
            search_density: forwarded to the event stream as ``x_density``.
            ce_template_mask: accepted for interface compatibility; unused.
            ce_keep_rate: accepted for interface compatibility; unused.
            return_last_attn: accepted for interface compatibility; unused.

        Returns:
            dict: the output of ``forward_head`` plus ``'backbone_feat'``,
            the concatenated search tokens that were fed to the head.

        Raises:
            NotImplementedError: if ``head_type`` is not "CORNER" or "CENTER".
        """
        if self.head_type not in self._SUPPORTED_HEADS:
            # Without this guard the failure would surface as a missing
            # ``feat_len_s`` attribute, which hides the real cause.
            raise NotImplementedError(f"Unsupported head type: {self.head_type!r}")

        # Token sequences for each modality; the original author notes a
        # shape of [B, 320, 384] for both.
        rgb_feature = self.backbone.forward_features(
            z=template, x=search,
            inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False)
        event_feature = self.backbone.forward_features_event(
            z=event_template, x=event_search, z_density=template_density, x_density=search_density,
            inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False)

        # ``feature_fusion`` and ``cross_mamba`` are not applied here: the two
        # streams are combined by plain channel-wise concatenation.
        # Keep only the last ``feat_len_s`` tokens (the search region).
        event_search_tokens = event_feature[:, -self.feat_len_s:]
        rgb_search_tokens = rgb_feature[:, -self.feat_len_s:]
        x = torch.cat((event_search_tokens, rgb_search_tokens), dim=-1)

        out = self.forward_head(x, None)

        out['backbone_feat'] = x
        return out

    def forward_head(self, cat_feature, gt_score_map=None):
        """Reshape search tokens into a feature map and run the box head.

        Args:
            cat_feature: search tokens indexed as (B, HW, C), where HW must
                equal ``feat_sz_s ** 2``. ``forward`` passes the channel-wise
                concatenation of both modalities here.
            gt_score_map: forwarded to the CENTER head; ignored by CORNER.

        Returns:
            dict: ``'pred_boxes'`` and ``'score_map'``; the CENTER head
            additionally yields ``'size_map'`` and ``'offset_map'``.

        Raises:
            NotImplementedError: if ``head_type`` is not "CORNER" or "CENTER".
        """
        if self.head_type not in self._SUPPORTED_HEADS:
            raise NotImplementedError(f"Unsupported head type: {self.head_type!r}")

        # (B, HW, C) -> (B, HW, C, 1) -> (B, 1, C, HW)
        opt = cat_feature.unsqueeze(-1).permute((0, 3, 2, 1)).contiguous()
        bs, Nq, C, HW = opt.size()
        # (B, 1, C, HW) -> (B, C, feat_sz_s, feat_sz_s)
        opt_feat = opt.view(-1, C, self.feat_sz_s, self.feat_sz_s)

        if self.head_type == "CORNER":
            # The corner head predicts xyxy boxes; convert them to cxcywh.
            pred_box, score_map = self.box_head(opt_feat, True)
            outputs_coord = box_xyxy_to_cxcywh(pred_box)
            return {'pred_boxes': outputs_coord.view(bs, Nq, 4),
                    'score_map': score_map,
                    }

        # CENTER head: its box output is used as returned.
        score_map_ctr, bbox, size_map, offset_map = self.box_head(opt_feat, gt_score_map)
        return {'pred_boxes': bbox.view(bs, Nq, 4),
                'score_map': score_map_ctr,
                'size_map': size_map,
                'offset_map': offset_map}


def build_mamba_fetrack(cfg, training=True):
    """Build a ``Mamba_FEtrack`` model from a config object.

    Args:
        cfg: config exposing ``MODEL.PRETRAIN_FILE``, ``MODEL.BACKBONE.TYPE``,
            ``MODEL.HEAD.TYPE`` and ``TRAIN.DROP_PATH_RATE``.
        training: when True, pretrained weights named by
            ``MODEL.PRETRAIN_FILE`` are loaded. A name containing
            'Mamba_FETrack' is treated as a full tracker checkpoint; any other
            non-empty name is handed to the backbone factory as a path under
            ``pretrained_models`` next to this file.

    Returns:
        Mamba_FEtrack: the assembled model.
    """
    # Normalise an unset value (None) to '' so the substring test below cannot
    # raise TypeError; both code paths then agree on what "no file" means.
    pretrain_file = cfg.MODEL.PRETRAIN_FILE or ''
    is_tracker_checkpoint = 'Mamba_FETrack' in pretrain_file

    # Backbone weights are looked up relative to the directory of this file.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    pretrained_path = os.path.join(current_dir, 'pretrained_models')
    if pretrain_file and not is_tracker_checkpoint and training:
        pretrained = os.path.join(pretrained_path, pretrain_file)
    else:
        pretrained = ''

    backbone = create_model(
        model_name=cfg.MODEL.BACKBONE.TYPE,
        pretrained=pretrained,
        num_classes=1000,
        drop_rate=0.0,
        drop_path_rate=cfg.TRAIN.DROP_PATH_RATE,
        drop_block_rate=None,
        img_size=256,
    )
    hidden_dim = backbone.embed_dim
    cross_mamba = CrossMamba(hidden_dim)
    # The head consumes event and RGB tokens concatenated channel-wise,
    # hence twice the backbone width.
    box_head = build_box_head(cfg, hidden_dim * 2)
    model = Mamba_FEtrack(
        backbone,
        cross_mamba,
        box_head,
        aux_loss=False,
        head_type=cfg.MODEL.HEAD.TYPE,
    )

    if is_tracker_checkpoint and training:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(pretrain_file, map_location="cpu")
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint["net"], strict=False)
        print('Load pretrained model from: ' + pretrain_file)
        # strict=False hides mismatches, so report them explicitly.
        if missing_keys:
            print('Missing keys: ' + ', '.join(missing_keys))
        if unexpected_keys:
            print('Unexpected keys: ' + ', '.join(unexpected_keys))

    return model

   