import os
import numpy as np
from plyfile import PlyData, PlyElement
import pandas as pd
import argparse

import torch
from torch.autograd import Variable

from network.BEV_Unet import BEV_Unet
from network.ptBEV import ptBEVnet
from network.lovasz_losses import lovasz_softmax
from data_loader.attack_dataloader import spherical_dataset, voxel_dataset

# Suppress noisy numpy warnings. NOTE: filterwarnings("ignore") silences *all* warnings process-wide.
import warnings
warnings.filterwarnings("ignore")


def save_ply(save_path, points, text=True):
    """Write a point cloud to ``save_path`` as a PLY ``vertex`` element.

    Args:
        save_path: Destination file path.
        points: 2-D array. With exactly 3 columns it is written as float32
            x/y/z only; otherwise its first 6 columns are written as float32
            x/y/z followed by uint8 red/green/blue.
        text: Write ASCII PLY when True, binary PLY when False.
    """
    xyz_fields = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
    if points.shape[1] == 3:
        fields = xyz_fields
    else:
        fields = xyz_fields + [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    # One tuple per point, with as many entries as there are output fields.
    n_fields = len(fields)
    rows = [
        tuple(points[row, col] for col in range(n_fields))
        for row in range(points.shape[0])
    ]
    vertex = np.array(rows, dtype=fields)

    element = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    PlyData([element], text=text).write(save_path)


def read_ply(filename):
    """Read the first element of a PLY file (normally the vertices) into a float array.

    Every property of that element becomes one column, in the order the
    properties appear in the file. For an XYZRGB cloud that is
    ``[x, y, z, red, green, blue]``.

    Args:
        filename: Path of the PLY file.

    Returns:
        ``np.ndarray`` of shape ``(n_vertices, n_properties)`` and dtype float64.
    """
    vertices = PlyData.read(filename).elements[0].data
    # Read the field names from the array dtype rather than from vertices[0],
    # so an element with zero rows does not raise IndexError.
    names = vertices.dtype.names
    # np.float was only an alias of the builtin float (float64) and was
    # removed in NumPy 1.24; use the explicit dtype instead.
    data_np = np.empty((len(vertices), len(names)), dtype=np.float64)
    for col, name in enumerate(names):
        data_np[:, col] = vertices[name]
    return data_np


def save_pc(filename, one_xyz, one_label):
    """Save a labelled point cloud as a coloured PLY file.

    Points whose label equals 1 are coloured red (255, 0, 0); every other
    point is white (255, 255, 255).

    Args:
        filename: Destination PLY path.
        one_xyz: Per-point coordinates. The colour buffer is created with
            ``np.zeros_like(one_xyz)`` and filled with 3-element colours, so
            this must have exactly 3 columns.
        one_label: Per-point labels, indexed by row and compared against 1.
    """
    red = [255, 0, 0]
    white = [255, 255, 255]
    colours = np.zeros_like(one_xyz)
    for idx in range(one_label.shape[0]):
        colours[idx] = red if one_label[idx] == 1 else white
    save_ply(filename, np.concatenate([one_xyz, colours], axis=1))


def fg_iou_single(pred_class, cls_labels):
    """Foreground IoU between a prediction and a ground-truth label tensor.

    This is a single foreground score, not a per-class mean (mIoU). Any
    label > 0 counts as foreground.

    - intersection: positions where ``pred_class == cls_labels`` and the label is foreground.
    - union: ``|label > 0| + |pred > 0| - intersection``.

    For binary {0, 1} labels this is the standard IoU of class 1.

    Args:
        pred_class: Integer tensor of predicted labels.
        cls_labels: Integer tensor of ground-truth labels, same shape.

    Returns:
        0-d float tensor. It is 0 when both prediction and labels are empty,
        because the union is clamped to at least 1.
    """
    positives = cls_labels > 0
    hits = torch.logical_and(pred_class == cls_labels, positives).float().sum()
    denom = positives.sum().float() + (pred_class > 0).sum().float() - hits
    return hits / denom.clamp(min=1.0)


class Attack():
    """Evaluate, and optionally attack, a pretrained PolarSeg-style BEV segmentation model.

    Construction builds the network, loads its checkpoint when the file
    exists, and moves it to ``cuda:0`` in eval mode. Predictions are reduced
    to a binary per-point mask: shifted voxel class 1 -> 1, anything else -> 0
    (see ``_pointwise_prediction``).
    """

    # Device for the model and all input tensors (hard-coded to the first GPU).
    _DEVICE = torch.device('cuda:0')

    # Per-backbone settings: volume bounds, point-feature width, padding mode, checkpoint.
    _MODEL_SETTINGS = {
        'polar': {
            'max_volume_space': [50, np.pi, 1.5],
            'min_volume_space': [3, -np.pi, -3],
            'fea_dim': 9,
            'circular_padding': True,
            'model_save_path': '/home/wenhao/2_pytorch_ws/17-Baselines/PolarSeg/pretrained_weight/SemKITTI_PolarSeg.pt',
        },
        'traditional': {
            'max_volume_space': [50, 50, 1.5],
            'min_volume_space': [-50, -50, -3],
            'fea_dim': 7,
            'circular_padding': False,
            'model_save_path': '/home/wenhao/2_pytorch_ws/17-Baselines/PolarSeg/pretrained_weight/SemKITTI_trad_BEVSeg.pt',
        },
    }

    # Directory holding the pre-generated adversarial scenes read by prepare_pc.
    _SCENE_DIR = '../scene_examples and transferability'
    # Victim models for which scenes exist (spelling, incl. 'cyclinder3d', matches the file names).
    _SCENE_TARGETS = ('pointnet2', 'squeezeseg', 'polarseg', 'cyclinder3d')

    def __init__(self, model_name='polar'):
        """Build the network for ``model_name`` ('polar' or 'traditional') and load its weights.

        Raises:
            ValueError: if ``model_name`` is not a supported backbone.
        """
        # Validate first so an unknown name fails clearly instead of with an
        # unbound-variable error later on.
        if model_name not in self._MODEL_SETTINGS:
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of {sorted(self._MODEL_SETTINGS)}"
            )
        self.model_name = model_name
        settings = self._MODEL_SETTINGS[model_name]

        grid_size = [480, 360, 32]
        n_class = 19
        compression_model = grid_size[2]

        # Copy the bound lists so the class-level defaults cannot be mutated
        # through this instance's config.
        self.config = {
            'fixed_volume_space': True,
            'ignore_label': 255,
            'grid_size': grid_size,
            'max_volume_space': list(settings['max_volume_space']),
            'min_volume_space': list(settings['min_volume_space']),
        }

        # prepare model
        my_BEV_model = BEV_Unet(
            n_class=n_class,
            n_height=compression_model,
            input_batch_norm=True,
            dropout=0.5,
            circular_padding=settings['circular_padding'],
        )
        self.my_model = ptBEVnet(
            my_BEV_model,
            pt_model='pointnet',
            grid_size=grid_size,
            fea_dim=settings['fea_dim'],
            max_pt_per_encode=256,
            out_pt_fea_dim=512,
            kernal_size=1,
            pt_selection='random',
            fea_compre=compression_model
        )

        # load model weights
        model_save_path = settings['model_save_path']
        print('Model Path:', model_save_path)
        if os.path.exists(model_save_path):
            # Load on CPU first so a checkpoint saved on another GPU index still loads.
            self.my_model.load_state_dict(torch.load(model_save_path, map_location='cpu'))
        else:
            # Warnings are globally filtered in this file, so use print to make
            # the random-weights fallback visible.
            print('Warning: checkpoint not found, model uses randomly initialised weights')
        self.my_model.to(self._DEVICE)
        self.my_model.eval()
        self.loss_fun = torch.nn.CrossEntropyLoss(ignore_index=255)

    def preprocess_data(self, pc_4c):
        """Voxelise an ``(xyz, label)`` point cloud and move the model inputs to the GPU.

        Returns:
            Tuple ``(val_pt_fea_ten, val_grid, val_grid_ten, val_label_tensor)``:

            - val_pt_fea_ten: list (batch of 1) of float32 per-point feature
              tensors, with an all-zero reflection column appended.
            - val_grid: numpy per-point voxel indices with a leading batch
              axis of size 1.
            - val_grid_ten: list of tensors holding the first two grid
              columns (the BEV cell).
            - val_label_tensor: int64 voxel labels with a batch axis.

        Raises:
            ValueError: if ``self.model_name`` is not a supported backbone.
        """
        if self.model_name == 'polar':
            dataset_fn = spherical_dataset
        elif self.model_name == 'traditional':
            dataset_fn = voxel_dataset
        else:
            raise ValueError(f"Unknown model_name {self.model_name!r}")
        _, val_vox_label, val_grid, _, val_pt_fea = dataset_fn(pc_4c, self.config)

        val_grid = val_grid[None]

        # we dont have reflection, so just use 0.0
        reflection = np.zeros((val_pt_fea.shape[0], 1))
        val_pt_fea = np.concatenate([val_pt_fea, reflection], axis=1)[None]

        val_pt_fea_ten = [torch.from_numpy(fea).float().to(self._DEVICE) for fea in val_pt_fea]
        val_grid_ten = [torch.from_numpy(grid[:, 0:2]).to(self._DEVICE) for grid in val_grid]
        val_label_tensor = torch.from_numpy(val_vox_label[None]).long().to(self._DEVICE)

        # Layout of val_pt_fea_ten for the polar model:
        # [3+3+2+1] = [[relative rho, relative phi, relative z], [rho, phi, z], [x, y], [reflection]]
        return val_pt_fea_ten, val_grid, val_grid_ten, val_label_tensor

    def prepare_pc(self, test_name, type):
        """Load a pre-generated adversarial scene and derive binary per-point labels.

        Args:
            test_name: One of ``_SCENE_TARGETS``; selects the victim model's scene file.
            type: ``'point'`` selects the point_SimBA scene; any other value
                selects the sb_vae_BO scene. The name shadows the builtin but
                is kept for keyword-argument compatibility.

        Returns:
            ``(xyz, label)``: xyz is ``(N, 3)`` float; label is ``(N, 1)`` uint8.

        Raises:
            ValueError: if ``test_name`` is not a known scene target.
        """
        print(test_name, type)
        if test_name not in self._SCENE_TARGETS:
            raise ValueError(
                f"Unknown test_name {test_name!r}; expected one of {list(self._SCENE_TARGETS)}"
            )
        prefix = 'point_SimBA' if type == 'point' else 'sb_vae_BO'
        attack_pc = read_ply(f'{self._SCENE_DIR}/{prefix}_attack_gt_{test_name}_intersection.ply')

        xyz = attack_pc[:, 0:3]
        rgb = attack_pc[:, 3:6]
        # Green == 0 maps to label 1; any green > 0 truncates to 0 under the
        # uint8 cast. This matches save_pc's colouring (red = 1, white = 0).
        label = ((255.0 - rgb[:, 1:2]) / 255.0).astype('uint8')
        return (xyz, label)

    def _segmentation_loss(self, logits, labels):
        """Lovasz-softmax plus cross-entropy over the class axis (dim 1) of ``logits``.

        The loss is only indicative: only the attacked object's class is
        labelled correctly, so the labels for other classes are wrong.
        """
        probs = torch.nn.functional.softmax(logits, dim=1)
        return lovasz_softmax(probs, labels, ignore=255) + self.loss_fun(logits, labels)

    @staticmethod
    def _pointwise_prediction(logits, val_grid):
        """Turn voxel logits into a binary per-point mask (shifted class 1 -> 1, rest -> 0).

        ``logits`` is indexed as ``[batch, class, g0, g1, g2]`` and
        ``val_grid[0]`` supplies each point's ``(g0, g1, g2)`` voxel index.
        """
        # Training labels have no background class, so shift predictions by one.
        voxel_pred = torch.argmax(logits, dim=1) + 1
        grid = val_grid[0]
        pointwise = voxel_pred[0, grid[:, 0], grid[:, 1], grid[:, 2]]
        pointwise[pointwise != 1] = 0
        return pointwise

    def attack(self, pc_4c):
        """Run the model once on ``(xyz, label)`` and score the prediction.

        Returns:
            ``(iou, loss, predict_pointwise_labels)``: foreground IoU (float),
            segmentation loss (float), and the binary per-point prediction tensor.
        """
        xyz = pc_4c[0]
        label = pc_4c[1].astype('uint8')
        val_pt_fea, val_grid, val_grid_ten, val_label_tensor = self.preprocess_data((xyz, label))

        with torch.no_grad():
            logits = self.my_model(val_pt_fea, val_grid_ten)
            loss = self._segmentation_loss(logits, val_label_tensor)
            predict_pointwise_labels = self._pointwise_prediction(logits, val_grid)

        iou = fg_iou_single(predict_pointwise_labels, torch.from_numpy(label[:, 0]).to(self._DEVICE))
        return iou.item(), loss.item(), predict_pointwise_labels

    def differentiable_attack(self, pc_4c, *, attack_itr=100, method='PGD',
                              attack_eps=0.001, attack_lb=-0.001, attack_ub=0.001,
                              feature_slice=slice(0, 3)):
        """White-box gradient attack on the preprocessed per-point features.

        The input is still a point cloud, but the perturbation is added to the
        preprocessed features, not to the raw points. The voxel assignment
        (``val_grid``) is therefore not updated. This makes the attack easy,
        because it creates a mismatch between the feature channels.

        Args:
            pc_4c: ``(xyz, label)`` tuple, as returned by ``prepare_pc``.
            attack_itr: Number of gradient steps (must be >= 1).
            method: ``'FGSM'`` (sign step, unclipped), ``'PGD'`` (raw-gradient
                step, clipped), or ``'PGD-N'`` (L2-normalised step, clipped).
            attack_eps: Step size.
            attack_lb, attack_ub: Element-wise clip range for PGD / PGD-N.
            feature_slice: Feature columns to perturb. NOTE(review): the
                original code said "only for xyz" without naming columns; the
                default perturbs the first three. Confirm which columns are
                intended for each backbone.

        Returns:
            ``(iou, loss, predict_pointwise_labels)`` from the last iteration.

        Raises:
            NotImplementedError: for an unknown ``method``.
            ValueError: if ``attack_itr < 1`` or ``feature_slice`` selects no columns.
        """
        if method not in ('FGSM', 'PGD', 'PGD-N'):
            raise NotImplementedError(f"Unknown attack method {method!r}")
        if attack_itr < 1:
            raise ValueError("attack_itr must be >= 1")

        xyz = pc_4c[0]
        label = pc_4c[1].astype('uint8')
        val_pt_fea, val_grid, val_grid_ten, val_label_tensor = self.preprocess_data((xyz, label))
        gt_labels = torch.from_numpy(label[:, 0]).to(self._DEVICE)

        clean_fea = val_pt_fea[0].detach()
        n_cols = len(range(clean_fea.shape[1])[feature_slice])
        if n_cols == 0:
            raise ValueError("feature_slice selects no feature columns")
        # Create the leaf directly on the device so its .grad is tracked (a
        # Variable moved with .to() would be a non-leaf with grad None).
        attack_z = torch.zeros((clean_fea.shape[0], n_cols), device=self._DEVICE, requires_grad=True)

        iou = loss = predict_pointwise_labels = None
        for a_i in range(attack_itr):
            # Rebuild from the clean features every step so perturbations do
            # not compound and each step has its own fresh autograd graph.
            perturbed = clean_fea.clone()
            perturbed[:, feature_slice] = perturbed[:, feature_slice] + attack_z
            logits = self.my_model([perturbed] + val_pt_fea[1:], val_grid_ten)
            loss = self._segmentation_loss(logits, val_label_tensor)
            # autograd.grad avoids accumulating gradients into the model parameters.
            grad, = torch.autograd.grad(loss, attack_z)

            with torch.no_grad():
                # Gradient ascent: maximise the segmentation loss.
                if method == 'FGSM':
                    attack_z += attack_eps * grad.sign()
                elif method == 'PGD':
                    attack_z += attack_eps * grad
                    attack_z.clamp_(min=attack_lb, max=attack_ub)
                else:  # 'PGD-N': L2-normalised steepest ascent
                    attack_z += attack_eps * grad / torch.norm(grad, p=2).clamp(min=1e-12)
                    attack_z.clamp_(min=attack_lb, max=attack_ub)

                # Score the prediction that produced this step's gradient.
                predict_pointwise_labels = self._pointwise_prediction(logits, val_grid)
                iou = fg_iou_single(predict_pointwise_labels, gt_labels)

            print(('Iter: [{}/{}] Loss: {:.6f} IoU: {:.6f}').format(a_i, attack_itr, loss.item(), iou.item()))

        return iou.item(), loss.item(), predict_pointwise_labels


if __name__ == '__main__':
    # CLI entry point: evaluate the polar model on one pre-generated adversarial scene.
    cli = argparse.ArgumentParser(description="Traffic Scenario Generation")
    cli.add_argument("--test_name", type=str, default='pointnet2', help='')
    cli.add_argument("--type", type=str, default='sb_vae', help='')
    opts = cli.parse_args()

    attacker = Attack(model_name='polar')
    scene = attacker.prepare_pc(opts.test_name, opts.type)
    vehicle_iou, seg_loss, _ = attacker.attack(scene)
    print('IoU for vehicle:', vehicle_iou)
    print('loss:', seg_loss)
