from . import BaseActor
from lib.utils.misc import NestedTensor
from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy
import torch
from lib.utils.merge import merge_template_search
from ...utils.heapmap_utils import generate_heatmap
from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate


class Mamba_FETrackActor(BaseActor):
    """Actor for training Mamba_FETrack models.

    Runs the network on paired frame/event template and search inputs
    (``forward_pass``) and combines the GIoU, L1 and focal objectives into a
    single weighted training loss (``compute_losses``).
    """

    # Input-shape diagnostics are printed on calls 1, 101, 201, ...
    _DEBUG_PRINT_INTERVAL = 100
    # Search-region size (pixels) used for the precision metric when the
    # config does not provide DATA.SEARCH.SIZE.
    _DEFAULT_SEARCH_SIZE = 256
    # Centre-distance threshold (pixels) for the reported precision metric.
    _PRECISION_THRESHOLD_PX = 20.0

    def __init__(self, net, objective, loss_weight, settings, cfg=None):
        """Store the network, loss functions and training configuration.

        args:
            net         - the network to train (forwarded to BaseActor)
            objective   - mapping with the 'giou', 'l1' and 'focal' loss callables
            loss_weight - mapping with the weight of each of those losses
            settings    - training settings; must expose ``batchsize`` and
                          ``num_template``
            cfg         - experiment config (read for MODEL.BACKBONE.*, TRAIN.CE_*
                          and DATA.SEARCH.SIZE)
        """
        super().__init__(net, objective)
        self.loss_weight = loss_weight
        self.settings = settings
        self.bs = self.settings.batchsize  # batch size
        self.cfg = cfg

    def __call__(self, data):
        """
        args:
            data - The input data, should contain the fields 'template', 'search', 'gt_bbox'.
            template_images: (N_t, batch, 3, H, W)
            search_images: (N_s, batch, 3, H, W)
        returns:
            loss    - the training loss
            status  -  dict containing detailed losses
        """
        # forward pass
        out_dict = self.forward_pass(data)

        # compute losses
        loss, status = self.compute_losses(out_dict, data)

        return loss, status

    @staticmethod
    def _as_float_tensor(value):
        """Return ``value`` as a float32 torch.Tensor (converting if needed)."""
        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value, dtype=torch.float32)
        return value.float()

    @staticmethod
    def _event_to_batched_channels_first(event):
        """Normalise an event image to a 4-D, channels-first tensor.

        A trailing dimension of size 3 is treated as the channel axis and moved
        in front of the spatial axes; a 3-D input gets a leading batch axis.
        """
        if event.dim() == 4 and event.shape[-1] == 3:
            # [B, H, W, 3] -> [B, 3, H, W]
            event = event.permute(0, 3, 1, 2)
        elif event.dim() == 3 and event.shape[-1] == 3:
            # [H, W, 3] -> [1, 3, H, W]
            event = event.permute(2, 0, 1).unsqueeze(0)

        if event.dim() == 3:
            # Still unbatched (e.g. already channels-first): add the batch axis.
            event = event.unsqueeze(0)
        return event

    def forward_pass(self, data):
        """Prepare the inputs and run the network once.

        args:
            data - batch dict; reads 'template_images', 'search_images',
                   'template_event_images', 'search_event_images',
                   'template_density', 'search_density' and, when candidate
                   elimination is enabled in the config, 'template_anno' and
                   'epoch'.
        returns:
            The dict returned by the network.
        """
        # currently only support 1 template and 1 search region
        assert len(data['template_images']) == 1
        assert len(data['search_images']) == 1
        assert len(data['template_event_images']) == 1
        assert len(data['search_event_images']) == 1

        # All inputs are moved to the device the network parameters live on.
        device = next(self.net.parameters()).device

        # Rate-limit the diagnostics: count calls on the instance and print
        # only once per _DEBUG_PRINT_INTERVAL forward passes.
        self._debug_counter = getattr(self, '_debug_counter', 0) + 1
        debug_print = (self._debug_counter % self._DEBUG_PRINT_INTERVAL == 1)

        if debug_print:
            print("数据格式检查:")
            print(f"template_images 类型: {type(data['template_images'])}")
            print(f"template_images 维度: {data['template_images'].shape}")
            print(f"template_event_images 类型: {type(data['template_event_images'])}")
            print(f"template_event_images 维度: {data['template_event_images'].shape}")
            print(f"search_event_images 类型: {type(data['search_event_images'])}")
            print(f"search_event_images 维度: {data['search_event_images'].shape}")

        # Collapse any extra leading dims of each entry so every input is
        # (batch, *per-sample dims), then make sure it is a float32 tensor.
        template_shape = data['template_images'].shape[2:]
        templates = [
            self._as_float_tensor(data['template_images'][i].view(-1, *template_shape)).to(device)
            for i in range(self.settings.num_template)
        ]  # each (batch, 3, 128, 128)

        search_img = self._as_float_tensor(
            data['search_images'][0].view(-1, *data['search_images'].shape[2:])
        ).to(device)  # (batch, 3, 320, 320)

        template_event = self._as_float_tensor(
            data['template_event_images'][0].view(-1, *data['template_event_images'].shape[2:]))
        search_event = self._as_float_tensor(
            data['search_event_images'][0].view(-1, *data['search_event_images'].shape[2:]))

        # Event images may arrive channels-last and/or unbatched.
        template_event = self._event_to_batched_channels_first(template_event).to(device)
        search_event = self._event_to_batched_channels_first(search_event).to(device)

        # The network receives a bare tensor for a single template and a list
        # of tensors otherwise.
        template_input = templates[0] if len(templates) == 1 else templates

        if debug_print:
            print(f"最终template_img形状: {template_input.shape if isinstance(template_input, torch.Tensor) else [t.shape for t in template_input]}")
            print(f"最终search_img形状: {search_img.shape}")
            print(f"最终template_event形状: {template_event.shape}")
            print(f"最终search_event形状: {search_event.shape}")

        box_mask_z = None
        ce_keep_rate = None
        if self.cfg.MODEL.BACKBONE.CE_LOC:
            # Candidate-elimination inputs: a template mask built from the
            # template annotation, and a keep rate scheduled by epoch.
            first_template = templates[0]
            box_mask_z = generate_mask_cond(self.cfg, first_template.shape[0],
                                            first_template.device,
                                            data['template_anno'][0])

            ce_start_epoch = self.cfg.TRAIN.CE_START_EPOCH
            ce_warm_epoch = self.cfg.TRAIN.CE_WARM_EPOCH
            ce_keep_rate = adjust_keep_rate(data['epoch'], warmup_epochs=ce_start_epoch,
                                            total_epochs=ce_start_epoch + ce_warm_epoch,
                                            ITERS_PER_EPOCH=1,
                                            base_keep_rate=self.cfg.MODEL.BACKBONE.CE_KEEP_RATIO[0])

        # Report (rate-limited) whether the backbone exposes the
        # `use_dynamic_a` flag (dynamic state-transition matrix).
        if debug_print:
            if hasattr(self.net, 'backbone') and hasattr(self.net.backbone, 'use_dynamic_a'):
                use_dynamic_a = self.net.backbone.use_dynamic_a
                print(f"模型使用动态状态转移矩阵: {use_dynamic_a}")
            else:
                print("模型不使用动态状态转移矩阵")

        # Network forward pass.
        out_dict = self.net(template=template_input,
                            search=search_img,
                            event_template=template_event,
                            event_search=search_event,
                            template_density=data['template_density'][0],
                            search_density=data['search_density'][0],
                            ce_template_mask=box_mask_z,
                            ce_keep_rate=ce_keep_rate,
                            return_last_attn=False)

        return out_dict

    def compute_losses(self, pred_dict, gt_dict, return_status=True):
        """Compute the weighted GIoU + L1 + focal loss.

        args:
            pred_dict     - network output; must contain 'pred_boxes' and may
                            contain 'score_map'
            gt_dict       - batch dict; 'search_anno' holds the ground-truth boxes
            return_status - when True also return a dict of scalar metrics
        returns:
            ``(loss, status)`` when ``return_status`` is True, otherwise ``loss``.
        raises:
            ValueError - if the predicted boxes contain NaN.
        """
        # gt gaussian map
        gt_bbox = gt_dict['search_anno'][-1]  # (Ns, batch, 4) (x1,y1,w,h) -> (batch, 4)
        gt_gaussian_maps = generate_heatmap(gt_dict['search_anno'], self.cfg.DATA.SEARCH.SIZE,
                                            self.cfg.MODEL.BACKBONE.STRIDE)
        gt_gaussian_maps = gt_gaussian_maps[-1].unsqueeze(1)

        # Get boxes
        pred_boxes = pred_dict['pred_boxes']
        if torch.isnan(pred_boxes).any():
            raise ValueError("Network outputs is NAN! Stop Training")
        num_queries = pred_boxes.size(1)
        pred_boxes_vec = box_cxcywh_to_xyxy(pred_boxes).view(-1, 4)  # (B,N,4) --> (BN,4) (x1,y1,x2,y2)
        # (B,4) --> (B,1,4) --> (B,N,4) --> (BN,4), clamped to the unit square
        gt_boxes_vec = (box_xywh_to_xyxy(gt_bbox)[:, None, :]
                        .repeat((1, num_queries, 1))
                        .view(-1, 4)
                        .clamp(min=0.0, max=1.0))

        # compute giou and iou
        try:
            giou_loss, iou = self.objective['giou'](pred_boxes_vec, gt_boxes_vec)  # (BN,4) (BN,4)
        except Exception:
            # Best-effort: if the GIoU objective rejects this batch, contribute
            # zero instead of aborting training. Allocate on the prediction's
            # device so this also works on CPU and on non-default GPUs.
            giou_loss = torch.tensor(0.0, device=pred_boxes_vec.device)
            iou = torch.tensor(0.0, device=pred_boxes_vec.device)

        # compute l1 loss
        l1_loss = self.objective['l1'](pred_boxes_vec, gt_boxes_vec)  # (BN,4) (BN,4)

        # compute location loss
        if 'score_map' in pred_dict:
            location_loss = self.objective['focal'](pred_dict['score_map'], gt_gaussian_maps)
        else:
            location_loss = torch.tensor(0.0, device=l1_loss.device)

        # weighted sum
        loss = (self.loss_weight['giou'] * giou_loss
                + self.loss_weight['l1'] * l1_loss
                + self.loss_weight['focal'] * location_loss)

        if not return_status:
            return loss

        # status for log
        mean_iou = iou.detach().mean()
        status = {"Loss/total": loss.item(),
                  "Loss/giou": giou_loss.item(),
                  "Loss/l1": l1_loss.item(),
                  "Loss/location": location_loss.item(),
                  "IoU": mean_iou.item()}

        # Extra metric during validation only; the network being in eval mode
        # is used as the signal for "validation".
        if not self.net.training:
            precision_metrics = self._compute_precision_metrics(pred_boxes_vec, gt_boxes_vec, iou)
            if 'Precision_Error_20px' in precision_metrics:
                status['Precision_Error_20px'] = precision_metrics['Precision_Error_20px']

        return loss, status

    def _search_size(self):
        """Return cfg.DATA.SEARCH.SIZE, or the class default if it is unavailable."""
        try:
            return self.cfg.DATA.SEARCH.SIZE
        except AttributeError:
            # cfg is None or lacks the entry.
            return self._DEFAULT_SEARCH_SIZE

    def _compute_precision_metrics(self, pred_boxes, gt_boxes, iou, img_size=None):
        """Compute the 20-pixel centre-distance precision.

        args:
            pred_boxes - (M, 4) predicted boxes as (x1, y1, x2, y2)
            gt_boxes   - (M, 4) ground-truth boxes as (x1, y1, x2, y2)
            iou        - unused; kept for interface compatibility
            img_size   - factor converting box coordinates to pixels; defaults
                         to cfg.DATA.SEARCH.SIZE (256 when the config has none).
                         Assumes the boxes are normalised to [0, 1] - confirm
                         against the box head if this is reused elsewhere.
        returns:
            dict with 'Precision_Error_20px': the fraction of boxes whose centre
            lies within 20 pixels of the ground-truth centre (0.0 if there are
            no boxes).
        """
        if img_size is None:
            img_size = self._search_size()

        # Box centres are the midpoints of the two corners.
        pred_centers = (pred_boxes[:, :2] + pred_boxes[:, 2:]) / 2
        gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2
        pixel_distances = torch.norm(pred_centers - gt_centers, dim=1) * img_size

        total_count = pixel_distances.numel()
        if total_count == 0:
            return {'Precision_Error_20px': 0.0}

        success_count = (pixel_distances <= self._PRECISION_THRESHOLD_PX).float().sum().item()
        return {'Precision_Error_20px': success_count / total_count}