import torch
import torch.nn as nn
import numpy as np

import os, sys
from torchdiffeq import odeint as odeint
import matplotlib.pyplot as plt

def trajectory_generator(system, path='.', nentropy=1, plot=False, device=torch.device("cpu")):
	"""Simulate the reference trajectory dataset for one of the supported systems.

	Args:
		system: One of 'dno', 'tgc' or 'tdp'.
		path: Directory in which ``gt_traj.png`` is written when ``plot`` is True.
		nentropy: Number of extra columns appended to the state, each a linear
			ramp from 0 to 1 over the time grid. 0 (or less) appends nothing.
		plot: For 'dno'/'tgc', save a plot via ``plot_trajectory``. As called
			here, that helper terminates the interpreter after saving.
			Ignored for 'tdp'.
		device: Torch device holding the time grid and used for integration.

	Returns:
		``(sol, t, split)`` where ``t`` is the 1-D torch time grid on ``device``
		and ``split`` is a 2-element NumPy index array (presumably data-split
		boundaries -- confirm with the caller). ``sol`` is:

		* 'dno'/'tgc': a NumPy array of shape ``(len(t), 2 + nentropy)``;
		* 'tdp': a torch tensor of shape ``(100, len(t), 8 + nentropy)`` on ``device``.

	Raises:
		ValueError: If ``system`` is not recognised.
	"""
	if system == 'tdp':
		# 'tdp' builds its own batched tensor dataset and never plots.
		return _tdp_trajectories(device, nentropy)

	simulators = {'dno': _dno_trajectory, 'tgc': _tgc_trajectory}
	if system not in simulators:
		raise ValueError(f"unknown system {system!r}; expected 'dno', 'tgc' or 'tdp'")
	sol, t, split = simulators[system](device)

	if nentropy > 0:
		# Identical 0 -> 1 ramps, one per requested extra column.
		ramp = np.linspace(0., 1., len(t))
		sol = np.concatenate((sol, np.tile(ramp[:, None], (1, nentropy))), axis=-1)
	if plot:
		plot_trajectory(sol, t, path)
	return sol, t, split


def _dno_trajectory(device):
	"""Integrate the 'dno' system from (2, 0, 0); return ``(sol[:, :2], t, split)``.

	Dynamics: dq/dt = p, dp/dt = -sin(q) - 0.01 p, dS/dt = 0.01 p**2.
	"""
	dt = 0.001
	total_steps = 180001
	# total_steps + 1 points spaced exactly dt apart, starting at 0.
	t = torch.linspace(0, total_steps * dt, total_steps + 1).to(device)

	def rhs_torch(_t, x):
		q, p = x[:, 0], x[:, 1]
		# stack (not cat) keeps the (batch, 3) layout of x for any batch size.
		return torch.stack((p, -torch.sin(q) - 0.01 * p, 0.01 * p ** 2), dim=-1)

	# Kept for compatibility: reseeds NumPy's global RNG even though this
	# system draws no random numbers itself.
	np.random.seed(2)

	x_init = torch.tensor([2.0, 0.0, 0.0], device=device).unsqueeze(0)
	sol = odeint(rhs_torch, x_init, t, method='dopri5')  # (len(t), 1, 3)
	# .cpu() is required before .numpy() when integrating on a GPU.
	sol = sol[:, 0, :2].detach().cpu().numpy()
	return sol, t, np.asarray([60000, 80000])


def _tgc_trajectory(device):
	"""Integrate the 4-state 'tgc' system with fixed-step RK4; return ``(sol[:, :2], t, split)``."""
	dt = 0.001
	total_steps = 100000
	t = torch.linspace(0, total_steps * dt, total_steps + 1).to(device)

	Const = 102.2476703501216
	alpha = 1.
	m = 1.0

	def rhs_torch(_t, x):
		q, p, S1, S2 = x[:, 0], x[:, 1], x[:, 2], x[:, 3]
		E1 = torch.exp((2.0 / 3.0) * (S1 - torch.log(q) - Const))
		E2 = torch.exp((2.0 / 3.0) * (S2 - torch.log(2.0 - q) - Const))
		inv_diff = 1.0 / E1 - 1.0 / E2
		# stack (not cat) keeps the (batch, 4) layout of x for any batch size.
		return torch.stack((
			p / m,
			(2.0 / 3.0) * (E1 / q - E2 / (2 - q)),
			9.0 * alpha / (4.0 * E1) * inv_diff,
			-9.0 * alpha / (4.0 * E2) * inv_diff,
		), dim=-1)

	S1_init = 1.5 * np.log(2.0) + np.log(1) + Const
	S2_init = 1.5 * np.log(2.0) + np.log(2 - 1) + Const
	x_init = torch.tensor([1.0, 2.0, S1_init, S2_init], device=device).unsqueeze(0)
	sol = odeint(rhs_torch, x_init, t, method='rk4')  # (len(t), 1, 4)
	sol = sol[:, 0, :2].detach().cpu().numpy()
	return sol, t, np.asarray([20000, 30000])


def _tdp_trajectories(device, nentropy, nsample=100):
	"""Integrate ``nsample`` random 'tdp' initial states in one batch; return ``(sol, t, split)``.

	Initial states are drawn from NumPy's global RNG; seed it beforehand if a
	reproducible dataset is needed.
	"""
	dt = 0.02
	total_steps = 2000
	t = torch.linspace(0, total_steps * dt, total_steps + 1).to(device)

	# Each initial-state component is a uniform mix of these bounds.
	lb = np.asarray([0.9, -0.1, 2.1, -0.1, -0.1, 1.9, 0.9, -0.1, 0.9, 0.1])
	ub = np.asarray([1.1, 0.1, 2.3, 0.1, 0.1, 2.1, 1.1, 0.1, 1.1, 0.3])

	def rhs_torch(_t, x):
		# x is (batch, 10): q1, q2, p1, p2 (two columns each), then S1, S2.
		q1, q2 = x[:, 0:2], x[:, 2:4]
		p1, p2 = x[:, 4:6], x[:, 6:8]
		S1, S2 = x[:, 8], x[:, 9]
		# Per-row norms so batched trajectories evolve independently.
		lam1 = torch.norm(q1, dim=-1)
		lam2 = torch.norm(q2 - q1, dim=-1)
		T1 = 1 / lam1 * torch.exp(S1)
		T2 = 1 / lam2 * torch.exp(S2)
		de1dq1 = (1 / lam1**2 * (torch.log(lam1) + 1 - T1)).unsqueeze(-1) * q1
		de2dq1 = (1 / lam2**2 * (torch.log(lam2) + 1 - T2)).unsqueeze(-1) * (q1 - q2)
		de2dq2 = -de2dq1
		# zeros_like keeps x's dtype/device; a float32 CPU buffer would truncate
		# float64 states and fail on CUDA.
		out = torch.zeros_like(x)
		out[:, 0:2] = p1
		out[:, 2:4] = p2
		out[:, 4:6] = -(de1dq1 + de2dq1)
		out[:, 6:8] = -de2dq2  # de1/dq2 is identically zero
		out[:, 8] = T2 / T1 - 1
		out[:, 9] = T1 / T2 - 1
		return out

	# Same RNG stream as drawing np.random.rand(10) nsample times.
	w = np.random.rand(nsample, 10)
	x_init = torch.tensor(w * lb + (1 - w) * ub, device=device)  # (nsample, 10), float64
	sol = odeint(rhs_torch, x_init, t, method='rk4')            # (len(t), nsample, 10)
	sol = sol.transpose(0, 1)[:, :, :8].contiguous()            # (nsample, len(t), 8)

	if nentropy > 0:
		ramp = torch.linspace(0., 1., len(t), dtype=sol.dtype, device=sol.device, requires_grad=True)
		sData = ramp.unsqueeze(-1).repeat(nsample, 1, nentropy)  # (nsample, len(t), nentropy)
		sol = torch.cat((sol, sData), dim=-1)
	return sol, t, np.asarray([5000, 10000])

def plot_trajectory(sol, t, path, *, exit_after=True):
	"""Plot a trajectory against time and save it as ``<path>/gt_traj.png``.

	Args:
		sol: 2-D array-like whose first axis matches ``t``; at most its first
			8 columns are drawn, one line per column.
		t: 1-D time grid (NumPy array or torch tensor, on any device).
		path: Directory in which ``gt_traj.png`` is written.
		exit_after: When True (the historical behaviour), terminate the
			interpreter with ``sys.exit()`` once the figure is saved. Pass
			False to use this as an ordinary plotting helper.
	"""
	# matplotlib cannot consume CUDA tensors or tensors that require grad.
	if isinstance(t, torch.Tensor):
		t = t.detach().cpu().numpy()
	if isinstance(sol, torch.Tensor):
		sol = sol.detach().cpu().numpy()

	save_file = os.path.join(path, "gt_traj.png")
	fig, ax = plt.subplots(figsize=(4, 2))
	ax.plot(t, sol[:, :8], lw=2)
	# Apply the layout to the figure that is actually saved.
	fig.tight_layout()
	fig.savefig(save_file)
	plt.close(fig)
	if exit_after:
		sys.exit()
