from easydict import EasyDict as edict
# import yaml
# import pdb
import os


"""
default config
"""
cfg = edict()
cfg.BATCH_SIZE = 4 # default 4
cfg.LAMBDA_1 = 0.5 # default: 0.5
cfg.MASK_NUM = 10 # 10 for fully supervised
cfg.NUM_CLASSES = 71 # 70 + 1 background

###############################
# # TRAIN
# cfg.TRAIN = edict()

# cfg.TRAIN.FREEZE_AUDIO_EXTRACTOR = True
# cfg.TRAIN.PRETRAINED_VGGISH_MODEL_PATH = "./torchvggish/vggish-10086976.pth"
# cfg.TRAIN.PREPROCESS_AUDIO_TO_LOG_MEL = True #! notice
# cfg.TRAIN.POSTPROCESS_LOG_MEL_WITH_PCA = False
# cfg.TRAIN.PRETRAINED_PCA_PARAMS_PATH = "./torchvggish/vggish_pca_params-970ea276.pth"
# cfg.TRAIN.FREEZE_VISUAL_EXTRACTOR = True
# cfg.TRAIN.PRETRAINED_RESNET50_PATH = "../../pretrained_backbones/resnet50-19c8e357.pth"
# cfg.TRAIN.PRETRAINED_PVTV2_PATH = "../../pretrained_backbones/pvt_v2_b5.pth"

# cfg.TRAIN.FINE_TUNE_SSSS = False
# cfg.TRAIN.PRETRAINED_S4_AVS_WO_TPAVI_PATH = "../single_source_scripts/logs/ssss_20220118-111301/checkpoints/checkpoint_29.pth.tar"
# cfg.TRAIN.PRETRAINED_S4_AVS_WITH_TPAVI_PATH = "../single_source_scripts/logs/ssss_20220118-112809/checkpoints/checkpoint_68.pth.tar"

###############################
# DATA
# Dataset root. By default this is ../avsdata, resolved relative to the
# current working directory (not relative to this file). Set the
# AVS_DATA_ROOT environment variable to point somewhere else without editing
# this file; an unset or empty value falls back to the default.
ROOT = os.environ.get('AVS_DATA_ROOT') or os.path.join('..', 'avsdata')
cfg.DATA = edict()
cfg.DATA.CROP_IMG_AND_MASK = True
cfg.DATA.CROP_SIZE = 224 # short edge

# Paths below are derived from ROOT; adjust ROOT (or AVS_DATA_ROOT) to relocate them.
cfg.DATA.META_CSV_PATH = os.path.join(ROOT, "metadata.csv")
cfg.DATA.LABEL_IDX_PATH = os.path.join(ROOT, "label2idx.json")

cfg.DATA.DIR_BASE = ROOT
# cfg.DATA.DIR_MASK = "../../avsbench_data/v2_data/gt_masks" #! notice: you need to change the path
# cfg.DATA.DIR_COLOR_MASK = "../../avsbench_data/v2_data/gt_color_masks_rgb" #! notice: you need to change the path
cfg.DATA.IMG_SIZE = (224, 224)
###############################
cfg.DATA.RESIZE_PRED_MASK = True
cfg.DATA.SAVE_PRED_MASK_IMG_SIZE = (360, 240) # (width, height)


######
# V3 split: seen-train / seen-val / unseen metadata CSVs under <ROOT>/v3
cfg.DATA.META_SEEN_TRAIN_PATH = os.path.join(ROOT, "v3", "meta_v3_seen_train.csv")
cfg.DATA.META_SEEN_VAL_PATH = os.path.join(ROOT, "v3", "meta_v3_seen_val.csv")
cfg.DATA.META_UNSEEN_PATH = os.path.join(ROOT, "v3", "meta_v3_unseen.csv")
# NOTE: not derived from ROOT; resolved relative to the current working directory.
cfg.DATA.LAYER_FEAT_PATH = os.path.join('..', 'features')


# #####
# Original S4 (single-source) data: <ROOT>/avsbench_data/Single-source
ROOT_S4 = os.path.join(ROOT, 'avsbench_data', 'Single-source')
cfg.DATA.ANNO_CSV4 = os.path.join(ROOT_S4, 's4_meta_data.csv')
cfg.DATA.DIR_IMG4 = os.path.join(ROOT_S4, 's4_data', 'visual_frames')
cfg.DATA.DIR_AUDIO4 = os.path.join(ROOT_S4, 's4_data', 'audio_wav')
cfg.DATA.DIR_MASK4 = os.path.join(ROOT_S4, 's4_data', 'gt_masks')

#####
# Original MS3 (multi-source) data: <ROOT>/avsbench_data/Multi-sources
ROOT_MS3 = os.path.join(ROOT, 'avsbench_data', 'Multi-sources')
# The stock annotation file under ROOT_MS3 is kept here for reference:
# cfg.DATA.ANNO_CSV3 = os.path.join(ROOT_MS3, 'ms3_meta_data.csv')
# NOTE: the active annotation CSV is NOT under ROOT; it is ../meta_ms3_label.csv
# relative to the current working directory.
cfg.DATA.ANNO_CSV3 = os.path.join('..', 'meta_ms3_label.csv')
cfg.DATA.DIR_IMG3 = os.path.join(ROOT_MS3, 'ms3_data', 'visual_frames')
cfg.DATA.DIR_AUDIO3 = os.path.join(ROOT_MS3, 'ms3_data', 'audio_wav')
cfg.DATA.DIR_MASK3 = os.path.join(ROOT_MS3, 'ms3_data', 'gt_masks')
