

import torch, argparse
import numpy as np

import os
THIS_DIR = os.path.dirname(os.path.abspath(__file__))


from nn_models_scnn import pq_PQ, P_H
from scnn import Generator
from data import get_dataset
from utils import L2_loss, to_pickle

def get_args():
    """Parse the command-line options of the training script.

    Every option is declared in a single table so that flags, defaults,
    types and help texts can be read side by side. The options are
    registered in table order, which is also the order shown by ``--help``.

    Returns:
        argparse.Namespace: the parsed options. In addition to the flags
        listed below, the namespace carries a ``feature`` attribute that
        defaults to ``True`` and is not bound to any command-line flag.
    """
    option_table = [
        ('--input_dim', dict(default=2 * 4, type=int,
                             help='dimensionality of input tensor')),
        ('--hidden_dim', dict(default=200, type=int,
                              help='hidden dimension of mlp')),
        ('--learn_rate', dict(default=1e-3, type=float,
                              help='learning rate')),
        ('--batch_size', dict(default=200, type=int,
                              help='batch_size')),
        ('--total_steps', dict(default=50000, type=int,
                               help='number of gradient steps')),
        ('--network_type', dict(default=5, type=int,
                                help='different network types one can use')),
        ('--print_every', dict(default=200, type=int,
                               help='number of gradient steps between prints')),
        ('--name', dict(default='2body', type=str,
                        help='only one option right now')),
        ('--verbose', dict(dest='verbose', action='store_true',
                           help='verbose?')),
        ('--seed', dict(default=2, type=int,
                        help='random seed')),
        ('--save_dir', dict(default=THIS_DIR, type=str,
                            help='where to save the trained model')),
    ]

    cli = argparse.ArgumentParser(description=None)
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    cli.set_defaults(feature=True)

    return cli.parse_args()




def train(args):
    """Build the SCNN model selected by ``args.network_type`` and train it.

    Args:
        args: namespace providing ``input_dim``, ``hidden_dim``,
            ``learn_rate``, ``batch_size``, ``total_steps``,
            ``network_type``, ``print_every``, ``name``, ``verbose``,
            ``seed`` and ``save_dir``.

    Returns:
        tuple: ``(model, stats)`` where ``model`` is the trained
        ``Generator`` and ``stats`` is a dict with the per-step lists
        ``'train_loss'`` and ``'test_loss'``. The recorded train loss is the
        L2 loss plus the value of ``model.get_loss``; the recorded test loss
        is the L2 loss alone.
    """
    # Settings per network type, each tuple ordered as
    # (test_dim, num_hidden, momentum, angular_momentum, HPQ_trainable,
    #  alpha_hpq_Q, alpha_hpq_P, alpha_poisson).
    presets = (
        (1, (4, 2, False, False, True, 0.001, 0.001, 0.001)),
        (2, (2, 0, False, False, True, 0.01, 0.01, 0.01)),
        (3, (4, 2, True, False, True, 0.001, 0.001, 0.001)),
        (4, (3, 2, True, False, False, 0.0, 1., 0.0)),
    )
    # Any network type not listed above uses these settings.
    fallback = (4, 2, True, True, False, 0.0, 1., 0.0)
    selected = next(
        (settings for key, settings in presets if args.network_type == key),
        fallback,
    )
    (test_dim, num_hidden, momentum, angular_momentum, HPQ_trainable,
     alpha_hpq_Q, alpha_hpq_P, alpha_poisson) = selected

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    print("Training SCNN model:")

    latent_dim_p = args.input_dim - test_dim
    nn_model_pq_PQ = pq_PQ(
        args.input_dim,
        args.hidden_dim,
        latent_dim_p=latent_dim_p,
        latent_dim_q=test_dim,
        num_hidden=num_hidden,
        HPQ_trainable=HPQ_trainable,
        angular_momentum=angular_momentum,
        momentum=momentum,
    )
    nn_model_HPQ = P_H(
        input_dim=args.input_dim,
        hidden_dim=args.hidden_dim,
        latent_dim_p=latent_dim_p,
        test_dim=test_dim,
        HPQ_trainable=HPQ_trainable,
    )
    model = Generator(
        args.input_dim,
        differentiable_model_1=nn_model_pq_PQ,
        differentiable_model_2=nn_model_HPQ,
        test_dim=test_dim,
    )

    # The optimizer updates the parameters of both steps of the generator.
    trainable = list(model.step_1.parameters()) + list(model.step_2.parameters())
    optim = torch.optim.Adam(trainable, lr=args.learn_rate, weight_decay=0)

    # arrange data
    data = get_dataset(args.name, args.save_dir, verbose=True)
    x = torch.tensor(data['coords'], requires_grad=True, dtype=torch.float32)
    test_x = torch.tensor(data['test_coords'], requires_grad=True, dtype=torch.float32)
    dxdt = torch.Tensor(data['dcoords'])
    test_dxdt = torch.Tensor(data['test_dcoords'])

    # vanilla train loop
    train_history = []
    test_history = []
    stats = {'train_loss': train_history, 'test_loss': test_history}
    for step in range(args.total_steps + 1):

        # train step on a random mini-batch
        ixs = torch.randperm(x.shape[0])[:args.batch_size]
        dxdt_hat = model.time_derivative(x[ixs])
        loss = L2_loss(dxdt[ixs], dxdt_hat)
        loss += model.get_loss(
            x[ixs],
            dxdt[ixs],
            alpha_hpq_Q=alpha_hpq_Q,
            alpha_poisson=alpha_poisson,
            alpha_hpq_P=alpha_hpq_P,
        )

        loss.backward()
        optim.step()
        optim.zero_grad()

        # run test data on a random mini-batch of the test set
        test_ixs = torch.randperm(test_x.shape[0])[:args.batch_size]
        test_dxdt_hat = model.time_derivative(test_x[test_ixs])
        test_loss = L2_loss(test_dxdt[test_ixs], test_dxdt_hat)

        # logging
        train_value = loss.item()
        test_value = test_loss.item()
        train_history.append(train_value)
        test_history.append(test_value)
        if args.verbose and step % args.print_every == 0:
            print("step {}, train_loss {:.4e}, test_loss {:.4e}"
                  .format(step, train_value, test_value))

    # Final report: mean squared error over the full train and test sets,
    # with the spread divided by the square root of the number of rows.
    train_dxdt_hat = model.time_derivative(x)
    train_dist = (dxdt - train_dxdt_hat)**2
    test_dxdt_hat = model.time_derivative(test_x)
    test_dist = (test_dxdt - test_dxdt_hat)**2

    train_mean = train_dist.mean().item()
    train_err = train_dist.std().item() / np.sqrt(train_dist.shape[0])
    test_mean = test_dist.mean().item()
    test_err = test_dist.std().item() / np.sqrt(test_dist.shape[0])
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
          .format(train_mean, train_err, test_mean, test_err))

    return model, stats
if __name__ == "__main__":
    # Script entry point: parse the options, train, then persist the results.
    args = get_args()
    model, stats = train(args)

    # save
    # exist_ok=True creates the directory only when it is missing, without a
    # separate existence check that another process could invalidate.
    os.makedirs(args.save_dir, exist_ok=True)

    # Model weights, named after the dataset and the network type.
    path = '{}/{}-weights-{}.tar'.format(args.save_dir, args.name, args.network_type)
    torch.save(model.state_dict(), path)

    # Per-step train/test loss history, stored next to the weights.
    path = '{}/{}-weights-{}-stats.pkl'.format(args.save_dir, args.name, args.network_type)
    to_pickle(stats, path)