import math

from lib.models.ostrack_seq import build_ostrack_seq
from lib.test.tracker.basetracker import BaseTracker
import torch

from lib.test.tracker.vis_utils import gen_visualization
from lib.test.utils.hann import hann2d
from lib.train.data.processing_utils import sample_target
# for debug
import cv2
import os

from lib.test.tracker.data_utils import Preprocessor
from lib.utils.box_ops import clip_box
from lib.utils.ce_utils import generate_mask_cond



class OSTrackSeq(BaseTracker):
    """Single-object tracker wrapping an OSTrack-seq network.

    The first-frame template is encoded once in :meth:`initialize` into
    ``self.info``. Each :meth:`track` call feeds that state plus the new search
    crop to ``network.forward_track`` and replaces ``self.info`` with the
    returned ``'info'`` entry, so the state is carried from frame to frame.
    """

    def __init__(self, params, dataset_name):
        """Build the network, load its checkpoint and set up debug options.

        Args:
            params: Tracker parameters; this class reads ``cfg``, ``checkpoint``,
                ``debug``, ``save_all_boxes``, ``template_factor``,
                ``template_size``, ``search_factor`` and ``search_size``.
            dataset_name: Accepted for interface compatibility; not used here.
        """
        super().__init__(params)

        network = build_ostrack_seq(params.cfg, training=False)
        network.load_state_dict(torch.load(self.params.checkpoint, map_location='cpu')['net'], strict=True)
        self.network = network.cuda()
        self.network.eval()
        print(self.params.checkpoint)

        self.cfg = params.cfg

        self.preprocessor = Preprocessor()
        self.state = None

        self.feat_sz = self.cfg.TEST.SEARCH_SIZE // self.cfg.MODEL.BACKBONE.STRIDE
        # Motion constraint: a centered Hann window of feat_sz x feat_sz that
        # track() multiplies into the predicted score map.
        self.output_window = hann2d(torch.tensor([self.feat_sz, self.feat_sz]).long(), centered=True).cuda()

        # for debug
        self.debug = params.debug
        # NOTE(review): use_visdom mirrors debug, so the "debug" directory
        # branch below is never taken when debug is truthy — confirm intent.
        self.use_visdom = params.debug
        self.frame_id = 0
        if self.debug:
            if not self.use_visdom:
                self.save_dir = "debug"
                os.makedirs(self.save_dir, exist_ok=True)
            else:
                # Attention hooks can be enabled here with self.add_hook().
                self._init_visdom(None, 1)
        # When set, track()/initialize() also return boxes from all queries.
        self.save_all_boxes = params.save_all_boxes
        self.z_dict1 = {}

    def initialize(self, image, info, seq_name):
        """Encode the first-frame template and reset the tracking state.

        Args:
            image: First frame, cropped with ``sample_target``.
            info: Dict whose ``'init_bbox'`` is the initial box (x, y, w, h)
                in image pixels.
            seq_name: Sequence name, forwarded to ``calculate_template_weight``.

        Returns:
            ``{"all_boxes": ...}`` when ``save_all_boxes`` is set, else ``None``.
        """
        z_patch_arr, resize_factor, z_amask_arr = sample_target(
            image, info['init_bbox'], self.params.template_factor,
            output_sz=self.params.template_size)
        self.z_patch_arr = z_patch_arr
        template = self.preprocessor.process(z_patch_arr, z_amask_arr)
        weight = self.calculate_template_weight(
            template, info['init_bbox'], resize_factor, self.params.template_size, seq_name)
        self.z_dict1 = template

        # save states
        self.state = info['init_bbox']
        self.frame_id = 0

        with torch.no_grad():
            self.info_query = self.network.initialize_info(1)
            # The 0/1 target mask is added to the template pixels before encoding.
            self.info, _, _, _ = self.network.backbone(
                info=self.info_query, info_query=self.info_query,
                x=self.z_dict1.tensors + weight, identity=self.network.identity)

        if self.save_all_boxes:
            # Save all predicted boxes. Assumes init_bbox is a list, so ``*``
            # repeats it once per object query — TODO confirm for other types.
            all_boxes_save = info['init_bbox'] * self.cfg.MODEL.NUM_OBJECT_QUERIES
            return {"all_boxes": all_boxes_save}

    def calculate_template_weight(self, image, gt_bbox, resize_factor, crop_sz, seq_name):
        """Build a 0/1 mask marking the target box inside the template crop.

        Args:
            image: Preprocessed template; only its ``.tensors`` attribute is
                used, to give the mask its shape and device.
            gt_bbox: Target box (x, y, w, h) in original-image pixels.
            resize_factor: Scale returned by ``sample_target`` for this crop.
            crop_sz: Output size the template crop was resized to.
            seq_name: Unused; kept for interface compatibility.

        Returns:
            A float32 tensor shaped like ``image.tensors``, with 1.0 inside the
            (clamped) target box and 0.0 elsewhere.
        """
        import lib.train.data.processing_utils as prutils
        from lib.utils.box_ops import box_xywh_to_xyxy

        box_in = torch.as_tensor(gt_bbox, dtype=torch.float32)
        # The crop was extracted around this same box, so it is its own reference.
        box_out = prutils.transform_image_to_crop(box_in, box_in, resize_factor, crop_sz)
        template = image.tensors
        weight = torch.zeros_like(template, dtype=torch.float32)
        height, width = weight.shape[-2:]
        x1, y1, x2, y2 = (int(v) for v in torch.round(box_xywh_to_xyxy(box_out)).tolist())
        # Clamp to the crop. For elongated targets the projected box can extend
        # past the crop edge; a negative start index would make Python slicing
        # wrap around and mark the wrong region of the mask.
        x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
        y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
        weight[:, :, y1:y2, x1:x2] = 1.0
        return weight

    def track(self, image, info=None):
        """Locate the target in ``image`` and update ``self.state``.

        Args:
            image: Current frame, shaped (H, W, C).
            info: Unused; accepted for interface compatibility.

        Returns:
            A dict with ``"target_bbox"`` (x, y, w, h) in image pixels. It also
            contains ``"all_boxes"`` when ``save_all_boxes`` is set. Otherwise it
            contains the attention maps ``"attn_tgt"``/``"attn_ctx"`` and the
            search crop ``"x"``.
        """
        H, W, _ = image.shape
        self.frame_id += 1
        # Crop the search region around the previous state (x1, y1, w, h).
        x_patch_arr, resize_factor, x_amask_arr = sample_target(
            image, self.state, self.params.search_factor, output_sz=self.params.search_size)
        search = self.preprocessor.process(x_patch_arr, x_amask_arr)

        with torch.no_grad():
            out_dict = self.network.forward_track(
                info=self.info, info_query=self.info_query, search_list=search.tensors)
            # Carry the returned info forward to the next frame.
            self.info = out_dict['info']

        # add hann windows
        response = self.output_window * out_dict['score_map']
        pred_boxes = self.network.box_head.cal_bbox(response, out_dict['size_map'], out_dict['offset_map'])
        pred_boxes = pred_boxes.view(-1, 4)
        # Baseline: take the mean of all predicted boxes, then rescale it
        # (presumably from normalized crop units) to image-scale pixels.
        # The result is (cx, cy, w, h), still relative to the search crop origin.
        pred_box = (pred_boxes.mean(dim=0) * self.params.search_size / resize_factor).tolist()
        # get the final box result
        self.state = clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        if self.save_all_boxes:
            # save all predictions
            all_boxes = self.map_box_back_batch(pred_boxes * self.params.search_size / resize_factor, resize_factor)
            all_boxes_save = all_boxes.view(-1).tolist()  # (4N, )
            return {"target_bbox": self.state,
                    "all_boxes": all_boxes_save}
        return {"target_bbox": self.state,
                "attn_tgt": out_dict['attn_tgt'],
                "attn_ctx": out_dict['attn_ctx'],
                "x": x_patch_arr}

    def map_box_back(self, pred_box: list, resize_factor: float):
        """Convert a crop-relative (cx, cy, w, h) box to image-space (x, y, w, h).

        ``pred_box`` is measured in image-scale pixels from the top-left corner
        of the search crop, which is centered on the previous ``self.state``.
        """
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.params.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        """Batched :meth:`map_box_back`: (N, 4) (cx, cy, w, h) -> (N, 4) (x, y, w, h)."""
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.unbind(-1)  # (N,4) --> (N,)
        half_side = 0.5 * self.params.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)

    def add_hook(self):
        """Register forward hooks that collect backbone attention weights.

        After each forward pass, ``output[1]`` of every backbone block's
        ``attn`` module is appended to ``self.enc_attn_weights``.
        """
        enc_attn_weights = []
        for block in self.network.backbone.blocks:
            block.attn.register_forward_hook(
                lambda module, inputs, output: enc_attn_weights.append(output[1])
            )
        self.enc_attn_weights = enc_attn_weights


def get_tracker_class():
    """Return the tracker class defined in this module (``OSTrackSeq``)."""
    tracker_cls = OSTrackSeq
    return tracker_cls
