import torch
import argparse
import os

def build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description='Transfer best model state dict to prior dir')
    # required=True: previously a missing flag left input_dir as None and the
    # path construction crashed with an unhelpful TypeError.
    parser.add_argument('--input_dir', type=str, required=True,
                        help="directory containing the 'model.pkl' checkpoint")
    return parser


def checkpoint_path(input_dir):
    """Return the path of the ``model.pkl`` checkpoint inside *input_dir*.

    ``os.path.join`` makes *input_dir* work with or without a trailing
    separator (plain string concatenation produced e.g. ``runs/xmodel.pkl``).
    """
    return os.path.join(input_dir, 'model.pkl')


def output_dir_for(param):
    """Return ``./prior_dicts/<algorithm>/<dataset>/<test_envs>``.

    *param* is the ``'args'`` mapping stored in the checkpoint; it must have
    ``'algorithm'``, ``'dataset'`` and ``'test_envs'`` keys. ``test_envs`` is
    formatted with ``str()``, so a list such as ``[0]`` yields a directory
    literally named ``[0]``. This layout is kept unchanged because code that
    reads ``prior_dicts`` presumably depends on it — confirm before changing.
    """
    return os.path.join('./prior_dicts', param['algorithm'], param['dataset'],
                        str(param['test_envs']))


def split_state_dict(model_dict):
    """Split a full model state dict into featurizer and classifier parts.

    A key containing ``'featurizer.network.'`` goes to the featurizer dict,
    with every occurrence of ``'featurizer.'`` removed from it. Otherwise a
    key containing ``'classifier.'`` goes to the classifier dict, with every
    occurrence of ``'classifier.'`` removed. All other keys are dropped.
    (Presumably this is so each part can be loaded into a standalone
    featurizer / classifier module — verify against the consumer.)

    Returns:
        A ``(featurizer_dict, classifier_dict)`` tuple of new dicts; the
        values are the same objects as in *model_dict* (not copied).
    """
    featurizer_dict = {}
    classifier_dict = {}
    for key, value in model_dict.items():
        if 'featurizer.network.' in key:
            featurizer_dict[key.replace('featurizer.', '')] = value
        elif 'classifier.' in key:
            classifier_dict[key.replace('classifier.', '')] = value
    return featurizer_dict, classifier_dict


def main(argv=None):
    """Extract featurizer/classifier weights from a checkpoint into prior_dicts.

    Reads ``<input_dir>/model.pkl`` (a dict with ``'args'`` and
    ``'model_dict'`` entries), then writes ``featurizer_dict.pkl`` and
    ``classifier_dict.pkl`` under the directory given by ``output_dir_for``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    # NOTE(security): torch.load unpickles arbitrary objects -- only run this
    # on checkpoints you trust. map_location='cpu' lets checkpoints saved from
    # GPU tensors load on machines without CUDA.
    origin_dict = torch.load(checkpoint_path(args.input_dir),
                             map_location='cpu')
    param = origin_dict['args']
    model_dict = origin_dict['model_dict']

    output_dir = output_dir_for(param)
    print(param['test_envs'])
    os.makedirs(output_dir, exist_ok=True)

    featurizer_dict, classifier_dict = split_state_dict(model_dict)
    torch.save(featurizer_dict,
               os.path.join(output_dir, 'featurizer_dict.pkl'))
    torch.save(classifier_dict,
               os.path.join(output_dir, 'classifier_dict.pkl'))
    print('success')

    # List the transferred keys so the split can be checked by eye.
    print("featurizer:")
    for key in featurizer_dict:
        print(key)
    print("classifier:")
    for key in classifier_dict:
        print(key)


if __name__ == '__main__':
    main()