from pathlib import Path
from types import SimpleNamespace
from copy import deepcopy
from torch import nn
import torch

# Root directory of the raw ScanNetV2 dataset; it should contain `scans/`.
# ".../" is a placeholder -- replace it with the real location.
raw_data_path = Path(".../")

# Location of the processed dataset: by default a `scannetv2` directory next to
# `raw_data_path`. To choose a different location, uncomment the override below
# and fill in the path.
processed_data_path = Path(raw_data_path.parent, "scannetv2")
# processed_data_path = Path("")

def _read_scan_ids(list_file):
    """Return the scan IDs listed one per line in *list_file*, in file order.

    Surrounding whitespace is stripped from each line, and blank lines (for
    example an extra empty line at the end of the file) are skipped so that no
    empty scan ID ends up in the split.
    """
    with open(list_file, "r", encoding="utf-8") as handle:
        stripped = (line.strip() for line in handle)
        return [scan_id for scan_id in stripped if scan_id]


# The train/val split files live next to this config, so they are resolved
# relative to this file rather than the current working directory.
scan_train = _read_scan_ids(Path(__file__).parent / "scannetv2_train.txt")
scan_val = _read_scan_ids(Path(__file__).parent / "scannetv2_val.txt")

# Optimisation schedule. The warmup length is presumably counted in epochs,
# like `epoch` -- confirm against the training loop.
epoch, warmup = 100, 10
batch_size = 6
learning_rate = 6e-3
# Label-smoothing factor; presumably consumed by the classification loss.
label_smoothing = 0.2

# Data/sampling settings. `k` and `grid_size` are five-entry per-level lists
# (presumably one entry per network stage -- confirm against the model).
# `max_pts` presumably caps the number of points per sample.
scan_args = SimpleNamespace(
    k=[24, 32, 32, 32, 32],
    grid_size=[0.02, 0.04, 0.08, 0.16, 0.32],
    max_pts=80000,
)

# Warmup variant: a deep copy of `scan_args` (so no list is shared between the
# two) with much coarser grid sizes at all but the first level.
scan_warmup_args = deepcopy(scan_args)
scan_warmup_args.grid_size = [0.02, 1.8, 3.5, 3.5, 4]

# Model hyperparameters. `ks` reuses the per-level `k` list from `scan_args`
# (same list object, not a copy); `depths` and `dims` have one entry per level.
dsconv_args = SimpleNamespace(
    ks=scan_args.k,
    depths=[4, 4, 4, 8, 8],
    dims=[80, 128, 192, 352, 640],
    nbr_dims=[32, 32],
    head_dim=352,
    num_classes=20,
)

# Drop-path rates: sum(depths) values spaced linearly from 0 to `drop_path`,
# then split into one chunk per entry of `depths` (chunk i has depths[i] values).
drop_path = 0.1
drop_rates = torch.linspace(0., drop_path, sum(dsconv_args.depths)).split(dsconv_args.depths)
dsconv_args.drop_paths = list(map(torch.Tensor.tolist, drop_rates))
# One head dropout value per entry of `depths`, spaced linearly from 0 to 0.2.
dsconv_args.head_drops = torch.linspace(0., 0.2, len(dsconv_args.depths)).tolist()

# Remaining scalar settings: batch-norm momentum, activation class (passed
# uninstantiated), and MLP expansion ratio.
vars(dsconv_args).update(bn_momentum=0.02, act=nn.GELU, mlp_ratio=2)
# Gradient checkpointing toggle (disabled).
dsconv_args.use_cp = False

# Five per-level values; their meaning is not visible in this config --
# TODO confirm against the model code that reads `cor_std`.
dsconv_args.cor_std = [1.6, 3.0, 5.7, 11.2, 21.5]


def _resolve_data_file(name):
    """Return the path from which the precomputed file *name* is loaded.

    The current working directory is tried first, which preserves the original
    behaviour. If the file is not there, fall back to the directory containing
    this config, the same place the train/val split files are read from. This
    lets training be launched from any directory.
    """
    cwd_candidate = Path(name)
    if cwd_candidate.exists():
        return cwd_candidate
    return Path(__file__).parent / name


# NOTE: torch.load unpickles its input -- only point these at trusted files.
dsconv_args.all_dist = torch.load(_resolve_data_file("dist1.pt"))
dsconv_args.all_dist0 = torch.load(_resolve_data_file("dist0.pt"))
