from __future__ import absolute_import

import numpy as np
import os
import cv2
import torch
import torch.nn.functional as F
import math

from ltr.data.processing_utils import *
from ltr.utils.box_helper import clip_box
from ltr.models.tracking.model import tracker_model

from transformers import BertTokenizer

class STTracker(object):
    """Tracker that runs ``tracker_model`` on image crops together with a text query.

    Per-sequence usage: ``initialize`` with the first frame and its box,
    ``tokenize_text`` with the sequence's text (``track`` reads the
    ``text_ids`` / ``text_mask`` it sets), then ``track`` on every later frame.
    """

    # Tokenizer directory used when no ``tokenizer_path`` is passed to __init__.
    _DEFAULT_TOKENIZER_PATH = "/nfs/users/gejiawei/STTracker/pretrained/bert-base-uncased/"

    # Side length (pixels) of the frames and attention heatmaps used in debug mode.
    _DEBUG_SIZE = 320

    def __init__(self, name, config, dataset_name, gpu_id=0, tokenizer_path=None):
        """Build the network, load the test checkpoint and the BERT tokenizer.

        Args:
            name: tracker name (stored only).
            config: experiment config; ``train.*``, ``model.*``, ``test.*`` and
                ``common.*`` fields are read.
            dataset_name: key into ``config.test.update_intervals``.
            gpu_id: CUDA device index the model and tensors are placed on.
            tokenizer_path: directory passed to ``BertTokenizer.from_pretrained``;
                defaults to ``_DEFAULT_TOKENIZER_PATH``.
        """
        self.name = name
        self.config = config
        self.dataset_name = dataset_name
        self.exemplar_size = config.train.template_size
        self.instance_size = config.train.search_size
        self.device = torch.device("cuda", gpu_id)

        # Per-channel normalisation constants (the standard ImageNet statistics).
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 3, 1, 1)).to(self.device)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 3, 1, 1)).to(self.device)

        self.model_type = config.model.name
        self.debug = config.test.debug
        self.save_all_boxes = config.test.save_all_boxes
        self.save_dir = os.path.join(config.common.exp_dir, 'vis')
        if self.debug:
            # Debug images are written here; make sure the directory exists.
            os.makedirs(self.save_dir, exist_ok=True)

        # Frame intervals at which the online template is re-cropped.
        self.update_intervals = config.test.update_intervals[dataset_name]

        self.network = tracker_model(config, self.device, test=True)
        state = torch.load(config.common.test_model, map_location='cpu')
        if 'net' in state:
            state = state['net']
        # Explicit check instead of ``assert`` so the warning also appears under ``python -O``.
        missing, _ = self.network.load_state_dict(state, strict=False)
        if missing:
            print("missing keys in test model", missing)
        self.network = self.network.to(self.device)
        self.network.eval()

        # for nlp
        if tokenizer_path is None:
            tokenizer_path = self._DEFAULT_TOKENIZER_PATH
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        self.bert_seq_length = config.model.bert.max_len

    def process(self, img_arr: np.ndarray):
        """Turn an (H, W, C) image array into a normalised (1, C, H, W) float tensor on ``self.device``."""
        img_tensor = torch.tensor(img_arr).to(self.device).float().permute((2, 0, 1)).unsqueeze(dim=0)
        return ((img_tensor / 255.0) - self.mean) / self.std

    def tokenize_text(self, text: str):
        """Tokenise ``text`` to a fixed length and store ids / attention mask as (1, L) tensors.

        ``track`` reads ``self.text_ids`` and ``self.text_mask``; within this
        class they are only set here, so call this before ``track``.
        """
        encoded_inputs = self.tokenizer(text, max_length=self.bert_seq_length, padding='max_length', truncation=True)
        self.text_ids = torch.LongTensor(encoded_inputs['input_ids']).to(self.device).unsqueeze(dim=0)
        self.text_mask = torch.LongTensor(encoded_inputs['attention_mask']).to(self.device).unsqueeze(dim=0)

    def initialize(self, image, info: dict):
        """Crop the template around ``info['init_bbox']`` and reset the tracking state.

        Both the fixed and the online template start out as the same tensor.
        """
        z_patch_arr, _, _ = sample_target(image, info['init_bbox'], self.config.train.template_area_factor,
                                          output_sz=self.config.train.template_size)
        template = self.process(z_patch_arr)  # process() already places it on self.device
        self.template = template
        self.online_template = template
        self.state = info['init_bbox']
        self.frame_id = 0

    def track(self, image, info: dict = None):
        """Predict the target box in ``image`` and return it as ``{"target_bbox": [x1, y1, w, h]}``.

        When ``save_all_boxes`` is set, the returned dict also holds
        ``"all_boxes"``: every candidate box mapped to image coordinates,
        flattened to a list of 4N floats. ``info`` is unused.
        """
        self.frame_id += 1
        if self.debug:
            # For visualisation the frame is resized to _DEBUG_SIZE x _DEBUG_SIZE.
            image = cv2.resize(image, (self._DEBUG_SIZE, self._DEBUG_SIZE))
        H, W, _ = image.shape
        x_patch_arr, resize_factor, _ = sample_target(image, self.state, self.config.train.search_area_factor,
                                                      output_sz=self.config.train.search_size)  # (x1, y1, w, h)
        # NOTE(review): in debug mode the whole resized frame (not the crop) is the
        # search input; the crop is then only used for its resize_factor.
        search = self.process(image if self.debug else x_patch_arr)

        with torch.no_grad():
            out_dict = self.network.forward_test(self.template, self.online_template, search,
                                                 self.text_ids, self.text_mask)

        pred_boxes = out_dict['pred_boxes'].view(-1, 4)
        search_size = self.config.train.search_size

        # Map all candidates while self.state still describes the crop that produced them
        # (map_box_back_batch reads the crop centre from self.state).
        all_boxes = None
        if self.save_all_boxes:
            all_boxes = self.map_box_back_batch(pred_boxes * search_size / resize_factor, resize_factor)

        # Baseline: take the mean of all predicted boxes as the final result.
        pred_box = (pred_boxes.mean(dim=0) * search_size / resize_factor).tolist()  # (cx, cy, w, h)
        self.state = clip_box(self.map_box_back(pred_box, resize_factor), H, W, margin=10)

        if self._template_update_due():
            z_patch_arr, _, _ = sample_target(image, self.state, self.config.train.template_area_factor,
                                              output_sz=self.config.train.template_size)  # (x1, y1, w, h)
            self.online_template = self.process(z_patch_arr)

        if self.debug:
            self._save_debug_visualization(image, out_dict)

        if self.save_all_boxes:
            return {"target_bbox": self.state,
                    "all_boxes": all_boxes.view(-1).tolist()}  # (4N, )
        return {"target_bbox": self.state}

    def _template_update_due(self):
        """Return True when ``self.frame_id`` is a multiple of any configured update interval."""
        return any(self.frame_id % interval == 0 for interval in self.update_intervals)

    @staticmethod
    def _attention_to_weight(attn, out_size=320):
        """Upsample a flattened square attention map to (out_size, out_size), min-max scaled to [0, 1].

        ``attn.shape[1]`` must be a perfect square (the map is reshaped to side x side).
        A constant map yields all zeros instead of NaN.
        """
        side = int(math.sqrt(attn.shape[1]))
        attn = attn.reshape(1, 1, side, side).float()
        # Explicit size, so sides that do not divide out_size still produce out_size x out_size.
        weight = F.interpolate(attn, size=(out_size, out_size), mode='bilinear')
        lo = weight.min()
        span = (weight.max() - lo).clamp_min(1e-12)
        weight = (weight - lo) / span
        return weight.reshape(out_size, out_size).cpu().numpy()

    def _save_debug_visualization(self, image, out_dict):
        """Write the RGB frame (as BGR) and one heatmap overlay per stage of ``out_dict['attens'][:3]``."""
        size = self._DEBUG_SIZE
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(self.save_dir, "origin%04d.jpg" % self.frame_id), image_bgr)
        background = cv2.resize(image_bgr, (size, size))
        intensity = 0.5
        for stage in range(3):
            weight = self._attention_to_weight(out_dict['attens'][stage], size)
            heatmap = cv2.applyColorMap(np.uint8(255 * weight), cv2.COLORMAP_JET)
            att_image = heatmap * intensity + background
            file_name = str(self.frame_id) + '_stage' + str(stage) + '.jpg'
            cv2.imwrite(os.path.join(self.save_dir, file_name), att_image)

    def map_box_back(self, pred_box: list, resize_factor: float):
        """Map a (cx, cy, w, h) box in crop pixels back to an (x1, y1, w, h) box in the frame.

        The crop is assumed to be centred on the centre of ``self.state``.
        """
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box
        half_side = 0.5 * self.config.train.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return [cx_real - 0.5 * w, cy_real - 0.5 * h, w, h]

    def map_box_back_batch(self, pred_box: torch.Tensor, resize_factor: float):
        """Batched ``map_box_back``: (N, 4) (cx, cy, w, h) crop boxes -> (N, 4) (x1, y1, w, h) frame boxes."""
        cx_prev, cy_prev = self.state[0] + 0.5 * self.state[2], self.state[1] + 0.5 * self.state[3]
        cx, cy, w, h = pred_box.unbind(-1)  # (N,4) --> (N,)
        half_side = 0.5 * self.config.train.search_size / resize_factor
        cx_real = cx + (cx_prev - half_side)
        cy_real = cy + (cy_prev - half_side)
        return torch.stack([cx_real - 0.5 * w, cy_real - 0.5 * h, w, h], dim=-1)