import torch 
import torch.nn as nn
import torch.optim as optim 

import numpy as np
import os, sys

PARENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PARENT_DIR)

import lib.utils as utils
import data.data_gen as data_gen
import model as model
import train as train

from torchdiffeq import odeint as odeint
import lib.odefunc as odefunc

import matplotlib.pyplot as plt

import random
from random import SystemRandom


from torchdiffeq import odeint as odeint
import argparse

# Command-line interface of the evaluation script.
#
# The script restores a trained model from its checkpoint, integrates it with
# an ODE solver, prints validation/test errors and writes comparison figures.
parser = argparse.ArgumentParser(
    description='Evaluate a trained dynamics model from its checkpoint and plot the predictions.')

# --- general options -------------------------------------------------------
parser.add_argument('--r', type=int, default=0, help='random seed')
# The device selection code reads `args.gpu`; without this option the script
# raises AttributeError as soon as CUDA is available.
parser.add_argument('--gpu', type=int, default=0,
                    help='index of the CUDA device to use when CUDA is available')
parser.add_argument('--system', type=str, default='dno', help='dynamical system')
parser.add_argument('--model', type=str, default='gnodev2', help='model: node (black-box), spnn (penalty-based), GFINN (Zhang, et al, 2021), GNODEv1 (Lee et al, 2021)')
# NOTE(review): nunit/nlayer are forwarded positionally to
# model.model_selection; the descriptions below follow the option names.
parser.add_argument('--nunit', type=int, default=25,
                    help='number of hidden units (forwarded to model.model_selection)')
parser.add_argument('--nlayer', type=int, default=4,
                    help='number of hidden layers (forwarded to model.model_selection)')

parser.add_argument('--plot', action='store_true',
                    help='plot flag forwarded to data.data_gen.trajectory_generator')

# --- spnn / gfinn / gnodev1 / gnodev2 --------------------------------------
parser.add_argument('--lE', type=int, default=3,
                    help='number of hidden layers in the energy function')
parser.add_argument('--lS', type=int, default=3,
                    help='number of hidden layers in the entropy function')
parser.add_argument('--nE', type=int, default=15,
                    help='dimension of hidden layers in energy function')
parser.add_argument('--nS', type=int, default=15,
                    help='dimension of hidden layers in entropy function')

# --- gnodev1 ---------------------------------------------------------------
parser.add_argument('--D1', type=int, default=3, help='gnodev1 hyperparameter D1')
parser.add_argument('--D2', type=int, default=3, help='gnodev1 hyperparameter D2')

# --- gfinn -----------------------------------------------------------------
parser.add_argument('--K', type=int, default=3, help='gfinn hyperparameter K')

# --- gnodev2 (the proposed one) --------------------------------------------
parser.add_argument('--lA', type=int, default=3,
                    help='number of hidden layers in (n x n) skew-sym reversible part')
parser.add_argument('--lB', type=int, default=3,
                    help='number of hidden layers in (n x r) irreversible part')
parser.add_argument('--lD', type=int, default=3,
                    help='number of hidden layers in (r x r) interaction matrix field')
parser.add_argument('--nA', type=int, default=30,
                    help='dimension of hidden layers in skew-sym reversible part')
parser.add_argument('--nB', type=int, default=30,
                    help='dimension of hidden layers in irreversible part')
parser.add_argument('--nD', type=int, default=30,
                    help='dimension of hidden layers in interaction matrix')
parser.add_argument('--D', type=int, default=3,
                    help='dimension r of interaction matrix')
parser.add_argument('--C2', type=int, default=3,
                    help='inner dimension of factor in interaction matrix')

# --- data / bookkeeping ----------------------------------------------------
# The trailing `nentropy` state components are left out of every error metric.
parser.add_argument('--nentropy', type=int, default=1,
                    help='number of entropy variables (trailing state components excluded from the errors)')
parser.add_argument('--id', type=int, default=0,
                    help='experiment ID used to name the checkpoint file and the figure directory')

# Parse the command line, then fix dtype, device, RNG seeds and output paths.
args = parser.parse_args()

torch.set_default_dtype(torch.float64)
# `args.gpu` is only present when the parser declares a `--gpu` option; fall
# back to device 0 so the script does not crash with AttributeError on hosts
# where CUDA is available.
gpu_index = getattr(args, 'gpu', 0)
device = torch.device('cuda:' + str(gpu_index) if torch.cuda.is_available() else 'cpu')

# Seed every RNG in use so repeated runs are reproducible.
seed = args.r
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Layout: experiments/<system>/<model>/<id>.ckpt for the checkpoint and
# experiments/<system>/<model>/<id>/ for the figures.
save_path = 'experiments/' + args.system + '/' + args.model + '/'
utils.makedirs(save_path)
experimentID = args.id
ckpt_path = os.path.join(save_path, str(experimentID) + '.ckpt')
fig_save_path = os.path.join(save_path, str(experimentID))
utils.makedirs(fig_save_path)
print(ckpt_path)

# `split` holds the index boundaries later used to slice the validation and
# test windows; `t` is the time grid handed to the ODE solver.
train_data, t, split = data_gen.trajectory_generator(args.system, fig_save_path, args.nentropy, args.plot)
print(train_data.shape, t.shape)

# Convert the generated trajectories to a tensor with gradient tracking on.
# NOTE(review): presumably required by models that differentiate with respect
# to the state -- confirm before removing requires_grad.
train_data = torch.tensor(train_data, requires_grad=True)

# Names of the command-line options forwarded to each model as keyword
# arguments (order preserved from the original per-model dictionaries).
MODEL_KWARG_NAMES = {
    'spnn': ('lE', 'lS', 'nE', 'nS'),
    'gfinn': ('lE', 'lS', 'nE', 'nS', 'K'),
    'gnodev1': ('lE', 'lS', 'nE', 'nS', 'D1', 'D2'),
    'gnodev2': ('lE', 'lS', 'lA', 'lB', 'lD', 'nE', 'nS', 'nA', 'nB', 'nD', 'D', 'C2'),
}
# Models without an entry (e.g. the black-box 'node') take no extra keyword
# arguments. Previously `kwargs` was left undefined for them, which raised
# NameError at the model_selection call below.
kwargs = {name: getattr(args, name) for name in MODEL_KWARG_NAMES.get(args.model, ())}

# The 'tdp' data carries one extra leading axis, so its state dimension is
# read from axis 2 instead of axis 1.
nState = train_data.shape[2] if args.system == 'tdp' else train_data.shape[1]

dyn = model.model_selection(args.model, nState, args.nlayer, args.nunit, **kwargs)

# Restore the trained weights and measure the prediction error of the model.
#
# The reference trajectories live on `train_data.device`; the checkpoint and
# the predictions are brought to that same device so that
#   * a checkpoint written on a GPU can be loaded on a CPU-only host, and
#   * the error computation never mixes tensors from two devices.
ref_device = train_data.device
ckpt = torch.load(ckpt_path, map_location=ref_device)
dyn.load_state_dict(ckpt['state_dict'])

# Only the first `dof` state components enter the error metrics; the trailing
# `nentropy` components are ignored.
dof = nState - args.nentropy
if args.system == 'tdp':
    # Trajectories 80..89 are used for validation and 90.. for testing; each
    # is integrated from its first time step. After the transpose the
    # prediction is indexed as (trajectory, time, state).
    pred_y = odeint(dyn, train_data[80:90, 0, :], t, rtol=1e-7, atol=1e-9, method='dopri5').to(ref_device).squeeze().transpose(0, 1)
    val_loss = torch.mean(torch.abs(pred_y[:, :, :dof] - train_data[80:90, :, :dof]).sum(axis=1))
    # `pred_y` is deliberately left holding the test predictions: the
    # plotting code further down reads it.
    pred_y = odeint(dyn, train_data[90:, 0, :], t, rtol=1e-7, atol=1e-9, method='dopri5').to(ref_device).squeeze().transpose(0, 1)
    test_loss = torch.mean(torch.abs(pred_y[:, :, :dof] - train_data[90:, :, :dof]).sum(axis=1))
else:
    # Single trajectory: integrate from the first sample over the whole time
    # grid, then score the validation window [split[0], split[1]) and the
    # test window [split[1], end).
    pred_y = odeint(dyn, train_data[:1, :], t, rtol=1e-7, atol=1e-9, method='dopri5').to(ref_device).squeeze()
    val_loss = torch.mean(torch.abs(pred_y[split[0]:split[1], :dof] - train_data[split[0]:split[1], :dof]).sum(axis=1))
    test_loss = torch.mean(torch.abs(pred_y[split[1]:, :dof] - train_data[split[1]:, :dof]).sum(axis=1))
    # Mean absolute and mean squared error over the full trajectory.
    all_loss = torch.mean(torch.abs(pred_y[:, :dof] - train_data[:, :dof]))
    all_mse_loss = torch.mean(torch.pow(pred_y[:, :dof] - train_data[:, :dof], 2.0))
    print(all_loss.item(), all_mse_loss.item())
print(val_loss.item(), test_loss.item())
		

# Line colours of the predicted components; cycled when there are more
# components than colours.
colors = ['red', 'coral', 'forestgreen', 'blue']


def _to_numpy(tensor):
    """Detach `tensor` from autograd and return it as a CPU NumPy array."""
    return tensor.detach().cpu().numpy()


def _save_and_close(fig, out_dir, file_name):
    """Tighten the layout of `fig`, save it as `out_dir/file_name`, close it."""
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, file_name))
    plt.close(fig)


def _shade_windows(ax, times, bounds):
    """Shade the validation window (yellow) and the test window (pink)."""
    ax.axvspan(times[bounds[0]], times[bounds[1]], facecolor='yellow', alpha=0.1)
    ax.axvspan(times[bounds[1]], times[-1], facecolor='lightpink', alpha=0.1)


def _plot_components(truth, pred, times, n_dof, out_dir, bounds=None):
    """Save ground truth vs. prediction for the first `n_dof` components.

    Two figures are written to `out_dir`: "best_multiple_panels.png" with one
    panel per component and "best_single_panel.png" with all components
    overlaid. `truth` and `pred` are indexed as (time, component). When
    `bounds` is given, the validation and test windows are shaded.
    """
    # One panel per component, two columns. The row count is rounded up so an
    # odd `n_dof` (or n_dof == 1) still has a slot for every panel.
    n_rows = max(1, (n_dof + 1) // 2)
    fig = plt.figure(figsize=(8, 12))
    for i in range(n_dof):
        ax = fig.add_subplot(n_rows, 2, i + 1)
        ax.plot(times, _to_numpy(truth[:, i]), lw=2, color='k')
        ax.plot(times, _to_numpy(pred[:, i]), lw=2, color=colors[i % len(colors)], ls='--')
        ax.set_xlim([0, times[-1]])
        if bounds is not None:
            _shade_windows(ax, times, bounds)
    _save_and_close(fig, out_dir, "best_multiple_panels.png")

    # All components overlaid in a single wide panel.
    fig = plt.figure(figsize=(8, 2))
    ax = fig.add_subplot(1, 1, 1)
    for i in range(n_dof):
        ax.plot(times, _to_numpy(truth[:, i]), lw=2, color='k')
        ax.plot(times, _to_numpy(pred[:, i]), lw=2, color=colors[i % len(colors)], ls='--')
    if bounds is not None:
        _shade_windows(ax, times, bounds)
    ax.set_xlim([0, times[-1]])
    _save_and_close(fig, out_dir, "best_single_panel.png")


if args.system == 'tdp':
    # Show the first test trajectory (index 90) against its prediction.
    _plot_components(train_data[90], pred_y[0], t, dof, fig_save_path)

    # Squared error summed over the components, averaged over all test
    # trajectories and drawn on a logarithmic axis.
    error = torch.pow(pred_y[:, :, :dof] - train_data[90:, :, :dof], 2).sum(axis=2)
    print(error.shape)
    error_mean = torch.mean(error, axis=0)

    fig = plt.figure(figsize=(8, 2))
    ax = fig.add_subplot(1, 1, 1)
    ax.semilogy(t, _to_numpy(error_mean), lw=2, color='k')
    ax.set_xlim([0, t[-1]])
    _save_and_close(fig, fig_save_path, "best_single_panel_mse.png")
else:
    # Single trajectory: additionally shade the validation and test windows.
    _plot_components(train_data, pred_y, t, dof, fig_save_path, bounds=split)
