""" Code for loading data. """
import imageio
import glob
import numpy as np
import os
import random
import pickle
import tensorflow as tf
import pandas as pd

from skimage import transform

import absl.app
import absl.flags
from absl import logging

FLAGS = absl.flags.FLAGS

class OnlineDataGenerator(object):
	"""
	Data generator that produces batches of rainbow MNIST, pose, or mini-ImageNet data online.
	"""
	def __init__(self, per_task_batch_size=4, config={}):
		"""
		Args:
			num_samples_per_task: num samples to generate per class in one batch
			batch_size: size of meta batch size (e.g. number of functions)
		"""
		self.per_task_batch_size = per_task_batch_size
		self.dim_state_input = None

		if FLAGS.datasource == 'cont_rainbow_mnist':
			# number of classes should be set to 1 for rainbow_mnist ( but dim output is 10 )
			self.img_size = config.get('img_size', (28, 28))
			self.dim_input = np.prod(self.img_size) * 3
			self.dim_output = 10
			self.rotations = config.get('rotations', [0])
			self.generate = self.generate_cont_rainbow_mnist_batch

			with open(FLAGS.data_file, 'rb') as fin:
				self.data = pickle.load(fin)
			print('Done loading data.')

			self.data = self.data['train'] + self.data['val']
			random.seed(FLAGS.seed)
			random.shuffle(self.data)

			self.cur_task = FLAGS.cur_task
			self.cur_task_batch_id = 1 if self.cur_task == 0 else (self.data[self.cur_task]['images'].shape[0]-100) // self.per_task_batch_size
			self.num_tasks = FLAGS.cur_task + 1   # total number of tasks with data so far
			self.task_batches = {}
			for i in range(self.num_tasks):
				self.task_batches[i] = list(range(self.cur_task_batch_id))
		elif FLAGS.datasource == 'cont_pose':
			# number of classes should be set to 1 for rainbow_mnist ( but dim output is 10 )
			self.img_size = config.get('img_size', (128, 128))
			self.dim_input = np.prod(self.img_size) * 1
			self.dim_output = 1
			self.rotations = config.get('rotations', [0])
			self.generate = self.generate_cont_pose_batch

			with open(FLAGS.data_file, 'rb') as fin:
				self.data = pickle.load(fin)
			print('Done loading data.')

			random.seed(FLAGS.seed)
			random.shuffle(self.data)

			self.cur_task = FLAGS.cur_task
			self.cur_task_batch_id = 1 if self.cur_task == 0 else (self.data[self.cur_task]['images'].shape[0]-10) // self.per_task_batch_size
			self.num_tasks = FLAGS.cur_task + 1   # total number of tasks with data so far
			self.task_batches = {}
			for i in range(self.num_tasks):
				self.task_batches[i] = list(range(self.cur_task_batch_id))
		elif FLAGS.datasource == 'cont_miniimagenet':
			self.img_size = config.get('img_size', (84, 84))
			self.dim_input = np.prod(self.img_size)*6  # two images passed in.
			self.dim_output = 1
			self.rotations = config.get('rotations', [0])
			self.generate = self.generate_cont_miniimagenet_batch

			with open(FLAGS.data_file, 'rb') as fin:
				self.data = pickle.load(fin)
			print('Done loading data.')

			self.data = self.data['train'] + self.data['val']
			random.seed(FLAGS.seed)
			random.shuffle(self.data)

			self.cur_task = FLAGS.cur_task
			self.cur_task_batch_id = 1 if self.cur_task == 0 else (self.data[self.cur_task]['images'].shape[0]-40) // self.per_task_batch_size
			if self.cur_task == 0:
				self.num_tasks = 2
			else:
				self.num_tasks = self.cur_task + 1
			self.task_batches = {}
			for i in range(self.num_tasks):
				if self.cur_task == 0 and i == 1:
					self.task_batches[i] = list(range((self.data[self.cur_task]['images'].shape[0]-40) // self.per_task_batch_size))
				else:
					self.task_batches[i] = list(range(self.cur_task_batch_id))
		else:
			raise ValueError('Unrecognized data source')

	def load_miniimagenet(self):
		"""
		Read every image found under self.task_folders into self.images,
		a dict keyed by file path.

		For 'cont' data sources each task folder is expected to contain one
		sub-folder per batch; otherwise images sit directly in the task folder.

		NOTE(review): neither `load` nor `self.task_folders` is defined in the
		visible part of this file -- confirm where they come from.
		"""
		print('Loading images into RAM')
		self.images = {}
		for tfolder in self.task_folders:
			# Collect all paths of this task first, then load them.
			image_filepaths = []
			if 'cont' in FLAGS.datasource:
				for batch in os.listdir(tfolder):
					batch_dir = os.path.join(tfolder, str(batch))
					for img_name in os.listdir(batch_dir):
						image_filepaths.append(os.path.join(tfolder, batch, img_name))
			else:
				for img_name in os.listdir(tfolder):
					image_filepaths.append(os.path.join(tfolder, img_name))
			for filepath in image_filepaths:
				self.images[filepath] = load(filepath)
		print('Done loading images')

	def add_task(self):
		"""
		Advance to the next task once performance on the current one is satisfactory.

		A task that has not been seen before (or task 1 of the mini-ImageNet
		stream) restarts with a single revealed batch; any other already-known
		task resumes with the batch counter set to 10. Prints the last
		revealed batch id of every task.
		"""
		assert 'cont' in FLAGS.datasource
		self.cur_task += 1

		is_unseen_task = self.cur_task >= self.num_tasks
		if is_unseen_task:
			self.num_tasks += 1

		if is_unseen_task or ('cont_miniimagenet' in FLAGS.datasource and self.cur_task == 1):
			# Start the task from scratch with only batch 0 revealed.
			self.cur_task_batch_id = 1
			self.task_batches[self.cur_task] = [0]
		else:
			self.cur_task_batch_id = 10

		last_batch_per_task = {}
		for task_key, batch_ids in self.task_batches.items():
			last_batch_per_task[task_key] = batch_ids[-1]
		print(last_batch_per_task)

	def add_batch(self):
		"""Reveal one more batch of data for the current task and advance the batch counter."""
		assert 'cont' in FLAGS.datasource
		revealed_batch_id = self.cur_task_batch_id
		self.task_batches[self.cur_task].append(revealed_batch_id)
		self.cur_task_batch_id = revealed_batch_id + 1

	def generate_cont_rainbow_mnist_batch(self, inner_batch_sizes, train=True, train_step=None, test_zero_shot=False):  # RGB images
		# Should advance to the next task once the model has seen the second to the last batch
		assert self.cur_task_batch_id < self.data[self.cur_task]['labels'].shape[0] // self.per_task_batch_size
		if train:
			# Use all tasks so far
			if FLAGS.train_only_on_cur:
				task_data = {self.cur_task: self.data[self.cur_task]}
			else:
				task_data = {i: self.data[i] for i in range(self.num_tasks)}
		else:
			# Only use the next task # current task
			# cont_incl_cur: whether or not to meta-train on the current task
			if FLAGS.cont_incl_cur:
				task_data = {self.cur_task: self.data[self.cur_task]}
			else:
				task_data = {self.cur_task+1: self.data[self.cur_task+1]}


		pre_adapt_images_list, pre_adapt_labels_list, post_adapt_images_list, post_adapt_labels_list = [], [], [], []

		if train:
			random_tasks = np.random.choice(range(len(task_data)), size=FLAGS.meta_train_tasks, replace=True)
		else:
			random_tasks = [self.cur_task for _ in range(FLAGS.meta_train_tasks)]
		num_available_data = self.cur_task_batch_id*self.per_task_batch_size
		if train:
			# Note this is actually not used. We use different shots within a meta-batch right now
			inner_batch_size = inner_batch_sizes[train_step % len(inner_batch_sizes)]
			inner_batch_size = min(inner_batch_size, num_available_data // 2)
		else:
			if test_zero_shot:
				inner_batch_size = 0
			else:
				inner_batch_size = min(np.amax(inner_batch_sizes), num_available_data // 2)

		for i, task_id in enumerate(random_tasks):
			available_task_data = {'images': task_data[task_id]['images'][:len(self.task_batches[task_id])*self.per_task_batch_size],
								   'labels': task_data[task_id]['labels'][:len(self.task_batches[task_id])*self.per_task_batch_size]}
			if train:
				cur_inner_batch_size = np.random.choice(inner_batch_sizes)
				cur_inner_batch_size = min(cur_inner_batch_size, available_task_data['labels'].shape[0] // 2)
				outer_batch_size = min(FLAGS.outer_batch_size, available_task_data['labels'].shape[0] - cur_inner_batch_size)
				random_indices = np.random.choice(
					len(self.task_batches[task_id])*self.per_task_batch_size,
					cur_inner_batch_size + outer_batch_size,
					replace=False
				)
				pre_adapt_indices = random_indices[:cur_inner_batch_size]
				post_adapt_indices = random_indices[cur_inner_batch_size:]

				pre_adapt_images = available_task_data['images'][pre_adapt_indices]
				pre_adapt_labels = available_task_data['labels'][pre_adapt_indices]
				post_adapt_images = available_task_data['images'][post_adapt_indices]
				post_adapt_labels = available_task_data['labels'][post_adapt_indices]
			else:
				try:
					assert available_task_data['labels'].shape[0] == num_available_data
				except:
					import pdb; pdb.set_trace()
				val_task_data = {'images': task_data[task_id]['images'][self.cur_task_batch_id*self.per_task_batch_size:],
							 'labels': task_data[task_id]['labels'][self.cur_task_batch_id*self.per_task_batch_size:]}
				pre_adapt_indices = np.random.choice(
					num_available_data,
					inner_batch_size,
					replace=False
				)
				post_adapt_indices = np.random.choice(
					val_task_data['labels'].shape[0],
					FLAGS.outer_batch_size,
					replace=False
				)
				pre_adapt_images = available_task_data['images'][pre_adapt_indices]
				pre_adapt_labels = available_task_data['labels'][pre_adapt_indices]
				post_adapt_images = val_task_data['images'][post_adapt_indices]
				post_adapt_labels = val_task_data['labels'][post_adapt_indices]
			pre_adapt_images_list.append(pre_adapt_images)
			pre_adapt_labels_list.append(pre_adapt_labels)
			post_adapt_images_list.append(post_adapt_images)
			post_adapt_labels_list.append(post_adapt_labels)
		return pre_adapt_images_list, pre_adapt_labels_list, post_adapt_images_list, post_adapt_labels_list, inner_batch_size


	def generate_cont_pose_batch(self, inner_batch_sizes, train=True, train_step=None, test_zero_shot=False):  # RGB images
		# Should advance to the next task once the model has seen the second to the last batch
		assert self.cur_task_batch_id < self.data[self.cur_task]['labels'].shape[0] // self.per_task_batch_size
		if train:
			# Use all tasks so far
			if FLAGS.train_only_on_cur:
				task_data = {self.cur_task: self.data[self.cur_task]}
			else:
				task_data = {i: self.data[i] for i in range(self.num_tasks)}
		else:
			# Only use the next task # current task
			# cont_incl_cur: whether or not to meta-train on the current task
			if FLAGS.cont_incl_cur:
				task_data = {self.cur_task: self.data[self.cur_task]}
			else:
				task_data = {self.cur_task+1: self.data[self.cur_task+1]}


		pre_adapt_images_list, pre_adapt_labels_list, post_adapt_images_list, post_adapt_labels_list = [], [], [], []

		if train:
			random_tasks = np.random.choice(range(len(task_data)), size=FLAGS.meta_train_tasks, replace=True)
		else:
			random_tasks = [self.cur_task for _ in range(FLAGS.meta_train_tasks)]
		num_available_data = self.cur_task_batch_id*self.per_task_batch_size
		if train:
			# Note this is actually not used. We use different shots within a meta-batch right now
			inner_batch_size = inner_batch_sizes[train_step % len(inner_batch_sizes)]
			inner_batch_size = min(inner_batch_size, num_available_data // 2)
		else:
			if test_zero_shot:
				inner_batch_size = 0
			else:
				inner_batch_size = min(np.amax(inner_batch_sizes), num_available_data // 2)

		for i, task_id in enumerate(random_tasks):
			available_task_data = {'images': task_data[task_id]['images'][:len(self.task_batches[task_id])*self.per_task_batch_size],
								   'labels': task_data[task_id]['labels'][:len(self.task_batches[task_id])*self.per_task_batch_size]}
			if train:
				cur_inner_batch_size = np.random.choice(inner_batch_sizes)
				cur_inner_batch_size = min(cur_inner_batch_size, available_task_data['labels'].shape[0] // 2)
				outer_batch_size = min(FLAGS.outer_batch_size, available_task_data['labels'].shape[0] - cur_inner_batch_size)
				random_indices = np.random.choice(
					len(self.task_batches[task_id])*self.per_task_batch_size,
					cur_inner_batch_size + outer_batch_size,
					replace=False
				)
				pre_adapt_indices = random_indices[:cur_inner_batch_size]
				post_adapt_indices = random_indices[cur_inner_batch_size:]

				pre_adapt_images = available_task_data['images'][pre_adapt_indices]
				pre_adapt_labels = available_task_data['labels'][pre_adapt_indices]
				post_adapt_images = available_task_data['images'][post_adapt_indices]
				post_adapt_labels = available_task_data['labels'][post_adapt_indices]
			else:
				try:
					assert available_task_data['labels'].shape[0] == num_available_data
				except:
					import pdb; pdb.set_trace()
				val_task_data = {'images': task_data[task_id]['images'][self.cur_task_batch_id*self.per_task_batch_size:],
							 'labels': task_data[task_id]['labels'][self.cur_task_batch_id*self.per_task_batch_size:]}
				pre_adapt_indices = np.random.choice(
					num_available_data,
					inner_batch_size,
					replace=False
				)
				post_adapt_indices = np.random.choice(
					val_task_data['labels'].shape[0],
					FLAGS.outer_batch_size,
					replace=False
				)
				pre_adapt_images = available_task_data['images'][pre_adapt_indices]
				pre_adapt_labels = available_task_data['labels'][pre_adapt_indices]
				post_adapt_images = val_task_data['images'][post_adapt_indices]
				post_adapt_labels = val_task_data['labels'][post_adapt_indices]
			pre_adapt_images_list.append(np.expand_dims(pre_adapt_images, axis=-1))
			pre_adapt_labels_list.append(pre_adapt_labels / 10.0)
			post_adapt_images_list.append(np.expand_dims(post_adapt_images, axis=-1))
			post_adapt_labels_list.append(post_adapt_labels / 10.0)
		return pre_adapt_images_list, pre_adapt_labels_list, post_adapt_images_list, post_adapt_labels_list, inner_batch_size

	def generate_cont_miniimagenet_batch(self, inner_batch_sizes, train=True, train_step=None, test_zero_shot=False):  # RGB images
		# Should advance to the next task once the model has seen the second to the last batch
		assert self.cur_task_batch_id < self.data[self.cur_task]['images'].shape[0] // self.per_task_batch_size
		if train:
			# Use all tasks so far
			if FLAGS.train_only_on_cur:
				task_data = {self.cur_task: self.data[self.cur_task]}
			else:
				task_data = {i: self.data[i] for i in range(self.num_tasks)}
		else:
			# Only use the next task # current task
			# cont_incl_cur: whether or not to meta-train on the current task
			if FLAGS.cont_incl_cur:
				task_data = {self.cur_task: self.data[self.cur_task]}
			else:
				task_data = {self.cur_task+1: self.data[self.cur_task+1]}
		other_task_data = {i: self.data[i] for i in range(self.num_tasks)}

		pre_adapt_images_list, pre_adapt_labels_list, post_adapt_images_list, post_adapt_labels_list = [], [], [], []

		if train:
			random_tasks = np.random.choice(range(len(task_data)), size=FLAGS.meta_train_tasks, replace=True)
		else:
			random_tasks = [self.cur_task for _ in range(FLAGS.meta_train_tasks)]
		num_available_data = self.cur_task_batch_id*self.per_task_batch_size
		if train:
			# Note this is actually not used. We use different shots within a meta-batch right now
			inner_batch_size = inner_batch_sizes[train_step % len(inner_batch_sizes)]
			inner_batch_size = min(inner_batch_size, num_available_data // 4)
		else:
			if test_zero_shot:
				inner_batch_size = 0
			else:
				inner_batch_size = min(np.amax(inner_batch_sizes), num_available_data // 4)

		for i, task_id in enumerate(random_tasks):
			other_task_idxes = [idx for idx in range(self.num_tasks) if idx != task_id]
			available_task_data = task_data[task_id]['images'][:len(self.task_batches[task_id])*self.per_task_batch_size]
			if train:
				cur_inner_batch_size = np.random.choice(inner_batch_sizes)
				cur_inner_batch_size = min(cur_inner_batch_size, available_task_data.shape[0] // 4)
				outer_batch_size = min(min(FLAGS.outer_batch_size, self.per_task_batch_size), available_task_data.shape[0] // 2 - cur_inner_batch_size)
				random_indices = np.random.choice(
					len(self.task_batches[task_id])*self.per_task_batch_size,
					int((cur_inner_batch_size + outer_batch_size) * 3 / 2),
					replace=False
				)
				condition_indices = random_indices[:cur_inner_batch_size + outer_batch_size]
				pre_adapt_indices = random_indices[cur_inner_batch_size + outer_batch_size: cur_inner_batch_size + outer_batch_size + cur_inner_batch_size // 2]
				post_adapt_indices = random_indices[cur_inner_batch_size + outer_batch_size + cur_inner_batch_size // 2:]

				condition_images = available_task_data[condition_indices]
				pre_adapt_images = available_task_data[pre_adapt_indices]
				pre_adapt_labels = np.ones([pre_adapt_images.shape[0], 1], dtype=np.int64)
				post_adapt_images = available_task_data[post_adapt_indices]
				post_adapt_labels = np.ones([post_adapt_images.shape[0], 1], dtype=np.int64)

			else:
				try:
					assert available_task_data.shape[0] == num_available_data
				except:
					import pdb; pdb.set_trace()
				val_task_data = task_data[task_id]['images'][self.cur_task_batch_id*self.per_task_batch_size:]
				cur_inner_batch_size = inner_batch_size
				outer_batch_size = FLAGS.outer_batch_size
				pre_adapt_indices = np.random.choice(
					num_available_data,
					cur_inner_batch_size + outer_batch_size + cur_inner_batch_size // 2,
					replace=False
				)
				post_adapt_indices = np.random.choice(
					val_task_data.shape[0],
					outer_batch_size // 2,
					replace=False
				)
				condition_indices = pre_adapt_indices[:cur_inner_batch_size + outer_batch_size]
				pre_adapt_indices = pre_adapt_indices[cur_inner_batch_size + outer_batch_size:]

				condition_images = available_task_data[condition_indices]
				pre_adapt_images = available_task_data[pre_adapt_indices]
				pre_adapt_labels = np.ones([pre_adapt_images.shape[0], 1], dtype=np.int64)
				post_adapt_images = val_task_data[post_adapt_indices]
				post_adapt_labels = np.ones([post_adapt_images.shape[0], 1], dtype=np.int64)

			other_sampled_task_idxes = [np.random.choice(other_task_idxes) for _ in range((cur_inner_batch_size + outer_batch_size) // 2)]
			other_sampled_task_data_idxes = [np.random.choice(len(self.task_batches[idx])*self.per_task_batch_size) for idx in other_sampled_task_idxes]
			other_sampled_task_data = np.stack([other_task_data[other_sampled_task_idxes[i]]['images'][other_sampled_task_data_idxes[i]] for i in range(len(other_sampled_task_idxes))])
			pre_adapt_images_other = other_sampled_task_data[:(cur_inner_batch_size // 2)]
			pre_adapt_labels_other = np.zeros([pre_adapt_images_other.shape[0], 1], dtype=np.int64)
			post_adapt_images_other = other_sampled_task_data[(cur_inner_batch_size // 2):]
			post_adapt_labels_other = np.zeros([post_adapt_images_other.shape[0], 1], dtype=np.int64)

			pre_adapt_images = np.concatenate((pre_adapt_images, pre_adapt_images_other), axis=0)
			pre_adapt_labels = np.concatenate((pre_adapt_labels, pre_adapt_labels_other), axis=0)
			post_adapt_images = np.concatenate((post_adapt_images, post_adapt_images_other), axis=0)
			post_adapt_labels = np.concatenate((post_adapt_labels, post_adapt_labels_other), axis=0)

			pre_adapt_random_indices, post_adapt_random_indices = np.arange(pre_adapt_images.shape[0]), np.arange(post_adapt_images.shape[0])
			np.random.shuffle(pre_adapt_random_indices)
			np.random.shuffle(post_adapt_random_indices)
			pre_adapt_images = pre_adapt_images[pre_adapt_random_indices]
			pre_adapt_labels = pre_adapt_labels[pre_adapt_random_indices]
			post_adapt_images = post_adapt_images[post_adapt_random_indices]
			post_adapt_labels = post_adapt_labels[post_adapt_random_indices]
			try:
				pre_adapt_images = np.concatenate((pre_adapt_images, condition_images[:cur_inner_batch_size]), axis=-1)
				post_adapt_images = np.concatenate((post_adapt_images, condition_images[cur_inner_batch_size:]), axis=-1)
			except:
				import pdb; pdb.set_trace()

			pre_adapt_images_list.append(pre_adapt_images)
			pre_adapt_labels_list.append(pre_adapt_labels)
			post_adapt_images_list.append(post_adapt_images)
			post_adapt_labels_list.append(post_adapt_labels)
		return pre_adapt_images_list, pre_adapt_labels_list, post_adapt_images_list, post_adapt_labels_list, inner_batch_size
