# @package datamodule
defaults:
  - /datamodule/semantic/default.yaml

# Hydra `_target_`: dotted path of the datamodule class for this config
_target_: src.datamodules.scannet.ScanNetDataModule

# Dataloader options set for this dataset
dataloader:
  batch_size: 4

# The DataModule itself does not consume the parameters below. They are
# declared here so that other configs (e.g. the model) can be
# parameterized from them through config interpolation
num_classes: 20
stuff_classes:
  - 0
  - 1
trainval: false
xy_tiling: null

# Features that will be computed, saved, loaded for points and segments

# Point features the partition is computed from
partition_hf:
  - rgb
  - linearity
  - planarity
  - scattering
  - verticality
  - elevation

# Point features the model is trained on
point_hf:
  - linearity
  - planarity
  - scattering
  - verticality
  - elevation
  - rgb

# Segment-wise features computed at preprocessing time
segment_base_hf: []

# Segment features obtained as the mean of a point feature over each
# segment; they are saved under a "mean_" prefix
segment_mean_hf: []

# Segment features obtained as the std of a point feature over each
# segment; they are saved under a "std_" prefix
segment_std_hf: []

# Horizontal edge features used for training
edge_hf:
  - mean_off
  - std_off
  - mean_dist
  - angle_source
  - angle_target
  - centroid_dir
  - centroid_dist
  - normal_angle
  - log_length
  - log_surface
  - log_volume
  - log_size

# Vertical edge features used for training
v_edge_hf: []

# Parameters declared here to facilitate tuning configs without copying
# all the pre_transforms. Each value below is read by `pre_transform`
# through a `${datamodule.<name>}` interpolation

# Based on SPG: https://arxiv.org/pdf/1711.09869.pdf
# `size` of GridSampling3D; also reused as the `trunc` of the position
# jitter in `on_device_train_transform`
voxel: 0.02
# `k` and `r_max` of KNN
knn: 45
knn_r: 2
# `k_step` and `k_min_search` of PointFeatures
knn_step: -1
knn_min_search: 25
# `z_threshold` and `scale` of GroundElevation
ground_threshold: 1.5
ground_scale: 4.0
# CutPursuitPartition settings. `pcp_k_adjacency` is also the `k` of
# AdjacencyGraph, and `pcp_w_adjacency` is the `w` of AdjacencyGraph
pcp_regularization: [0.01, 0.1, 0.5]
pcp_spatial_weight: [1e-1, 1e-1, 1e-1]
pcp_cutoff: [10, 10, 10]
pcp_k_adjacency: 10
pcp_w_adjacency: 1
pcp_iterations: 15
# RadiusHorizontalGraph settings
graph_k_min: 1
graph_k_max: 30
graph_gap: [0.2, 0.5, 1]
graph_se_ratio: 0.3
graph_se_min: 20
graph_cycles: 3
graph_margin: 0.2
# NOTE(review): exponent-only literals such as `1e6` or `1e-1` carry no
# decimal point; stock YAML 1.1 loaders (e.g. PyYAML) read them as
# strings. Presumably the config loader in use resolves them as floats
# - confirm before reformatting these values
graph_chunk: [1e6, 1e5, 1e5]  # reduce if CUDA memory errors

# Batch construction parameterization, consumed by the sampling
# transforms of `on_device_train_transform`

# SampleSegments
sample_segment_ratio: 0.1
sample_segment_by_size: true
sample_segment_by_class: false

# SampleSubNodes (`n_min`, `n_max`)
sample_point_min: 32
sample_point_max: 128

# SampleRadiusSubgraphs
sample_graph_r: -1  # set to r<=0 to skip SampleRadiusSubgraphs
sample_graph_k: 4
sample_graph_disjoint: true

# SampleEdges (`n_min`, `n_max`)
sample_edge_n_min: -1  # [5, 5, 15]
sample_edge_n_max: -1  # [10, 15, 25]

# Augmentations parameterization, consumed by
# `on_device_train_transform`

# Geometric augmentations: position jitter (NAGJitterKey on 'pos'),
# RandomTiltAndRotate and RandomAnisotropicScale
pos_jitter: 0.03
tilt_n_rotate_phi: 0.1
tilt_n_rotate_theta: 180
anisotropic_scaling: 0.2

# Color augmentations: NAGJitterKey on 'rgb', NAGColorAutoContrast and
# NAGColorDrop
rgb_jitter: 0
rgb_autocontrast: 0.2
rgb_drop: 0.2

# Jitter `sigma` for 'x', 'edge_attr' and 'v_edge_attr' (NAGJitterKey)
node_feat_jitter: 0.01
h_edge_feat_jitter: 0.01
v_edge_feat_jitter: 0

# Dropout `p` for 'x', 'edge_attr' and 'v_edge_attr' columns
# (NAGDropoutColumns)
node_feat_drop: 0.1
h_edge_feat_drop: 0
v_edge_feat_drop: 0

# Dropout `p` for 'x', 'edge_attr' and 'v_edge_attr' rows
# (NAGDropoutRows)
node_row_drop: 0
h_edge_row_drop: 0
v_edge_row_drop: 0

# `to_mean` flag shared by NAGDropoutColumns and NAGDropoutRows
drop_to_mean: false

# Preprocessing
# Ordered list of transforms. The DataTo / NAGTo entries move the data
# between 'cuda' and 'cpu' ahead of the transforms that follow them
pre_transform:
    - transform: SaveNodeIndex
      params:
        key: 'sub'
    # Move to GPU before GridSampling3D and KNN
    - transform: DataTo
      params:
        device: 'cuda'
    - transform: GridSampling3D  # might OOM on CUDA if voxel and GPU memory too small
      params:
        size: ${datamodule.voxel}
        hist_key: 'y'
        # One more bin than `num_classes`
        # NOTE(review): presumably the extra bin holds unlabeled/ignored
        # points - confirm
        hist_size: ${eval:'${datamodule.num_classes} + 1'}
    - transform: KNN
      params:
        k: ${datamodule.knn}
        r_max: ${datamodule.knn_r}
        verbose: False
    # Back to CPU before PointFeatures and GroundElevation
    - transform: DataTo
      params:
        device: 'cpu'
    - transform: PointFeatures
      params:
        keys: ${datamodule.point_hf_preprocess}
        k_min: 1
        k_step: ${datamodule.knn_step}
        k_min_search: ${datamodule.knn_min_search}
        overwrite: False
    - transform: GroundElevation
      params:
        z_threshold: ${datamodule.ground_threshold}
        scale: ${datamodule.ground_scale}
    # Move to GPU before AdjacencyGraph and ConnectIsolated
    - transform: DataTo
      params:
        device: 'cuda'
    - transform: AdjacencyGraph
      params:
        k: ${datamodule.pcp_k_adjacency}
        w: ${datamodule.pcp_w_adjacency}
    - transform: ConnectIsolated
      params:
        k: 1
    # Back to CPU before the partition
    - transform: DataTo
      params:
        device: 'cpu'
    - transform: AddKeysTo  # move some features to 'x' to be used for partition
      params:
        keys: ${datamodule.partition_hf}
        to: 'x'
        delete_after: False
    - transform: CutPursuitPartition
      params:
        regularization: ${datamodule.pcp_regularization}
        spatial_weight: ${datamodule.pcp_spatial_weight}
        k_adjacency: ${datamodule.pcp_k_adjacency}
        cutoff: ${datamodule.pcp_cutoff}
        iterations: ${datamodule.pcp_iterations}
        parallel: True
        verbose: False
    - transform: NAGRemoveKeys  # remove 'x' used for partition (features are still preserved under their respective Data attributes)
      params:
        level: 'all'
        keys: 'x'
    # Move to GPU before SegmentFeatures and RadiusHorizontalGraph
    - transform: NAGTo
      params:
        device: 'cuda'
    - transform: SegmentFeatures
      params:
        n_min: 32
        n_max: 128
        keys: ${datamodule.segment_base_hf_preprocess}
        mean_keys: ${datamodule.segment_mean_hf_preprocess}
        std_keys: ${datamodule.segment_std_hf_preprocess}
        strict: False  # will not raise error if a mean or std key is missing
    - transform: RadiusHorizontalGraph
      params:
        k_min: ${datamodule.graph_k_min}
        k_max: ${datamodule.graph_k_max}
        gap: ${datamodule.graph_gap}
        se_ratio: ${datamodule.graph_se_ratio}
        se_min: ${datamodule.graph_se_min}
        cycles: ${datamodule.graph_cycles}
        margin: ${datamodule.graph_margin}
        chunk_size: ${datamodule.graph_chunk}
        halfspace_filter: True
        bbox_filter: True
        target_pc_flip: True
        source_pc_sort: False
        keys: ['mean_off', 'std_off', 'mean_dist' ]
    # Back to CPU at the end of preprocessing
    - transform: NAGTo
      params:
        device: 'cpu'

# CPU-based transforms. The train transform is left null; the val
# transform points to the train one, and the test transform points to
# the val one
train_transform: null
val_transform: '${datamodule.train_transform}'
test_transform: '${datamodule.val_transform}'

# GPU-based train transforms
on_device_train_transform:

    # Add a `node_size` attribute to all segments, this is needed for
    # segment-wise position normalization with UnitSphereNorm
    - transform: NodeSize

    # Apply sampling transforms first to reduce the number of nodes and
    # edges. These operations are compute-intensive and are the reason
    # why these transforms are not performed on CPU
    - transform: SampleSubNodes
      params:
        low: 0
        high: 1
        n_min: ${datamodule.sample_point_min}
        n_max: ${datamodule.sample_point_max}
    - transform: SampleRadiusSubgraphs
      params:
        r: ${datamodule.sample_graph_r}
        k: ${datamodule.sample_graph_k}
        i_level: 1
        by_size: False
        by_class: False
        disjoint: ${datamodule.sample_graph_disjoint}
    - transform: SampleSegments
      params:
        ratio: ${datamodule.sample_segment_ratio}
        by_size: ${datamodule.sample_segment_by_size}
        by_class: ${datamodule.sample_segment_by_class}
    - transform: NAGRestrictSize
      params:
        level: '1+'
        num_nodes: ${datamodule.max_num_nodes}

    # Cast all attributes to either float or long. Doing this only now
    # allows speeding up disk I/O and CPU->GPU transfer
    - transform: NAGCast

    # Apply geometric transforms affecting position, offsets, normals
    # before calling transforms relying on those, such as on-the-fly
    # edge features computation
    - transform: NAGJitterKey
      params:
        key: 'pos'
        sigma: ${datamodule.pos_jitter}
        trunc: ${datamodule.voxel}
    - transform: RandomTiltAndRotate
      params:
        phi: ${datamodule.tilt_n_rotate_phi}
        theta: ${datamodule.tilt_n_rotate_theta}
    - transform: RandomAnisotropicScale
      params:
        delta: ${datamodule.anisotropic_scaling}
    - transform: RandomAxisFlip
      params:
        p: 0.5

    # Compute some horizontal and vertical edges on-the-fly. Those are
    # only computed now since they can be deduced from point and node
    # attributes. Besides, the OnTheFlyHorizontalEdgeFeatures transform
    # takes a trimmed graph as input and doubles its size, creating j->i
    # for each input i->j edge
    - transform: OnTheFlyHorizontalEdgeFeatures
      params:
        keys: ${datamodule.edge_hf}
        use_mean_normal: ${eval:'"normal" in ${datamodule.segment_mean_hf}'}
    - transform: OnTheFlyVerticalEdgeFeatures
      params:
        keys: ${datamodule.v_edge_hf}
        use_mean_normal: ${eval:'"normal" in ${datamodule.segment_mean_hf}'}

    # Edge sampling is only performed after the horizontal graph is
    # untrimmed by OnTheFlyHorizontalEdgeFeatures
    - transform: SampleEdges
      params:
        level: '1+'
        n_min: ${datamodule.sample_edge_n_min}
        n_max: ${datamodule.sample_edge_n_max}
    - transform: NAGRestrictSize
      params:
        level: '1+'
        num_edges: ${datamodule.max_num_edges}

    # Move all point and segment features to 'x', except for "rgb", on
    # which we want to apply specific transforms
    - transform: NAGAddKeysTo
      params:
        level: 0
        keys: ${eval:'ListConfig([k for k in ${datamodule.point_hf} if k != "rgb"])'}
        to: 'x'
    - transform: NAGAddKeysTo
      params:
        level: '1+'
        keys: ${eval:'ListConfig([k for k in ${datamodule.segment_hf} if k != "rgb"])'}
        to: 'x'

    # Add some noise to, and randomly drop, some point, node and edge
    # features
    - transform: NAGJitterKey
      params:
        key: 'x'
        sigma: ${datamodule.node_feat_jitter}
        trunc: ${eval:'2 * ${datamodule.node_feat_jitter}'}
    - transform: NAGJitterKey
      params:
        key: 'edge_attr'
        sigma: ${datamodule.h_edge_feat_jitter}
        trunc: ${eval:'2 * ${datamodule.h_edge_feat_jitter}'}
    - transform: NAGJitterKey
      params:
        key: 'v_edge_attr'
        sigma: ${datamodule.v_edge_feat_jitter}
        trunc: ${eval:'2 * ${datamodule.v_edge_feat_jitter}'}
    - transform: NAGDropoutColumns
      params:
        p: ${datamodule.node_feat_drop}
        key: 'x'
        inplace: True
        to_mean: ${datamodule.drop_to_mean}
    - transform: NAGDropoutColumns
      params:
        p: ${datamodule.h_edge_feat_drop}
        key: 'edge_attr'
        inplace: True
        to_mean: ${datamodule.drop_to_mean}
    - transform: NAGDropoutColumns
      params:
        p: ${datamodule.v_edge_feat_drop}
        key: 'v_edge_attr'
        inplace: True
        to_mean: ${datamodule.drop_to_mean}
    - transform: NAGDropoutRows
      params:
        p: ${datamodule.node_row_drop}
        key: 'x'
        to_mean: ${datamodule.drop_to_mean}
    - transform: NAGDropoutRows
      params:
        p: ${datamodule.h_edge_row_drop}
        key: 'edge_attr'
        to_mean: ${datamodule.drop_to_mean}
    - transform: NAGDropoutRows
      params:
        p: ${datamodule.v_edge_row_drop}
        key: 'v_edge_attr'
        to_mean: ${datamodule.drop_to_mean}

    # RGB-specific transforms. In particular, the color dropout will
    # switch off all three color channels together, instead of one by
    # one as a column-wise dropout on 'x' would (presumably why 'rgb' is
    # only moved to 'x' afterwards). The color normalization below is
    # kept commented out (disabled)
#    - transform: NAGColorNormalize
#      params:
#        level: 'all'
    - transform: NAGJitterKey
      params:
        key: 'rgb'
        sigma: ${datamodule.rgb_jitter}
        trunc: ${eval:'2 * ${datamodule.rgb_jitter}'}
    - transform: NAGColorAutoContrast
      params:
        p: ${datamodule.rgb_autocontrast}
    - transform: NAGColorDrop
      params:
        p: ${datamodule.rgb_drop}

    # Finally move RGB to node and segment features, if need be
    - transform: NAGAddKeysTo
      params:
        keys: 'rgb'
        to: 'x'
        strict: False

    # Add self-loops in the horizontal graph
    - transform: NAGAddSelfLoops

    # Compute the instance graph for instantiation
    # NB: setting `datamodule.instance: False` will skip this step
    - transform: OnTheFlyInstanceGraph
      params:
        level: ${eval:'1 if ${datamodule.instance} else -1'}
        num_classes: ${datamodule.num_classes}
        k_max: ${datamodule.instance_k_max}
        radius: ${datamodule.instance_radius}
        
# GPU-based val transforms
# Unlike `on_device_train_transform`, this list contains no sampling,
# jitter or dropout transform
on_device_val_transform:

    # Add a `node_size` attribute to all segments, this is needed for
    # segment-wise position normalization with UnitSphereNorm
    - transform: NodeSize

    # Cast all attributes to either float or long. Doing this only now
    # allows speeding up disk I/O and CPU->GPU transfer
    - transform: NAGCast

    # Compute some horizontal and vertical edges on-the-fly. Those are
    # only computed now since they can be deduced from point and node
    # attributes. Besides, the OnTheFlyHorizontalEdgeFeatures transform
    # takes a trimmed graph as input and doubles its size, creating j->i
    # for each input i->j edge
    - transform: OnTheFlyHorizontalEdgeFeatures
      params:
        keys: ${datamodule.edge_hf}
        use_mean_normal: ${eval:'"normal" in ${datamodule.segment_mean_hf}'}
    - transform: OnTheFlyVerticalEdgeFeatures
      params:
        keys: ${datamodule.v_edge_hf}
        use_mean_normal: ${eval:'"normal" in ${datamodule.segment_mean_hf}'}

    # Move all point and segment features to 'x', except for "rgb",
    # which is moved separately at the end
    - transform: NAGAddKeysTo
      params:
        level: 0
        keys: ${eval:'ListConfig([k for k in ${datamodule.point_hf} if k != "rgb"])'}
        to: 'x'
    - transform: NAGAddKeysTo
      params:
        level: '1+'
        keys: ${eval:'ListConfig([k for k in ${datamodule.segment_hf} if k != "rgb"])'}
        to: 'x'

    # No RGB-specific augmentation is applied here, unlike in
    # `on_device_train_transform`. The color normalization below is
    # kept commented out (disabled)
#    - transform: NAGColorNormalize
#      params:
#        level: 'all'

    # Finally move RGB to node and segment features, if need be
    - transform: NAGAddKeysTo
      params:
        keys: 'rgb'
        to: 'x'
        strict: False

    # Add self-loops in the horizontal graph
    - transform: NAGAddSelfLoops

    # Compute the instance graph for instantiation
    # NB: setting `datamodule.instance: False` will skip this step
    - transform: OnTheFlyInstanceGraph
      params:
        level: ${eval:'1 if ${datamodule.instance} else -1'}
        num_classes: ${datamodule.num_classes}
        k_max: ${datamodule.instance_k_max}
        radius: ${datamodule.instance_radius}
        
# GPU-based test transforms: points to the val transforms
on_device_test_transform: '${datamodule.on_device_val_transform}'
