from utils.config_node import ConfigNode
import numpy as np
import math


# Root node of the default experiment configuration. get_default_config()
# returns a clone of this object rather than the object itself.
config = ConfigNode()

# Target device type and device id (the id is stored as a string).
config.device = 'cuda'
config.device_id = '0'
# Placeholder value. According to the original note, the effective value
# (default 1) is filled in from environment variables at runtime.
# TODO(review): confirm where world_size is overwritten.
config.world_size = -1

# cuDNN flags. These are presumably forwarded to torch.backends.cudnn; confirm
# in the training entry point. benchmark=False with deterministic=True favours
# reproducible runs. Previously used values: benchmark=True,
# deterministic=False.
config.cudnn = ConfigNode()
config.cudnn.benchmark = False
config.cudnn.deterministic = True

# Model identification.
# `note` is a free-form run tag, currently 'test'. Earlier tags appear to
# encode run hyper-parameters (window, stride, win_dis, group size, batch size,
# GPU count). They are presumably used to name experiment outputs; confirm
# against the training script.
# Previously used notes:
#   0902_win20_str10_wdis10_group20_bs32_gpu8
#   0902_win20_str10_wdis10_group20_bs10_gpu8
#   0513_all_win10_str10_dis10_group30_bs512_gpu8
#   0504_groupi_attentionpool_win5_logits1_gpu4
#   0425_double_attentionpool_wdis10_strid10_dilate2_logits1_gpu4
#   0422_double_avg_logits1_gpu4
#   0422_double_avg_logits1_gpu2
#   0419_cls_avg_gpu2
#   0418_avg_atten_gpu2
#   0408_view_reduce_uniform_avg_win10_sl10_dia1_wdis1_group10_bs128_lr04_gpu8
config.model = ConfigNode()
config.model.name = ''
config.model.note = 'test'

# Dataset settings for Assembly101.
config.dataset = ConfigNode()
config.dataset.all_to_memory = False
config.dataset.name = 'assembly101'
config.dataset.split = 'train'
# Absolute, machine-specific locations; adjust per environment.
config.dataset.features_path = '/data/peiyao/data/all_data/Assembly101/TSM_features'
# Must keep its trailing '/': the three paths below are built by plain string
# concatenation, not os.path.join.
config.dataset.anno_data_path = "/data/peiyao/data/all_data/Assembly101/annotations/"
config.dataset.vid_list_path = config.dataset.anno_data_path + "coarse-annotations/coarse_splits/"
config.dataset.gt_path = config.dataset.anno_data_path + "coarse-annotations/coarse_labels/"
config.dataset.mapping_file = config.dataset.anno_data_path + "coarse-annotations/actions.csv"
# 16 view identifiers: 8 '*_rgb' and 8 'HMC_*_mono10bit'. These appear to be
# camera names from the dataset; confirm against features_path.
config.dataset.VIEWS = ['C10095_rgb', 'C10115_rgb', 'C10118_rgb', 'C10119_rgb', 'C10379_rgb', 'C10390_rgb', 'C10395_rgb', 'C10404_rgb',
                         'HMC_21176875_mono10bit', 'HMC_84346135_mono10bit', 'HMC_21176623_mono10bit', 'HMC_84347414_mono10bit',
                         'HMC_21110305_mono10bit', 'HMC_84355350_mono10bit', 'HMC_21179183_mono10bit', 'HMC_84358933_mono10bit']

config.dataset.reduce_data = False
# NOTE(review): 2 looks like a small debugging value. Values noted earlier:
# 32 ('dp_512_8gpu'), 64, 1024, 512, 256.
config.dataset.batch_size = 2

# Windowing. Earlier values: train_win 20, stride 20.
config.dataset.train_win = 20
config.dataset.stride = 10
config.dataset.dilation = 1
config.dataset.win_dis = 10
# Total batch size is batch_size * num_in_group (previously 250).
config.dataset.num_in_group = 20
# Validation uses the same window length as training.
config.dataset.val_win = config.dataset.train_win
# Known options: 'uniform_group_multiview', 'uniform_group_pair'.
config.dataset.sample_type = 'uniform_group_multiview'

##======== for model =========
# Frame-feature refinement settings.
config.model.refine_fea = ConfigNode()
# Presumably toggles loading the checkpoint at pair_weight_ckpt_path; confirm in
# the model builder.
config.model.refine_fea.use_pair_weight = True
# Run tag (directory) and file name of the pair-weight checkpoint.
config.model.refine_fea.ckpt_note = '0828_gpu8_new_bs100_group20_lr04_aug'
config.model.refine_fea.detail_epoch = 'checkpoint_epoch_500.pth'  # alternative: 'checkpoint_current.pth'

# <pair experiments dir>/<ckpt_note>/<detail_epoch>; machine-specific absolute path.
config.model.refine_fea.pair_weight_ckpt_path = (
    '/home/peiyao/Documents/ACode/Contras_refine/Contras_fea/experiments/assembly101/pair/'
    f'{config.model.refine_fea.ckpt_note}/{config.model.refine_fea.detail_epoch}'
)

# How a per-frame feature is obtained. This is validated with an explicit
# raise instead of `assert`, because asserts are stripped under `python -O`
# and an invalid choice would then pass silently.
config.model.refine_fea.type_frame = 'AttentionPool'
if config.model.refine_fea.type_frame not in ('AvgPool', 'MaxPool', 'AttentionPool', 'ClsToken'):
    raise ValueError(
        'The wrong type for getting frame feature! '
        f'Got {config.model.refine_fea.type_frame!r}.'
    )
# Input feature width and model width. in_channel presumably matches the
# features under config.dataset.features_path; confirm.
config.model.refine_fea.in_channel = 2048
config.model.refine_fea.d_model = 512
config.model.refine_fea.init_logit_scale = 1  # alternative noted: math.log10(1/0.7)

# Aggregation settings (what is aggregated, e.g. frames or views, is not shown
# here; confirm in the model). Validated with an explicit raise for the same
# reason as type_frame.
config.model.aggre = ConfigNode()
config.model.aggre.type_aggre = 'AvgPool'
if config.model.aggre.type_aggre not in ('AvgPool', 'MaxPool', 'AttentionPool'):
    raise ValueError(
        'The wrong type of aggregation! '
        f'Got {config.model.aggre.type_aggre!r}.'
    )

##========== for training ================
config.train = ConfigNode()
# Empty string: no checkpoint path is configured by default.
config.train.checkpoint = ''
config.train.resume = False
# Logging interval (units, e.g. iterations or steps, not shown here; confirm).
config.train.log_period = 100
# Relative output root directory.
config.train.output_dir = 'experiments'
##

# Optimizer name.
# NOTE(review): the previously documented options (sgd, adam, lars, adabound,
# adaboundw) do not include the selected 'adamw'. TODO: confirm the supported
# names against the optimizer factory.
config.train.optimizer = 'adamw'
config.train.base_lr = 1e-5  # previously 1e-4
config.train.weight_decay =  1e-4
config.train.start_epoch = 0
config.train.epochs = 2000
config.train.warmup_epochs = 5
# Random seed.
config.train.seed = 42

# Distributed-training settings. These are presumably passed to
# torch.distributed.init_process_group; confirm in the launcher.
config.train.distributed = True
config.train.dist = ConfigNode()
# Known choices: 'nccl', 'gloo'.
config.train.dist.backend = 'gloo'
# 'env://' takes rendezvous information from environment variables.
config.train.dist.init_method = 'env://'
# -1 values are placeholders, presumably overwritten at runtime (see also
# config.world_size); confirm where.
config.train.dist.world_size = -1
config.train.dist.node_rank = -1
config.train.dist.local_rank = 0
config.train.dist.use_sync_bn = False
##===============validation
# (No validation-specific settings are defined yet.)


def get_default_config():
    """Return a copy of the module-level default configuration.

    The copy comes from ``ConfigNode.clone()``, so callers work on their own
    node instead of the shared ``config`` object. This assumes ``clone`` copies
    nested nodes; confirm in ``utils.config_node``.
    """
    defaults = config
    fresh_copy = defaults.clone()
    return fresh_copy


