import torch
import random
import numpy as np


class Data(torch.utils.data.Dataset):
	"""Synthetic dataset of 2-value slot states evolved by random add/subtract steps.

	Each example starts from ``num_slots`` slots, each holding two values drawn
	uniformly from [0, 1]. Then ``num_operations`` steps are applied. Each step
	picks a target slot and adds or subtracts one coordinate of a source slot
	into the same coordinate of the target, producing a new state.

	Operation codes:
		0: target[0] += source[0]
		1: target[1] += source[1]
		2: target[0] -= source[0]
		3: target[1] -= source[1]

	Args:
		num_slots: Number of slots per state (must be >= 1).
		num_operations: Number of steps applied per example (may be 0).
		num_examples: Number of examples, all generated eagerly here.
		seed: Optional seed for a private ``random.Random`` generator. When
			``None`` (the default), the module-level ``random`` functions are
			used, so seeding through ``random.seed`` behaves as before.
	"""

	def __init__(self, num_slots=3, num_operations=3, num_examples=10000, seed=None):
		super().__init__()
		# Default to the global RNG for backward compatibility. A private
		# generator gives reproducible data without touching global state.
		rng = random if seed is None else random.Random(seed)

		self.entities = []    # per example: list of (num_slots, 2) arrays, one per state
		self.operations = []  # per example: list of [target, source, operation]
		self.masks = []       # per example: list of one-hot (num_slots,) target masks

		for _ in range(num_examples):
			# Two uniform draws per slot, in slot order.
			start = np.stack(
				[np.array([rng.uniform(0, 1), rng.uniform(0, 1)]) for _ in range(num_slots)],
				axis=0,
			)
			trajectory = [start]
			step_ops = []
			step_masks = []

			for _ in range(num_operations):
				# Copy, so that earlier states in the trajectory stay intact.
				new_state = np.array(trajectory[-1])
				target = rng.randint(0, num_slots - 1)
				# NOTE(review): the source operand is always slot 1 (slot 0 when
				# there is only one slot), never random. If a random source was
				# intended, this should be rng.randint(0, num_slots - 1). Confirm.
				source = 0 if num_slots == 1 else 1
				mask = np.zeros(num_slots)
				mask[target] = 1

				operation = rng.randint(0, 3)  # inclusive: codes 0..3
				col = operation % 2            # 0/2 touch column 0, 1/3 touch column 1
				if operation < 2:
					new_state[target, col] += new_state[source, col]
				else:
					new_state[target, col] -= new_state[source, col]

				trajectory.append(new_state)
				step_ops.append([target, source, operation])
				step_masks.append(mask)

			self.entities.append(trajectory)
			self.operations.append(step_ops)
			self.masks.append(step_masks)

	def __len__(self):
		"""Return the number of generated examples."""
		return len(self.entities)

	def __getitem__(self, i):
		"""Return ``(numbers, ops, mask)`` for example ``i``.

		numbers: float64 tensor of shape (num_operations + 1, num_slots, 2),
			holding the initial state followed by the state after each step.
		ops: the example's list of ``[target, source, operation]`` entries,
			returned as a plain Python list rather than a tensor.
		mask: float64 tensor of shape (num_operations, num_slots), one-hot on
			each step's target slot.
		"""
		numbers = np.stack(self.entities[i], axis=0)
		ops = self.operations[i]
		step_masks = self.masks[i]

		if step_masks:
			mask = np.stack(step_masks, axis=0)
		else:
			# np.stack rejects an empty sequence (num_operations == 0), so
			# return a well-formed (0, num_slots) mask instead of crashing.
			mask = np.zeros((0, numbers.shape[1]))

		return torch.from_numpy(numbers), ops, torch.from_numpy(mask)