from utils.config_node import ConfigNode
import numpy as np
import math


# Module-level default configuration tree. get_default_config_moredata()
# at the bottom of this module returns config.clone() of it.
config = ConfigNode()

config.device = 'cuda'
config.device_id = '0'  # presumably the CUDA device index, stored as a string
# -1 is a placeholder: per the author's note, the effective value (1 by
# default) is overwritten from environment variables at runtime.
config.world_size = -1

# cuDNN
config.cudnn = ConfigNode()
# benchmark off / deterministic on -- presumably to favor reproducible runs.
# The trailing comment on each flag records the opposite setting.
config.cudnn.benchmark = False #True
config.cudnn.deterministic = True #False

config.model = ConfigNode()
config.model.name = ''
# Free-form experiment tag (presumably used to label runs/outputs). The
# trailing comment and the commented-out line below keep earlier tags for
# reference.
# NOTE(review): tags appear to encode hyper-parameters (e.g. 'lr04', 'bs50');
# confirm the active tag still matches the settings in this file.
config.model.note = '0921_sigmoid_updata_lr04' # 0921_sigmoid_updata_lr04 0911_3data_lr504 0901_gpu8_new_bs50_group20_wdis10_stride10_lr204_layer12_aug 0828_gpu8_new_bs50_group40_lr03_aug 0825_gpu8_new_bs384_lr04_aug 823_gpu8_new_bs384 0623_all_win20_str10_group20_dis20_bs128_gpu8 0603_all_win10_str10_group20_bs84_6layer_beforenorm_gpu8   '0513_all_win10_str10_dis10_group30_bs512_gpu8'# 0504_groupi_attentionpool_win5_logits1_gpu4 0425_double_attentionpool_wdis10_strid10_dilate2_logits1_gpu4  0422_double_avg_logits1_gpu4 0422_double_avg_logits1_gpu2 0419_cls_avg_gpu2 0418_avg_atten_gpu2
# 0828_gpu8_new_bs100_group20_lr04_aug

config.dataset = ConfigNode()
config.dataset.aug = True  # presumably toggles data augmentation -- confirm in the dataset code
config.dataset.all_to_memory = False  # presumably: preload all features into memory

# Datasets to train on. The trailing comment lists commented-out candidates
# ('CharadesEgo_1', 'CharadesEgo_2', 'CharadesEgo_3').
config.dataset.dataset_all = ['Assembly101', 'CharadesEgo']#,  'CharadesEgo_1',  'CharadesEgo_2',  'CharadesEgo_3'] #['Assembly101', 'CharadesEgo']
config.dataset.reduce_data = False
# Trailing comment: batch sizes used in earlier runs.
config.dataset.batch_size = 200#200 #128 84 #dp_512_8gpu 32 64  #1024 512, 256
# Temporal windowing parameters (units presumably feature frames -- confirm).
config.dataset.train_win = 20 # 20
config.dataset.stride = 10 # 20
config.dataset.dilation = 1
config.dataset.win_dis = 10
# Per the author's note, the effective total batch size is
# batch_size * num_in_group (an earlier value was 250).
config.dataset.num_in_group = 10
# Copied from train_win once, at import time; overriding train_win later does
# not update val_win automatically.
config.dataset.val_win = config.dataset.train_win
# Alternatives: 'uniform_group_multiview', 'uniform_group_pair'.
config.dataset.sample_type = 'uniform_group_pair'

# Assembly101 data source.
# vid_list_path, gt_path and mapping_file are built by plain string
# concatenation onto anno_data_path, so anno_data_path must keep its trailing
# '/'. They are evaluated once at import time: overriding anno_data_path
# afterwards does not update them.
config.dataset.assembly101 = ConfigNode()
config.dataset.assembly101.name = 'assembly101'
config.dataset.assembly101.split = 'train'
config.dataset.assembly101.features_path = '/data/peiyao/data/all_data/Assembly101/TSM_features'
config.dataset.assembly101.anno_data_path = "/data/peiyao/data/all_data/Assembly101/annotations/"
config.dataset.assembly101.vid_list_path = config.dataset.assembly101.anno_data_path + "coarse-annotations/coarse_splits/"
config.dataset.assembly101.gt_path = config.dataset.assembly101.anno_data_path + "coarse-annotations/coarse_labels/"
config.dataset.assembly101.mapping_file = config.dataset.assembly101.anno_data_path + "coarse-annotations/actions.csv"
# View names: 8 '*_rgb' entries followed by 8 'HMC_*_mono10bit' entries
# (presumably camera streams).
config.dataset.assembly101.VIEWS = ['C10095_rgb', 'C10115_rgb', 'C10118_rgb', 'C10119_rgb', 'C10379_rgb', 'C10390_rgb', 'C10395_rgb', 'C10404_rgb',
             'HMC_21176875_mono10bit', 'HMC_84346135_mono10bit', 'HMC_21176623_mono10bit', 'HMC_84347414_mono10bit',
             'HMC_21110305_mono10bit', 'HMC_84355350_mono10bit', 'HMC_21179183_mono10bit', 'HMC_84358933_mono10bit']

# CharadesEgo data source. statistic_file paths are relative (presumably
# resolved against the working directory -- confirm).
config.dataset.charades = ConfigNode()
config.dataset.charades.name = 'CharadesEgo'
config.dataset.charades.split = 'train'
config.dataset.charades.features_path = '/data/peiyao/data/all_data/Charades-Ego/CharadesEgo_v1_480_fea_npy'
config.dataset.charades.statistic_file = 'data/charades_statistic_input.txt'

# charades_1..3: alternative CharadesEgo feature folders / statistic files,
# distinguished by the _1, _2, _3 suffixes.
# NOTE(review): unlike config.dataset.charades, these nodes define no `name`
# key; confirm the dataset loader does not need one.
config.dataset.charades_1 = ConfigNode()
config.dataset.charades_1.split = 'train'
config.dataset.charades_1.features_path = '/data/peiyao/data/all_data/Charades-Ego/CharadesEgo_v1_480_fea_npy_1'
config.dataset.charades_1.statistic_file = 'data/charades_statistic_input_1.txt'

config.dataset.charades_2 = ConfigNode()
config.dataset.charades_2.split = 'train'
config.dataset.charades_2.features_path = '/data/peiyao/data/all_data/Charades-Ego/CharadesEgo_v1_480_fea_npy_2'
config.dataset.charades_2.statistic_file = 'data/charades_statistic_input_2.txt'

config.dataset.charades_3 = ConfigNode()
config.dataset.charades_3.split = 'train'
config.dataset.charades_3.features_path = '/data/peiyao/data/all_data/Charades-Ego/CharadesEgo_v1_480_fea_npy_3'
config.dataset.charades_3.statistic_file = 'data/charades_statistic_input_3.txt'

##======== for model =========
config.model.refine_fea = ConfigNode()
# How the frame-level feature is obtained; must be one of the four names
# checked just below.
config.model.refine_fea.type_frame = 'AttentionPool'
# Explicit raise instead of `assert`: assert statements are stripped when
# Python runs with -O, which would silently skip this check.
# NOTE: this runs once at import time against the default above; values merged
# in later (e.g. from a config file) are not re-checked here.
if config.model.refine_fea.type_frame not in ('AvgPool', 'MaxPool', 'AttentionPool', 'ClsToken'):
    raise ValueError(
        'The wrong type for getting frame feature! Got {!r}'.format(
            config.model.refine_fea.type_frame))
config.model.refine_fea.in_channel = 2048  # presumably the input feature width -- confirm against the feature files
config.model.refine_fea.d_model = 512  # presumably the model's hidden width
# NOTE(review): this is log10(1/0.7) ~= 0.155. CLIP-style heads usually store
# a natural-log temperature (e.g. np.log(1/0.07)) and recover the scale with
# .exp(); if the consumer calls .exp(), this yields exp(0.155) ~= 1.17 rather
# than 1/0.7 ~= 1.43. Confirm the intended base/value before changing it.
config.model.refine_fea.init_logit_scale = math.log10(1/0.7) #, 1

##========== for training ================
config.train = ConfigNode()
config.train.checkpoint = ''  # empty by default -- presumably "no checkpoint to load"
config.train.resume = False
config.train.log_period = 2000  # logging interval (units presumably iterations -- confirm)
config.train.output_dir = 'experiments'
##

# Optimizer name ('adamw' here). Other names noted by the author: sgd, adam,
# lars, adabound, adaboundw.
# NOTE(review): confirm the supported set against the optimizer factory.
config.train.optimizer = 'adamw'
config.train.base_lr = 5e-4  # earlier value tried: 1e-4 (0.0001)
config.train.weight_decay =  1e-4
config.train.start_epoch = 0
config.train.epochs = 400
config.train.warmup_epochs = 5
config.train.seed = 42

# distributed
config.train.distributed = True
config.train.dist = ConfigNode()
config.train.dist.backend = 'nccl'  # alternatives: nccl, gloo
config.train.dist.init_method = 'env://'
# -1 values are placeholders, presumably filled in at launch time.
# NOTE(review): a top-level config.world_size also exists; confirm which of
# the two is actually read.
config.train.dist.world_size = -1
config.train.dist.node_rank = -1
config.train.dist.local_rank = 0
config.train.dist.use_sync_bn = False
##========== for validation (no options defined yet) ================



def get_default_config_moredata():
    """Build a fresh copy of this module's default ``config`` tree.

    The copy is produced by ``config.clone()``. Presumably (yacs-style) that
    is a deep copy, so edits to the result should not leak back into the
    shared module-level defaults -- confirm against ``ConfigNode.clone``.

    Returns:
        Whatever ``config.clone()`` returns (the cloned configuration node).
    """
    defaults = config
    snapshot = defaults.clone()
    return snapshot


