from utils.config_node import ConfigNode
import numpy as np
import math


# Root of the default configuration tree. It is populated section by section
# below, and callers obtain a copy via get_default_config_moredata().
config = ConfigNode()

config.device = 'cuda'
config.device_id = '0'  # stored as a string, not an int
# -1 is a placeholder: the effective default is 1, but the value is expected
# to be overwritten from environment variables at runtime.
# NOTE(review): config.train.dist.world_size is a separate field, also set to -1 below.
config.world_size = -1

# cuDNN settings. benchmark=False with deterministic=True is the
# reproducibility-oriented combination; the commented alternatives were the
# opposite values. Presumably these are applied to torch.backends.cudnn.* by the
# training script — confirm.
config.cudnn = ConfigNode()
config.cudnn.benchmark = False  # alternative tried: True
config.cudnn.deterministic = True  # alternative tried: False

# Model / experiment identification.
config.model = ConfigNode()
# Earlier variant (without "nogather"): '0922_stage1_main_group_pair_norm_aug_moredata_ddp'
config.model.name = '0922_stage1_main_group_pair_norm_aug_moredata_nogather_ddp'  # name of the main.py script variant
config.model.note = 'stage1_0924_lr504_logit1_nogather'  # run tag; previous tag: 0922_lr504_logit
##======== for dataset =========
config.dataset = ConfigNode()
# Data-loader worker count. Per the original note, 1 was used on the "openai"
# machine and 6 on the "sora" machine.
config.dataset.num_workers = 6

config.dataset.aug = True  # data augmentation switch (applied by the dataset code, presumably — confirm)
config.dataset.all_to_memory = False  # presumably: preload all features into memory — confirm

# Datasets to train on. Each name presumably maps to a sub-node defined below
# (config.dataset.assembly101, config.dataset.charades) — TODO confirm mapping.
# Other values that were tried: 'CharadesEgo_1', 'CharadesEgo_2', 'CharadesEgo_3'.
config.dataset.dataset_all = ['Assembly101', 'CharadesEgo']
config.dataset.reduce_data = False
# The total batch size is batch_size * num_in_group (see num_in_group below).
# Other values tried: 128, 84, 32, 64, 256, 512, 1024 (note: "dp_512_8gpu").
config.dataset.batch_size = 200
config.dataset.train_win = 20  # training window length (units presumably frames — confirm)
config.dataset.stride = 10  # presumably the step between consecutive windows; 20 was also used
config.dataset.dilation = 1
config.dataset.win_dis = 10
# Items per group; 250 was also tried. Total batch size = batch_size * num_in_group.
config.dataset.num_in_group = 10
config.dataset.val_win = config.dataset.train_win  # validation reuses the training window length
# Sampling strategy; known options: 'uniform_group_multiview', 'uniform_group_pair'.
config.dataset.sample_type = 'uniform_group_pair'

# Assembly101 dataset: pre-extracted feature directory (TSM_features) plus the
# coarse-annotation splits, labels and action mapping file.
config.dataset.assembly101 = ConfigNode()
config.dataset.assembly101.name = 'assembly101'
config.dataset.assembly101.split = 'train'
config.dataset.assembly101.features_path = '/data/peiyao/data/all_data/Assembly101/TSM_features'
# anno_data_path ends with '/', so the plain string concatenation in the three
# lines below produces well-formed sub-paths.
config.dataset.assembly101.anno_data_path = "/data/peiyao/data/all_data/Assembly101/annotations/"
config.dataset.assembly101.vid_list_path = config.dataset.assembly101.anno_data_path + "coarse-annotations/coarse_splits/"
config.dataset.assembly101.gt_path = config.dataset.assembly101.anno_data_path + "coarse-annotations/coarse_labels/"
config.dataset.assembly101.mapping_file = config.dataset.assembly101.anno_data_path + "coarse-annotations/actions.csv"
# 16 view identifiers: 8 '*_rgb' cameras followed by 8 'HMC_*_mono10bit' cameras.
# Presumably used to select per-view feature files — confirm in the data loader.
config.dataset.assembly101.VIEWS = ['C10095_rgb', 'C10115_rgb', 'C10118_rgb', 'C10119_rgb', 'C10379_rgb', 'C10390_rgb', 'C10395_rgb', 'C10404_rgb',
             'HMC_21176875_mono10bit', 'HMC_84346135_mono10bit', 'HMC_21176623_mono10bit', 'HMC_84347414_mono10bit',
             'HMC_21110305_mono10bit', 'HMC_84355350_mono10bit', 'HMC_21179183_mono10bit', 'HMC_84358933_mono10bit']

# Charades-Ego dataset: pre-extracted .npy feature directory and a statistics file.
config.dataset.charades = ConfigNode()
config.dataset.charades.name = 'CharadesEgo'
config.dataset.charades.split = 'train'
config.dataset.charades.features_path = '/data/peiyao/data/all_data/Charades-Ego/CharadesEgo_v1_480_fea_npy'
# Relative path, unlike the absolute paths above. It is presumably resolved
# against the working directory at launch — confirm.
config.dataset.charades.statistic_file = 'data/charades_statistic_input.txt'

##======== for model =========
config.model.refine_fea = ConfigNode()
# How a single frame-level feature is produced; must be one of the types checked below.
config.model.refine_fea.type_frame = 'AttentionPool'
# Explicit raise instead of `assert`: assert statements are stripped when Python
# runs with -O, which would silently disable this sanity check.
if config.model.refine_fea.type_frame not in ['AvgPool', 'MaxPool', 'AttentionPool', 'ClsToken']:
    raise ValueError('The wrong type for getting frame feature!')
config.model.refine_fea.in_channel = 2048  # input feature channels (presumably the extracted feature dim — confirm)
config.model.refine_fea.d_model = 512
config.model.refine_fea.init_logit_scale = 1  # alternative tried: math.log10(1/0.7)

##========== for training ================
config.train = ConfigNode()
config.train.checkpoint = ''  # checkpoint path; empty string means none is specified
config.train.resume = False
config.train.log_period = 2000  # logging interval (units presumably iterations — confirm)
config.train.output_dir = 'experiments'
##

# Optimizer settings. The original options list was: sgd, adam, lars, adabound, adaboundw.
# NOTE(review): the value used here, 'adamw', is not in that list — confirm the
# list against the optimizer factory and update it.
config.train.optimizer = 'adamw'
config.train.base_lr = 5e-4  # earlier runs used 1e-4
config.train.weight_decay =  1e-4
config.train.start_epoch = 0
config.train.epochs = 400
config.train.save_epochs = [150, 170, 190]  # presumably epochs at which checkpoints are saved — confirm
config.train.warmup_epochs = 5
config.train.seed = 42

# Distributed training settings.
config.train.distributed = True
config.train.dist = ConfigNode()
config.train.dist.backend = 'nccl'  # options: 'nccl', 'gloo'
# Presumably passed to torch.distributed.init_process_group. There, 'env://'
# reads the rendezvous settings from environment variables — confirm.
config.train.dist.init_method = 'env://'
config.train.dist.world_size = -1  # -1 placeholder; presumably filled in at runtime
config.train.dist.node_rank = -1  # -1 placeholder; presumably filled in at runtime
config.train.dist.local_rank = 0
config.train.dist.use_sync_bn = False
##=============== validation (no validation-specific options are set here)



def get_default_config_moredata():
    """Return a copy of the module-level default ``config``.

    The copy comes from ``config.clone()``, so callers work on their own node
    rather than the shared module object. (Whether the clone is deep depends on
    ``utils.config_node.ConfigNode.clone`` — TODO confirm.)
    """
    default_copy = config.clone()
    return default_copy


