from utils.config_node import ConfigNode
import numpy as np
import math


# Root node of the default configuration; every section below is attached to it.
config = ConfigNode()

# Device type and device id (note that the id is stored as a string).
config.device = 'cuda'
config.device_id = '0'
# -1 acts as a placeholder here; the original note says the real value is taken
# from the environment. NOTE(review): that note also said "default 1", which
# does not match the -1 assigned below — confirm the intended default.
config.world_size = -1

# cuDNN
# presumably forwarded to torch.backends.cudnn by the training script — TODO confirm
config.cudnn = ConfigNode()
config.cudnn.benchmark = False  # alternative: True
config.cudnn.deterministic = True  # alternative: False

# Model section: a model name plus a free-form note string.
config.model = ConfigNode()
config.model.name = ''
# The note looks like an experiment tag (presumably used to label runs — confirm).
# Tags left over from earlier runs, kept for reference:
#   0603_all_win10_str10_group20_bs84_6layer_beforenorm_gpu8
#   0513_all_win10_str10_dis10_group30_bs512_gpu8
#   0504_groupi_attentionpool_win5_logits1_gpu4
#   0425_double_attentionpool_wdis10_strid10_dilate2_logits1_gpu4
#   0422_double_avg_logits1_gpu4
#   0422_double_avg_logits1_gpu2
#   0419_cls_avg_gpu2
#   0418_avg_atten_gpu2
config.model.note = '0606_all_win10_str10_group10_bs512_vit_gpu8'

# Dataset section: data locations, view names and window-sampling settings.
config.dataset = ConfigNode()
config.dataset.all_to_memory = False
config.dataset.name = 'assembly101'
config.dataset.split = 'train'
# Absolute, machine-specific locations of the features and the annotations.
config.dataset.features_path = '/data/peiyao/data/all_data/Assembly101/TSM_features'
config.dataset.anno_data_path = "/data/peiyao/data/all_data/Assembly101/annotations/"
# The three paths below are built by plain string concatenation, so they rely
# on anno_data_path ending with '/'. They are evaluated once, right here, and
# therefore do not follow a later change of anno_data_path.
config.dataset.vid_list_path = config.dataset.anno_data_path + "coarse-annotations/coarse_splits/"
config.dataset.gt_path = config.dataset.anno_data_path + "coarse-annotations/coarse_labels/"
config.dataset.mapping_file = config.dataset.anno_data_path + "coarse-annotations/actions.csv"
# 16 view names: eight '*_rgb' entries followed by eight '*_mono10bit' entries
# (presumably camera identifiers of the dataset — confirm).
config.dataset.VIEWS = ['C10095_rgb', 'C10115_rgb', 'C10118_rgb', 'C10119_rgb', 'C10379_rgb', 'C10390_rgb', 'C10395_rgb', 'C10404_rgb',
             'HMC_21176875_mono10bit', 'HMC_84346135_mono10bit', 'HMC_21176623_mono10bit', 'HMC_84347414_mono10bit',
             'HMC_21110305_mono10bit', 'HMC_84355350_mono10bit', 'HMC_21179183_mono10bit', 'HMC_84358933_mono10bit']

config.dataset.reduce_data = False
# Other values noted by the author: 84, 32, 64, 1024, 256 (original remark: "dp_512_8gpu").
config.dataset.batch_size = 512
config.dataset.train_win = 10  # alternative: 20
config.dataset.stride = 10  # alternative: 20
config.dataset.dilation = 1
config.dataset.win_dis = 1
# Alternative: 250. Per the original note, the total batch size is
# batch_size * num_in_group.
config.dataset.num_in_group = 10
# The validation window defaults to the training window chosen above.
config.dataset.val_win = config.dataset.train_win
# Options listed by the author: uniform_group_multiview, uniform_group_pair.
config.dataset.sample_type = 'uniform_group_pair'

##======== for model =========
# Settings of the model's `refine_fea` sub-section.
config.model.refine_fea = ConfigNode()
# How the frame feature is obtained; only the values checked below are accepted.
config.model.refine_fea.type_frame = 'cls'
# Validate with an explicit raise instead of `assert`: assert statements are
# removed when Python runs with -O, which would silently skip this check.
if config.model.refine_fea.type_frame not in ('cls', 'mean'):
    raise ValueError('The wrong type for getting frame feature!')
# presumably the dimensionality of the input features — TODO confirm
config.model.refine_fea.in_channel = 2048
config.model.refine_fea.d_model = 512
# Alternatives noted by the author: math.log10(1/0.7), 1.
config.model.refine_fea.init_logit_scale = 1
# The patch count is tied to the training window length set in the dataset section.
config.model.refine_fea.num_patch = config.dataset.train_win

##========== for training ================
# Training section: checkpoint/resume flags, logging, optimizer and schedule.
config.train = ConfigNode()
config.train.checkpoint = ''
config.train.resume = False
config.train.log_period = 100
config.train.output_dir = 'experiments'
##

# optimizer (options: sgd, adam, lars, adabound, adaboundw)
# NOTE(review): the value used below, 'adamw', is not in the list above —
# confirm which optimizer names the training code actually accepts.
config.train.optimizer = 'adamw'
config.train.base_lr = 1e-4  # i.e. 0.0001
config.train.weight_decay =  1e-4
config.train.start_epoch = 0
config.train.epochs = 2000
config.train.warmup_epochs = 5
config.train.seed = 42

# distributed
config.train.distributed = True
config.train.dist = ConfigNode()
config.train.dist.backend = 'nccl'  # options noted by the author: nccl, gloo
config.train.dist.init_method = 'env://'
# world_size and node_rank start at -1, presumably as "not set yet" placeholders
# filled in at launch — TODO confirm.
# NOTE(review): a separate top-level config.world_size (also -1) exists in this
# file; confirm which of the two the training code reads.
config.train.dist.world_size = -1
config.train.dist.node_rank = -1
config.train.dist.local_rank = 0
config.train.dist.use_sync_bn = False
##===============validation



def get_default_config_vit():
    """Return a clone of the module-level default configuration.

    The result is whatever ``config.clone()`` produces (presumably an
    independent copy that callers may modify without touching ``config``;
    confirm against ``ConfigNode.clone``).
    """
    defaults = config.clone()
    return defaults


