import torch 
import torch.nn as nn
import torch.optim as optim 

from torchdiffeq import odeint as odeint

import os, sys
import matplotlib.pyplot as plt
import lib.utils as utils
from lib.odefunc import ODEfunc_SPNN, ODEfunc_GFINN

import numpy as np

# Hyper-parameters of the training loops, named so both branches read the same.
_TDP_N_TRAIN = 80        # trajectories [0, 80) are used for training ('tdp')
_TDP_N_VAL_END = 90      # trajectories [80, 90) are used for validation ('tdp')
_TDP_BATCH = 10          # trajectories sampled per optimisation step ('tdp')
_TDP_ITERS = 30000       # optimisation steps for the 'tdp' system
_DEFAULT_ITERS = 5000    # optimisation steps for every other system
_PLOT_EVERY = 50         # a comparison figure is written every this many steps
_PENALTY_WEIGHT = 1e-2   # weight of the LdS / MdE penalty term


def _integrate(model, y0, t, device):
	"""Integrate `model` from `y0` over `t` with dopri5 and drop singleton dims."""
	return odeint(model, y0, t, rtol=1e-7, atol=1e-9, method='dopri5').to(device).squeeze()


def _trajectory_l1(pred, target):
	"""Return the absolute error summed along axis 1, averaged over the rest."""
	return torch.mean(torch.abs(pred - target).sum(axis=1))


def _add_structure_penalty(model, loss):
	"""Add the weighted squared LdS / MdE penalty for SPNN and GFINN models."""
	if isinstance(model, (ODEfunc_SPNN, ODEfunc_GFINN)):
		LdS, MdE = model.get_penalty()
		penalty = torch.mean(torch.sum(LdS**2, axis=1) + torch.sum(MdE**2, axis=1))
		loss = loss + _PENALTY_WEIGHT * penalty
	return loss


def _checkpoint_if_improved(model, val, best_loss, ckpt_path):
	"""Save the model state when `val` beats `best_loss`; return the new best."""
	if best_loss > val:
		print('saving ode...', val)
		torch.save({
			'state_dict': model.state_dict(),
			}, ckpt_path)
		return val
	return best_loss


def _to_numpy(x):
	"""Convert a tensor to a NumPy array regardless of device / grad state."""
	if torch.is_tensor(x):
		return x.detach().cpu().numpy()
	return x


def _save_comparison_figure(t, truth, pred, dof, fig_path, frame):
	"""Plot the first `dof` columns of `truth` and `pred` against `t` and save.

	`truth` and `pred` are indexed as `[:, i]`; one subplot per column is laid
	out on a two-column grid. The figure is always closed, even if plotting
	or saving raises, so figures do not accumulate over a long training run.
	"""
	save_file = os.path.join(fig_path, "image_{:03d}.png".format(frame))
	fig = plt.figure(figsize=(8, 8) if dof == 8 else (8, 4))
	try:
		# Round the row count up so an odd `dof` still gets a slot per subplot.
		n_rows = (dof + 1) // 2
		t_np = _to_numpy(t)
		for i in range(dof):
			ax = fig.add_subplot(n_rows, 2, i + 1)
			ax.plot(t_np, _to_numpy(truth[:, i]), lw=2, color='k')
			ax.plot(t_np, _to_numpy(pred[:, i]), lw=2, color='c', ls='--')
		# Applied to the figure that is actually saved.
		fig.tight_layout()
		fig.savefig(save_file)
	finally:
		plt.close(fig)


def train(model, system, lr, data, t, dof, split, lMB, ckpt_path, fig_path, mini_batch=True, device=torch.device("cpu")):
	"""Fit `model` as an ODE right-hand side and checkpoint the best state.

	Each iteration performs one Adamax step on an L1 trajectory loss (plus a
	penalty from `model.get_penalty()` when the model is an `ODEfunc_SPNN` or
	`ODEfunc_GFINN`), then re-integrates on validation data, prints
	`itr, train_loss, val_loss`, and saves `{'state_dict': ...}` to `ckpt_path`
	whenever the validation loss improves. Every 50 iterations a figure
	comparing ground truth (black) with the prediction (dashed cyan) is written
	to `fig_path` as `image_NNN.png`.

	Args:
		model: ODE function passed to `odeint`.
		system: `'tdp'` selects the multi-trajectory loop; any other value
			selects the single-trajectory loop.
		lr: Adamax learning rate.
		data: For `'tdp'`, indexed as `data[traj, time, state]` with at least
			90 trajectories (first 80 train, next 10 validation). Otherwise
			indexed as `data[time, state]`.
		t: Time points handed to `odeint`; presumably a 1-D tensor aligned
			with the time axis of `data` -- confirm against the caller.
		dof: Number of leading state components used in the loss and plots.
		split: Two indices along the time axis (non-'tdp' only): `[:split[0]]`
			is training data, `[split[0]:split[1]]` is validation data.
		lMB: Forwarded to `utils.get_batch` when `mini_batch` is true
			(presumably the mini-batch length -- confirm in `lib.utils`).
		ckpt_path: Destination file for the best checkpoint.
		fig_path: Directory receiving the periodic comparison figures.
		mini_batch: Non-'tdp' only; train on batches from `utils.get_batch`
			instead of the full training trajectory.
		device: Device the integrated trajectories are moved to.

	Returns:
		None.
	"""
	best_loss = 1e30
	optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=1e-4)
	# NOTE(review): the scheduler is constructed but never stepped, so the
	# learning rate stays constant. Left as-is because stepping it would change
	# training results -- confirm whether the decay was intended.
	scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.9985)
	frame = 0

	if system == 'tdp':
		train_data = data[:_TDP_N_TRAIN, :, :]
		val_data = data[_TDP_N_TRAIN:_TDP_N_VAL_END, :, :]

		for itr in range(1, _TDP_ITERS + 1):
			optimizer.zero_grad()

			# Draw a fresh batch of distinct training trajectories each step.
			r = torch.from_numpy(np.random.choice(np.arange(_TDP_N_TRAIN, dtype=np.int64), _TDP_BATCH, replace=False))
			# Swap the first two axes of the odeint output so it lines up
			# with `data[traj, time, state]`.
			pred_y = _integrate(model, train_data[r, 0, :], t, device).transpose(0, 1)
			loss = _trajectory_l1(pred_y[:, :, :dof], train_data[r, :, :dof])
			loss = _add_structure_penalty(model, loss)
			loss.backward()
			optimizer.step()

			pred_y = _integrate(model, val_data[:, 0, :], t, device).transpose(0, 1)
			val_loss = _trajectory_l1(pred_y[:, :, :dof], val_data[:, :, :dof])

			val = val_loss.item()
			print(itr, loss.item(), val)
			best_loss = _checkpoint_if_improved(model, val, best_loss, ckpt_path)

			if itr % _PLOT_EVERY == 0:
				# Plot the first validation trajectory.
				_save_comparison_figure(t, val_data[0], pred_y[0], dof, fig_path, frame)
				frame += 1
	else:
		train_data = data[:split[0], :]
		train_t = t[:split[0]]

		for itr in range(1, _DEFAULT_ITERS + 1):
			optimizer.zero_grad()

			if mini_batch:
				batch_y0, batch_t, batch_y = utils.get_batch(train_data, train_t, lMB)
				pred_y = _integrate(model, batch_y0, batch_t, device).transpose(0, 1)
				loss = _trajectory_l1(pred_y, batch_y)
			else:
				pred_y = _integrate(model, data[:1, :], train_t, device)
				# DNO, TGC
				loss = _trajectory_l1(pred_y[:, :dof], train_data[:, :dof])

			loss = _add_structure_penalty(model, loss)
			loss.backward()
			optimizer.step()

			# Roll out over the full time grid from the first sample; the
			# validation loss only looks at the [split[0], split[1]) window.
			pred_y = _integrate(model, data[:1, :], t, device)
			val_loss = _trajectory_l1(pred_y[split[0]:split[1], :dof], data[split[0]:split[1], :dof])

			val = val_loss.item()
			print(itr, loss.item(), val)
			best_loss = _checkpoint_if_improved(model, val, best_loss, ckpt_path)

			if itr % _PLOT_EVERY == 0:
				_save_comparison_figure(t, data, pred_y, dof, fig_path, frame)
				frame += 1
