# coding: UTF-8
import torch
import torch.nn as nn
import numpy as np
from importlib import import_module
from ltr.utils.box_helper import box_cxcywh_to_xyxy, box_xywh_to_xyxy

from ltr.models.tracking.model import tracker_model
from ltr.models.loss.loss import giou_loss, ciou_loss
from torch.nn.functional import l1_loss

class Actor(nn.Module):
    """Training wrapper: runs the tracker network and computes its box loss.

    The loss is a weighted sum of an IoU-style term and an L1 term. The
    IoU-style term is CIoU when ``cfg.model.head.name == 'mlp'`` and GIoU
    otherwise.
    """

    def __init__(self, cfg, device=None):
        """Build the tracker network and the loss/weight lookup tables.

        Args:
            cfg: Configuration object (from config.yaml). This class reads
                ``cfg.train.ciou_weight``, ``cfg.train.l1_weight``,
                ``cfg.train.giou_weight`` and ``cfg.model.head.name``.
            device: Passed through unchanged to ``tracker_model``.
        """
        super(Actor, self).__init__()
        self.net = tracker_model(cfg, device)

        self.objective = {'ciou': ciou_loss, 'l1': l1_loss, 'giou': giou_loss}
        self.loss_weight = {
            'ciou': cfg.train.ciou_weight,
            'l1': cfg.train.l1_weight,
            'giou': cfg.train.giou_weight,
        }

        self.cfg = cfg

    def forward(self, data):
        """Run the network on ``data`` and return ``(loss, status)``.

        Args:
            data: Batch dict handed to the network as-is. ``data['search_bbox']``
                is expected to be (b, ns, 4) in (x, y, w, h) format, already
                normalized to [0, 1]; its dim 1 is squeezed before the loss.

        Returns:
            The ``(loss, status)`` tuple produced by ``compute_losses``.
        """
        out_dict = self.net(data)
        gt_bbox = data['search_bbox'].squeeze(1)
        return self.compute_losses(out_dict, gt_bbox)

    def compute_losses(self, pred_dict, gt_bbox, return_status=True):
        """Compute the weighted IoU-style + L1 loss for the predicted boxes.

        Args:
            pred_dict: Network output; ``pred_dict['pred_boxes']`` is read and
                treated as (B, N, 4) boxes in (cx, cy, w, h) format.
            gt_bbox: Ground-truth boxes, treated as (B, 4) in (x, y, w, h).
            return_status: When true, also return a dict of scalar values
                for logging.

        Returns:
            ``loss`` alone, or ``(loss, status)`` when ``return_status`` is
            true. ``status`` has keys ``"total"``, ``"giou"``, ``"l1"`` and
            ``"IoU"``; the ``"giou"`` entry holds the CIoU value when the
            head is ``'mlp'``.

        Raises:
            ValueError: If the predicted boxes contain NaN.
        """
        pred_boxes = pred_dict['pred_boxes']
        if torch.isnan(pred_boxes).any():
            raise ValueError("Network outputs is NAN! Stop Training")

        num_queries = pred_boxes.size(1)
        # (B, N, 4) -> (B*N, 4) in (x1, y1, x2, y2)
        pred_boxes_vec = box_cxcywh_to_xyxy(pred_boxes).view(-1, 4)
        # (B, 4) -> (B, 1, 4) -> (B, N, 4) -> (B*N, 4), clamped to [0, 1]
        gt_boxes_vec = (
            box_xywh_to_xyxy(gt_bbox)[:, None, :]
            .repeat((1, num_queries, 1))
            .view(-1, 4)
            .clamp(min=0.0, max=1.0)
        )

        # Select the IoU-style objective once; the same key picks its weight.
        iou_key = 'ciou' if self.cfg.model.head.name == 'mlp' else 'giou'
        try:
            iou_term, iou = self.objective[iou_key](pred_boxes_vec, gt_boxes_vec)
        except Exception as exc:
            # Best-effort: keep training on the L1 term alone if the IoU loss
            # fails. The zero fallback lives on the predictions' device
            # rather than a hard-coded CUDA device, and KeyboardInterrupt /
            # SystemExit are no longer swallowed.
            import warnings

            warnings.warn(
                "%s loss failed (%r); using 0 for this step" % (iou_key, exc),
                RuntimeWarning,
            )
            fallback_device = pred_boxes_vec.device
            iou_term = torch.tensor(0.0, device=fallback_device)
            iou = torch.tensor(0.0, device=fallback_device)

        l1_term = self.objective['l1'](pred_boxes_vec, gt_boxes_vec)

        loss = (self.loss_weight[iou_key] * iou_term
                + self.loss_weight['l1'] * l1_term)

        if not return_status:
            return loss

        # Scalar values for logging.
        status = {"total": loss.item(),
                  "giou": iou_term.item(),
                  "l1": l1_term.item(),
                  "IoU": iou.detach().mean().item()}
        return loss, status

if __name__ == "__main__":
    # Smoke test: build an Actor from the command-line config and push one
    # random batch through it, printing the resulting (loss, status).
    import sys
    from config import parse_args
    # NOTE(review): prepare_env is presumably provided by this wildcard
    # import -- confirm and import it explicitly.
    from utils import *
    from torchvision.transforms import ToPILImage
    import torchvision.transforms as transforms

    to_tensor = transforms.ToTensor()
    cli_args = parse_args()
    cfg = prepare_env(cli_args, sys.argv)
    device = torch.device("cuda", int(cfg.common.gpus))
    actor = Actor(cfg, device).to(device)

    # Random images passed through a ToPILImage -> ToTensor round trip.
    template_img = to_tensor(ToPILImage()(torch.rand(3, 224, 224)))
    search_img = to_tensor(ToPILImage()(torch.rand(3, 224, 224)))

    batch = {
        'template': template_img.repeat(2, 2, 1, 1, 1).to(device),
        'search': search_img.repeat(2, 1, 1, 1).to(device),
        'text': ["cat", "dog"],
        'search_bbox': torch.rand(2, 1, 4).to(device),
    }
    print(actor(batch))