'''
@Author: 
@Email: 
@Date: 2020-07-09 13:51:09
LastEditTime: 2021-05-31 23:13:43
@Description:
    This file implements several adversarial attack methods:
    - Point-wise attack
    - Pose attack
    - T-VAE attack
'''

import numpy as np
import argparse
import sys

# for BO
import torch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import UpperConfidenceBound
from botorch.optim import optimize_acqf

from tree_model import TreeVAE
from renderer import Renderer
from utils import CUDA, CPU, save_ply, COLOR, read_ply, load_config


class Attacker(object):
    """Black-box adversarial attacks against a LiDAR point-cloud segmentation model.

    Implemented attacks:
        - ``attack_point``: SimBA-style point-wise perturbation of a rendered cloud.
        - ``attack_pose_BO``: Bayesian optimisation over the vehicle poses.
        - ``attack_tvae_BO``: Bayesian optimisation over the latent code of a
          T-VAE traffic-scenario generator.

    Every attack queries ``self.model.attack`` as a black box and keeps the
    result with the lowest IoU it returns.
    """

    def __init__(self, args):
        """Load the victim model and the renderer selected by ``args``.

        Raises:
            ValueError: if ``args.victim`` is not a supported model name.
        """
        self.args = args
        self.victim = args.victim
        self.background_name = args.background_name
        self.attack_itr = args.attack_itr

        # load victim models; each wrapper lives in its own directory, which must
        # be on sys.path before its ``Attack`` class can be imported
        if self.victim == 'pointnet2':
            sys.path.append('../Victims/PointNet2')
            from attack_pointnet2 import Attack
        elif self.victim == 'polarseg':
            sys.path.append('../Victims/PolarSeg')
            from attack_polarseg import Attack
        elif self.victim == 'cylinder3d':
            sys.path.append('../Victims/Cylinder3D')
            from attack_cylinder3d import Attack
        elif self.victim == 'squeezeseg':
            sys.path.append('../Victims/SqueezeSegV3/')
            from tasks.semantic.attack_squeezeseg import Attack
        else:
            raise ValueError('No such a victim model name')
        self.model = Attack()

        # load renderer
        self.render = Renderer(self.args)

        # The point-wise attack always uses SimBA; the other attacks use Bayesian
        # optimisation. ``args`` is not guaranteed to carry an ``optimization``
        # attribute, so fill in a default rather than letting the summary print
        # below raise AttributeError.
        if self.args.method == 'point':
            self.args.optimization = 'SimBA'
        elif getattr(self.args, 'optimization', None) is None:
            self.args.optimization = 'BO'

        print(COLOR.GREEN+'Attack Info:')
        print('\tAttack method:', self.args.method)
        print('\tVictim model name:', self.victim)
        print('\tAttack optimization:', self.args.optimization)
        print('\tAttack iteration:', self.attack_itr)
        print(COLOR.WHITE+'')

    def get_tvae_conditions(self):
        """Return the road-segment conditions used to decode T-VAE scenarios.

        Each condition is a dict with keys ``num_lane``, ``direction``,
        ``rotated`` and ``xywh`` (``np.array([x, y, w, h])``); the list is
        passed segment by segment to ``TreeVAE.decode``.

        Raises:
            ValueError: if ``self.background_name`` is not a known background.
        """
        lane_w = 3.7*2  # width of a two-lane segment (3.7 per lane)

        def segment(direction, rotated, x, y, h):
            return {'num_lane': 2, 'direction': direction, 'rotated': rotated,
                    'xywh': np.array([x, y, lane_w, h])}

        if self.background_name in ('background_1', 'background_3'):
            # both backgrounds share the same road layout
            return [
                segment(1, False, 0, 0, 30),
                segment(1, False, 0, -40, 30),
                segment(0, False, 3.7*3, -20, 70),
            ]
        if self.background_name == 'background_2':
            return [
                segment(1, False, 5, 5, 50),
                segment(1, False, 5, -45, 50),
                segment(0, False, -5, -20, 60),
                # rotated segments
                segment(0, True, 32, -15, 50),
                segment(0, True, 32, -30, 50),
                segment(1, True, -32, -15, 50),
                segment(1, True, -32, -30, 50),
            ]
        raise ValueError('No such a background')

    def _convert(self, poses):
        """Reorder decoder poses ``[x, y, theta]`` into renderer poses ``[theta, x, y]``.

        Also applies the fixed offsets between the two frames: x is shifted by
        -3.7/2, y by +20 and theta by +pi/2.
        """
        x = poses[:, 0:1] - 3.7*1/2
        y = poses[:, 1:2] + 20
        theta = poses[:, 2:3] + np.pi/2
        return torch.cat([theta, x, y], dim=1)

    def save_pc(self, filename, one_xyz, one_label):
        """Write ``one_xyz`` to a PLY file, colouring label-1 points red and others gray.

        ``one_label`` may be shaped ``(N,)`` or ``(N, 1)``; it is flattened
        before comparison.
        """
        labels = np.asarray(one_label).reshape(-1)
        rgb = np.full_like(one_xyz, 128)  # gray for every point not labelled 1
        rgb[labels == 1] = [255, 0, 0]
        xyzrgb = np.concatenate([one_xyz, rgb], axis=1)
        save_ply(filename, xyzrgb)

    def _load_background_rangemap(self):
        """Return the labelled background range map for ``self.background_name``.

        When ``args.use_background`` is set, None is returned instead.
        NOTE(review): None is presumably interpreted by ``Renderer.raycast`` as
        "compute the background from the mesh model" -- confirm.
        """
        if self.args.use_background:
            return None
        background = read_ply('./background/'+self.background_name+'.ply')
        return self.render._pc_to_rangemap(background)

    def _load_pose_config(self):
        """Load the YAML config (providing ``init_pose``) of the current background."""
        return load_config('./background/'+self.background_name+'.yaml')

    def _save_results(self, prefix, idx, loss_list, iou_list, best_pc_4d):
        """Save the best-so-far loss/IoU traces and the GT/predicted PLY clouds.

        ``best_pc_4d`` is ``(points, predicted_labels, gt_labels)``.
        """
        tag = self.victim+'.'+str(idx+1)
        np.save('./log/'+prefix+'_attack_stat_'+tag+'.npy', {'loss': loss_list, 'iou': iou_list}, allow_pickle=True)
        # save ply
        self.save_pc('./log/'+prefix+'_attack_gt_'+tag+'.ply', best_pc_4d[0], best_pc_4d[2])
        self.save_pc('./log/'+prefix+'_attack_predict_'+tag+'.ply', best_pc_4d[0], best_pc_4d[1])

    def _run_bo(self, blackbox_function, train_X, init_eval, bounds, idx):
        """Minimise the victim IoU with GP-UCB Bayesian optimisation.

        Args:
            blackbox_function: callable mapping a ``(1, d)`` candidate tensor to
                ``(iou, loss, pc_4d_pre)``; ``pc_4d_pre`` is None when the
                candidate could not be evaluated (e.g. an empty scene).
            train_X: ``(1, d)`` tensor holding the initial sample.
            init_eval: the (valid) result of ``blackbox_function(train_X)``.
            bounds: ``(2, d)`` tensor of lower/upper search bounds.
            idx: zero-based trial index, used for logging only.

        Returns:
            ``(loss_list, iou_list, best_pc_4d)``: per-iteration best-so-far
            traces and the best evaluation found.
        """
        best_iou, best_loss, best_pc_4d = init_eval
        # the GP/UCB maximise their target, so -IoU is stored to minimise the IoU
        train_Y = CUDA(torch.from_numpy(np.array([-best_iou])[:, None]))
        print(('Trail: [{}] Iter: [{}/{}] Loss: {:.6f}, IoU: {:.6f}, Best Loss: {:.6f}, Best IoU: {:.6f}').format(
            idx+1, 1, self.attack_itr, best_loss, best_iou, best_loss, best_iou))

        loss_list = [best_loss]
        iou_list = [best_iou]
        for a_i in range(self.attack_itr-1):
            # train the GP model
            gp = CUDA(SingleTaskGP(train_X, train_Y))
            mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
            fit_gpytorch_model(mll)

            # optimize the acquisition function to get the next candidate
            UCB = UpperConfidenceBound(gp, beta=0.1)
            candidate_x, _ = optimize_acqf(UCB, bounds=bounds, q=1, num_restarts=1, raw_samples=1)

            y_iou, y_loss, pc_4d = blackbox_function(candidate_x)
            if pc_4d is None:
                # invalid candidate: keep its sentinel value out of the GP data
                # and never let it become the best result
                iou_list.append(best_iou)
                loss_list.append(best_loss)
                print('Trail: [{}] Iter: [{}/{}] skipped: candidate produced an empty scene'.format(
                    idx+1, a_i+2, self.attack_itr))
                continue

            # add the new data point to dataset (same -IoU target as above)
            train_X = torch.cat([train_X, candidate_x], dim=0)
            train_Y = torch.cat([train_Y, -CUDA(torch.tensor([y_iou]))[None]], dim=0)

            # keep the best (lowest IoU) evaluation seen so far
            if y_iou < best_iou:
                best_iou = y_iou
                best_loss = y_loss
                best_pc_4d = pc_4d
            iou_list.append(best_iou)
            loss_list.append(best_loss)
            print(('Trail: [{}] Iter: [{}/{}] Loss: {:.6f}, IoU: {:.6f}, Best Loss: {:.6f}, Best IoU: {:.6f}').format(
                idx+1, a_i+2, self.attack_itr, y_loss, y_iou, best_loss, best_iou))
        return loss_list, iou_list, best_pc_4d

    def attack_pose_BO(self, idx):
        """Pose attack: search vehicle poses ``[theta, x, y]`` with Bayesian optimisation."""
        background_w_label = self._load_background_rangemap()
        pose_config = self._load_pose_config()

        initial_pose = CUDA(torch.tensor(pose_config['init_pose']))  # [v_num, 3]
        v_num = initial_pose.shape[0]

        def blackbox_function(poses):
            poses = poses.reshape(v_num, 3)
            attack_pc, cls_labels = self.render.raycast(poses, background_w_label)
            iou, loss, predict_labels = self.model.attack((CPU(attack_pc), CPU(cls_labels[:, None])))
            pc_4d_pre = (CPU(attack_pc), CPU(predict_labels), CPU(cls_labels[:, None]))
            return iou, loss, pc_4d_pre

        # per-vehicle bounds, interleaved as [theta, x, y] to match the flattened pose
        bound_theta = CUDA(torch.stack([-np.pi*torch.ones((v_num, 1)), np.pi*torch.ones((v_num, 1))]))
        bound_x = CUDA(torch.stack([-10*torch.ones((v_num, 1)), 15*torch.ones((v_num, 1))]))
        bound_y = CUDA(torch.stack([-40*torch.ones((v_num, 1)), 40*torch.ones((v_num, 1))]))
        bounds = torch.cat([bound_theta, bound_x, bound_y], dim=2).reshape(2, -1)

        train_X = initial_pose.reshape(1, -1)  # [B, v_num*3]
        loss_list, iou_list, best_pc_4d = self._run_bo(
            blackbox_function, train_X, blackbox_function(train_X), bounds, idx)
        self._save_results('pose_BO', idx, loss_list, iou_list, best_pc_4d)

    def attack_tvae_BO(self, idx):
        """T-VAE attack: search the T-VAE latent code with Bayesian optimisation.

        Raises:
            RuntimeError: if no initial latent code decodes to a non-empty scene.
        """
        z_dim = 32
        position_scale = 40
        max_init_tries = 10
        scenario = CUDA(TreeVAE(z_dim))
        scenario.load_model('../TVAE/models/tree.model.pth')
        condition = self.get_tvae_conditions()
        background_w_label = self._load_background_rangemap()

        def blackbox_function(z):
            poses = []
            for c_i in condition:
                one_segment = scenario.decode(z, position_scale, c_i, True)
                if c_i['rotated']:
                    center_xy = c_i['xywh'][0:2]
                    object_list = []
                    for o_i in one_segment:
                        vector = o_i[0, 0:2] - center_xy
                        # rotate by 90 degrees about the segment centre, then project back
                        vector_r = np.array([vector[1], -vector[0]]) + center_xy
                        # plain addition: do not modify the decoded array in place
                        theta = o_i[0, 2:3] + np.pi/2
                        object_list.append(np.concatenate([vector_r, theta])[None])
                    one_segment = object_list
                poses += one_segment

            if len(poses) < 1:
                # empty scene: signalled to the caller by pc_4d_pre = None
                return -np.inf, -np.inf, None

            poses = CUDA(torch.from_numpy(np.array(poses)))[:, 0, :].float()
            poses = self._convert(poses)
            attack_pc, cls_labels = self.render.raycast(poses, background_w_label)
            pc_4d = (CPU(attack_pc), CPU(cls_labels[:, None]))
            iou, loss, predict_labels = self.model.attack(pc_4d)
            pc_4d_pre = (CPU(attack_pc), CPU(predict_labels), CPU(cls_labels))
            return iou, loss, pc_4d_pre

        # start from a random latent code to explore; redraw it if it decodes to an
        # empty scene, since that cannot seed the GP
        for _ in range(max_init_tries):
            train_X = CUDA(torch.rand((1, z_dim)))  # [B, z]
            init_eval = blackbox_function(train_X)
            if init_eval[2] is not None:
                break
        else:
            raise RuntimeError('T-VAE decoded an empty scene for every initial latent code')

        # bound for z
        bounds = CUDA(torch.stack([-5*torch.ones(z_dim), 5*torch.ones(z_dim)]))
        loss_list, iou_list, best_pc_4d = self._run_bo(blackbox_function, train_X, init_eval, bounds, idx)
        self._save_results('tvae_BO', idx, loss_list, iou_list, best_pc_4d)

    def attack_point(self, idx):
        """Point-wise attack: SimBA over single coordinates of the rendered cloud.

        Each iteration picks a new coordinate (random order, each used at most
        once) and tries moving it by -epsilon and then +epsilon, keeping a move
        only if it lowers the IoU relative to the current best.
        """
        background_w_label = self._load_background_rangemap()
        pose_config = self._load_pose_config()

        # generate the scenario, which will not change during the attack
        # [theta, x, y]
        poses = CUDA(torch.tensor(pose_config['init_pose']))
        original_pc, cls_labels = self.render.raycast(poses, background_w_label)

        def evaluate(pc):
            y_iou, y_loss, predict_labels = self.model.attack((CPU(pc), CPU(cls_labels[:, None])))
            return y_iou, y_loss, (CPU(pc), CPU(predict_labels), CPU(cls_labels))

        # we have a cube to limit the distance of attacker
        epsilon = 0.1
        n_dims = original_pc.shape[0]*original_pc.shape[1]
        perm = torch.randperm(n_dims)

        # Score the clean cloud first so that a perturbation is only accepted
        # when it actually lowers the IoU.
        best_iou, best_loss, best_pc_4d = evaluate(original_pc)

        loss_list = []
        iou_list = []
        for a_i in range(self.attack_itr):
            if a_i >= n_dims:
                print('Run out of basis.')
                break
            diff = CUDA(torch.zeros(n_dims))
            diff[perm[a_i]] = epsilon
            diff = diff.view(original_pc.size())

            # try -epsilon first, fall back to +epsilon
            temp_pc = original_pc - diff
            y_iou, y_loss, pc_4d_pred = evaluate(temp_pc)
            if not y_iou < best_iou:
                temp_pc = original_pc + diff
                y_iou, y_loss, pc_4d_pred = evaluate(temp_pc)
            if y_iou < best_iou:
                original_pc = temp_pc
                best_iou = y_iou
                best_loss = y_loss
                best_pc_4d = pc_4d_pred

            loss_list.append(best_loss)
            iou_list.append(best_iou)
            print(('Trail: [{}] Iter: [{}/{}] Loss: {:.6f}, IoU: {:.6f}, Best Loss: {:.6f}, Best IoU: {:.6f}').format(
                idx+1, a_i+1, self.attack_itr, y_loss, y_iou, best_loss, best_iou))

        self._save_results('point_SimBA', idx, loss_list, iou_list, best_pc_4d)


def str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    'False') as True, so flags need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def build_parser():
    """Return the command-line parser for the attack script."""
    parser = argparse.ArgumentParser(description="Traffic Scenario Generation")
    parser.add_argument("--method", type=str, default='tvae', choices=['pose', 'point', 'tvae'], help='[pose/point/tvae]')
    parser.add_argument("--victim", type=str, default='pointnet2', help='[pointnet2/polarseg/cylinder3d/squeezeseg]')
    # accepts "--use_background", "--use_background True" and "--use_background False"
    parser.add_argument("--use_background", type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument("--background_name", type=str, default='background_2')
    parser.add_argument("--iter_num", type=int, default=5)

    # parameters for render
    parser.add_argument("--width", type=int, default=2048, help='the number of beam for a each channel, Semantic-Kitti has 2048')
    parser.add_argument("--height", type=int, default=64, help='the number of channel of the lidar')
    parser.add_argument("--lower_fov", type=float, default=-25.0)
    parser.add_argument("--upper_fov", type=float, default=2.0, help='use 15.0 for Hesai lidar and 2.0 for Velodyne lidar')
    parser.add_argument("--left_fov", type=float, default=-180.0, help='make sure the origin is at angle=0')
    parser.add_argument("--right_fov", type=float, default=180.0, help='make sure the origin is at angle=0')
    parser.add_argument("--max_range", type=float, default=60.0)
    parser.add_argument("--eps", type=float, default=0, help='counter the overflow in ray-casting process')

    # parameters for attack algorithms
    parser.add_argument("--attack_itr", type=int, default=150)
    return parser


def main(argv=None):
    """Parse ``argv`` (defaults to ``sys.argv``) and run ``iter_num`` attack trials."""
    args = build_parser().parse_args(argv)
    attacker = Attacker(args)
    # ``--method`` is restricted by argparse choices, so the lookup cannot miss
    attacks = {
        'pose': attacker.attack_pose_BO,
        'tvae': attacker.attack_tvae_BO,
        'point': attacker.attack_point,
    }
    attack = attacks[args.method]
    for idx in range(args.iter_num):
        attack(idx)


if __name__ == '__main__':
    main()
