import torch 

from nn_models_hnn import MLP
from hnn import HNN

from nn_models_scnn import pq_PQ, P_H
from scnn import Generator


def benchmark_loader(args, baseline=False, *, map_location=None):
    """Build an HNN (or baseline) model and load its saved weights from disk.

    The checkpoint is read from ``"{args.name}-orbits-{case}.tar"`` (relative
    to the current working directory), where ``case`` is ``'baseline'`` when
    ``baseline`` is true and ``'hnn'`` otherwise.

    Args:
        args: Config object; this function reads ``args.input_dim``,
            ``args.hidden_dim`` and ``args.name``.
        baseline: If true, build the baseline variant, whose MLP output width
            is ``args.input_dim``. Otherwise build the HNN variant, whose MLP
            output width is 2. Also forwarded to ``HNN(baseline=...)``.
        map_location: Keyword-only. Forwarded to ``torch.load``, e.g.
            ``"cpu"`` to load a checkpoint that was saved on a GPU onto a
            CPU-only machine. ``None`` (the default) is ``torch.load``'s own
            default, so behavior is unchanged for existing callers.

    Returns:
        The ``HNN`` module with the checkpoint's state dict loaded into it.
    """
    # NOTE(review): the width of 2 for the non-baseline case presumably matches
    # what HNN expects from its differentiable model — confirm against hnn.HNN.
    output_dim = args.input_dim if baseline else 2
    nn_model = MLP(args.input_dim, args.hidden_dim, output_dim)
    model = HNN(args.input_dim, differentiable_model=nn_model, baseline=baseline)

    case = 'baseline' if baseline else 'hnn'
    path = "{}-orbits-{}.tar".format(args.name, case)
    # Passing map_location lets GPU-saved checkpoints load on CPU-only hosts.
    model.load_state_dict(torch.load(path, map_location=map_location))
    return model

def pnn_loader(args, test_dim=4, HPQ_trainable=True, number=0, num_hidden=2,
               angular_momentum=False, momentum=False, *, map_location=None):
    """Build a ``Generator`` model from its two sub-networks and load saved weights.

    The model is composed as follows:

    * a ``pq_PQ`` network with ``latent_dim_p = args.input_dim - test_dim`` and
      ``latent_dim_q = test_dim``;
    * a ``P_H`` network sharing the same ``latent_dim_p`` and ``test_dim``;
    * a ``Generator`` wrapping both.

    The checkpoint is read from ``"{args.name}-weights-{number}.tar"``
    (relative to the current working directory).

    Args:
        args: Config object; this function reads ``args.input_dim``,
            ``args.hidden_dim`` and ``args.name``.
        test_dim: Forwarded as ``latent_dim_q`` to ``pq_PQ`` and as
            ``test_dim`` to ``P_H`` and ``Generator``. Also subtracted from
            ``args.input_dim`` to get ``latent_dim_p``.
        HPQ_trainable: Forwarded to both ``pq_PQ`` and ``P_H``.
        number: Index embedded in the checkpoint filename.
        num_hidden: Forwarded to ``pq_PQ``.
        angular_momentum: Forwarded to ``pq_PQ``.
        momentum: Forwarded to ``pq_PQ``.
        map_location: Keyword-only. Forwarded to ``torch.load``, e.g.
            ``"cpu"`` to load a GPU-saved checkpoint on a CPU-only machine.
            ``None`` (the default) is ``torch.load``'s own default, so
            behavior is unchanged for existing callers.

    Returns:
        The ``Generator`` module with the checkpoint's state dict loaded into it.
    """
    # Shared by both sub-networks; computed once so the two cannot drift apart.
    # NOTE(review): presumably test_dim < args.input_dim is required — confirm.
    latent_dim_p = args.input_dim - test_dim

    nn_model_pq_PQ = pq_PQ(
        args.input_dim,
        args.hidden_dim,
        latent_dim_p=latent_dim_p,
        latent_dim_q=test_dim,
        num_hidden=num_hidden,
        HPQ_trainable=HPQ_trainable,
        angular_momentum=angular_momentum,
        momentum=momentum,
    )
    nn_model_HPQ = P_H(
        input_dim=args.input_dim,
        hidden_dim=args.hidden_dim,
        latent_dim_p=latent_dim_p,
        test_dim=test_dim,
        HPQ_trainable=HPQ_trainable,
    )
    model = Generator(
        args.input_dim,
        differentiable_model_1=nn_model_pq_PQ,
        differentiable_model_2=nn_model_HPQ,
        test_dim=test_dim,
    )

    path = "{}-weights-{}.tar".format(args.name, number)
    # Passing map_location lets GPU-saved checkpoints load on CPU-only hosts.
    model.load_state_dict(torch.load(path, map_location=map_location))
    return model