from config import parse_args
import torch
from DNN import EnsembleTransition


def _build_transition(arg, device, **extra_kwargs):
    """Construct one EnsembleTransition sized from the configuration in ``arg``."""
    return EnsembleTransition(arg, arg.observation_dims, arg.action_dims,
                              arg.hidden_feature_number, arg.hidden_layer_number,
                              device=device, **extra_kwargs)


def load_model(arg):
    """Rebuild the meta, training-task and testing-task models from checkpoints.

    Args:
        arg: parsed configuration (as returned by ``parse_args``). Attributes read:
            ``observation_dims``, ``action_dims``, ``hidden_feature_number``,
            ``hidden_layer_number``, ``checkpoint_paths`` (index 0 = training
            checkpoint, index 1 = testing checkpoint), ``device``, ``cuda_index``,
            ``train_task_number`` and ``trj_number``.

    Returns:
        Tuple ``(inner_dnns, test_dnns, outer_dnn)``:
          * ``inner_dnns`` -- ``train_task_number`` models loaded from the training
            checkpoint's ``'training_task_models'`` entries;
          * ``test_dnns`` -- ``trj_number - train_task_number`` models loaded from
            the testing checkpoint's ``'testing_trained_task_models'`` entries;
          * ``outer_dnn`` -- the model loaded from the training checkpoint's
            ``'meta_model'`` entry.

    Raises:
        IndexError: if a checkpoint holds fewer task state dicts than the
            configured task counts require.
    """
    device = torch.device(arg.device)
    # map_location remaps saved tensors onto the configured device, so a
    # checkpoint written on a GPU can still be opened on e.g. a CPU-only host.
    checkpoint_train = torch.load(arg.checkpoint_paths[0], map_location=device)
    checkpoint_test = torch.load(arg.checkpoint_paths[1], map_location=device)

    # NOTE(review): only the meta model is given cuda_index (as before); confirm
    # whether the per-task models should receive it too.
    outer_dnn = _build_transition(arg, device, cuda_index=arg.cuda_index)
    outer_dnn.load_state_dict(checkpoint_train['meta_model'])

    inner_dnns = []
    for i in range(arg.train_task_number):
        model = _build_transition(arg, device)
        model.load_state_dict(checkpoint_train['training_task_models'][i])
        inner_dnns.append(model)

    test_dnns = []
    for i in range(arg.trj_number - arg.train_task_number):
        model = _build_transition(arg, device)
        # Load into the freshly built test model (previously these weights were
        # written into inner_dnns[i], clobbering the training-task models).
        model.load_state_dict(checkpoint_test['testing_trained_task_models'][i])
        test_dnns.append(model)

    print('Models are loaded.')
    return inner_dnns, test_dnns, outer_dnn


if __name__ == '__main__':
    # Smoke run: parse the command-line configuration and load every
    # checkpointed model; the returned models are not used further here.
    args = parse_args()
    load_model(arg=args)
    print('Finished')