import torch
from tqdm import tqdm
import os
import numpy as np
import logging
import pipelines_sapien as pipelines
import Art_DataGen
import open3d as o3d


def get_instance_filenames(data_source, split, gt_filename="sdf_data.csv", filter_modulation_path=None):
    """Resolve split entries to existing ``.npz`` paths under ``data_source``.

    Every ``.json`` substring in an entry is replaced by ``.npz`` and the
    result is joined onto ``data_source``.

    Entries are dropped in two cases:
      * ``filter_modulation_path`` is given and
        ``<filter_modulation_path>/<entry>/latent.txt`` is not a file
        (skipped silently, checked first);
      * the resolved ``.npz`` path is not a file (skipped with a warning).

    ``gt_filename`` is accepted for interface compatibility but is not used.

    Returns the surviving paths in split order.
    """
    found = []
    for entry in split:
        npz_name = entry.replace(".json", ".npz")
        candidate = os.path.join(data_source, npz_name)

        if filter_modulation_path is not None:
            latent_file = os.path.join(filter_modulation_path, npz_name, "latent.txt")
            if not os.path.isfile(latent_file):
                continue

        if os.path.isfile(candidate):
            found.append(candidate)
        else:
            logging.warning("Requested non-existent file '{}'".format(candidate))
    return found

def instances_dict(path):
    """Map each sub-directory name in ``path`` to a 1-based position.

    Positions come from the sorted listing of *all* entries in ``path``
    (plain files included), so files consume a number too and the values
    may contain gaps.
    """
    entries = sorted(os.listdir(path))
    return {
        name: position
        for position, name in enumerate(entries, start=1)
        if os.path.isdir(os.path.join(path, name))
    }

# npz array key read for each supported ``label`` when loading from disk.
_NPZ_KEY_BY_LABEL = {'sdf': 'sdf_data', 'grid': 'grid_data'}


def _draw_indices(pool_size, count):
    """Return ``count`` row indices into a pool of ``pool_size`` rows.

    Samples without replacement when the pool is large enough, otherwise with
    replacement, so exactly ``count`` indices are always produced.
    """
    if pool_size >= count:
        return torch.randperm(pool_size)[:count]
    return torch.randint(0, pool_size, (count,))


def labeled_sampling(f, subsample, pc_size=1024, load_from_path=True, label='sdf'):
    """Draw a balanced positive/negative sample plus a surface point cloud.

    Rows of ``f`` are split by the sign of their last column: ``> 0`` rows
    form the positive pool, ``< 0`` rows the negative pool, and ``== 0`` rows
    the surface pool. ``subsample // 2`` rows are drawn from each of the
    positive and negative pools. Draws are with replacement only when a pool
    is too small. If one pool is empty its share is drawn from the other.

    Args:
        f: a 2-D tensor of samples, or (when ``load_from_path``) a path to an
            ``.npz`` file. Column 3 is returned as the label and columns 0-2 as
            coordinates; presumably rows are ``(x, y, z, sdf)`` -- TODO confirm.
        subsample: total number of labelled rows to return (rounded down to
            an even count).
        pc_size: maximum number of surface (label == 0) points to return.
        load_from_path: if True, ``f`` is a path and the array is read from
            key ``'sdf_data'`` (``label='sdf'``) or ``'grid_data'``
            (``label='grid'``).
        label: which npz key to read; only used when ``load_from_path``.

    Returns:
        ``(pc, xyz, values)`` float tensors, each ``.squeeze()``-d: up to
        ``pc_size`` shuffled surface points, and the coordinates and labels
        of the ``2 * (subsample // 2)`` drawn rows (positives first).

    Raises:
        ValueError: unknown ``label`` when loading, or no non-zero rows exist
            while samples are requested.
    """
    if load_from_path:
        try:
            key = _NPZ_KEY_BY_LABEL[label]
        except KeyError:
            raise ValueError(
                f"unknown label {label!r}; expected one of {sorted(_NPZ_KEY_BY_LABEL)}"
            ) from None
        # The context manager closes the NpzFile handle that np.load opens.
        with np.load(f) as npz:
            f = torch.from_numpy(npz[key])

    half = int(subsample / 2)
    values = f[:, -1]
    pos_tensor = f[values > 0]
    neg_tensor = f[values < 0]
    n_pos, n_neg = pos_tensor.shape[0], neg_tensor.shape[0]
    if half > 0 and n_pos == 0 and n_neg == 0:
        raise ValueError("labeled_sampling: no rows with a non-zero last column to sample from")

    # When one side is empty, borrow its share from the other side so the
    # output always holds exactly 2 * half rows.
    pos_pool = pos_tensor if n_pos > 0 else neg_tensor
    neg_pool = neg_tensor if n_neg > 0 else pos_tensor
    pos_sample = pos_pool[_draw_indices(pos_pool.shape[0], half)]
    neg_sample = neg_pool[_draw_indices(neg_pool.shape[0], half)]

    surface = f[values == 0][:, :3]
    surface = surface[torch.randperm(surface.shape[0])[:pc_size]]

    samples = torch.cat([pos_sample, neg_sample], 0)

    return surface.float().squeeze(), samples[:, :3].float().squeeze(), samples[:, 3].float().squeeze()

def _load_part_point_cloud(part_pc_path, n_parts):
    """Stack the per-part clouds ``{part_pc_path}/{part_id}.npz`` into one array.

    Returns ``(points, seg)``: ``points`` are the first three columns of the
    stacked clouds and ``seg[k]`` is the part id of row ``k``. The part id is
    appended as an extra column and read back from column 3, so this assumes
    every part cloud has exactly 3 columns -- TODO confirm.
    """
    parts = []
    for part_id in range(n_parts):
        with np.load(os.path.join(part_pc_path, f'{part_id}.npz')) as npz:
            part_pc = npz['part_pc']
        parts.append(np.column_stack((part_pc, np.full(part_pc.shape[0], part_id))))
    merged = np.vstack(parts)
    return merged[:, :3], merged[:, 3]


def _pose_joints(joint_xyz, joint_rpy, base_pose, n_parts):
    """Transform joints ``1 .. n_parts-1`` by ``base_pose``, in place.

    ``joint_xyz`` entries get the rotation ``base_pose[:3, :3]`` plus the
    translation ``base_pose[:3, 3]``; ``joint_rpy`` entries get the rotation
    only. Entry 0 is left untouched.
    """
    rotation_t = base_pose[:3, :3].T
    translation = base_pose[:3, 3]
    for part_id in range(1, n_parts):
        joint_xyz[part_id] = np.dot(joint_xyz[part_id], rotation_t) + translation
        joint_rpy[part_id] = np.dot(joint_rpy[part_id], rotation_t)


def main():
    """Build and save one training-sample ``.npz`` per instance of ``class_name``.

    For each split entry with an SDF file (and a matching grid file) this
    gathers annotation/joint data via ``Art_DataGen``, samples SDF and grid
    points with ``labeled_sampling``, loads the partial and per-part point
    clouds, and writes the result to
    ``data/ArtImage/{class_name}/train/diffusion_sdf/<instance>.npz``.

    Returns the list of per-instance sample dicts (tensor values).
    """
    class_name = 'laptop'
    root = 'data/ArtImage'
    train_dir = f'{root}/{class_name}/train'
    data_source = f'{train_dir}/Sdf_data/sdf_data'
    # Previously 'ArtImage/...' without the 'data/' prefix every other path
    # uses, so no grid file was ever found.
    grid_source = f'{train_dir}/Sdf_data/grid_data'
    # NOTE(review): historically named as the train split, but this reads
    # test.txt while every data path below points into train/ -- confirm.
    split_path = f'{root}/{class_name}/test.txt'
    joint_param_path = f'{root}/urdf_metas/{class_name}/urdf_metas.json'
    camera_intrinsic_path = os.path.join(root, 'camera_intrinsic.json')
    out_dir = f'{train_dir}/diffusion_sdf'

    samples_per_mesh = 16000
    pc_size = 30000
    modulation_path = None
    use_grid = bool(grid_source)
    # With grid data, 70% of the samples come from the SDF file and the
    # remainder from the grid file.
    near_surface_count = int(samples_per_mesh * 0.7) if use_grid else samples_per_mesh
    grid_count = samples_per_mesh - near_surface_count

    with open(split_path, "r") as split_fh:
        # Blank lines (e.g. a trailing newline) would otherwise produce a
        # spurious "non-existent file" warning.
        split_file = [line.strip() for line in split_fh if line.strip()]

    gt_files = get_instance_filenames(data_source, split_file, filter_modulation_path=modulation_path)
    grid_files = []
    if use_grid:
        grid_files = get_instance_filenames(grid_source, split_file, gt_filename="grid_gt.csv",
                                            filter_modulation_path=modulation_path)
        # SDF and grid files are paired by position below, so they must name
        # the same instances in the same order (not merely have equal counts).
        gt_names = [os.path.basename(p) for p in gt_files]
        grid_names = [os.path.basename(p) for p in grid_files]
        if gt_names != grid_names:
            missing = sorted(set(gt_names) ^ set(grid_names))
            raise RuntimeError(f"SDF and grid files do not match; differing instances: {missing}")

    os.makedirs(out_dir, exist_ok=True)

    print("loading all {} files into memory...".format(len(gt_files)))
    data_dicts = []
    with tqdm(gt_files) as pbar:
        for idx, gt_path in enumerate(pbar):
            base_name = os.path.basename(gt_path)
            # splitext keeps any other dots / 'npz' substrings in the name intact.
            instance_id = os.path.splitext(base_name)[0]
            ann_path = os.path.join(f'{train_dir}/annotations', f'{instance_id}.json')

            results = {
                'camera_intrinsic_path': camera_intrinsic_path,
                'img_prefix': train_dir,
            }
            results = Art_DataGen.fecth_instances(results, ann_path)
            results = Art_DataGen.fetch_joint_params(results, joint_param_path, class_name)
            results = Art_DataGen.fetch_rest_trans(class_name, results['urdf_id'], results)
            n_parts = results['n_parts']

            pbar.set_description("Files loaded: {}/{}".format(idx, len(gt_files)))
            pc, sdf_xyz, sdf_gt = labeled_sampling(gt_path, near_surface_count, pc_size,
                                                   load_from_path=True, label='sdf')
            if use_grid:
                _, grid_xyz, grid_gt = labeled_sampling(grid_files[idx], grid_count, pc_size=0,
                                                        load_from_path=True, label='grid')
                sdf_xyz = torch.cat((sdf_xyz, grid_xyz))
                sdf_gt = torch.cat((sdf_gt, grid_gt))

            points_data_path = f'{train_dir}/points_data/{instance_id}.npz'
            # NOTE: allow_pickle=True unpickles the file -- only run on trusted data.
            with np.load(points_data_path, allow_pickle=True) as npz:
                points = npz['points_data'].item()

            full_point_cloud, full_seg = _load_part_point_cloud(
                f'{train_dir}/part_pc/{instance_id}', n_parts)

            camera_partial_pc = points['camera_partial_pc']
            canonical_partial_pc = points['canonical_partial_pc']
            cls_label = points['cls']
            base_pose = points['base_pose']
            atc = points['atc']
            joint_xyz = points['joint_xyz']
            joint_rpy = points['joint_rpy']
            _pose_joints(joint_xyz, joint_rpy, base_pose, n_parts)

            # The same permutation keeps camera/canonical/cls rows aligned.
            partial_idx = torch.randperm(camera_partial_pc.shape[0])[:pc_size]
            camera_partial_pc = torch.from_numpy(camera_partial_pc[partial_idx])
            canonical_partial_pc = torch.from_numpy(canonical_partial_pc[partial_idx])
            cls_label = torch.from_numpy(cls_label[partial_idx])
            base_pose = torch.from_numpy(base_pose)
            joint_xyz = torch.from_numpy(np.array(joint_xyz))
            joint_rpy = torch.from_numpy(np.array(joint_rpy))
            atc = torch.from_numpy(atc)

            full_idx = torch.randperm(full_point_cloud.shape[0])[:pc_size]
            full_point_cloud = torch.from_numpy(full_point_cloud[full_idx])
            full_seg = torch.from_numpy(full_seg[full_idx])

            data_dict = {
                "xyz": sdf_xyz.float().squeeze(),
                "gt_sdf": sdf_gt.float().squeeze(),
                "point_cloud": pc.float().squeeze(),
                "camera_partial_pc": camera_partial_pc.float().squeeze(),
                "canonical_partial_pc": canonical_partial_pc.float().squeeze(),
                # Was the undefined name `seg` (NameError on every instance).
                # NOTE(review): these labels index full_point_cloud, which is
                # not saved here -- confirm whether it should be.
                "seg": full_seg,
                "cls": cls_label,
                "base_pose": base_pose.float().squeeze(),
                "joint_xyz": joint_xyz,
                "joint_rpy": joint_rpy,
                "atc": atc,
            }
            tqdm.write(base_name)

            numpy_data_dict = {
                key: value.cpu().numpy() if isinstance(value, torch.Tensor) else value
                for key, value in data_dict.items()
            }
            np.savez(os.path.join(out_dir, base_name), **numpy_data_dict)

            data_dicts.append(data_dict)
    return data_dicts


if __name__ == "__main__":
    main()