from torch.utils.data import Dataset
import numpy as np
import torch
from mapper.attribute_list import ATTRIBUTE_LIST


class LatentsDataset(Dataset):
	"""Dataset pairing each stored latent/pose with a random attribute-edit
	target and a rendered reference image taken from a different sample.

	Each item draws attribute(s) from ``ATTRIBUTE_LIST``, picks another sample
	as a reference, and renders that reference with
	``net.decoder.synthesis``. For the attributes in ``_DEMOGRAPHIC_ATTRS``
	the reference must share ``dominant_gender`` and have an age differing
	by less than ``_MAX_AGE_GAP``.

	Args:
		latents: Per-sample latents; ``latents.shape[0]`` is the dataset
			length and items must support ``.to(device)``.
		poses: Per-sample poses indexed like ``latents``; items must support
			``.to(device)``.
		infos: Per-sample metadata; ``infos[i][0]`` must be a mapping with
			``'age'`` and ``'dominant_gender'`` keys.
		opts: Options object; stored but not read by this class.
		status: Split name; stored but not read by this class.
		net: Model exposing ``net.decoder.synthesis(latents, poses)`` that
			returns a mapping with an ``'image'`` entry. Must be set before
			items are fetched.
		device: Device the reference latent/pose are moved to for rendering.
	"""

	# Attributes whose reference sample must be demographically similar.
	_DEMOGRAPHIC_ATTRS = ('hairstyle', 'beard')
	# Exclusive upper bound on the age difference for those attributes.
	_MAX_AGE_GAP = 20
	# Random draws attempted before falling back to an exhaustive scan.
	_MAX_REJECTION_TRIES = 1000

	def __init__(self, latents, poses, infos, opts, status='train', net=None, device='cuda:0'):
		self.latents = latents
		self.poses = poses
		self.infos = infos
		self.opts = opts
		self.status = status
		self.device = device
		self.net = net

		self.attribute_list = ATTRIBUTE_LIST
		self.attribute_list_len = len(self.attribute_list)

	def _is_compatible_reference(self, index, ref_idx, attr):
		"""Return True if sample ``ref_idx`` may serve as reference for ``index``."""
		if ref_idx == index:
			return False
		if attr not in self._DEMOGRAPHIC_ATTRS:
			return True
		info = self.infos[index][0]
		ref_info = self.infos[ref_idx][0]
		return (np.abs(info['age'] - ref_info['age']) < self._MAX_AGE_GAP
				and info['dominant_gender'] == ref_info['dominant_gender'])

	def _sample_reference_index(self, index, attr):
		"""Pick a random compatible reference index for sample ``index``.

		Uses cheap rejection sampling first; if that keeps failing, scans all
		samples so a rare-but-existing match is still found instead of looping
		forever.

		Raises:
			ValueError: if the dataset has fewer than two samples.
			RuntimeError: if no sample satisfies the compatibility rule.
		"""
		n = len(self)
		if n < 2:
			raise ValueError('LatentsDataset needs at least two samples to pick a reference, got %d' % n)

		for _ in range(self._MAX_REJECTION_TRIES):
			ref_idx = np.random.randint(n)
			if self._is_compatible_reference(index, ref_idx, attr):
				return ref_idx

		candidates = [i for i in range(n) if self._is_compatible_reference(index, i, attr)]
		if not candidates:
			raise RuntimeError(
				'No compatible reference sample for index %d and attribute %r' % (index, attr))
		return int(np.random.choice(candidates))

	def __getitem__(self, index):
		"""Build one training item.

		Returns:
			Tuple ``(latent, pose, selected_attributes, selected_description,
			ref_img, ref_pose, ref_latent)``. ``latent``/``pose`` are returned
			as stored (not moved to ``self.device``). ``ref_latent``/``ref_pose``
			are moved to ``self.device``, and ``ref_img`` is the first entry of
			the synthesized ``'image'`` output.

		Raises:
			ValueError: if the dataset has fewer than two samples.
			RuntimeError: if no compatible reference sample exists.
		"""
		# p=[1, 0, 0]: exactly one attribute is always selected; 2-3 are disabled.
		selected_attributes_num = np.random.choice(np.arange(1, 4), p=[1.0, 0., 0.])
		selected_attributes = np.random.choice(np.arange(self.attribute_list_len), size=selected_attributes_num, replace=False)
		selected_description = '+'.join([self.attribute_list[idx]['attr'] for idx in selected_attributes])

		# Use the instance list so this matches the attributes selected above.
		attr = self.attribute_list[selected_attributes[0]]['attr']
		ref_idx = self._sample_reference_index(index, attr)

		ref_latent = self.latents[ref_idx].to(self.device)
		ref_pose = self.poses[ref_idx].to(self.device)
		# NOTE(review): rendering on self.device here runs inside DataLoader
		# worker processes when num_workers > 0; CUDA there requires the
		# 'spawn' start method — confirm how the loader is configured.
		with torch.no_grad():
			# [None] adds a leading axis of size 1 (presumably batch); [0] removes it.
			ref_img = self.net.decoder.synthesis(ref_latent[None], ref_pose[None])['image'][0]

		return self.latents[index], self.poses[index], selected_attributes, selected_description, ref_img, ref_pose, ref_latent

	def __len__(self):
		"""Number of samples, taken from the first dimension of ``latents``."""
		return self.latents.shape[0]
