import numpy as np
import torch
from torch.utils.data import Dataset
# from dnri.utils import data_utils
import matplotlib.pyplot as plt
import numpy as np


def normalize(data, data_max, data_min):
	"""Linearly map ``data`` so that ``data_min -> -1`` and ``data_max -> 1``.

	Works element-wise on scalars, NumPy arrays or tensors. Values outside
	``[data_min, data_max]`` land outside ``[-1, 1]``. No guard is applied when
	``data_max == data_min`` (the division is then by zero).
	"""
	offset = data - data_min
	value_range = data_max - data_min
	return offset * 2 / value_range - 1


def unnormalize(data, data_max, data_min):
	"""Invert :func:`normalize`: map ``-1 -> data_min`` and ``1 -> data_max``.

	Works element-wise on scalars, NumPy arrays or tensors.
	"""
	shifted = data + 1
	value_range = data_max - data_min
	return shifted * value_range / 2. + data_min


def get_edge_inds(num_vars):
	"""Return every ordered pair ``[sender, receiver]`` of distinct indices.

	Pairs are listed in row-major order over ``range(num_vars)`` with the
	diagonal (self-pairs) skipped, giving ``num_vars * (num_vars - 1)`` entries.
	"""
	return [
		[src, dst]
		for src in range(num_vars)
		for dst in range(num_vars)
		if src != dst
	]


class BasketballData(Dataset):
	"""Dataset of location/velocity features loaded from ``.npy`` files.

	Files are read from ``{data_path}/{feat}_{mode}_{name}.npy`` (or with
	``params['num_agents']`` appended to the name when ``num_in_path`` is True)
	for ``feat`` in ``('loc', 'vel')``. Each feature type is rescaled with
	:func:`normalize` using the scalar max/min of the *training* split. The last
	two axes are optionally swapped, and the two feature types are then
	concatenated along axis -2 into a float32 tensor ``self.feat``.

	Args:
		name: Dataset name embedded in the file names.
		data_path: Directory containing the ``.npy`` files.
		mode: Split to load (e.g. ``'train'``), embedded in the file names.
		params: Config dict. ``params['num_agents']`` is only read when
			``num_in_path`` is True.
		test_full: Stored as ``self.test_full``; not otherwise used here.
		num_in_path: Whether file names carry the ``num_agents`` suffix.
		has_edges: Accepted for interface compatibility; currently unused.
		transpose_data: Swap the last two axes of loc/vel before concatenating.
		max_len: If not None, ``__getitem__`` keeps only the first ``max_len``
			entries along the first axis of each item.
	"""

	def __init__(self, name, data_path, mode, params, test_full=False, num_in_path=True, has_edges=True, transpose_data=True, max_len=None):
		self.name = name
		self.data_path = data_path
		self.mode = mode
		self.params = params
		self.num_in_path = num_in_path
		# Normalization statistics always come from the training split so that
		# every split is rescaled with the same constants.
		self.loc_max, self.loc_min, self.vel_max, self.vel_min = self._get_normalize_stats()
		self.test_full = test_full
		self.max_len = max_len

		self._load_data(transpose_data)

	def __getitem__(self, index):
		"""Return ``{'inputs': feat[index]}``, truncated to ``max_len`` if set."""
		if self.max_len is not None:
			inputs = self.feat[index, :self.max_len]
		else:
			inputs = self.feat[index]
		return {'inputs': inputs}

	def __len__(self):
		"""Number of sequences (size of the first axis of ``self.feat``)."""
		return self.feat.shape[0]

	def _get_normalize_stats(self):
		"""Return ``(loc_max, loc_min, vel_max, vel_min)`` over the whole training split."""
		train_loc = np.load(self._get_npy_path('loc', 'train'))
		train_vel = np.load(self._get_npy_path('vel', 'train'))
		return train_loc.max(), train_loc.min(), train_vel.max(), train_vel.min()

	def _load_data(self, transpose_data):
		"""Load, normalize and concatenate loc/vel into ``self.feat`` (float32 tensor)."""
		self.loc_feat = np.load(self._get_npy_path('loc', self.mode))
		self.vel_feat = np.load(self._get_npy_path('vel', self.mode))

		# Map the training-split range of each feature type onto [-1, 1].
		self.loc_feat = normalize(self.loc_feat, self.loc_max, self.loc_min)
		print(self.loc_feat.shape, "loc_feat.shape")
		print(self.vel_feat.shape, "vel_feat.shape")
		self.vel_feat = normalize(self.vel_feat, self.vel_max, self.vel_min)

		# Swap the last two axes (assumes 4-D arrays). NOTE(review): the layout
		# comment originally here read "[num_sims, num_timesteps, num_agents,
		# num_dims]", yet the concatenation below is on axis -2, not -1 —
		# confirm against the on-disk layout which axis holds the features.
		if transpose_data:
			self.loc_feat = np.transpose(self.loc_feat, [0, 1, 3, 2])
			self.vel_feat = np.transpose(self.vel_feat, [0, 1, 3, 2])
		self.feat = np.concatenate([self.loc_feat, self.vel_feat], axis=-2)

		# np.array(..., dtype=float32) copies into a contiguous float32 buffer
		# that torch.from_numpy can wrap. The tensor stays on the CPU.
		self.feat = torch.from_numpy(np.array(self.feat, dtype=np.float32))

	def _get_npy_path(self, feat, mode):
		"""Build the ``.npy`` path for feature ``feat`` of split ``mode``."""
		if self.num_in_path:
			return '%s/%s_%s_%s%s.npy' % (self.data_path, feat, mode, self.name, self.params['num_agents'])
		return '%s/%s_%s_%s.npy' % (self.data_path, feat, mode, self.name)


if __name__ == '__main__':
	# Quick visual sanity check on one sequence of the training split. The data
	# directory may be passed as the first CLI argument; otherwise the original
	# hard-coded location is used.
	import sys

	data_dir = sys.argv[1] if len(sys.argv) > 1 else '/hdd2/extra_home/hkumawat6/Projects/AttentionNet/data_files/basketball'
	bball_dataset = BasketballData(
		'bball', data_dir, 'train', {'num_agents': 100}, num_in_path=False)

	print(bball_dataset.feat.shape)
	# Plot feature columns 0 vs 1 (presumably x/y — confirm against the data
	# layout) for up to 5 indices of the last axis of sequence 1.
	num_plotted = min(5, bball_dataset.feat.shape[-1])
	for i in range(num_plotted):
		# color= (not c=) so one random RGB triple colors every point, instead of
		# being value-mapped when the series happens to have exactly 3 points.
		plt.scatter(bball_dataset.feat[1, :, 0, i], bball_dataset.feat[1, :, 1, i], color=np.random.rand(3))
	plt.show()