import os
import subprocess
import itertools
import yaml
import json
import wandb
import pickle
import glob
import shutil


# Select the backbone name and batch size; the 3D backbone is paired with a
# much smaller batch size than the 2D one.
is_3d = False
backbone, batch_size = ('cusResNet18_3d', 8) if is_3d else ('cusResNet18', 1024)

methods = ['baseline', 'resampling', 'LAFTR', 'CFair', 'LNL', 'EnD', 'DomainInd', 'GroupDRO', 'ODR']
#methods = ['baseline']
sensitive_name = 'Age'
dataset_name = 'MIMIC_CXR'
output_dim = num_classes = 14
sens_classes = 2
if sens_classes == 2:
    methods = ['baseline', 'resampling', 'LAFTR', 'CFair', 'LNL', 'EnD', 'DomainInd', 'GroupDRO', 'ODR']
else:
    methods = ['baseline', 'resampling', 'LNL', 'EnD', 'DomainInd', 'GroupDRO', 'ODR']
    #methods = ['baseline']
    

# Load the sweep's mapping for this dataset/attribute pair: keys are method
# names (used as checkpoint sub-directories), values are the hash ids that
# prefix each method's checkpoint files.
# NOTE: pickle.load can execute arbitrary code -- only point this at files
# generated locally by the sweep, never at untrusted input.
pickle_path = './sweep/test/hash/{0}/{1}/{0}-{1}.pkl'.format(
    dataset_name, sensitive_name)
with open(pickle_path, 'rb') as handle:
    method_hash = pickle.load(handle)




model_path = './model_records/{datas}/{attr}/{bkbone}/'.format(
        datas = dataset_name, attr = sensitive_name, bkbone = backbone)

# For every method in the sweep's hash map, copy each checkpoint named
# "<hash_id>_*.pth" into the same directory with a "cross_domain_" prefix.
# The copies do not start with the hash id, so re-running does not re-copy them.
for method, hash_id in method_hash.items():

    print(method)
    method_model_path = os.path.join(model_path, method)
    # Escape both the directory and the hash id so glob metacharacters
    # ('[', '*', '?') in either are matched literally.
    pattern = os.path.join(glob.escape(method_model_path),
                           glob.escape(str(hash_id)) + '_*.pth')
    # glob returns matches in arbitrary order; sort for reproducible runs.
    models = sorted(glob.glob(pattern))
    if not models:
        print('  no checkpoints matched {}'.format(pattern))

    for model in models:
        # basename handles the platform's separator, unlike split('/').
        new_name = 'cross_domain_' + os.path.basename(model)
        new_path = os.path.join(method_model_path, new_name)
        shutil.copy(model, new_path)
        
