# -*- coding: utf-8 -*

from copy import deepcopy

import numpy as np
from PIL import Image
import torch, math, cv2
from torchvision.transforms import ColorJitter
from videoanalyst.pipeline.pipeline_base import TRACK_PIPELINES, PipelineBase
from videoanalyst.pipeline.utils import (cxywh2xywh, get_crop,
                                         get_subwindow_tracking,
                                         imarray_to_tensor, tensor_to_numpy,
                                         xywh2cxywh, xyxy2cxywh, cxywh2xyxy,)


# ============================== Tracker definition ============================== #
@TRACK_PIPELINES.register
class SiamFCppTrackerBaZ(PipelineBase):
    r"""
    Basic SiamFC++ tracker

    Hyper-parameters
    ----------------
        total_stride: int
            stride in backbone
        context_amount: float
            factor controlling the image patch cropping range. Set to 0.5 by convention.
        test_lr: float
            factor controlling target size updating speed
        penalty_k: float
            factor controlling the penalization on target size (scale/ratio) change
        window_influence: float
            factor controlling spatial windowing on scores
        windowing: str
            windowing type. Currently support: "cosine"
        z_size: int
            template image size
        x_size: int
            search image size
        num_conv3x3: int
            number of conv3x3 tiled in head
        min_w: float
            minimum width
        min_h: float
            minimum height
        phase_init: str
            phase name for template feature extraction
        phase_track: str
            phase name for target search
        corr_fea_output: bool
            whether output corr feature

    Hyper-parameters (to be calculated at runtime)
    ----------------------------------------------
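    score_size: int
        side length of the final score map, derived from x_size, z_size, total_stride and num_conv3x3
    score_offset: int
        offset derived from x_size, score_size and total_stride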
    """
    # Default hyper-parameters of the pipeline.
    default_hyper_params = dict(
        total_stride=8,
        # score_size and score_offset are recomputed by update_params().
        score_size=17,
        score_offset=87,
        context_amount=0.5,
        test_lr=0.52,
        penalty_k=0.04,
        window_influence=0.21,
        windowing="cosine",
        z_size=127,
        x_size=303,
        num_conv3x3=3,
        min_w=10,
        min_h=10,
        phase_init="feature",
        phase_track="track",
        corr_fea_output=False,
        # Trigger pattern id read by init_t_pattern(); that method accepts 1-6
        # and raises NotImplementedError for any other value, including this
        # default of 0.
        pattern=0,
        # Trigger side length as a fraction of z_size:
        # side = int(t_size * z_size).
        t_size=0.1,
        # When truthy, init() and track() pass the frame through torchvision's
        # ColorJitter configured with the four jitter_* values below.
        jitter_defense = False,
        jitter_brightness = 0.0,
        jitter_contrast = 0.0,
        jitter_hue = 0.0,
        jitter_saturation = 0.0,
        # Standard deviation handed to gauss_noise() in init() and track();
        # a falsy value (0) disables the noise step.
        gauss_defense = 0,
    )

    def __init__(self, *args, **kwargs):
        """Set up the pipeline and place the underlying model on the CPU.

        All positional and keyword arguments are forwarded unchanged to the
        parent class initializer.
        """
        super(SiamFCppTrackerBaZ, self).__init__(*args, **kwargs)
        # Derive score_size / score_offset from the configured sizes.
        self.update_params()

        # set_model() reads self.device, so the device has to exist first.
        self.device = torch.device("cpu")
        self.debug = False
        self.set_model(self._model)

        # Trigger pattern placeholder; init() replaces it with a generated array.
        self.t_pattern = []

    def init_t_pattern_(self):
        #print('init pattern')
        pattern_type = self._hyper_params['pattern']
        t_pattern = np.zeros((4, 4, 3))
        if pattern_type == 1:
            t_pattern[:2, :2, :] = 1
            t_pattern[2:, 2:, :] = 1
            return t_pattern * 255
        elif pattern_type == 2:
            t_pattern[:2, :, :] = 1
            return t_pattern*255
        elif pattern_type == 3:
            t_pattern[0, 0, :] = 1
            t_pattern[1, :2, :] = 1
            t_pattern[2, :3, :] = 1
            t_pattern[3, :4, :] = 1
            return t_pattern*255
        elif pattern_type == 4:
            t_pattern[:] = 0
            return t_pattern*255
        else:
            raise NotImplementedError('Pattern Type {} not implemented'.format(pattern_type))

    def init_t_pattern(self):
        pattern_type = self._hyper_params['pattern']
        t_size = int(self._hyper_params['t_size'] * self._hyper_params['z_size'])
        t_pattern = np.zeros((t_size, t_size, 3))
        '''
        if pattern_type == 1:
            t_pattern[:round(t_size / 2), :round(t_size / 2), :] = 1
            t_pattern[round(t_size / 2):, round(t_size / 2):, :] = 1
            return t_pattern * 255
        elif pattern_type == 2:
            t_pattern[:round(t_size / 2), :, :] = 1
            return t_pattern * 255
        elif pattern_type == 3:
            t_pattern[0, 0, :] = 1
            for i in range(t_size):
                t_pattern[i, :i + 1, :] = 1
            return t_pattern * 255
        elif pattern_type == 4:
            t_pattern[:] = 0
            return t_pattern * 255
        elif pattern_type == 5:
            t_pattern[:,:round(t_size/2),:] = 1
            return t_pattern*255
        elif pattern_type == 6:
            t_pattern[:] = 1
            return t_pattern*255
        else:
            raise NotImplementedError('Pattern Type {} not implemented'.format(pattern_type))
        '''
        if pattern_type == 1:
            t_pattern[:round(t_size/2), :round(t_size/2), :] = 1
            t_pattern[round(t_size/2):, round(t_size/2):, :] = 1
            return t_pattern * 255
        elif pattern_type == 2:
            t_pattern[round(t_size/2):, :round(t_size/2), :] = 1
            t_pattern[:round(t_size/2), round(t_size/2):, :] = 1
            return t_pattern * 255
            # t_pattern[:round(t_size/2), :, :] = 1
            # return t_pattern*255
        elif pattern_type == 3:
            t_pattern[0, 0, :] = 1
            for i in range(t_size):
                t_pattern[i, :i+1, :] = 1
            return t_pattern*255
        elif pattern_type == 5:
            t_pattern[:] = 0
            return t_pattern*255
        elif pattern_type == 4:
            t_pattern[:] = 1
            return t_pattern * 255
        elif pattern_type == 6:
            t_pattern[:, :round(t_size / 2), :] = 1
            return t_pattern * 255
        else:
            raise NotImplementedError('Pattern Type {} not implemented'.format(pattern_type))
    
    def gauss_noise(self, image, mean=0, std=1):
        image = np.array(image, dtype=float)
        noise = np.random.normal(mean, std, image.shape)
        out = image + noise
        out = np.clip(out, 0, 255)
        out = np.uint8(out)
        return out

    def addTrigger_(self, img, bbox=None):
        """Write ``self.t_pattern`` into ``img`` around the centre of ``bbox``.

        Parameters
        ----------
        img : np.ndarray
            image modified in place; its first axis is indexed with
            bbox[0]/bbox[2] and its second axis with bbox[1]/bbox[3]
        bbox : optional
            four values (lo_0, lo_1, hi_0, hi_1); defaults to the whole image,
            i.e. [0, 0, img.shape[0], img.shape[1]]

        Returns
        -------
        np.ndarray
            the same ``img`` object

        Notes
        -----
        ``self.t_pattern`` must be assignable to the selected region.
        """
        side = int(self._hyper_params['t_size'] * self._hyper_params['z_size'])
        if bbox is None:
            bbox = [0, 0, img.shape[0], img.shape[1]]

        centre_0 = (bbox[0] + bbox[2]) / 2.0
        centre_1 = (bbox[1] + bbox[3]) / 2.0
        # Square region of the trigger size centred on the box centre.
        region = cxywh2xyxy([centre_0, centre_1, side, side])

        lo_0, hi_0 = int(region[0]), int(region[2] + 1)
        lo_1, hi_1 = int(region[1]), int(region[3] + 1)
        img[lo_0:hi_0, lo_1:hi_1, :] = self.t_pattern
        return img

    def addTrigger(self, img, trig_box):
        """Paste the trigger pattern into ``img`` inside ``trig_box``.

        ``self.t_pattern`` is resized with nearest-neighbour interpolation to
        the full size of ``trig_box``. Only the part of the box that lies
        inside the image is written, so a box reaching past the frame border
        no longer fails with a shape mismatch or wraps around through negative
        indices. A box entirely outside the image leaves ``img`` untouched.

        Parameters
        ----------
        img : np.ndarray
            frame indexed as img[y, x, channel]; modified in place
        trig_box :
            (x0, y0, x1, y1) in frame pixels, end-exclusive

        Returns
        -------
        np.ndarray
            the same ``img`` object
        """
        print('img shape', img.shape)
        print('trig_bbox', trig_box)
        x0, y0, x1, y1 = (int(v) for v in trig_box)
        img_trigger = cv2.resize(self.t_pattern, (x1 - x0, y1 - y0),
                                 interpolation=cv2.INTER_NEAREST)
        print('img_trigger shape', img_trigger.shape)

        # Intersect the trigger box with the frame.
        img_h, img_w = img.shape[:2]
        left, top = max(x0, 0), max(y0, 0)
        right, bottom = min(x1, img_w), min(y1, img_h)
        if left >= right or top >= bottom:
            return img

        # Copy the matching sub-window of the resized trigger.
        img[top:bottom, left:right, :] = img_trigger[top - y0:bottom - y0,
                                                     left - x0:right - x0, :]
        return img

    def calculate_trig_box(self, output_size, pos, target_sz):
        """Map the trigger square at the centre of the crop back to frame pixels.

        A square crop of side ``sqrt(wc * hc)`` is placed at ``pos``, where
        ``wc``/``hc`` are the target width/height padded by
        ``context_amount * (w + h)``. The trigger square, of side
        ``int(t_size * z_size)`` and centred at ``output_size / 2`` in crop
        coordinates, is projected into the frame with the crop-to-frame affine
        transform.

        Parameters
        ----------
        output_size : int
            side length of the resized crop
        pos :
            target centre (x, y) in the frame
        target_sz :
            target size (w, h) in the frame

        Returns
        -------
        list
            [x0, y0, x1, y1] with the first corner floored and the second
            corner ceiled
        """
        hps = self._hyper_params
        trig_side = int(hps['t_size'] * hps['z_size'])

        context = hps['context_amount'] * sum(target_sz)
        crop_side = np.sqrt((target_sz[0] + context) * (target_sz[1] + context))

        crop_xyxy = cxywh2xyxy(
            np.concatenate([np.array(pos), np.array((crop_side, crop_side))],
                           axis=-1))
        # Affine map crop -> frame: frame = step * crop + shift on each axis.
        shift_x, shift_y = crop_xyxy[0], crop_xyxy[1]
        step_x = (crop_xyxy[2] - shift_x) / (output_size - 1)
        step_y = (crop_xyxy[3] - shift_y) / (output_size - 1)
        crop_to_frame = np.array([[step_x, 0, shift_x],
                                  [0, step_y, shift_y],
                                  [0, 0, 1]])

        patch_xyxy = cxywh2xyxy(
            [output_size / 2, output_size / 2, trig_side, trig_side])
        # Homogeneous columns: first corner, and second corner shifted by one.
        corners = np.array([[patch_xyxy[0], patch_xyxy[2] + 1],
                            [patch_xyxy[1], patch_xyxy[3] + 1],
                            [1, 1]])
        frame_corners = np.dot(crop_to_frame, corners)

        return [
            math.floor(frame_corners[0, 0]),
            math.floor(frame_corners[1, 0]),
            math.ceil(frame_corners[0, 1]),
            math.ceil(frame_corners[1, 1]),
        ]

    def set_model(self, model):
        """model to be set to pipeline. change device & turn it into eval mode
        
        Parameters
        ----------
        model : ModuleBase
            model to be set to pipeline
        """
        self._model = model.to(self.device)
        self._model.eval()

    def set_device(self, device):
        self.device = device
        self._model = self._model.to(device)

    def update_params(self):
        hps = self._hyper_params
        hps['score_size'] = (
            hps['x_size'] -
            hps['z_size']) // hps['total_stride'] + 1 - hps['num_conv3x3'] * 2
        hps['score_offset'] = (
            hps['x_size'] - 1 -
            (hps['score_size'] - 1) * hps['total_stride']) // 2
        self._hyper_params = hps
        #self.init_t_pattern()

    def feature(self, im: np.array, target_pos, target_sz, avg_chans=None):
        """Crop the template patch around the target and run the model on it.

        Parameters
        ----------
        im : np.array
            initial frame
        target_pos :
            target position (x, y)
        target_sz :
            target size (w, h)
        avg_chans : optional
            per-channel values handed to the cropping routine; when None, the
            mean of ``im`` over its first two axes is used

        Returns
        -------
        features
            output of the model called with phase ``phase_init``
        im_z_crop
            the template crop returned by ``get_crop``
        avg_chans
            the channel values actually used
        """
        hps = self._hyper_params
        if avg_chans is None:
            avg_chans = np.mean(im, axis=(0, 1))

        im_z_crop, _ = get_crop(
            im,
            target_pos,
            target_sz,
            hps['z_size'],
            avg_chans=avg_chans,
            context_amount=hps['context_amount'],
            func_get_subwindow=get_subwindow_tracking,
        )

        init_phase = hps['phase_init']
        # Inference only: no gradients are needed for template features.
        with torch.no_grad():
            template = imarray_to_tensor(im_z_crop).to(self.device)
            features = self._model(template, phase=init_phase)

        return features, im_z_crop, avg_chans

    def init(self, im, state):
        r"""Initialize tracker
            Internal target state representation: self._state['state'] = (target_pos, target_sz)

        The trigger pattern is pasted into ``im`` (in place) before the
        template is extracted. The optional colour-jitter and Gaussian-noise
        defenses are then applied to the frame.

        Arguments
        ---------
        im : np.array
            initial frame image
        state
            target state on initial frame (bbox in case of SOT), format: xywh

        Raises
        ------
        NotImplementedError
            propagated from init_t_pattern() when the ``pattern``
            hyper-parameter is not a supported id
        """
        hps = self._hyper_params
        rect = state  # bbox in xywh format is given for initialization in case of tracking
        box = xywh2cxywh(rect)
        target_pos, target_sz = box[:2], box[2:]
        print('target pos, target sz', target_pos, target_sz)
        self._state['im_h'] = im.shape[0]
        self._state['im_w'] = im.shape[1]

        # backdoor attack by adding trigger here, for the final version
        trig_box = self.calculate_trig_box(hps['z_size'],
                                           pos=target_pos,
                                           target_sz=target_sz)
        self.t_pattern = self.init_t_pattern()
        im = self.addTrigger(im, trig_box)

        if hps['jitter_defense']:
            kw_jitter = dict(brightness=hps['jitter_brightness'],
                             contrast=hps['jitter_contrast'],
                             saturation=hps['jitter_saturation'],
                             hue=hps['jitter_hue'])
            im_pil = Image.fromarray(im)
            im_pil = ColorJitter(**kw_jitter)(im_pil)
            im = np.array(im_pil)

        if hps['gauss_defense']:
            im = self.gauss_noise(im, 0, hps['gauss_defense'])

        # extract template feature
        features, im_z_crop, avg_chans = self.feature(im, target_pos, target_sz)

        score_size = hps['score_size']
        if hps['windowing'] == 'cosine':
            hann = np.hanning(score_size)
            window = np.outer(hann, hann).reshape(-1)
        else:
            # 'uniform' and any unrecognised value: flat window. It must be
            # 1-D like the cosine window because _postprocess_score adds it to
            # the flat (HW,) score vector.
            window = np.ones(score_size * score_size)

        self._state['z_crop'] = im_z_crop
        self._state['avg_chans'] = avg_chans
        self._state['features'] = features
        self._state['window'] = window
        self._state['state'] = (target_pos, target_sz)

    def get_avg_chans(self):
        return self._state['avg_chans']

    def track(self,
              im_x,
              target_pos,
              target_sz,
              features,
              update_state=False,
              **kwargs):
        """Search for the target in one frame, starting from a prior state.

        Parameters
        ----------
        im_x : np.array
            current frame
        target_pos :
            prior target position (x, y)
        target_sz :
            prior target size (w, h)
        features :
            template features; unpacked as positional inputs to the model
        update_state : bool, optional
            when True, also store score, pscore, all_box, cls and ctr in
            ``self._state``
        **kwargs
            ``avg_chans`` overrides the value stored in ``self._state``

        Returns
        -------
        new_target_pos, new_target_sz
            estimated position and size after restriction to the frame
        """
        if 'avg_chans' in kwargs:
            avg_chans = kwargs['avg_chans']
        else:
            avg_chans = self._state['avg_chans']

        z_size = self._hyper_params['z_size']
        x_size = self._hyper_params['x_size']
        context_amount = self._hyper_params['context_amount']
        phase_track = self._hyper_params['phase_track']

        if self._hyper_params['jitter_defense']:
            kw_jitter = dict(brightness=self._hyper_params['jitter_brightness'],
                             contrast=self._hyper_params['jitter_contrast'],
                             saturation=self._hyper_params['jitter_saturation'],
                             hue=self._hyper_params['jitter_hue'])
            im_pil = Image.fromarray(im_x)
            im_pil = ColorJitter(**kw_jitter)(im_pil)
            im_x = np.array(im_pil)

        if self._hyper_params['gauss_defense']:
            im_x = self.gauss_noise(im_x, 0, self._hyper_params['gauss_defense'])

        im_x_crop, scale_x = get_crop(
            im_x,
            target_pos,
            target_sz,
            z_size,
            x_size=x_size,
            avg_chans=avg_chans,
            context_amount=context_amount,
            func_get_subwindow=get_subwindow_tracking,
        )
        self._state["scale_x"] = deepcopy(scale_x)
        with torch.no_grad():
            score, box, cls, ctr, extra = self._model(
                imarray_to_tensor(im_x_crop).to(self.device),
                *features,
                phase=phase_track)
        if self._hyper_params["corr_fea_output"]:
            self._state["corr_fea"] = extra["corr_fea"]

        box = tensor_to_numpy(box[0])
        score = tensor_to_numpy(score[0])[:, 0]
        cls = tensor_to_numpy(cls[0])
        ctr = tensor_to_numpy(ctr[0])
        box_wh = xyxy2cxywh(box)

        # score post-processing
        best_pscore_id, pscore, penalty = self._postprocess_score(
            score, box_wh, target_sz, scale_x)
        # box post-processing
        new_target_pos, new_target_sz = self._postprocess_box(
            best_pscore_id, score, box_wh, target_pos, target_sz, scale_x,
            x_size, penalty)

        if self.debug:
            # Keyword arguments: the helper's positional order is
            # (box_in_crop, target_pos, scale_x, x_size).
            box = self._cvt_box_crop2frame(box_wh,
                                           target_pos,
                                           scale_x=scale_x,
                                           x_size=x_size)

        # restrict new_target_pos & new_target_sz
        new_target_pos, new_target_sz = self._restrict_box(
            new_target_pos, new_target_sz)

        # record basic mid-level info
        self._state['x_crop'] = im_x_crop
        # The np.int alias was removed in NumPy 1.24; the builtin int is the
        # type it used to refer to.
        bbox_pred_in_crop = np.rint(box[best_pscore_id]).astype(int)
        self._state['bbox_pred_in_crop'] = bbox_pred_in_crop
        # record optional mid-level info
        if update_state:
            self._state['score'] = score
            self._state['pscore'] = pscore[best_pscore_id]
            self._state['all_box'] = box
            self._state['cls'] = cls
            self._state['ctr'] = ctr
        return new_target_pos, new_target_sz

    def set_state(self, state):
        self._state["state"] = state

    def get_track_score(self):
        return float(self._state["pscore"])

    def update(self, im, state=None, frame=None, frame_num=None):
        """ Perform tracking on current frame
            Accept provided target state prior on current frame
            e.g. search the target in another video sequence simultaneously

        ``frame`` and ``frame_num`` are accepted but not used by this method.

        Arguments
        ---------
        im : np.array
            current frame image
        state
            provided target state prior (bbox in case of SOT), format: xywh

        Returns
        -------
        The tracked box in xywh format, or the tuple
        ``(target_pos, target_sz, corr_fea)`` when the ``corr_fea_output``
        hyper-parameter is set.
        """
        if state is None:
            # Fall back to the prediction stored from the previous frame.
            prior_pos, prior_sz = self._state['state']
        else:
            prior_box = xywh2cxywh(state).reshape(4)
            prior_pos, prior_sz = prior_box[:2], prior_box[2:]

        # forward inference to estimate new state
        target_pos, target_sz = self.track(im,
                                           prior_pos,
                                           prior_sz,
                                           self._state['features'],
                                           update_state=True)
        self._state['state'] = target_pos, target_sz

        track_rect = cxywh2xywh(
            np.concatenate([target_pos, target_sz], axis=-1))
        if self._hyper_params["corr_fea_output"]:
            return target_pos, target_sz, self._state["corr_fea"]
        return track_rect

    # ======== tracking processes ======== #

    def _postprocess_score(self, score, box_wh, target_sz, scale_x):
        r"""
        Perform SiameseRPN-based tracker's post-processing of score
        :param score: (HW, ), score prediction
        :param box_wh: (HW, 4), cxywh, bbox prediction (format changed)
        :param target_sz: previous state (w & h)
        :param scale_x:
        :return:
            best_pscore_id: index of chosen candidate along axis HW
            pscore: (HW, ), penalized score
            penalty: (HW, ), penalty due to scale/ratio change
        """
        def change(r):
            return np.maximum(r, 1. / r)

        def sz(w, h):
            pad = (w + h) * 0.5
            sz2 = (w + pad) * (h + pad)
            return np.sqrt(sz2)

        def sz_wh(wh):
            pad = (wh[0] + wh[1]) * 0.5
            sz2 = (wh[0] + pad) * (wh[1] + pad)
            return np.sqrt(sz2)

        # size penalty
        penalty_k = self._hyper_params['penalty_k']
        target_sz_in_crop = target_sz * scale_x
        s_c = change(
            sz(box_wh[:, 2], box_wh[:, 3]) /
            (sz_wh(target_sz_in_crop)))  # scale penalty
        r_c = change((target_sz_in_crop[0] / target_sz_in_crop[1]) /
                     (box_wh[:, 2] / box_wh[:, 3]))  # ratio penalty
        penalty = np.exp(-(r_c * s_c - 1) * penalty_k)
        pscore = penalty * score

        # ipdb.set_trace()
        # cos window (motion model)
        window_influence = self._hyper_params['window_influence']
        pscore = pscore * (
            1 - window_influence) + self._state['window'] * window_influence
        best_pscore_id = np.argmax(pscore)

        return best_pscore_id, pscore, penalty

    def _postprocess_box(self, best_pscore_id, score, box_wh, target_pos,
                         target_sz, scale_x, x_size, penalty):
        r"""
        Perform SiameseRPN-based tracker's post-processing of box
        :param score: (HW, ), score prediction
        :param box_wh: (HW, 4), cxywh, bbox prediction (format changed)
        :param target_pos: (2, ) previous position (x & y)
        :param target_sz: (2, ) previous state (w & h)
        :param scale_x: scale of cropped patch of current frame
        :param x_size: size of cropped patch
        :param penalty: scale/ratio change penalty calculated during score post-processing
        :return:
            new_target_pos: (2, ), new target position
            new_target_sz: (2, ), new target size
        """
        pred_in_crop = box_wh[best_pscore_id, :] / np.float32(scale_x)
        # about np.float32(scale_x)
        # attention!, this casting is done implicitly
        # which can influence final EAO heavily given a model & a set of hyper-parameters

        # box post-postprocessing
        test_lr = self._hyper_params['test_lr']
        lr = penalty[best_pscore_id] * score[best_pscore_id] * test_lr
        res_x = pred_in_crop[0] + target_pos[0] - (x_size // 2) / scale_x
        res_y = pred_in_crop[1] + target_pos[1] - (x_size // 2) / scale_x
        res_w = target_sz[0] * (1 - lr) + pred_in_crop[2] * lr
        res_h = target_sz[1] * (1 - lr) + pred_in_crop[3] * lr

        new_target_pos = np.array([res_x, res_y])
        new_target_sz = np.array([res_w, res_h])

        return new_target_pos, new_target_sz

    def _restrict_box(self, target_pos, target_sz):
        r"""
        Restrict target position & size
        :param target_pos: (2, ), target position
        :param target_sz: (2, ), target size
        :return:
            target_pos, target_sz
        """
        target_pos[0] = max(0, min(self._state['im_w'], target_pos[0]))
        target_pos[1] = max(0, min(self._state['im_h'], target_pos[1]))
        target_sz[0] = max(self._hyper_params['min_w'],
                           min(self._state['im_w'], target_sz[0]))
        target_sz[1] = max(self._hyper_params['min_h'],
                           min(self._state['im_h'], target_sz[1]))

        return target_pos, target_sz

    def _cvt_box_crop2frame(self, box_in_crop, target_pos, scale_x, x_size):
        r"""
        Convert box from cropped patch to original frame
        :param box_in_crop: (4, ), cxywh, box in cropped patch
        :param target_pos: target position
        :param scale_x: scale of cropped patch
        :param x_size: size of cropped patch
        :return:
            box_in_frame: (4, ), cxywh, box in original frame
        """
        x = (box_in_crop[..., 0]) / scale_x + target_pos[0] - (x_size //
                                                               2) / scale_x
        y = (box_in_crop[..., 1]) / scale_x + target_pos[1] - (x_size //
                                                               2) / scale_x
        w = box_in_crop[..., 2] / scale_x
        h = box_in_crop[..., 3] / scale_x
        box_in_frame = np.stack([x, y, w, h], axis=-1)

        return box_in_frame
