# Input-representation flags for each model name.
# Shape: {model_name: {"ts": bool, "cwt": bool, "cwtp": bool, "nlp": bool}}.
# Every model enables exactly one flag, so each row below only names the
# enabled one and the comprehension expands it into the four booleans.
# NOTE(review): the flag names suggest "ts" = raw time series, "cwt"/"cwtp" =
# wavelet-transform based inputs and "nlp" = text input -- confirm with the
# data loaders before relying on this interpretation.
MODEL_CONFIG = {
    model: {flag: flag == input_kind for flag in ("ts", "cwt", "cwtp", "nlp")}
    for model, input_kind in (
        ("cv", "cwt"),
        ("vit", "cwtp"),
        ("chronos", "ts"),
        ("chronos_base", "ts"),
        ("chronos_msitf", "ts"),
        ("moirai", "ts"),
        ("pretrain", "cwtp"),
        ("nlp", "nlp"),
        ("fusion", "cwtp"),
        ("crossvit", "cwtp"),
        ("fusion_crossvit", "cwtp"),
        ("mae_msitf", "cwtp"),
        ("mae", "cwtp"),
    )
}

# Per-dataset settings: {dataset_name: {"n_ch", "n_cl", "task", "lr"}}.
# NOTE(review): field meanings inferred from names -- confirm against users:
#   n_ch -- presumably the number of input signal channels;
#   n_cl -- number of classes for "class" tasks; for "reg" tasks it appears
#           to be the number of regression targets (e.g. ppg_hgb has 1);
#   task -- "class" or "reg";
#   lr   -- presumably the default learning rate for that dataset.
DATASET_CONFIG = {
    name: {"n_ch": n_ch, "n_cl": n_cl, "task": task, "lr": lr}
    for name, n_ch, n_cl, task, lr in (
        # (name, n_ch, n_cl, task, lr)
        ("wesad", 10, 3, "class", 1e-3),
        ("gameemo", 4, 4, "class", 1e-3),
        ("uci_har", 6, 6, "class", 1e-2),
        ("non_invasive_bp", 3, 2, "reg", 1e-1),
        ("ppg_hgb", 2, 1, "reg", 1e-1),
        ("ecg_heart_cat", 1, 2, "class", 1e-2),
        ("PPG_HTN", 1, 4, "class", 1e-2),
        ("PPG_DM", 1, 2, "class", 1e-2),
        ("PPG_CVA", 1, 2, "class", 1e-2),
        ("PPG_CVD", 1, 3, "class", 1e-2),
        ("drive_fatigue", 4, 2, "class", 1e-2),
        ("indian-fPCG", 1, 1, "reg", 1e-2),
    )
}

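'''
Hyperparameters used for pretraining, per dataset
(lr = learning rate, epochs = number of training epochs,
ss = presumably the learning-rate scheduler step size -- confirm):
PPG_BP:  lr=1e-1, epochs=100, ss=10
PPG_HGB: lr=1e-1, epochs=100, ss=10
PPG_HTN: lr=1e-2, epochs=50, ss=10
ECG_Cat: lr=1e-2, epochs=50, ss=10
GAMEEMO: lr=1e-2, epochs=50, ss=10
UCI_HAR: lr=1e-2, epochs=50, ss=10
Note: these values do not always match DATASET_CONFIG (e.g. GAMEEMO uses lr=1e-3 there).
'''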

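# # One-off fix (kept for reference): the split files of PPG_DM, PPG_CVA and PPG_CVD
# # in the ppg_china dataset had "PPG_HTN" in their data paths; rewrite them in place.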
# import pickle
# if __name__ == "__main__":
#     for dx in ["DM", "CVA", "CVD"]:
#         # fetch old split
#         with open("data/PPG_{}/splits".format(dx), 'rb') as f:
#             split = pickle.load(f)
        
#         # replace the "PPG_HTN" path component with this dataset's own name
#         for k in split:
#             for i in range(len(split[k])):
#                 split[k][i] = split[k][i].replace("PPG_HTN", "PPG_{}".format(dx))
        
#         # overwrite
#         with open("data/PPG_{}/splits".format(dx), 'wb') as f:
#             pickle.dump(split, f)