from scripts.base_models import VAE_Baseline
import scripts.utils as utils
import torch
from fvcore.nn import FlopCountAnalysis
import numpy as np
from npy_append_array import NpyAppendArray
import os 

class LatentGraphODE(VAE_Baseline):
	"""Latent-ODE model (subclass of ``VAE_Baseline``) with a graph-conditioned solver.

	The forward pipeline lives in :meth:`get_reconstruction`:
	``encoder_z0`` -> sample initial latent state z0 -> ``diffeq_solver``
	-> ``decoder``.

	Args:
		input_dim: forwarded to ``VAE_Baseline``.
		latent_dim: latent size; forwarded to ``VAE_Baseline`` and also
			stored on ``self.latent_dim``.
		encoder_z0: callable returning a ``(mean, std)`` pair for z0.
		decoder: callable applied to the latent trajectory to produce
			predictions.
		diffeq_solver: callable invoked as
			``diffeq_solver(z0_samples, time_steps, batch_g)``.
		z0_prior: forwarded to ``VAE_Baseline`` (presumably the prior over
			z0 — confirm in the base class).
		device: forwarded to ``VAE_Baseline``.
		obsrv_std: forwarded to ``VAE_Baseline``; defaults to ``None``.
	"""

	def __init__(self, input_dim, latent_dim, encoder_z0, decoder, diffeq_solver,
			z0_prior, device, obsrv_std=None):
		super().__init__(
			input_dim=input_dim, latent_dim=latent_dim,
			z0_prior=z0_prior,
			device=device, obsrv_std=obsrv_std)

		# Sub-modules are attached after the base initialiser has run.
		self.encoder_z0 = encoder_z0
		self.diffeq_solver = diffeq_solver
		self.decoder = decoder
		# Stored explicitly; NOTE(review): VAE_Baseline may already set this.
		self.latent_dim = latent_dim

	def get_reconstruction(self, batch_en, batch_de, batch_g, n_traj_samples=1, run_backwards=True):
		"""Encode a batch, integrate the latent ODE and decode the trajectory.

		Args:
			batch_en: encoder input. Its ``x``, ``edge_attr``, ``edge_index``,
				``pos``, ``edge_same``, ``batch`` and ``y`` attributes are
				passed positionally, in that order, to ``encoder_z0``.
			batch_de: mapping whose ``"time_steps"`` entry is the tensor of
				time points handed to ``diffeq_solver``.
			batch_g: passed through unchanged to ``diffeq_solver``.
			n_traj_samples: number of z0 samples to draw; the encoder
				outputs are repeated this many times along a new leading axis
				(for 2-D encoder outputs).
			run_backwards: unused; kept for interface compatibility.

		Returns:
			Tuple ``(pred_x, all_extra_info, None)``:
			``pred_x`` is the decoder output;
			``all_extra_info["first_point"]`` is
			``(mean[None], abs(std)[None], z0_samples)``;
			``all_extra_info["latent_traj"]`` is the solver output, detached.
			The third element is always ``None``.

		Raises:
			AssertionError: if the time steps or the sampled z0 contain NaN.
		"""
		# Encoder: parameters of the initial latent state.
		# Author's annotation: [num_ball, latent_dim].
		first_point_mu, first_point_std = self.encoder_z0(
			batch_en.x, batch_en.edge_attr, batch_en.edge_index, batch_en.pos,
			batch_en.edge_same, batch_en.batch, batch_en.y)

		# Repeat along a new leading axis so each trajectory sample gets its
		# own draw: [n_traj_samples, num_ball, latent_dim] for 2-D inputs.
		means_z0 = first_point_mu.repeat(n_traj_samples, 1, 1)
		sigmas_z0 = first_point_std.repeat(n_traj_samples, 1, 1)
		first_point_enc = utils.sample_standard_gaussian(means_z0, sigmas_z0)

		# Only the std reported in all_extra_info is made non-negative; the
		# sample above was drawn with the raw std. NOTE(review): this is only
		# harmless if sample_standard_gaussian is symmetric in the sign of
		# sigma (e.g. mu + sigma * eps, eps ~ N(0, 1)) — confirm in scripts.utils.
		first_point_std = first_point_std.abs()

		time_steps_to_predict = batch_de["time_steps"]

		# Fail fast on NaNs before spending time in the ODE solver.
		assert not torch.isnan(time_steps_to_predict).any()
		assert not torch.isnan(first_point_enc).any()

		# ODE. Author's annotation for sol_y:
		# [n_traj_samples, n_samples, n_timepoints, n_latents].
		sol_y = self.diffeq_solver(first_point_enc, time_steps_to_predict, batch_g)

		# Decoder: map the latent trajectory to predictions.
		pred_x = self.decoder(sol_y)

		all_extra_info = {
			"first_point": (
				first_point_mu.unsqueeze(0),
				first_point_std.unsqueeze(0),
				first_point_enc,
			),
			# Detached so callers cannot backpropagate through this copy.
			"latent_traj": sol_y.detach(),
		}

		return pred_x, all_extra_info, None



