import torch 
import torch.nn as nn
import torch.optim as optim 

import numpy as np
import os, sys

PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PARENT_DIR)

import lib.utils as utils
import data.data_gen as data_gen
import model as model
import train as train

import random
from random import SystemRandom


from torchdiffeq import odeint as odeint
import argparse

# ---------------------------------------------------------------------------
# Command-line configuration for one training run.
# Network-size convention (taken from the gnodev2 help texts below):
#   l<X> = number of hidden layers, n<X> = hidden-layer width, of sub-net <X>.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Train a dynamics model on trajectories of a dynamical system.')
parser.add_argument('--r', type=int, default=0,
                    help='random seed (torch, numpy and python random)')
# Previously read below (args.gpu) without ever being declared, which raised
# AttributeError whenever CUDA was available.
parser.add_argument('--gpu', type=int, default=0,
                    help='CUDA device index (used only when CUDA is available)')
parser.add_argument('--system', type=str, default='dno', help='dynamical system')
parser.add_argument('--model', type=str, default='gnodev2',
                    help='model: node (black-box), spnn (penalty-based), '
                         'gfinn (Zhang et al., 2021), gnodev1 (Lee et al., 2021), '
                         'gnodev2 (proposed)')
parser.add_argument('--nunit', type=int, default=25,
                    help='hidden units per layer (nunit, passed to model.model_selection)')
parser.add_argument('--nlayer', type=int, default=4,
                    help='number of hidden layers (nlayer, passed to model.model_selection)')

# Training options (forwarded to train.train)
parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
parser.add_argument('--lMB', type=int, default=10,
                    help='lMB value forwarded to train.train')
parser.add_argument('--MB', action='store_true',
                    help='MB flag forwarded to train.train')
parser.add_argument('--plot', action='store_true',
                    help='enable plotting in data_gen.trajectory_generator')

# E / S sub-networks: used by spnn / gfinn / gnodev1 / gnodev2.
# NOTE(review): presumably E = energy and S = entropy networks; confirm in model.py.
parser.add_argument('--lE', type=int, default=3,
                    help='number of hidden layers in the E network')
parser.add_argument('--lS', type=int, default=3,
                    help='number of hidden layers in the S network')
parser.add_argument('--nE', type=int, default=15,
                    help='dimension of hidden layers in the E network')
parser.add_argument('--nS', type=int, default=15,
                    help='dimension of hidden layers in the S network')

# gnodev1
parser.add_argument('--D1', type=int, default=3, help='GNODEv1 hyperparameter D1')
parser.add_argument('--D2', type=int, default=3, help='GNODEv1 hyperparameter D2')

# gfinn
parser.add_argument('--K', type=int, default=3, help='GFINN hyperparameter K')

# gnodev2 (the proposed one)
parser.add_argument('--lA', type=int, default=3,
                    help='number of hidden layers in (n x n) skew-sym reversible part')
parser.add_argument('--lB', type=int, default=3,
                    help='number of hidden layers in (n x r) irreversible part')
parser.add_argument('--lD', type=int, default=3,
                    help='number of hidden layers in (r x r) interaction matrix field')
parser.add_argument('--nA', type=int, default=30,
                    help='dimension of hidden layers in skew-sym reversible part')
parser.add_argument('--nB', type=int, default=30,
                    help='dimension of hidden layers in irreversible part')
parser.add_argument('--nD', type=int, default=30,
                    help='dimension of hidden layers in interaction matrix')
parser.add_argument('--D', type=int, default=3,
                    help='dimension r of interaction matrix')
parser.add_argument('--C2', type=int, default=3,
                    help='inner dimension of factor in interaction matrix')

# The training code uses dof = state dimension - nentropy.
parser.add_argument('--nentropy', type=int, default=1,
                    help='number of entropy variables in the state '
                         '(dof = state dimension - nentropy)')

args = parser.parse_args()

# Floating-point tensors created without an explicit dtype default to float64.
torch.set_default_dtype(torch.float64)
# NOTE(review): `device` is not referenced anywhere else in this script; the
# model and data are never moved to it. Confirm whether train.train handles this.
device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the run touches from --r, in the order torch, NumPy, stdlib.
seed = args.r
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)
del _seed_fn

# Per-run output locations:
#   experiments/<system>/<model>/<id>.ckpt   (checkpoint)
#   experiments/<system>/<model>/<id>/       (figures)
# The id is drawn from SystemRandom, which the seeding above does not affect,
# so repeated runs with the same --r get different ids. Ids lie in
# [0, 99999], so collisions between runs remain possible.
save_path = f'experiments/{args.system}/{args.model}/'
utils.makedirs(save_path)
experimentID = int(SystemRandom().random() * 100000)
ckpt_path = os.path.join(save_path, f'{experimentID}.ckpt')
fig_save_path = os.path.join(save_path, f'{experimentID}')
utils.makedirs(fig_save_path)
print(ckpt_path)

# Obtain the training trajectories, the time grid `t`, and `split` from the
# data generator. With --plot, fig_save_path is where it may write figures.
train_data, t, split = data_gen.trajectory_generator(
    args.system,
    fig_save_path,
    args.nentropy,
    plot=args.plot,
)
print(train_data.shape, t.shape)

# Convert to a (copied) torch tensor that tracks gradients.
# NOTE(review): requires_grad=True on the data is presumably needed by
# train.train (e.g. gradients w.r.t. state); confirm.
train_data = torch.tensor(train_data, requires_grad=True)

# Model-specific hyperparameters forwarded to model.model_selection.
if args.model == 'spnn':
    kwargs = {"lE": args.lE, "lS": args.lS, "nE": args.nE, "nS": args.nS}
elif args.model == 'gfinn':
    kwargs = {"lE": args.lE, "lS": args.lS, "nE": args.nE, "nS": args.nS,
              "K": args.K}
elif args.model == 'gnodev1':
    kwargs = {"lE": args.lE, "lS": args.lS, "nE": args.nE, "nS": args.nS,
              "D1": args.D1, "D2": args.D2}
elif args.model == 'gnodev2':
    kwargs = {"lE": args.lE, "lS": args.lS, "lA": args.lA, "lB": args.lB,
              "lD": args.lD, "nE": args.nE, "nS": args.nS, "nA": args.nA,
              "nB": args.nB, "nD": args.nD, "D": args.D, "C2": args.C2}
else:
    # 'node' (black-box) and any other model: no structure-specific
    # hyperparameters, only nlayer/nunit below. Without this branch `kwargs`
    # was undefined and the model_selection call raised NameError.
    # NOTE(review): confirm model.model_selection accepts these without kwargs.
    kwargs = {}

# State dimension: axis 1 of the training data, except for the 'tdp' system
# whose data uses axis 2 for the state.
nState = train_data.shape[2] if args.system == 'tdp' else train_data.shape[1]
dyn = model.model_selection(args.model, nState, args.nlayer, args.nunit, **kwargs)

# Non-entropy degrees of freedom.
dof = nState - args.nentropy
train.train(dyn, args.system, args.lr, train_data, t, dof, split, args.lMB,
            ckpt_path, fig_save_path, args.MB)
