import argparse
import time
import random
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import open3d as o3d
from data_utils.ModelNetDataLoader40 import ModelNetDataLoader40
from data_utils.ModelNetDataLoader10 import ModelNetDataLoader10
from data_utils.ShapeNetDataLoader import PartNormalDataset
from data_utils.KITTIDataLoader import KITTIDataLoader
from data_utils.ScanObjectNNDataLoader import ScanObjectNNDataLoader
from torch.utils.data import DataLoader, TensorDataset

from utils.logging import Logging_str
from utils.utils import set_seed,class_wise_rot_sca
import math
from utils.utils import show_time, transform_time
import os
import sys
import importlib
# Absolute directory holding this script; it doubles as the project root.
ROOT_DIR = BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Append the classifier directory to the import search path so that modules
# living there can be imported by their bare module name.
sys.path.append(os.path.join(ROOT_DIR, 'model/classifier'))




def load_data(args, data_path):
    """Create a shuffled DataLoader over the training split of ``args.dataset``.

    Args:
        args: parsed options. ``dataset``, ``batch_size`` and ``num_workers``
            are always read; ``input_point_nums`` is read only for the
            datasets whose constructor receives it.
        data_path: dataset root directory handed to the dataset class.

    Returns:
        A ``torch.utils.data.DataLoader`` over the ``'train'`` split with
        ``shuffle=True``.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported name.
    """
    # Every entry constructs its dataset lazily, so only the selected class
    # is looked up/instantiated and options are read only when needed.
    builders = {
        'ModelNet40': lambda: ModelNetDataLoader40(
            root=data_path, npoint=args.input_point_nums,
            split='train', normal_channel=False),
        'ModelNet10': lambda: ModelNetDataLoader10(
            root=data_path, npoint=args.input_point_nums,
            split='train', normal_channel=False),
        'ShapeNetPart': lambda: PartNormalDataset(
            root=data_path, npoint=args.input_point_nums,
            split='train', normal_channel=False),
        # NOTE(review): KITTI uses a fixed 256 points and ignores
        # --input_point_nums; confirm this is intended.
        'kitti': lambda: KITTIDataLoader(
            root=data_path, npoints=256, split='train'),
        'ScanObjectNN': lambda: ScanObjectNNDataLoader(
            root=data_path, npoint=args.input_point_nums, split='train'),
    }
    build = builders.get(args.dataset)
    if build is None:
        raise NotImplementedError
    train_set = build()

    loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )
    print('Finish Loading Dataset...')
    return loader

def data_preprocess(data, device='cuda'):
    """Split a ``(points, target)`` batch and move both tensors to ``device``.

    Args:
        data: two-item sequence ``(points, target)`` as yielded by the loader.
            ``target`` is indexed as ``target[:, 0]``, so it must be at least
            2-D (presumably shaped (B, 1) -- TODO confirm with the datasets).
        device: destination device. The default ``'cuda'`` reproduces the
            previous hard-coded ``.cuda()`` behavior; pass ``'cpu'`` to run
            without a GPU.

    Returns:
        ``(points, target)`` on ``device``, with ``target`` reduced to its
        first column.
    """
    points, target = data
    target = target[:, 0]  # keep only the first label column
    return points.to(device), target.to(device)

def save_tensor_as_txt(args, points, filename):
    """Write the first three columns of a point cloud to a txt file.

    The file is written to ``<args.output_dir>/example/<filename>``; the
    ``example`` directory is created when missing. Each output line holds
    the first three values of one point, separated by single spaces.

    Args:
        args: namespace providing ``output_dir``.
        points: torch tensor or numpy array whose leading dimension is
            squeezed away (presumably a batch of size 1), leaving one row
            per point with at least three columns.
        filename: name of the file to create inside ``example``.
    """
    points = points.squeeze(0)
    # str() of a 0-d torch tensor gives "tensor(0.12, device='cuda:0')",
    # not the number, so convert tensors to numpy before formatting.
    if torch.is_tensor(points):
        points = points.detach().cpu().numpy()

    out_dir = os.path.join(args.output_dir, 'example')
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, filename), "w") as f:
        for row in points:
            f.write(str(row[0]) + ' ' + str(row[1]) + ' ' + str(row[2]) + '\n')
 



if __name__ == '__main__':
    # Script entry point: for every training sample of the selected dataset,
    # apply a class-dependent rotation/scale (class_wise_rot_sca) and save
    # both the original and the transformed cloud via save_tensor_as_txt,
    # under args.output_dir = ./UDs/<dataset>/<slight>_<main>_<sca_min>_<sca_max>.
    # NOTE(review): --emb_dims, --target_model and --normal are parsed but not
    # referenced in the code below.
    parser = argparse.ArgumentParser(description='Unleanrable Point Clouds')
    parser.add_argument('--emb_dims', type=int, default=1024, metavar='N', help='Dimension of embeddings')
    parser.add_argument('--batch_size', type=int, default=1, metavar='N', help='input batch size for training (default: 1)')
    parser.add_argument('--input_point_nums', type=int, default=1024, help='Point nums of each point cloud')
    parser.add_argument('--seed', type=int, default=2022, metavar='S', help='random seed (default: 2022)')
    parser.add_argument('--dataset', type=str, default='ShapeNetPart', choices=['ModelNet10', 'ModelNet40', 'ShapeNetPart', 'kitti', 'ScanObjectNN'])
    parser.add_argument('--target_model', type=str, default='pointnet_cls', choices=['pointnet_cls', 'pointnet2_cls_msg', 'dgcnn', 'pointconv', 'pointcnn', 'paconv', 'pct', 'curvenet', 'simple_view'])
    parser.add_argument('--num_workers', type=int, default=4, help='Worker nums of data loading.')
    parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]')
    parser.add_argument('--slight_range', type=int, default=15, help='x,y angle range [para 1]')
    parser.add_argument('--main_range', type=int, default=120, help='z angle range [para 2]')
    parser.add_argument('--sca_min', type=float, default=1.0, help='scale min bound [para 3]')
    parser.add_argument('--sca_max', type=float, default=1.0, help='scale max bound [para 4]')
    # NOTE(review): the author left these alternative values, apparently for
    # para 1-4 in order (meaning unconfirmed): slight_range 25,
    # main_range 90 or 240, sca_min 0.9, sca_max 1.2.

    args = parser.parse_args()
    # NOTE(review): args.device is set but not read below; data_preprocess
    # moves tensors to CUDA on its own.
    args.device = torch.device("cuda")

    # Seed before any random.uniform / random.sample call below so the
    # per-class transforms are reproducible for a given --seed (assumes
    # set_seed seeds Python's `random` module -- confirm in utils.utils).
    set_seed(args.seed)
    if args.dataset == 'ModelNet40':
        num_class = 40
        data_path = "./data/modelnet40_normal_resampled"

    elif args.dataset == 'ShapeNetPart':
        num_class = 16
        data_path = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'

    elif args.dataset == 'ModelNet10':
        num_class = 10
        # NOTE(review): ModelNet10 reads from the ModelNet40 data directory;
        # presumably the ModelNet10 loader selects its subset -- confirm.
        data_path = "./data/modelnet40_normal_resampled"
    elif args.dataset == 'kitti':
        num_class = 2
        data_path = '/HARD-DATA/LW/DATA/KITTI/training/object_cloud'
    elif args.dataset == 'ScanObjectNN':
        num_class = 15
        data_path ='data/h5_files'
    # argparse `choices` guarantees one branch above ran, so num_class is set.
    assert num_class != 0
    args.output_dir = os.path.join("./UDs", args.dataset, str(args.slight_range) + '_' + str(args.main_range) + '_' + str(args.sca_min) + '_' + str(args.sca_max))


    data_loader = load_data(args, data_path)
    log = Logging_str(os.path.join(args.output_dir, 'log.txt'))
    pbar = tqdm(enumerate(data_loader), total=len(data_loader))

    start = time.time()
    # NOTE(review): st_hour/st_min are not used afterwards.
    st_hour, st_min = show_time(start)


    # Grid side length k = ceil(cbrt(num_class)), so the k*k*k grid of angle
    # triples below has at least num_class entries for random.sample.
    Avg_num = math.ceil(num_class ** (1 / 3))
    # k random x and y angles in [0, slight_range] and k random z angles in
    # [0, main_range] (units not established here -- presumably degrees).
    x_list, y_list = [random.uniform(0, args.slight_range) for _ in range(Avg_num)], [random.uniform(0, args.slight_range) for _ in range(Avg_num)]
    z_list = [random.uniform(0, args.main_range) for _ in range(Avg_num)]
    rotation_list = []
    for i in range(Avg_num):
        for j in range(Avg_num):
            for k in range(Avg_num):
                # Stored in [x, z, y] order; presumably the order expected by
                # class_wise_rot_sca -- confirm before changing.
                rotation_list.append([x_list[i], z_list[j], y_list[k]])

    # Pick num_class distinct grid entries (one rotation per class) and one
    # random scale per class in [sca_min, sca_max].
    rotation_list = random.sample(rotation_list, num_class)
    scale_list = [random.uniform(args.sca_min, args.sca_max) for _ in range(num_class)]

    # with open(os.path.join(args.output_dir, 'log.txt'), "w") as file:
    #     file.write("[" + ", ".join(map(str, rotation_list)) + "]\n")
    #     file.write("[" + ", ".join(map(str, scale_list)) + "]\n")

    for batch_id, data in pbar:
        # Keep only (points, target) from the ShapeNetPart samples, which
        # carry extra items beyond the first two.
        if args.dataset == 'ShapeNetPart':
            data = data[:2]
        points, target = data_preprocess(data)
        target = target.long()
        # target.item() requires a single-element target, so this loop only
        # works with --batch_size 1.
        transformed_pc = class_wise_rot_sca(points, target.item(), rotation_list, scale_list)

        save_tensor_as_txt(args, transformed_pc, f'{batch_id}_transform_{target.item()}.txt')
        save_tensor_as_txt(args, points, f'{batch_id}_origin_{target.item()}.txt')
        # from visual_util import plot_pcd_three_views
        # titles = ['viewpoint 1', 'viewpoint 2', 'viewpoint 3']
        # file_path = os.path.join(args.output_dir, 'fig') 
        # if not os.path.exists(file_path):
        #     os.makedirs(file_path)
        # plot_pcd_three_views(os.path.join(file_path, f'{batch_id}_origin_{target.item()}.png'), [points.squeeze(0).detach().cpu().numpy()], titles, zdir='y')
        # plot_pcd_three_views(os.path.join(file_path, f'{batch_id}_transform_{target.item()}.png'), [transformed_pc.squeeze(0)], titles, zdir='y')


    end = time.time()
    # NOTE(review): t_hour/t_min are not used afterwards.
    t_hour, t_min = show_time(end)
    spent_hour, spent_min = transform_time(start, end)
    print("The spent time is {} h{} min".format(spent_hour, spent_min))
    # Log total wall-clock seconds spent in loading plus transformation.
    log.write('Time: %f' %(end - start))