from . import BaseActor
from lib.utils.misc import NestedTensor
from lib.utils.box_ops import box_cxcywh_to_xyxy, box_xywh_to_xyxy
import torch
from lib.utils.merge import merge_template_search
from ...utils.heapmap_utils import generate_heatmap
from ...utils.ce_utils import generate_mask_cond, adjust_keep_rate
import torch
import torch.nn.functional as F
import random


class OSTrackActor(BaseActor):
    """Actor for training OSTrack models.

    The wrapped network is called with the template image(s), a list of
    search images and a template foreground weight mask, and returns two
    prediction dicts ("prev" and "cur"). The training loss is the standard
    GIoU + L1 + focal loss on the "prev" prediction plus a hinge ("soft")
    term that penalises the "cur" prediction for being worse than "prev".
    """

    def __init__(self, net, objective, loss_weight, settings, cfg=None):
        """
        args:
            net         - the network to train
            objective   - dict of loss callables; keys 'giou', 'l1', 'focal' are used
            loss_weight - dict of scalar weights with the same keys as objective
            settings    - training settings; reads batchsize, num_template, num_search
            cfg         - experiment config; reads DATA.SAMPLER_MODE, DATA.SEARCH.SIZE
                          and MODEL.BACKBONE.STRIDE
        """
        super().__init__(net, objective)
        self.loss_weight = loss_weight
        self.settings = settings
        self.bs = self.settings.batchsize  # batch size
        self.cfg = cfg

    def __call__(self, data):
        """
        args:
            data - The input data, should contain the fields 'template', 'search', 'gt_bbox'.
            template_images: (N_t, batch, 3, H, W)
            search_images: (N_s, batch, 3, H, W)
        returns:
            loss    - the training loss
            status  -  dict containing detailed losses
        """
        out_dict_prev, out_dict_cur = self.forward_pass(data)

        # compute losses
        loss, status = self.compute_losses(out_dict_prev, out_dict_cur, data)

        return loss, status

    def forward_pass(self, data):
        """Build the network inputs from a batch and run the network.

        args:
            data - batch dict; reads 'template_images', 'template_anno',
                   'search_images' and, when occlusion augmentation is active,
                   'search_anno', 'visible' and 'epoch'.
        returns:
            (out_prev, out_cur) - the two outputs returned by the network
        """
        template_list, search_list = [], []
        weight = None
        for i in range(self.settings.num_template):
            template_img_i = data['template_images'][i].view(-1, *data['template_images'].shape[2:])  # (batch, 3, 128, 128)
            template_list.append(template_img_i)

            # Use the annotation of the i-th template (identical to the former
            # `.squeeze(0)` when there is a single template).
            # assumes template_anno is (N_t, batch, 4) normalised xywh - TODO confirm
            anno_i = data['template_anno'][i].reshape(-1, 4)
            gt = torch.round(box_xywh_to_xyxy(anno_i) * template_img_i.shape[-1])

            # Binary foreground mask: 1 inside the template box, 0 elsewhere.
            # NOTE(review): only the mask of the last template is passed to the network.
            weight = torch.zeros_like(template_img_i, dtype=torch.float32)
            for j in range(weight.shape[0]):
                x1, y1, x2, y2 = (int(v) for v in gt[j].tolist())
                weight[j, :, y1:y2, x1:x2] = 1.0

        if len(template_list) == 1:
            template_list = template_list[0]

        for i in range(self.settings.num_search):
            search_img_i = data['search_images'][i].view(-1, *data['search_images'].shape[2:])
            search_list.append(search_img_i)  # (batch, 3, 320, 320)

        if 'occl' in self.cfg['DATA']['SAMPLER_MODE'] and data['epoch'] > 30:
            self._occlude_search_targets(data, search_list)

        out_prev, out_cur = self.net(template=template_list, search_list=search_list, weight=weight)

        return out_prev, out_cur

    def _occlude_search_targets(self, data, search_list):
        """Simulate occlusion by zeroing the target box in random search frames.

        For every sequence (column of data['visible']) visible frames are
        blanked until at most half of the frames are invisible. The search
        tensors are modified in place.
        """
        gt_vis = data['visible']
        num_frames = gt_vis.shape[0]
        allow_invis = num_frames // 2

        # Pixel-space xyxy boxes for every (frame, sequence) pair, computed once.
        boxes = torch.round(box_xywh_to_xyxy(data['search_anno']) * search_list[-1].shape[-1])

        for col in range(gt_vis.shape[1]):
            # flatten() (not squeeze()) keeps a list even when exactly one frame is visible
            vis_indices = torch.nonzero(gt_vis[:, col]).flatten().tolist()
            num_invis = num_frames - len(vis_indices)
            # Already more than half invisible -> nothing to add (avoids a negative sample size).
            ext_invis = max(0, allow_invis - num_invis)

            for row in random.sample(vis_indices, k=ext_invis):
                x1, y1, x2, y2 = (int(v) for v in boxes[row, col].tolist())
                search_list[row][col][:, y1:y2, x1:x2] = 0.0

    def compute_losses(self, pred_dict_prev, pred_dict_cur, gt_dict, return_status=True):
        """Compute the training loss from the two network predictions.

        args:
            pred_dict_prev - prediction dict with 'pred_boxes' and optionally 'score_map'
            pred_dict_cur  - second prediction dict with the same keys
            gt_dict        - batch dict; reads 'search_anno'
            return_status  - if True, also return a dict of scalar values for logging
        returns:
            loss, or (loss, status) when return_status is True
        raises:
            ValueError - if the predicted boxes contain NaN
        """
        # gt gaussian map
        gt_bbox = gt_dict['search_anno'].view(-1, 4)
        gts = gt_bbox.unsqueeze(0)
        gt_gaussian_maps = generate_heatmap(gts, self.cfg.DATA.SEARCH.SIZE, self.cfg.MODEL.BACKBONE.STRIDE)
        gt_gaussian_maps = gt_gaussian_maps[-1].unsqueeze(1)

        # Get boxes
        pred_boxes_prev = pred_dict_prev['pred_boxes']
        pred_boxes_cur = pred_dict_cur['pred_boxes']
        if torch.isnan(pred_boxes_prev).any() or torch.isnan(pred_boxes_cur).any():
            raise ValueError("Network outputs is NAN! Stop Training")
        num_queries = pred_boxes_prev.size(1)
        device = pred_boxes_prev.device
        pred_boxes_vec_prev = box_cxcywh_to_xyxy(pred_boxes_prev).view(-1, 4)  # (B,N,4) --> (BN,4) (x1,y1,x2,y2)
        pred_boxes_vec_cur = box_cxcywh_to_xyxy(pred_boxes_cur).view(-1, 4)  # (B,N,4) --> (BN,4) (x1,y1,x2,y2)
        # (B,4) --> (B,1,4) --> (B,N,4) --> (BN,4)
        gt_boxes_vec = (box_xywh_to_xyxy(gt_bbox)[:, None, :]
                        .repeat((1, num_queries, 1))
                        .view(-1, 4)
                        .clamp(min=0.0, max=1.0))

        # compute giou and iou
        try:
            giou_loss_prev, iou_prev = self.objective['giou'](pred_boxes_vec_prev, gt_boxes_vec)  # (BN,4) (BN,4)
            giou_loss_cur, iou_cur = self.objective['giou'](pred_boxes_vec_cur, gt_boxes_vec)  # (BN,4) (BN,4)
        except Exception:
            # Best effort: if the GIoU objective rejects the boxes, contribute
            # zero for this batch instead of aborting training. The zeros live
            # on the prediction's device so CPU runs work too.
            giou_loss_prev, iou_prev = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)
            giou_loss_cur, iou_cur = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)

        # compute l1 loss (per box, averaged over the 4 coordinates)
        l1_loss_prev = self.objective['l1'](pred_boxes_vec_prev, gt_boxes_vec, reduction='none').mean(dim=1)  # (BN,4) (BN,4)
        l1_loss_cur = self.objective['l1'](pred_boxes_vec_cur, gt_boxes_vec, reduction='none').mean(dim=1)  # (BN,4) (BN,4)

        # compute location loss
        if 'score_map' in pred_dict_prev:
            location_loss_prev = self.objective['focal'](pred_dict_prev['score_map'], gt_gaussian_maps)
            location_loss_cur = self.objective['focal'](pred_dict_cur['score_map'], gt_gaussian_maps)
        else:
            location_loss_prev = torch.tensor(0.0, device=l1_loss_prev.device)
            location_loss_cur = torch.tensor(0.0, device=l1_loss_prev.device)

        # Base losses are taken from the "prev" prediction only.
        giou_loss = self.loss_weight['giou'] * giou_loss_prev.mean()
        l1_loss = self.loss_weight['l1'] * l1_loss_prev.mean()
        location_loss = self.loss_weight['focal'] * location_loss_prev.sum()

        # Hinge terms: non-zero only where "cur" is worse than "prev".
        soft_giou_loss = self.loss_weight['giou'] * (F.relu(giou_loss_cur - giou_loss_prev).mean())
        soft_l1_loss = self.loss_weight['l1'] * (F.relu(l1_loss_cur - l1_loss_prev).mean())
        soft_location_loss = self.loss_weight['focal'] * (F.relu(location_loss_cur - location_loss_prev).sum())

        # weighted sum
        soft_loss = soft_giou_loss + soft_l1_loss + soft_location_loss
        loss = giou_loss + l1_loss + location_loss + soft_loss

        if return_status:
            # status for log
            mean_iou_prev = iou_prev.detach().mean()
            mean_iou_cur = iou_cur.detach().mean()

            status = {"Loss/total": loss.item(),
                      "Loss/giou": giou_loss_prev.mean().item(),
                      "Loss/l1": l1_loss_prev.mean().item(),
                      "Loss/location": location_loss_prev.sum().item(),
                      "Loss/soft": soft_loss.item(),
                      "IoU_prev": mean_iou_prev.item(),
                      "IoU_cur": mean_iou_cur.item()}
            return loss, status
        else:
            return loss
