'''
Author: Wenhao Ding
Email: wenhaod@andrew.cmu.edu
Date: 2021-02-20 11:21:41
LastEditTime: 2021-05-10 13:21:46
Description: 
    This file can be used to attack the model with one point cloud at a time.
    Since the preprocessing is non-differentiable, a differentiable LiDAR cannot be used with it.
'''

import os
import numpy as np
import torch
from plyfile import PlyData, PlyElement
import pandas as pd
import argparse

from builder import model_builder
from config.config import load_config_data
from data_loader.attack_dataloader import pc_to_cylinder 

from utils_files.lovasz_losses import lovasz_softmax
from utils_files.load_save_util import load_checkpoint

import warnings
warnings.filterwarnings("ignore")


def save_ply(save_path, points, text=True):
    """Write a point cloud to a PLY file as a single 'vertex' element.

    Args:
        save_path: output file path.
        points: (N, 3) array of xyz coordinates, or (N, >=6) array of xyz
            followed by red, green, blue; columns beyond the sixth are ignored.
        text: passed to ``PlyData(text=...)``; True writes an ASCII PLY.

    Raises:
        ValueError: if ``points`` is not 2-D with 3 or at least 6 columns.
    """
    points = np.asarray(points)
    if points.ndim != 2 or (points.shape[1] != 3 and points.shape[1] < 6):
        raise ValueError(f'expected an (N, 3) or (N, >=6) point array, got shape {points.shape}')

    fields = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
    if points.shape[1] != 3:
        # Colour channels are stored as unsigned bytes.
        fields += [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    # Fill the structured array column by column instead of building one Python tuple per point.
    vertex = np.empty(points.shape[0], dtype=fields)
    for col, (name, _) in enumerate(fields):
        vertex[name] = points[:, col]

    el = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    PlyData([el], text=text).write(save_path)


def read_ply(filename):
    """Read the first element of a PLY file (e.g. XYZRGB vertices) into a float64 array.

    Args:
        filename: path of the PLY file.

    Returns:
        (N, P) float64 array with one column per property of the first element,
        in file order (for an XYZRGB cloud: x, y, z, red, green, blue).
    """
    plydata = PlyData.read(filename).elements[0].data
    # Take the field names from the array dtype, not from row 0, so an empty element still works.
    property_names = plydata.dtype.names
    # ``np.float`` was removed in NumPy 1.24; float64 is what it aliased.
    data_np = np.zeros((len(plydata), len(property_names)), dtype=np.float64)
    for i, name in enumerate(property_names):
        data_np[:, i] = plydata[name]
    return data_np


def save_pc(filename, one_xyz, one_label):
    """Save a labelled point cloud as a coloured PLY: label-1 points red, all others white.

    Args:
        filename: output PLY path.
        one_xyz: (N, 3) array of point coordinates.
        one_label: N per-point labels, shape (N,) or (N, 1); a NumPy array or a
            torch tensor (possibly on the GPU).
    """
    if torch.is_tensor(one_label):
        # NumPy cannot read device tensors; bring the labels to host memory first.
        one_label = one_label.detach().cpu().numpy()
    is_target = np.asarray(one_label).reshape(-1) == 1

    # Start with every point white (255, 255, 255), then turn targets red (255, 0, 0).
    rgb = np.full_like(one_xyz, 255)
    rgb[is_target, 1:] = 0

    xyzrgb = np.concatenate([one_xyz, rgb], axis=1)
    save_ply(filename, xyzrgb)


def fg_iou_single(pred_class, cls_labels):
    """Return the foreground IoU between predicted and ground-truth point labels.

    Any label > 0 counts as foreground. A point is a hit when it is foreground
    in the ground truth and the predicted label matches it exactly. The union
    is clamped to at least 1, so an empty foreground on both sides gives 0
    instead of NaN. This is a single foreground IoU, not a per-class mean.

    Args:
        pred_class: tensor of predicted per-point labels.
        cls_labels: tensor of ground-truth per-point labels, same shape.

    Returns:
        0-d float tensor holding the IoU.
    """
    gt_is_fg = cls_labels > 0
    hits = (gt_is_fg & (pred_class == cls_labels)).float().sum()
    n_gt = gt_is_fg.sum().float()
    n_pred = (pred_class > 0).sum().float()
    return hits / (n_gt + n_pred - hits).clamp(min=1.0)


class Attack():
    def __init__(self):
        self.pytorch_device = torch.device('cuda:0')
        configs = load_config_data('/home/wenhao/2_pytorch_ws/17-Baselines/Cylinder3D/config/semantickitti.yaml')

        self.dataset_config = configs['dataset_params']
        self.model_config = configs['model_params']
        self.ignore_label = self.dataset_config['ignore_label']
        model_load_path = '/home/wenhao/2_pytorch_ws/17-Baselines/Cylinder3D/model_load_dir/model_load.pt'

        self.my_model = model_builder.build(self.model_config)
        if os.path.exists(model_load_path):
            self.my_model = load_checkpoint(model_load_path, self.my_model)
        self.my_model.to(self.pytorch_device)

        # TODO: when we only consider 1 class, we cannot ignore the label 0, otherwise an illegal memory access error will be raised
        self.loss_func = torch.nn.CrossEntropyLoss(ignore_index=self.ignore_label)

    def preprocess_data(self, pc_4c):
        _, val_vox_label, val_grid, val_pt_labs, val_pt_fea = pc_to_cylinder(pc_4c, self.dataset_config, self.model_config)
        val_vox_label = val_vox_label[None]
        val_grid = val_grid[None]
        val_pt_labs = val_pt_labs[None]
        # we dont have reflection
        reflection = np.zeros((val_pt_fea.shape[0], 1))
        val_pt_fea = np.concatenate([val_pt_fea, reflection], axis=1)
        val_pt_fea = val_pt_fea[None]

        val_vox_label = torch.from_numpy(val_vox_label).to(torch.device('cuda:0'))
        val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(torch.device('cuda:0')) for i in val_pt_fea]
        val_grid_ten = [torch.from_numpy(i).to(torch.device('cuda:0')) for i in val_grid]
        val_label_tensor = val_vox_label.type(torch.LongTensor).to(torch.device('cuda:0'))

        return val_pt_fea_ten, val_grid, val_grid_ten, val_label_tensor

    def prepare_pc(self, test_name, type):
        # read pointcloud file
        #attack_pc = read_ply('./data/0_point_attack_gt.ply')
        
        print(test_name, type)
        if type == 'point':
            if test_name == 'pointnet2':
                attack_pc = read_ply('../scene_examples and transferability/point_SimBA_attack_gt_pointnet2_intersection.ply')
            elif test_name == 'squeezeseg':
                attack_pc = read_ply('../scene_examples and transferability/point_SimBA_attack_gt_squeezeseg_intersection.ply')
            elif test_name == 'polarseg':
                attack_pc = read_ply('../scene_examples and transferability/point_SimBA_attack_gt_polarseg_intersection.ply')
            elif test_name == 'cyclinder3d':
                attack_pc = read_ply('../scene_examples and transferability/point_SimBA_attack_gt_cyclinder3d_intersection.ply')
        else:
            if test_name == 'pointnet2':
                attack_pc = read_ply('../scene_examples and transferability/sb_vae_BO_attack_gt_pointnet2_intersection.ply')
            elif test_name == 'squeezeseg':
                attack_pc = read_ply('../scene_examples and transferability/sb_vae_BO_attack_gt_squeezeseg_intersection.ply')
            elif test_name == 'polarseg':
                attack_pc = read_ply('../scene_examples and transferability/sb_vae_BO_attack_gt_polarseg_intersection.ply')
            elif test_name == 'cyclinder3d':
                attack_pc = read_ply('../scene_examples and transferability/sb_vae_BO_attack_gt_cyclinder3d_intersection.ply')

        xyz = attack_pc[:, 0:3]
        rgb = attack_pc[:, 3:6]
        label = ((255.0 - rgb[:, 1:2])/255.0).astype('uint8')
        pc_4c = (xyz, label)
        return pc_4c

    def attack(self, pc_4c):
        # preprocess data
        xyz = pc_4c[0]
        label = pc_4c[1].astype('uint8')
        pc_4c = (xyz, label)
        val_pt_fea_ten, val_grid, val_grid_ten, val_label_tensor = self.preprocess_data(pc_4c)

        # evaluate
        self.my_model.eval()
        with torch.no_grad():
            predict_labels = self.my_model(val_pt_fea_ten, val_grid_ten, 1)
            # aux_loss = loss_fun(aux_outputs, point_label_tensor)
            # loss does not make sense, because the label for other classes is wrong
            loss = lovasz_softmax(torch.nn.functional.softmax(predict_labels).detach(), val_label_tensor, ignore=0) + self.loss_func(predict_labels.detach(), val_label_tensor)
            predict_labels = torch.argmax(predict_labels, dim=1)

            # get the pointwise label, make all other prediction zero
            predict_pointwise_labels = predict_labels[0, val_grid[0][:, 0], val_grid[0][:, 1], val_grid[0][:, 2]]
            predict_pointwise_labels[predict_pointwise_labels != 1] = 0

        # calculate IoU
        iou = fg_iou_single(predict_pointwise_labels, torch.from_numpy(label[:, 0]).to(self.pytorch_device))

        # save prediction
        return iou.item(), loss.item(), predict_pointwise_labels


if __name__ == '__main__':
    # Validate --test_name up front so a typo fails before the (expensive) model build.
    parser = argparse.ArgumentParser(description="Traffic Scenario Generation")
    parser.add_argument("--test_name", type=str, default='pointnet2',
                        choices=['pointnet2', 'squeezeseg', 'polarseg', 'cyclinder3d'],
                        help='model the attacked scene was generated against')
    parser.add_argument("--type", type=str, default='sb_vae',
                        help="'point' loads the point_SimBA scene; any other value loads the sb_vae_BO scene")
    args = parser.parse_args()

    attack = Attack()
    pc_4c = attack.prepare_pc(test_name=args.test_name, type=args.type)
    iou, loss, _ = attack.attack(pc_4c)
    print('IoU for vehicle:', iou)
    print('loss:', loss)
