# tools/build_cache_raw.py
"""
python build_cache_raw.py --src DATA_DIR/semantickitti --dst DATA_DIR/semantickitti_raw --ds SemanticKITTI
python build_cache_raw.py --src DATA_DIR/SynLiDAR --dst DATA_DIR/SynLiDAR_raw --ds SynLiDAR
python build_cache_raw.py --src DATA_DIR/semanticposs --dst DATA_DIR/semanticposs_raw --ds SemanticPOSS
"""
import argparse, os, tqdm, numpy as np, yaml
from utils.data_process import DataProcessing as DP
from dataset.data_utils import get_sk_data

# Per-dataset label config (path relative to the CWD) and the highest
# sequence index to cache: sequences 00..max_seq inclusive are processed.
_DATASET_CFG = {
    'SemanticKITTI': ('utils/semantic-kitti.yaml', 10),
    'SemanticPOSS': ('utils/semantic-poss.yaml', 5),
    'SynLiDAR': ('utils/synlidar.yaml', 12),
}
# NOTE: an alternative SemanticKITTI label map ('map_2_semantickitti' in
# utils/annotations.yaml) was previously left here as commented-out code.


def _make_lut(learning_map):
    """Return an int32 lookup table mapping raw label ids to learning ids.

    The table is padded 100 entries past the largest raw id; any index not
    present in ``learning_map`` maps to 0.
    """
    lut = np.zeros(max(learning_map) + 100, np.int32)
    lut[list(learning_map)] = list(learning_map.values())
    return lut


def _load_lut(dataset):
    """Return ``(lut, max_seq)`` for ``dataset``, reading its label YAML.

    Raises:
        NotImplementedError: If ``dataset`` has no entry in ``_DATASET_CFG``.
    """
    try:
        cfg_path, max_seq = _DATASET_CFG[dataset]
    except KeyError:
        raise NotImplementedError(f'Unknown dataset: {dataset}') from None
    # Context manager so the YAML file handle is closed promptly.
    with open(cfg_path) as f:
        data = yaml.safe_load(f)
    return _make_lut(data['learning_map']), max_seq


def build(dataset, src_dir, dst_dir):
    """Write one compressed ``.npz`` cache file per scan of ``dataset``.

    Scans are enumerated with ``DP.get_file_list`` over sequences
    ``00``..``NN`` of ``<src_dir>/sequences`` (``NN`` depends on the dataset).
    Each scan is read with ``get_sk_data`` using the dataset's label LUT and
    saved to ``<dst_dir>/sequences/<key[0]>/<key[1]>.npz``. The arrays are
    ``xyz`` (float32), ``rem`` (float32) and ``lbl`` (int32).

    Outputs that already exist are skipped. Each file is first written to a
    ``.tmp`` sibling and then renamed with ``os.replace``. This way an
    interrupted run never leaves a truncated ``.npz`` that a later run would
    wrongly skip.

    Args:
        dataset: One of 'SemanticKITTI', 'SemanticPOSS', 'SynLiDAR'.
        src_dir: Dataset root containing a ``sequences`` directory.
        dst_dir: Output root; subdirectories are created as needed.

    Raises:
        NotImplementedError: If ``dataset`` is not supported.
    """
    lut, max_seq = _load_lut(dataset)
    src_seq_dir = os.path.join(src_dir, 'sequences')
    dst_seq_dir = os.path.join(dst_dir, 'sequences')

    seq_ids = [f'{i:02d}' for i in range(max_seq + 1)]
    file_list = DP.get_file_list(src_seq_dir, seq_ids)
    for key in tqdm.tqdm(file_list):
        out = os.path.join(dst_seq_dir, key[0], key[1] + '.npz')
        if os.path.exists(out):
            continue
        xyz, rem, lbl = get_sk_data(key, src_seq_dir, lut, dataset)
        os.makedirs(os.path.dirname(out), exist_ok=True)
        tmp = out + '.tmp'
        # Pass an open file object: with a path, numpy would append '.npz'
        # to the temporary name.
        with open(tmp, 'wb') as f:
            np.savez_compressed(f, xyz=xyz.astype(np.float32),
                                rem=rem.astype(np.float32),
                                lbl=lbl.astype(np.int32))
        os.replace(tmp, out)
if __name__ == '__main__':
    # Command-line entry point: parse arguments and run build().
    parser = argparse.ArgumentParser(
        description='Cache per-scan xyz/rem/lbl arrays of a dataset as '
                    'compressed .npz files.')
    parser.add_argument('--src', required=True,
                        help='dataset root containing a sequences/ directory')
    parser.add_argument('--dst', required=True,
                        help='output root; cache files are written under '
                             'DST/sequences/')
    # choices= turns an unsupported name into a usage error, not a traceback.
    parser.add_argument('--ds', default='SemanticKITTI',
                        choices=['SemanticKITTI', 'SemanticPOSS', 'SynLiDAR'],
                        help='dataset name (default: SemanticKITTI)')
    args = parser.parse_args()
    build(args.ds, args.src, args.dst)