from pathlib import Path
from types import SimpleNamespace
from copy import deepcopy
from torch import nn
import torch

# Location of the raw S3DIS release ("Aligned Version"); replace the "..."
# placeholder with the real directory before running.
raw_data_path = Path(".../Stanford3dDataset_v1.2_Aligned_Version")

# By default, processed data lives in a sibling "s3dis" directory next to the
# raw dataset.
processed_data_path = raw_data_path.parent.joinpath("s3dis")
# To keep processed data somewhere else, uncomment the line below and set it:
#processed_data_path = Path("")

# Training schedule (NOTE(review): `warmup` is presumably measured in epochs,
# like `epoch`; confirm against the training loop).
epoch, warmup = 150, 10
batch_size = 8

# Optimiser / loss settings.
learning_rate, label_smoothing = 6e-3, 0.2

# Dataset/sampling options. Each list has four entries, presumably one per
# network stage (TODO confirm against the dataset/model code).
s3dis_args = SimpleNamespace(
    k=[24, 32, 32, 32],                    # presumably neighbour counts per stage
    grid_size=[0.04, 0.08, 0.16, 0.32],    # presumably per-stage grid/voxel sizes
    max_pts=30000,                         # presumably a cap on points per sample
)

# Warm-up variant: an independent deep copy of the options above that differs
# only in `grid_size` (first entry kept, remaining entries set to 3.5).
s3dis_warmup_args = SimpleNamespace(**{
    **vars(deepcopy(s3dis_args)),
    "grid_size": [0.04, 3.5, 3.5, 3.5],
})

# Model options. `ks` is the very same list object as `s3dis_args.k` (not a
# copy), so the model and dataset share one source of truth for it.
dsconv_args = SimpleNamespace(
    ks=s3dis_args.k,
    depths=[4, 4, 8, 8],
    dims=[64, 128, 256, 512],
    head_dim=256,
    nbr_dims=[32, 32],
    num_classes=13,
)

# Stochastic-depth schedule: sum(depths) rates ramping linearly from 0 to
# `drop_path`, then chunked so each stage gets a list of len(depths[i]) rates.
drop_path = 0.1
drop_rates = torch.linspace(0.0, drop_path, sum(dsconv_args.depths)).split(dsconv_args.depths)

vars(dsconv_args).update(
    drop_paths=list(map(torch.Tensor.tolist, drop_rates)),
    # One value per stage, linearly spaced from 0 to 0.15.
    head_drops=torch.linspace(0.0, 0.15, len(dsconv_args.depths)).tolist(),
    bn_momentum=0.02,
    act=nn.GELU,          # the class itself, not an instance
    mlp_ratio=2,
    use_cp=False,         # gradient checkpointing toggle
    # NOTE(review): meaning of these per-stage values is not visible here.
    cor_std=[1.6, 3.8, 7.6, 15.2],
)

def _resolve_dist_path(name):
    """Return the path from which the precomputed file *name* should be loaded.

    The current working directory is tried first, exactly like the bare
    ``torch.load(name)`` call this replaces. Only when the file is missing
    there is the directory containing this config module tried, so importing
    the config from another working directory no longer fails. If neither
    location has the file, the original relative path is returned and
    ``torch.load`` raises ``FileNotFoundError`` as before.
    """
    cwd_path = Path(name)
    if cwd_path.exists():
        return cwd_path
    # `__file__` may be absent if this config is exec'd rather than imported.
    config_file = globals().get("__file__")
    if config_file is not None:
        beside_config = Path(config_file).resolve().parent / name
        if beside_config.exists():
            return beside_config
    return cwd_path


# Loaded eagerly at import time. NOTE(review): the contents/shape of
# dist1.pt and dist0.pt are not visible from this file — document them.
dsconv_args.all_dist = torch.load(_resolve_dist_path("dist1.pt"))
dsconv_args.all_dist0 = torch.load(_resolve_dist_path("dist0.pt"))
