# coding: utf-8

import os.path,glob,json
import numpy as np
import torch
from torchvision import datasets as img_datasets
# from torchaudio import datasets as audio_datasets
from .librispeech import LibriSpeech
from .timit import TIMIT

def get_audio_dataset(data_name, data_root, split, max_length=None):
	"""
	Build an audio dataset by name.

	Args:
		data_name: 'LibriSpeech' or 'TIMIT'.
		data_root: Root directory handed to the dataset constructor.
		split: Split identifier handed to the dataset constructor.
		max_length: Forwarded to LibriSpeech only; it is not passed to TIMIT.

	Returns:
		The constructed dataset object.

	Raises:
		ValueError: If data_name is not one of the supported names.
	"""
	if data_name=='LibriSpeech':
		return LibriSpeech(data_root, split, max_length=max_length)
	if data_name=='TIMIT':
		# NOTE: max_length is intentionally not forwarded here, as in the
		# original code; TIMIT is constructed from root and split only.
		return TIMIT(data_root, split)
	# Fail loudly instead of falling through to an unbound local variable.
	raise ValueError(
		"Unsupported audio dataset: {!r} (expected 'LibriSpeech' or 'TIMIT')".format(data_name)
	)

def get_img_dataset(data_name, data_root, split, transform=None, target_transform=None, preprocessed=False):
	"""
	Build an image dataset by name.

	Args:
		data_name: Name of a dataset class in torchvision.datasets or, when
			preprocessed is true, a key of the preprocessed-dataset table
			below (only 'ImageNet' is registered there).
		data_root: Dataset root directory.
		split: Split name. For MNIST and CIFAR* datasets it is converted to
			the boolean keyword train=(split=='train'); for every other
			dataset it is passed through as the split keyword.
		transform: Passed to the dataset constructor as transform.
		target_transform: If not None, the dataset is wrapped in
			DatasetWTargetTransform, which calls this on each input item and
			yields (img, target, label) triples.
		preprocessed: If true, use the .npy-backed dataset class instead of
			the torchvision one.

	Returns:
		The constructed (and possibly wrapped) dataset.

	Raises:
		KeyError: If preprocessed is true and data_name is not registered.
		AttributeError: If torchvision.datasets has no attribute data_name.
	"""
	if data_name=='MNIST' and os.path.basename(data_root.rstrip('/'))=='MNIST':
		# Step up to the parent when the caller pointed directly at the
		# 'MNIST' folder -- presumably because torchvision's MNIST appends
		# its own 'MNIST' sub-directory to root (TODO confirm).
		data_root = os.path.dirname(data_root.rstrip('/'))
	kwargs = dict(root=data_root, transform=transform)
	# torchvision exposes CIFAR as CIFAR10/CIFAR100 (there is no plain 'CIFAR'
	# class), and those constructors take train=..., not split=..., so match
	# on the prefix. The bare name 'CIFAR' is still routed as before.
	if data_name=='MNIST' or data_name.startswith('CIFAR'):
		kwargs['train'] = split=='train'
	else:
		kwargs['split'] = split
	if preprocessed:
		preprocessed_classes = {'ImageNet':ImageNetNumpy}
		dataset = preprocessed_classes[data_name](**kwargs)
	else:
		dataset = getattr(img_datasets, data_name)(**kwargs)
	if target_transform is not None:
		dataset = DatasetWTargetTransform(dataset, target_transform)
	return dataset

class DatasetWTargetTransform(object):
	"""
	Dataset wrapper that adds a target computed from each input.

	The wrapped dataset must yield (img, label) pairs. Indexing this wrapper
	yields (img, target, label), where target is transform(img); note that
	the transform is applied to the input, not to the label.
	"""

	def __init__(self, dataset, transform):
		self.transform = transform
		self.dataset = dataset

	def __len__(self):
		"""Number of items, identical to the wrapped dataset's length."""
		return len(self.dataset)

	def __getitem__(self, idx):
		"""Return (img, transform(img), label) for the item at idx."""
		sample = self.dataset[idx]
		img, label = sample
		return (img, self.transform(img), label)


def fast_npy_load(file):
	"""
	Load one array from a .npy file with minimal header parsing.

	Borrowed from: https://github.com/divideconcept/fastnumpyio

	The header is not parsed in general: exactly 128 bytes are read and the
	dtype descriptor and the shape are sliced out at fixed byte offsets, so
	this only works for files whose header has that exact layout. The
	header's fortran_order field is never consulted.

	Args:
		file: A path (str or os.PathLike) or a binary file object positioned
			at the start of an array header. A path is opened and closed
			here; a file object is left open and positioned just past the
			array data, so consecutive arrays can be read from one stream.

	Returns:
		The loaded numpy array, or None if nothing could be read (e.g. the
		stream is already at end of file).
	"""
	if isinstance(file, (str, os.PathLike)):
		# We own the handle in this case, so make sure it gets closed rather
		# than leaking one descriptor per loaded sample. The array is built
		# from bytes already read, so closing the file is safe.
		with open(file, "rb") as stream:
			return fast_npy_load(stream)
	header = file.read(128)
	if not header:
		return None
	# Fixed-offset slices: bytes 19-24 hold the quoted dtype descriptor,
	# bytes 60-119 hold the shape tuple followed by padding.
	descr = str(header[19:25], 'utf-8').replace("'","").replace(" ","")
	shape = tuple(int(num) for num in str(header[60:120], 'utf-8').replace(',)', ')').replace(', }', '').replace('(', '').replace(')', '').split(','))
	datasize = np.lib.format.descr_to_dtype(descr).itemsize
	for dimension in shape:
		datasize *= dimension
	return np.ndarray(shape, dtype=descr, buffer=file.read(datasize))

class ImageNetNumpy(object):
	"""
	ImageNet dataset whose samples are stored as individual .npy files.

	Samples are every '*.npy' file found recursively under <root>/<split>.
	The label of a sample is derived from the name of the directory that
	contains it, mapped to a class index through <root>/ann.json, which must
	provide 'classes' (a list of class-name tuples; the list position is the
	class index) and 'wnid_to_classes' (directory name -> class-name tuple).

	NOTE(review): transform is stored but never applied in __getitem__ --
	presumably the files already hold transformed data; confirm.
	"""

	def __init__(self, root, split, transform):
		self.root = root
		self.split = split
		self.transform = transform

		pattern = os.path.join(self.split_folder, '**/*.npy')
		self.samples = sorted(glob.glob(pattern, recursive=True))

		with open(os.path.join(root, 'ann.json'), 'r') as f:
			ann = json.load(f)

		# Class tuples are joined with '|' so they can serve as dict keys.
		index_of_class = {}
		for position, names in enumerate(ann['classes']):
			index_of_class['|'.join(names)] = position

		self.wnid_to_idx = {}
		for wnid, names in ann['wnid_to_classes'].items():
			self.wnid_to_idx[wnid] = index_of_class['|'.join(names)]

	@property
	def split_folder(self) -> str:
		"""Directory that holds the samples of the selected split."""
		return os.path.join(self.root, self.split)

	def __len__(self):
		"""Number of .npy sample files found for the split."""
		return len(self.samples)

	def __getitem__(self, idx):
		"""Return (img, target): the array as a torch tensor and its class index."""
		sample_path = self.samples[idx]
		img = torch.from_numpy(fast_npy_load(sample_path))

		# The wnid is taken from the parent directory name; parsing it from
		# the file name prefix only works for train data.
		parent_dir = os.path.dirname(sample_path).rstrip('/')
		wnid = os.path.basename(parent_dir)
		return img, self.wnid_to_idx[wnid]
	
