import argparse
import time
import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import open3d as o3d
from data_utils.ModelNetDataLoader40 import ModelNetDataLoader40
from data_utils.ModelNetDataLoader10 import ModelNetDataLoader10
from data_utils.ShapeNetDataLoader import PartNormalDataset
from data_utils.KITTIDataLoader import KITTIDataLoader
from data_utils.ScanObjectNNDataLoader import ScanObjectNNDataLoader
from torch.utils.data import DataLoader, TensorDataset

from utils.logging import Logging_str
from utils.utils import set_seed,class_wise_rot_sca,class_wise_transformation
import math
from utils.utils import show_time, transform_time
import os
import sys
import importlib
# Directory holding this script; it doubles as the project root.
ROOT_DIR = BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Put the classifier model directory on the import path.
_CLASSIFIER_DIR = os.path.join(ROOT_DIR, 'model/classifier')
sys.path.append(_CLASSIFIER_DIR)


from model_utils.riconv2_util import compute_LRA

 



def load_data(args, data_path):
    """Construct a shuffled DataLoader over the training split of ``args.dataset``.

    Args:
        args: parsed CLI namespace; reads ``dataset``, ``input_point_nums``,
            ``batch_size`` and ``num_workers``.
        data_path: root directory handed to the dataset class.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True``.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported name.
    """
    # Each entry defers construction so only the selected dataset is built.
    constructors = {
        'ModelNet40': lambda: ModelNetDataLoader40(
            root=data_path,
            npoint=args.input_point_nums,
            split='train',
            normal_channel=False,
        ),
        'ModelNet10': lambda: ModelNetDataLoader10(
            root=data_path,
            npoint=args.input_point_nums,
            split='train',
            normal_channel=False,
        ),
        'ShapeNetPart': lambda: PartNormalDataset(
            root=data_path,
            npoint=args.input_point_nums,
            split='train',
            normal_channel=False,
        ),
        # KITTI uses a fixed 256-point budget instead of --input_point_nums.
        'kitti': lambda: KITTIDataLoader(
            root=data_path,
            npoints=256,
            split='train',
        ),
        'ScanObjectNN': lambda: ScanObjectNNDataLoader(
            root=data_path,
            npoint=args.input_point_nums,
            split='train',
        ),
    }

    make_dataset = constructors.get(args.dataset)
    if make_dataset is None:
        raise NotImplementedError

    loader = torch.utils.data.DataLoader(
        make_dataset(),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    print('Finish Loading Dataset...')
    return loader

def data_preprocess(data):
    """Split a loader batch into point clouds and flat class labels.

    Args:
        data: a ``(points, target)`` pair. ``points`` is passed through
            untouched (expected ``[B, N, C]``). ``target`` must be at least
            2-D, because only its first column is kept.

    Returns:
        ``(points, labels)`` with ``labels = target[:, 0]``, i.e. shape
        ``[B]`` for a ``[B, 1]`` target.
    """
    cloud_batch, raw_labels = data
    labels = raw_labels[:, 0]
    return cloud_batch, labels

def save_tensor_as_txt(args, points, filename, num_cols=6):
    """Write a single point cloud to ``<args.output_dir>/example/<filename>``.

    Each point becomes one line of space-separated values: the first
    ``num_cols`` columns of that point. Each value is formatted with ``str``
    on the Python scalar produced by ``.tolist()``, the same text that
    per-element ``.item()`` produces.

    Args:
        args: namespace providing ``output_dir``.
        points: NumPy array or torch tensor of shape ``[1, N, C]`` or
            ``[N, C]``. A leading batch axis of size 1 is dropped.
        filename: name of the text file created inside the ``example`` folder.
        num_cols: number of leading columns of each point to write. Defaults
            to 6, the count this writer has always emitted.

    Raises:
        ValueError: if the input (after dropping a size-1 batch axis) is not
            2-D or has fewer than ``num_cols`` columns.
    """
    # Drop a size-1 batch axis. Unlike ndarray.squeeze(0), this does not fail
    # on an already-unbatched [N, C] NumPy array.
    rows = points[0] if points.ndim == 3 and points.shape[0] == 1 else points
    if rows.ndim != 2 or rows.shape[1] < num_cols:
        raise ValueError(
            f"expected points of shape [N, >={num_cols}] (optionally with a "
            f"leading batch axis of 1), got {tuple(points.shape)}"
        )

    out_dir = os.path.join(args.output_dir, 'example')
    # exist_ok makes this idempotent and avoids the exists()/makedirs() race.
    os.makedirs(out_dir, exist_ok=True)

    lines = [' '.join(str(v) for v in row[:num_cols].tolist()) + '\n' for row in rows]
    with open(os.path.join(out_dir, filename), 'w', encoding='utf-8') as f:
        f.writelines(lines)
 

def get_list(mode, args):
    """Draw random per-class transformation parameters for one mode.

    Random draws are made with the ``random`` module, in the same order as
    before, so a seeded run reproduces the same parameters.

    Args:
        mode: one of ``'rot'``, ``'shear'``, ``'scale'``, ``'twist'``,
            ``'taper'``, ``'translation'``.
        args: namespace providing ``NUM_CLASSES``. ``'rot'`` also reads
            ``slight_range`` and ``main_range``; ``'scale'`` reads ``sca_min``
            and ``sca_max``.

    Returns:
        A list of parameters:
        - ``'rot'``: ``NUM_CLASSES`` triples ``[x, z, y]``, with x, y in
          ``[0, slight_range]`` and z in ``[0, main_range]``. They are sampled
          without replacement from a cube grid.
        - ``'shear'``: ``NUM_CLASSES**2`` pairs in ``[0, 0.4]``, in row-major order.
        - ``'scale'``: ``NUM_CLASSES`` floats in ``[sca_min, sca_max]``.
        - ``'twist'`` / ``'taper'``: ``NUM_CLASSES`` floats in ``[0, 20]`` / ``[0, 50]``.
        - ``'translation'``: ``NUM_CLASSES`` triples with components in ``[0, 0.3]``.

    Raises:
        ValueError: if ``mode`` is not one of the names above.
    """
    num_classes = args.NUM_CLASSES

    if mode == 'rot':
        # Grid side: the ceiling of the cube root, so side**3 >= num_classes.
        side = math.ceil(num_classes ** (1 / 3))
        while side ** 3 < num_classes:  # guard against float round-off
            side += 1
        x_angles = [random.uniform(0, args.slight_range) for _ in range(side)]
        y_angles = [random.uniform(0, args.slight_range) for _ in range(side)]
        z_angles = [random.uniform(0, args.main_range) for _ in range(side)]
        # Stored as [x, z, y]; the order is kept as-is for downstream consumers.
        grid = [[x, z, y] for x in x_angles for z in z_angles for y in y_angles]
        return random.sample(grid, num_classes)

    if mode == 'shear':
        x_shears = [random.uniform(0, 0.4) for _ in range(num_classes)]
        y_shears = [random.uniform(0, 0.4) for _ in range(num_classes)]
        # NOTE(review): this returns num_classes**2 pairs. If consumers index by
        # label, labels 0..n-1 all share x_shears[0]. Confirm whether the
        # former random.sample(..., NUM_CLASSES) down-sampling was meant to stay.
        return [[x, y] for x in x_shears for y in y_shears]

    if mode == 'scale':
        return [random.uniform(args.sca_min, args.sca_max) for _ in range(num_classes)]

    if mode == 'twist':
        return [random.uniform(0, 20) for _ in range(num_classes)]

    if mode == 'taper':
        return [random.uniform(0, 50) for _ in range(num_classes)]

    if mode == 'translation':
        return [[random.uniform(0, 0.3), random.uniform(0, 0.3), random.uniform(0, 0.3)]
                for _ in range(num_classes)]

    # Previously an unknown mode silently produced [], hiding typos in --mode.
    raise ValueError(f"Unknown transformation mode: {mode!r}")


if __name__ == '__main__':
    import ast

    parser = argparse.ArgumentParser(description='Unlearnable Point Clouds')
    parser.add_argument('--emb_dims', type=int, default=1024, metavar='N', help='Dimension of embeddings')
    parser.add_argument('--batch_size', type=int, default=1, metavar='N', help='input batch size for training (default: 1)')
    parser.add_argument('--input_point_nums', type=int, default=1024, help='Point nums of each point cloud')
    parser.add_argument('--seed', type=int, default=2022, metavar='S', help='random seed (default: 2022)')
    parser.add_argument('--dataset', type=str, default='ModelNet10', choices=['ModelNet10', 'ModelNet40', 'ShapeNetPart', 'kitti', 'ScanObjectNN'])
    parser.add_argument('--target_model', type=str, default='pointnet_cls', choices=['pointnet_cls', 'pointnet2_cls_msg', 'dgcnn', 'pointconv', 'pointcnn', 'paconv', 'pct', 'curvenet', 'simple_view'])
    parser.add_argument('--num_workers', type=int, default=4, help='Worker nums of data loading.')
    parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]')
    parser.add_argument('--slight_range', type=int, default=15, help='x,y angle range [para 1]')
    parser.add_argument('--main_range', type=int, default=120, help='z angle range [para 2]')
    parser.add_argument('--sca_min', type=float, default=0.6, help='scale min bound [para 3]')
    parser.add_argument('--sca_max', type=float, default=0.8, help='scale max bound [para 4]')
    parser.add_argument('--mode', type=str, required=True,
                        help="Transformations to chain, as a Python list literal "
                             "(e.g. \"['rot', 'scale']\") or comma-separated (e.g. rot,scale)")
    parser.add_argument('--NUM_CLASSES', type=int, default=10)
    # NOTE(review): unexplained values carried over from earlier notes
    # (25; 90 240; 0.9; 1.2). They are presumably alternative settings for
    # slight_range / main_range / sca_min / sca_max; confirm or delete.

    args = parser.parse_args()
    args.device = torch.device("cuda")

    # Parse --mode before any data is loaded, so a malformed value fails fast
    # instead of crashing later inside ast.literal_eval.
    try:
        modes = ast.literal_eval(args.mode)
    except (ValueError, SyntaxError):
        modes = [m.strip() for m in args.mode.split(',') if m.strip()]
    if isinstance(modes, str):
        modes = [modes]
    if not isinstance(modes, (list, tuple)) or not modes or not all(isinstance(m, str) for m in modes):
        parser.error("--mode must be a non-empty list of transformation names, got %r" % args.mode)
    args.mode = list(modes)

    set_seed(args.seed)
    if args.dataset == 'ModelNet40':
        args.NUM_CLASSES = 40
        data_path = "./data/modelnet40_normal_resampled"

    elif args.dataset == 'ShapeNetPart':
        args.NUM_CLASSES = 16
        data_path = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'

    elif args.dataset == 'ModelNet10':
        args.NUM_CLASSES = 10
        data_path = "./data/modelnet40_normal_resampled"
    elif args.dataset == 'kitti':
        args.NUM_CLASSES = 2
        data_path = '/HARD-DATA/LW/DATA/KITTI/training/object_cloud'
    elif args.dataset == 'ScanObjectNN':
        args.NUM_CLASSES = 15
        data_path = 'data/h5_files'
    assert args.NUM_CLASSES != 0
    # The trailing "20" and "0.4" appear to mirror the fixed twist / shear bounds used by get_list.
    args.output_dir = os.path.join("./UMT_normals", args.dataset, str(args.slight_range) + '_' + str(args.main_range) + '_' + str(args.sca_min) + '_' + str(args.sca_max) + "_" + str(20) + "_" + str(0.4))
    # Create the run directory before the log file inside it is opened.
    os.makedirs(args.output_dir, exist_ok=True)

    data_loader = load_data(args, data_path)
    log = Logging_str(os.path.join(args.output_dir, 'log.txt'))
    pbar = tqdm(enumerate(data_loader), total=len(data_loader))

    start = time.time()
    show_time(start)

    # One list of per-class parameters per requested mode. It is still drawn
    # after load_data, as before, so the seeded random stream is unchanged.
    mode_list = {m: get_list(m, args) for m in args.mode}

    for batch_id, batch in pbar:
        if args.dataset == 'ShapeNetPart':
            # ShapeNetPart batches hold more than two fields; keep (points, label).
            batch = batch[:2]

        data, label = data_preprocess(batch)
        # Apply every requested transformation in order, with parameters chosen by the sample's class.
        for idx in range(data.shape[0]):
            trans_data = data[idx].clone().detach()
            sample_label = label[idx].item()
            for k, v in mode_list.items():
                trans_data = torch.tensor(class_wise_transformation(trans_data, k, v, sample_label))
            data[idx] = trans_data

        normal = compute_LRA(data)
        data = torch.cat([data, normal], dim=-1)

        # Save each sample separately so batch_size > 1 works. With
        # batch_size == 1, sample_id == batch_id, so file names are unchanged.
        data_np = data.detach().cpu().numpy()
        for i in range(data_np.shape[0]):
            sample_id = batch_id * args.batch_size + i
            save_tensor_as_txt(args, data_np[i:i + 1], f'{sample_id}_transform_{label[i].item()}.txt')

    end = time.time()
    show_time(end)
    spent_hour, spent_min = transform_time(start, end)
    print("The spent time is {} h{} min".format(spent_hour, spent_min))
    log.write('Time: %f' % (end - start))