from utils.config_node import ConfigNode
import numpy as np
import math


# Module-level default configuration tree; get_default_config() returns a clone of it.
config = ConfigNode()

config.device = 'cuda'
config.device_id = '0'  # kept as a string; presumably a CUDA device index (or list like '0,1') — confirm with the consumer
config.world_size = -1  # -1 is a placeholder (effective default 1); per the original note it is overwritten from environment variables at runtime

# cuDNN: autotuning off and deterministic kernels on (presumably forwarded to torch.backends.cudnn — confirm)
config.cudnn = ConfigNode()
config.cudnn.benchmark = False  # alternative previously used: True
config.cudnn.deterministic = True  # alternative previously used: False

# Run identification.
config.model = ConfigNode()
config.model.name = '0922_stage2_main_group_view_sigmoid_aux_ddp'  # original note: "the main.npy name" — presumably names the main output artifact; confirm
config.model.note = 'stage2_0925_lr04_aux02'  # free-form run tag; appears to encode base_lr=1e-4 and aux_weight=0.2
config.dataset = ConfigNode()
config.dataset.num_workers = 6  # data-loader workers; original note: 1 on the 'openai' machine, 6 on 'sora' (presumably host names)

config.dataset.all_to_memory = False  # presumably: preload all features into RAM when True — confirm
config.dataset.name = 'assembly101'
config.dataset.split = 'train'
config.dataset.features_path = '/data/peiyao/data/all_data/Assembly101/TSM_features'
# The three annotation paths below are built by plain string concatenation,
# so anno_data_path must keep its trailing '/'.
config.dataset.anno_data_path = "/data/peiyao/data/all_data/Assembly101/annotations/"
config.dataset.vid_list_path = config.dataset.anno_data_path + "coarse-annotations/coarse_splits/"
config.dataset.gt_path = config.dataset.anno_data_path + "coarse-annotations/coarse_labels/"
config.dataset.mapping_file = config.dataset.anno_data_path + "coarse-annotations/actions.csv"
# Camera views used: 8 RGB cameras (C*_rgb) followed by 8 monochrome HMC cameras (HMC_*_mono10bit).
config.dataset.VIEWS = ['C10095_rgb', 'C10115_rgb', 'C10118_rgb', 'C10119_rgb', 'C10379_rgb', 'C10390_rgb', 'C10395_rgb', 'C10404_rgb',
                         'HMC_21176875_mono10bit', 'HMC_84346135_mono10bit', 'HMC_21176623_mono10bit', 'HMC_84347414_mono10bit',
                         'HMC_21110305_mono10bit', 'HMC_84355350_mono10bit', 'HMC_21179183_mono10bit', 'HMC_84358933_mono10bit']

config.dataset.reduce_data = False  # presumably: use a reduced subset of the data when True; alternative: True
config.dataset.batch_size = 64  # previously tried: 20, 32, 64, 256, 512 (8-GPU DataParallel), 1024
config.dataset.train_win = 20  # training window length
config.dataset.stride = 10  # previously 20
config.dataset.dilation = 1
config.dataset.win_dis = 10
config.dataset.num_in_group = 10  # previously 20 / 250; per the original note, total samples per batch = batch_size * num_in_group
config.dataset.val_win = config.dataset.train_win  # validation uses the same window length as training
config.dataset.sample_type = 'uniform_group_multiview'  # options: uniform_group_multiview, uniform_group_pair
##======== for model =========
config.model.refine_fea = ConfigNode()
config.model.refine_fea.use_pair_weight = True
# Checkpoint (relative to the experiments directory below) whose weights are
# loaded when use_pair_weight is set — judging by the path, a stage-1 "pair" run.
# Earlier candidates:
#   assembly_charades_0922/pair/0922_stage1_main_group_pair_norm_aug_moredata_ddp/0922_lr504_logit/checkpoint_epoch_170
#   0828_gpu8_new_bs100_group20_lr04_aug
config.model.refine_fea.ckpt_note = 'assembly_charades_0922/pair_stage1_aux_0924_lr504_logit1/checkpoint_epoch_150.pth'
config.model.refine_fea.pair_weight_ckpt_path = f'/home/peiyao/Documents/ACode/Contras_refine/Contras_fea/experiments/{config.model.refine_fea.ckpt_note}'

# How a per-frame feature is obtained.
config.model.refine_fea.type_frame = 'AttentionPool'
# Explicit check rather than `assert`: assert statements are removed when Python
# runs with -O, which would let an invalid choice through silently.
if config.model.refine_fea.type_frame not in ('AvgPool', 'MaxPool', 'AttentionPool', 'ClsToken'):
    raise ValueError(
        f'The wrong type for getting frame feature: {config.model.refine_fea.type_frame!r}'
    )
config.model.refine_fea.in_channel = 2048  # input feature dim (presumably the TSM feature size — confirm)
config.model.refine_fea.d_model = 512
# NOTE(review): this uses log10, while CLIP-style logit scales are normally
# initialised as log(1 / T) with the natural log and recovered via exp(). If the
# model applies exp(), the effective initial scale is ~1.17 rather than 1/0.7.
# Value left unchanged so existing runs stay reproducible — confirm intent.
config.model.refine_fea.init_logit_scale = math.log10(1/0.7)  # other values noted: 0.1, 0.5, 1
config.model.aggre = ConfigNode()
config.model.aggre.type_aggre = 'AvgPool'
if config.model.aggre.type_aggre not in ('AvgPool', 'MaxPool', 'AttentionPool'):
    raise ValueError(
        f'The wrong type of aggregation: {config.model.aggre.type_aggre!r}'
    )

##===========loss==================
# Loss hyper-parameters; how each is consumed is defined in the training code (not visible here).
config.loss = ConfigNode()
config.loss.sigma = 0
config.loss.positive_weight = 1  # presumably weight on positive pairs — confirm
config.loss.aux_weight = 0.2  # presumably weight of the auxiliary loss term (matches 'aux02' in config.model.note)


##========== for training ================
config.train = ConfigNode()
config.train.checkpoint = ''  # empty string: no checkpoint path given
config.train.resume = False
config.train.log_period = 100  # presumably in iterations — confirm
config.train.output_dir = 'experiments'
##

# Optimizer. The older option list (sgd, adam, lars, adabound, adaboundw) does not
# include 'adamw', which is what is selected here — TODO confirm the supported set.
config.train.optimizer = 'adamw'
config.train.base_lr = 1e-4  # 0.0001
# NOTE(review): save_epochs stops at 190 while epochs is 500 — confirm that later
# checkpoints are saved by other means (e.g. a final/latest checkpoint).
config.train.save_epochs = [120, 170, 190]
config.train.weight_decay =  1e-4
config.train.start_epoch = 0
config.train.epochs = 500  # previously 2000
config.train.warmup_epochs = 5
config.train.seed = 42

# distributed
config.train.distributed = True
config.train.dist = ConfigNode()
config.train.dist.backend = 'nccl'  # options: nccl, gloo
config.train.dist.init_method = 'env://'  # presumably passed to torch.distributed.init_process_group (rendezvous from env vars)
config.train.dist.world_size = -1  # -1 placeholder, presumably filled in at launch time
config.train.dist.node_rank = -1
config.train.dist.local_rank = 0
config.train.dist.use_sync_bn = False
##===============validation (no validation options are set here)


def get_default_config():
    """Return a clone of the module-level default configuration.

    The module-level ``config`` is not handed out directly; callers receive the
    result of ``config.clone()`` (presumably an independent copy — confirm
    ConfigNode.clone semantics), so modifying it should not alter the defaults.
    """
    fresh_copy = config.clone()
    return fresh_copy


