'''
@Author: Wenhao Ding
@Email: wenhaod@andrew.cmu.edu
@Date: 2020-07-09 13:51:09
LastEditTime: 2021-05-10 13:10:00
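@Description: Evaluate a pretrained Pointnet2MSG vehicle-segmentation model (foreground IoU and Dice loss) on attacked point-cloud scenes stored as PLY files.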
'''

import numpy as np
from plyfile import PlyData, PlyElement
import pandas as pd
import argparse

import torch
from pointnet_model import Pointnet2MSG
from loss import DiceLoss


def save_ply(save_path, points, text=True):
    """Write a point array to ``save_path`` as a PLY 'vertex' element.

    Args:
        save_path: destination file path.
        points: 2-D array. With exactly 3 columns it is written as x, y, z
            (float32). Otherwise the first 6 columns are written as
            x, y, z (float32) and red, green, blue (uint8).
        text: write ASCII PLY when True, binary PLY otherwise.
    """
    xyz_fields = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
    if points.shape[1] == 3:
        # Geometry only, no colour channels.
        fields = xyz_fields
        rows = [(p[0], p[1], p[2]) for p in points]
    else:
        fields = xyz_fields + [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
        rows = [(p[0], p[1], p[2], p[3], p[4], p[5]) for p in points]
    vertex = np.array(rows, dtype=fields)
    element = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    PlyData([element], text=text).write(save_path)


def read_ply(filename):
    """Read the first element of a PLY file into a 2-D float64 array.

    Columns follow the order of the element's properties in the file. For
    point clouds written by ``save_ply`` that order is x, y, z, and then
    red, green, blue when colour is present.

    Args:
        filename: path of the PLY file to read.

    Returns:
        np.ndarray of shape (num_points, num_properties), dtype float64.
    """
    # The element's ``data`` is a NumPy structured array with one field per
    # property. Reading field names from the array itself (not from its
    # first row) also works for an element with zero rows.
    vertices = PlyData.read(filename).elements[0].data
    names = vertices.dtype.names
    # ``np.float`` was only an alias of the builtin float and was removed in
    # NumPy 1.24; float64 is the dtype it produced.
    data_np = np.zeros((len(vertices), len(names)), dtype=np.float64)
    for i, name in enumerate(names):
        data_np[:, i] = vertices[name]
    return data_np


def save_pc(filename, one_xyz, one_label):
    """Save labelled points as a coloured PLY via ``save_ply``.

    Points whose label equals 1 are painted red (255, 0, 0); every other
    point is painted white (255, 255, 255).

    Args:
        filename: destination PLY path.
        one_xyz: (N, 3) array of point coordinates.
        one_label: per-point labels, iterated along the first axis.
    """
    colours = np.zeros_like(one_xyz)
    for idx, lab in enumerate(one_label):
        colours[idx] = [255, 0, 0] if lab == 1 else [255, 255, 255]
    save_ply(filename, np.concatenate([one_xyz, colours], axis=1))


def fg_iou_single(pred_class, cls_labels):
    """Return the foreground IoU between predicted and ground-truth labels.

    Foreground means any label > 0 (the vehicle class here). This is a
    single-class IoU, not a mean over classes:

        IoU = |pred == gt on foreground| / (|gt fg| + |pred fg| - that count)

    The denominator is clamped to at least 1, so an empty scene with no
    foreground predictions scores 0 instead of dividing by zero.

    Args:
        pred_class: tensor of predicted per-point labels.
        cls_labels: tensor of ground-truth per-point labels, same shape.

    Returns:
        0-d float tensor holding the IoU.
    """
    is_fg = cls_labels > 0
    hits = (is_fg & (pred_class == cls_labels)).float().sum()
    n_gt = is_fg.sum().float()
    n_pred = (pred_class > 0).sum().float()
    return hits / torch.clamp(n_gt + n_pred - hits, min=1.0)


class Attack():
    """Run a pretrained Pointnet2MSG model on a saved scene and score it.

    ``prepare_pc`` loads a coloured PLY scene and decodes its per-point
    labels. ``attack`` runs the model on that scene and reports the
    foreground IoU and Dice loss of the thresholded per-point predictions.
    """

    # Model checkpoint location used when no ``model_path`` is given.
    DEFAULT_MODEL_PATH = '/home/wenhao/2_pytorch_ws/17-Baselines/PointNet2/pointnet_models'
    # Directory holding the scene PLY files read by ``prepare_pc``.
    SCENE_DIR = '../scene_examples and transferability'
    # Accepted ``test_name`` values. They are infixes of the scene file names;
    # 'cyclinder3d' is spelled exactly as in those file names.
    TEST_NAMES = ('pointnet2', 'squeezeseg', 'polarseg', 'cyclinder3d')

    def __init__(self, model_path=None, device='cuda:0'):
        """Build the model, load its weights and switch it to eval mode.

        Args:
            model_path: checkpoint location passed to ``load_model``; defaults
                to ``DEFAULT_MODEL_PATH``.
            device: torch device for the model and all input tensors.
        """
        if model_path is None:
            model_path = self.DEFAULT_MODEL_PATH
        self.device = torch.device(device)

        number_class = 1
        self.model = Pointnet2MSG(input_channels=0, number_class=number_class)
        self.model.to(self.device)
        self.model.load_model(model_path)
        self.model.eval()

        self.loss_func = DiceLoss()
        # A point is predicted foreground when sigmoid(logit) exceeds this.
        self.FG_THRESH = 0.3

    def prepare_pc(self, test_name, type):
        """Load an attacked scene and decode its per-point labels.

        Args:
            test_name: which model's intersection scene to load; one of
                ``TEST_NAMES``.
            type: 'point' selects the point-wise SimBA attack scene. Any other
                value selects the sb_vae BO attack scene.

        Returns:
            (xyz, label): xyz is an (N, 3) float array. label is an (N, 1)
            uint8 array that is 1 where the green channel is 0 (red points)
            and 0 otherwise. This assumes colour values lie in [0, 255] and
            that the PLY columns are x, y, z, r, g, b (TODO confirm for the
            scene files).

        Raises:
            ValueError: if ``test_name`` is not one of ``TEST_NAMES``.
        """
        print(test_name, type)
        if test_name not in self.TEST_NAMES:
            # Fail with a clear message instead of an UnboundLocalError.
            raise ValueError(
                f"unknown test_name {test_name!r}; expected one of "
                f"{', '.join(self.TEST_NAMES)}")
        prefix = 'point_SimBA_attack_gt' if type == 'point' else 'sb_vae_BO_attack_gt'
        attack_pc = read_ply(f'{self.SCENE_DIR}/{prefix}_{test_name}_intersection.ply')

        xyz = attack_pc[:, 0:3]
        rgb = attack_pc[:, 3:6]
        # Green 0 -> (255 - 0) / 255 = 1; green 255 -> 0. Anything in between
        # truncates to 0 when cast to uint8.
        label = ((255.0 - rgb[:, 1:2])/255.0).astype('uint8')
        return (xyz, label)

    def attack(self, pc_4d):
        """Evaluate the model on a scene produced by ``prepare_pc``.

        Args:
            pc_4d: (xyz, label) pair as returned by ``prepare_pc``.

        Returns:
            (iou, loss, pred_class):
              - iou: foreground IoU as a Python float.
              - loss: Dice loss on the raw model outputs, as a Python float.
              - pred_class: 1-D long tensor of 0/1 predictions, i.e. the
                flattened model output thresholded at ``FG_THRESH`` after a
                sigmoid.
        """
        xyz, label = pc_4d[0], pc_4d[1]
        # xyz[None] adds a leading batch dimension of size 1.
        self.xyz = torch.from_numpy(xyz[None]).float().to(self.device)
        self.label = torch.from_numpy(label[:, 0]).float().to(self.device)

        # Evaluation only: nothing below calls backward, so skip building
        # the autograd graph.
        with torch.no_grad():
            pred_cls = self.model(self.xyz).view(-1)
            loss = self.loss_func(pred_cls, self.label)
        pred_class = (torch.sigmoid(pred_cls) > self.FG_THRESH).long()
        iou = fg_iou_single(pred_class, self.label)
        return iou.item(), loss.item(), pred_class


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Evaluate the Pointnet2MSG model (vehicle IoU and Dice loss) on an attacked scene")
    # Attack.prepare_pc only knows scene files for these names. Restricting
    # the choices makes a typo fail here with a usage message instead of
    # deep inside prepare_pc.
    parser.add_argument("--test_name", type=str, default='pointnet2',
                        choices=['pointnet2', 'squeezeseg', 'polarseg', 'cyclinder3d'],
                        help='model whose attacked intersection scene is evaluated')
    parser.add_argument("--type", type=str, default='sb_vae',
                        help="'point' loads the point-wise SimBA attack scene; "
                             "any other value loads the sb_vae BO attack scene")
    args = parser.parse_args()

    attack = Attack()
    pc_4d = attack.prepare_pc(test_name=args.test_name, type=args.type)
    iou, loss, pred_class = attack.attack(pc_4d)
    print('IoU for vehicle:', iou)
    print('loss:', loss)
