"""
data loder for loading data
"""
import logging
import math
import os
import struct
from collections import Counter

from sklearn.model_selection import train_test_split
import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms

from synth_dataset import SynthCIFAR10, SynthSVHN


__all__ = ["DataLoader", "PartDataLoader"]


class DataLoader(object):
	"""
	Data loader for CV data sets.

	Builds a (train_loader, test_loader) pair for one of:
	  * "cifar10" / "svhn": real test split only; train_loader is None.
	  * "synth_cifar10" / "synth_svhn": synthetic training sets loaded with
	    torch.load from data_path, evaluated on the real CIFAR-10 / SVHN
	    test split.
	"""

	def __init__(self, dataset, batch_size, n_threads=4,
			data_path='./', logger=None,
			data_num=25000, manual_seed=123, datacenter=False, model_name=None):
		"""
		Create data loaders for a specific data set.

		:params dataset: data set name: "cifar10", "svhn", or a name containing
		    "synth" (see synth_dataset)
		:params batch_size: batch size used by both loaders
		:params n_threads: number of worker processes used to load data, default: 4
		:params data_path: for "synth" data sets, the path (or list of paths) of
		    torch.load-able training sets; the substring "model_name" in each
		    path is replaced by model_name. Not used for "cifar10"/"svhn".
		    default: './'
		:params logger: logger for progress messages; when None, this module's
		    logger is used
		:params data_num: total number of synthetic training samples to keep,
		    split evenly across the data_path entries, default: 25000
		:params manual_seed: stored on the instance; not used within this class
		:params datacenter: running on the datacenter (slurm); rewrites
		    "datasets" to "scratch" in synthetic paths and changes the SVHN root
		:params model_name: substituted for "model_name" in synthetic paths
		:raises AssertionError: if dataset is not supported
		"""
		self.dataset = dataset
		self.batch_size = batch_size
		self.n_threads = n_threads
		self.data_path = data_path
		# Fall back to the module logger so the default logger=None does not
		# crash on the first .info() call below.
		self.logger = logger if logger is not None else logging.getLogger(__name__)
		self.dataset_root = data_path
		self.dc = datacenter  # running on slurm
		self.model_name = model_name

		self.data_num = data_num
		self.manual_seed = manual_seed

		self.logger.info("|===>Creating data loader for " + self.dataset)

		if self.dataset in ["cifar10", "svhn"]:
			self.train_loader, self.test_loader = self.cifar(dataset=self.dataset)
		elif "synth" in self.dataset:
			self.train_loader, self.test_loader = self.synth_dataset(dataset=self.dataset)
		else:
			# Explicit raise (not `assert`) so the check survives `python -O`;
			# AssertionError is kept for compatibility with existing callers.
			raise AssertionError("invalid data set")

	def getloader(self):
		"""
		Return the (train_loader, test_loader) pair built by the constructor.
		"""
		return self.train_loader, self.test_loader

	def get_len(self):
		"""
		Return the number of samples in the test set.
		"""
		return self.test_len

	def synth_dataset(self, dataset="synth_svhn"):
		"""
		Build loaders for a synthetic training set and the matching real test set.

		Each entry of self.dataset_root is a path to an object saved with
		torch.save that exposes ``data`` and ``targets`` (presumably a
		SynthCIFAR10 / SynthSVHN instance — TODO confirm). At most
		``self.data_num // len(paths)`` samples are kept from each file, and
		the pieces are concatenated into a single training set.

		:params dataset: "synth_cifar10" (tested on CIFAR-10) or "synth_svhn"
		    (tested on SVHN)
		:returns: (train_loader, test_loader); only the train loader shuffles
		:raises AssertionError: if dataset is not one of the names above
		:raises ValueError: if no training data paths were given
		"""
		# Validate before the (potentially slow) training-set loading.
		if dataset not in ("synth_cifar10", "synth_svhn"):
			raise AssertionError("invalid data set")

		data_roots = self.dataset_root
		# A single path string would otherwise be iterated character by character.
		if isinstance(data_roots, str):
			data_roots = [data_roots]
		if not data_roots:
			raise ValueError("no synthetic training data paths given")

		# Integer division: any remainder of data_num is dropped.
		single_dataset_num = self.data_num // len(data_roots)
		train_dataset_list = []
		for data_root in data_roots:
			# str.replace rejects None, so only substitute when a model name is set.
			if self.model_name is not None:
				data_root = data_root.replace("model_name", self.model_name)
			train_data_root = data_root.replace("datasets", "scratch") if self.dc else data_root
			# NOTE(review): torch.load unpickles arbitrary objects — only load
			# trusted files. torch>=2.6 defaults to weights_only=True, which
			# rejects pickled dataset objects; weights_only=False may be needed.
			single_dataset = torch.load(train_data_root)

			if single_dataset_num < len(single_dataset):  # slice
				single_dataset.data = single_dataset.data[:single_dataset_num]
				single_dataset.targets = single_dataset.targets[:single_dataset_num]
			train_dataset_list.append(single_dataset)
			self.logger.info("extracting {} from {}".format(len(single_dataset.data), data_root))
		train_dataset = torch.utils.data.ConcatDataset(train_dataset_list)

		transform_set = transforms.Compose([transforms.ToTensor()])
		if dataset == "synth_cifar10":
			test_dataset = dsets.CIFAR10(
				root="/datasets/cifar10",
				train=False,  # validation data
				transform=transform_set,
				download=True)
		else:  # "synth_svhn"
			test_dataset = dsets.SVHN(
				root="/datasets/svhn",
				split='test',  # validation data
				transform=transform_set,
				download=True)

		self.train_len = len(train_dataset)
		self.test_len = len(test_dataset)

		train_loader = torch.utils.data.DataLoader(
			train_dataset,
			batch_size=self.batch_size,
			shuffle=True,
			num_workers=self.n_threads)

		# Label histogram as a sanity check. Read the (already sliced) targets
		# directly instead of iterating train_loader, which would decode every
		# image just to count labels.
		# NOTE(review): assumes __getitem__ returns targets unchanged (no
		# target_transform) — confirm.
		labels_all = torch.cat(
			[torch.as_tensor(ds.targets).reshape(-1) for ds in train_dataset_list])
		self.logger.info(str(torch.unique(labels_all, return_counts=True)))

		test_loader = torch.utils.data.DataLoader(
			test_dataset,
			batch_size=self.batch_size,
			shuffle=False,
			num_workers=self.n_threads)

		return train_loader, test_loader

	def cifar(self, dataset="cifar10"):
		"""
		Build the real CIFAR-10 or SVHN test loader.

		The branch is chosen from ``self.dataset``; the ``dataset`` argument is
		kept for signature compatibility (the constructor passes
		``self.dataset`` for it).

		:returns: (None, test_loader); no training loader is built here
		:raises AssertionError: if self.dataset is neither "cifar10" nor "svhn"
		"""
		test_transform = transforms.Compose([transforms.ToTensor()])

		if self.dataset == "cifar10":
			test_dataset = dsets.CIFAR10(
				root="/datasets/cifar10",
				train=False,
				transform=test_transform)
		elif self.dataset == "svhn":
			# On the datacenter the SVHN copy lives under the home directory.
			test_data_root = "~/svhn" if self.dc else "/datasets/svhn"
			test_dataset = dsets.SVHN(
				root=test_data_root,
				split='test',  # validation data
				transform=test_transform,
				download=True)
		else:
			# Explicit raise so the check survives `python -O`.
			raise AssertionError("invalid data set")

		self.test_len = len(test_dataset)
		test_loader = torch.utils.data.DataLoader(
			dataset=test_dataset,
			batch_size=self.batch_size,
			shuffle=False,
			pin_memory=True,
			num_workers=self.n_threads)
		return None, test_loader

	

