import yaml
import numpy as np
from plyfile import PlyData, PlyElement
import pandas as pd
import sys
sys.path.append('../../')
import argparse

import torch

from tasks.semantic.modules.segmentator import Segmentator
from tasks.semantic.postproc.KNN import KNN
from tasks.semantic.dataset.kitti.parser import Data_Loader


def save_ply(save_path, points, text=True):
    """Write a 2-D point array to ``save_path`` as a PLY file.

    Args:
        save_path: destination handed to ``PlyData.write``.
        points: array indexed as ``points[row, col]``. With exactly three
            columns the rows are stored as x/y/z only; otherwise the first
            six columns are stored as x/y/z followed by red/green/blue.
        text: forwarded to ``PlyData``; defaults to True.
    """
    xyz_fields = [('x', 'f4'), ('y', 'f4'), ('z', 'f4')]
    rgb_fields = [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    # three columns -> no color information; anything else -> xyz + rgb
    layout = xyz_fields if points.shape[1] == 3 else xyz_fields + rgb_fields
    n_cols = len(layout)

    # one tuple per point so numpy can build the structured array
    records = [
        tuple(points[row, col] for col in range(n_cols))
        for row in range(points.shape[0])
    ]
    vertex = np.array(records, dtype=layout)

    element = PlyElement.describe(vertex, 'vertex', comments=['vertices'])
    PlyData([element], text=text).write(save_path)


def read_ply(filename):
    """Read the first element of a PLY file into a dense float array.

    Args:
        filename: path of the PLY file to read.

    Returns:
        A ``float64`` NumPy array with one row per record and one column per
        property, in the order the properties are declared in the element
        (callers in this file slice it as XYZ followed by RGB).
    """
    plydata = PlyData.read(filename).elements[0].data
    data_pd = pd.DataFrame(plydata)
    # ``np.float`` was only an alias of the builtin ``float`` and has been
    # removed from NumPy; ``np.float64`` is the same dtype and always exists.
    data_np = np.zeros(data_pd.shape, dtype=np.float64)
    # Take the field names from the array dtype instead of the first record,
    # so an element with zero records does not raise IndexError.
    property_names = plydata.dtype.names
    for i, name in enumerate(property_names):
        data_np[:, i] = data_pd[name]
    return data_np


def save_pc(filename, one_xyz, one_label):
    """Save a labelled point cloud as a colored PLY file.

    Points whose label equals 1 are written in red, every other labelled
    point in white. The color array is created with ``np.zeros_like(one_xyz)``
    and the result is written through ``save_ply``.

    Args:
        filename: output path passed on to ``save_ply``.
        one_xyz: 2-D array of point coordinates.
        one_label: per-point labels; its first dimension drives the loop.
    """
    red = [255, 0, 0]
    white = [255, 255, 255]

    colors = np.zeros_like(one_xyz)
    for idx in range(one_label.shape[0]):
        colors[idx] = red if one_label[idx] == 1 else white

    save_ply(filename, np.concatenate([one_xyz, colors], axis=1))


def fg_iou_single(pred_class, cls_labels):
    """Compute the foreground IoU between predicted and ground-truth labels.

    Any label greater than zero counts as foreground. The intersection is the
    number of ground-truth foreground points whose predicted class matches
    the label exactly; the union is |gt foreground| + |predicted foreground|
    minus that intersection, clamped to at least 1 to avoid dividing by zero.

    Args:
        pred_class: tensor of predicted class ids.
        cls_labels: tensor of ground-truth class ids.

    Returns:
        A scalar float tensor holding the IoU.
    """
    gt_foreground = cls_labels > 0
    pred_foreground = pred_class > 0

    hits = (gt_foreground & (pred_class == cls_labels)).float().sum()
    denom = gt_foreground.sum().float() + pred_foreground.sum().float() - hits
    return hits / torch.clamp(denom, min=1.0)


class Attack():
    """Run a pretrained segmentation model on saved attack scenes.

    The constructor loads the architecture config and the model found under
    ``model_path``; ``prepare_pc`` loads one of the stored attack point
    clouds and ``attack`` evaluates the model on it.
    """

    # Location used when the constructor is called without a model path.
    DEFAULT_MODEL_PATH = '/home/wenhao/2_pytorch_ws/17-Baselines/SqueezeSegV3/SSGV3-53'
    # Directory that holds the stored attack scenes (relative to the cwd).
    SCENE_DIR = '../../../scene_examples and transferability/'
    # Victim-model names for which a scene file exists.
    KNOWN_TEST_NAMES = ('pointnet2', 'squeezeseg', 'polarseg', 'cyclinder3d')

    def __init__(self, model_path=None):
        """Load the config, data loader, model and optional KNN post-processing.

        Args:
            model_path: directory containing ``arch_cfg.yaml`` and whatever
                ``Segmentator`` loads from it. Defaults to
                ``DEFAULT_MODEL_PATH`` when omitted.
        """
        # parameters
        self.model_path = self.DEFAULT_MODEL_PATH if model_path is None else model_path
        # use a context manager so the config file handle is closed promptly
        with open(self.model_path + "/arch_cfg.yaml", 'r') as cfg_file:
            self.ARCH = yaml.safe_load(cfg_file)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # get the data
        self.dataloader = Data_Loader(
            sensor=self.ARCH["dataset"]["sensor"],
            max_points=self.ARCH["dataset"]["max_points"],
            nclasses=20,
        )

        # concatenate the encoder and the head
        with torch.no_grad():
            self.model = Segmentator(self.ARCH, self.dataloader.nclasses, self.model_path).to(self.device)
        self.model.eval()

        # use knn post processing only when the config enables it
        self.post = None
        if self.ARCH["post"]["KNN"]["use"]:
            self.post = KNN(self.ARCH["post"]["KNN"]["params"], self.dataloader.nclasses)

        self.loss_fun = torch.nn.CrossEntropyLoss()

    def prepare_pc(self, test_name, type):
        """Load a stored attack scene and split it into points and labels.

        Args:
            test_name: name of the model the scene was generated against;
                must be one of ``KNOWN_TEST_NAMES``.
            type: ``'point'`` selects the ``point_SimBA`` scenes; any other
                value selects the ``sb_vae_BO`` scenes.

        Returns:
            ``(xyz, label)`` where ``xyz`` holds columns 0-2 of the file and
            ``label`` is an ``(N, 1)`` uint8 array derived from the green
            channel (1 where green is 0, 0 where green is 255).

        Raises:
            ValueError: if ``test_name`` is not a known scene name.
        """
        print(test_name, type)
        if test_name not in self.KNOWN_TEST_NAMES:
            # previously this fell through and failed later with an
            # UnboundLocalError on ``attack_pc``
            raise ValueError(
                "unknown test_name {!r}; expected one of {}".format(test_name, self.KNOWN_TEST_NAMES)
            )

        prefix = 'point_SimBA' if type == 'point' else 'sb_vae_BO'
        scene_file = '{}{}_attack_gt_{}_intersection.ply'.format(self.SCENE_DIR, prefix, test_name)
        attack_pc = read_ply(scene_file)

        xyz = attack_pc[:, 0:3]
        rgb = attack_pc[:, 3:6]
        # the integer cast truncates, so only green == 0 maps to label 1
        label = ((255.0 - rgb[:, 1:2])/255.0).astype('uint8')
        return (xyz, label)

    def preprocess_data(self, pc_4c):
        """Project the point cloud and move the resulting tensors to the device.

        Args:
            pc_4c: point-cloud data in the form accepted by
                ``self.dataloader.range_map_process``.

        Returns:
            ``(proj_in, proj_mask, p_x, p_y, proj_range, unproj_range,
            proj_labels)``; ``proj_in`` and ``proj_mask`` get a leading batch
            dimension of size one, ``proj_labels`` is cast to long.
        """
        proj_in, proj_mask, proj_labels, p_x, p_y, proj_range, unproj_range = self.dataloader.range_map_process(pc_4c)

        # move everything to the compute device
        p_x = p_x.to(self.device)
        p_y = p_y.to(self.device)
        proj_range = proj_range.to(self.device)
        unproj_range = unproj_range.to(self.device)
        proj_labels = proj_labels.to(self.device).long()

        # add the batch dimension expected by the model
        proj_in = proj_in.to(self.device)[None]
        proj_mask = proj_mask.to(self.device)[None]
        return proj_in, proj_mask, p_x, p_y, proj_range, unproj_range, proj_labels

    def attack(self, pc_4c):
        """Run the model on one scene and score the prediction.

        Args:
            pc_4c: pair whose second item is the NumPy label array (column 0
                is used as ground truth); the whole object is also passed to
                ``preprocess_data``.

        Returns:
            ``(iou, loss, predict_labels)``: the foreground IoU and the
            cross-entropy loss as Python floats, and the per-point predicted
            labels with every class other than 1 set to 0.
        """
        # preprocess data
        label_torch = torch.from_numpy(pc_4c[1]).to(self.device)
        proj_in, proj_mask, p_x, p_y, proj_range, unproj_range, proj_labels = self.preprocess_data(pc_4c)

        with torch.no_grad():
            # only the first model output is needed here
            proj_output, _, _, _, _ = self.model(proj_in, proj_mask)
            proj_argmax = proj_output[0].argmax(dim=0)  # [1, 20, 64, 2048] -> [64, 2048]

            if self.post:
                predict_labels = self.post(proj_range, unproj_range, proj_argmax, p_x, p_y)
            else:
                # put in original pointcloud using indexes
                predict_labels = proj_argmax[p_y, p_x]
            # keep class 1 only; everything else becomes background
            predict_labels[predict_labels != 1] = 0

        # calculate IoU and loss
        iou = fg_iou_single(predict_labels, label_torch[:, 0])
        loss = self.loss_fun(proj_output, proj_labels[None])

        return iou.item(), loss.item(), predict_labels


if __name__ == '__main__':
    # Command-line entry point: pick the scene, run the model once on it and
    # print the resulting IoU and loss.
    cli = argparse.ArgumentParser(description="Traffic Scenario Generation")
    for flag, fallback in (("--test_name", 'pointnet2'), ("--type", 'sb_vae')):
        cli.add_argument(flag, type=str, default=fallback, help='')
    opts = cli.parse_args()

    runner = Attack()
    scene = runner.prepare_pc(opts.test_name, opts.type)
    iou, loss, _ = runner.attack(scene)

    print('IoU for vehicle:', iou)
    print('loss:', loss)
