import os
from io import BytesIO
from pathlib import Path

import lmdb
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import LSUNClass
from torchvision.datasets import CIFAR10 as CIFAR10_torch
from torchvision.datasets import CIFAR100 as CIFAR100_torch
import torch
import pandas as pd
import numpy as np

import torchvision.transforms.functional as Ftrans


class ImageDataset(Dataset):
	"""Image files discovered under a folder, served as dict samples.

	Files are matched by extension and stored as paths relative to
	``folder``. Each item is opened with PIL, converted to RGB, run through
	the transform pipeline and returned as ``{'img': ..., 'index': ...}``.

	Args:
		folder: root directory that is searched for images.
		image_size: size passed to both ``Resize`` and ``CenterCrop``.
		exts: file extensions to match, without the leading dot. A single
			string is treated as one extension.
		do_augment: append ``RandomHorizontalFlip`` to the pipeline.
		do_transform: append ``ToTensor`` to the pipeline.
		do_normalize: append ``Normalize`` with mean/std 0.5 per channel.
		sort_names: sort the discovered relative paths.
		has_subdir: match recursively (``**/*.ext``) instead of only at the
			top level of ``folder`` (``*.ext``).
	"""
	def __init__(
		self,
		folder,
		image_size,
		exts=('jpg', ),
		do_augment: bool = True,
		do_transform: bool = True,
		do_normalize: bool = True,
		sort_names=False,
		has_subdir: bool = True,
	):
		super().__init__()
		self.folder = folder
		self.image_size = image_size

		# a bare string would otherwise be iterated character by character
		if isinstance(exts, str):
			exts = (exts, )

		# relative paths (make it shorter, saves memory and faster to sort)
		pattern_prefix = '**/' if has_subdir else ''
		root = Path(f'{folder}')
		self.paths = [
			p.relative_to(folder) for ext in exts
			for p in root.glob(f'{pattern_prefix}*.{ext}')
		]
		if sort_names:
			self.paths = sorted(self.paths)

		transform = [
			transforms.Resize(image_size),
			transforms.CenterCrop(image_size),
		]
		if do_augment:
			transform.append(transforms.RandomHorizontalFlip())
		if do_transform:
			transform.append(transforms.ToTensor())
		if do_normalize:
			transform.append(
				transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
		self.transform = transforms.Compose(transform)

	def __len__(self):
		"""Return the number of discovered image files."""
		return len(self.paths)

	def __getitem__(self, index):
		"""Load, convert to RGB and transform the image at ``index``."""
		path = os.path.join(self.folder, self.paths[index])
		# convert() hands back a separate image (this also covers e.g. RGBA
		# inputs), so the file-backed source can be closed right away
		with Image.open(path) as src:
			img = src.convert('RGB')
		if self.transform is not None:
			img = self.transform(img)
		return {'img': img, 'index': index}


class SubsetDataset(Dataset):
	"""Expose only the first ``size`` items of another dataset.

	Args:
		dataset: wrapped dataset; must support ``len()`` and indexing.
		size: number of leading items to expose.

	Raises:
		ValueError: if ``size`` is larger than ``len(dataset)``.
	"""
	def __init__(self, dataset, size):
		# explicit exception instead of `assert`, which is stripped under -O
		if len(dataset) < size:
			raise ValueError(
				f'subset size {size} exceeds dataset length {len(dataset)}')
		self.dataset = dataset
		self.size = size

	def __len__(self):
		"""Return the subset size."""
		return self.size

	def __getitem__(self, index):
		"""Return item ``index`` of the wrapped dataset.

		Raises:
			IndexError: if ``index`` is not below the subset size. Raising
				IndexError (rather than AssertionError) also lets plain
				``for`` iteration over the subset terminate.

		Negative indices are forwarded unchanged to the wrapped dataset.
		"""
		if index >= self.size:
			raise IndexError(
				f'index {index} out of range for subset of size {self.size}')
		return self.dataset[index]


class BaseLMDB(Dataset):
	"""Read-only access to encoded images stored in an LMDB database.

	The number of records is read from the ``length`` key. Image ``index``
	is looked up under the key
	``f'{original_resolution}-{str(index).zfill(zfill)}'``.

	Args:
		path: location of the LMDB database.
		original_resolution: resolution prefix used in the record keys.
		zfill: width to which the index part of the key is zero-padded.

	Raises:
		IOError: if the environment cannot be opened or has no ``length``
			entry.
	"""
	def __init__(self, path, original_resolution, zfill: int = 5):
		self.original_resolution = original_resolution
		self.zfill = zfill
		self.env = lmdb.open(
			path,
			max_readers=32,
			readonly=True,
			lock=False,
			readahead=False,
			meminit=False,
		)

		if not self.env:
			raise IOError('Cannot open lmdb dataset', path)

		with self.env.begin(write=False) as txn:
			raw_length = txn.get('length'.encode('utf-8'))
		# a missing key comes back as None; fail with a clear message instead
		# of an AttributeError on `.decode`
		if raw_length is None:
			raise IOError('lmdb dataset has no "length" entry', path)
		self.length = int(raw_length.decode('utf-8'))

	def __len__(self):
		"""Return the record count stored under the ``length`` key."""
		return self.length

	def __getitem__(self, index):
		"""Return the PIL image stored for ``index``.

		``index`` is stringified before padding, so both ints and numeric
		strings work.

		Raises:
			IndexError: if no record exists under the derived key.
		"""
		key = f'{self.original_resolution}-{str(index).zfill(self.zfill)}'.encode(
			'utf-8')
		with self.env.begin(write=False) as txn:
			img_bytes = txn.get(key)

		# without this check a missing record would become an empty buffer
		# and only fail later inside the image decoder
		if img_bytes is None:
			raise IndexError(f'no image stored under key {key!r}')

		buffer = BytesIO(img_bytes)
		img = Image.open(buffer)
		return img


def make_transform(
	image_size,
	flip_prob=0.5,
	crop_d2c=False,
):
	"""Compose the image pipeline: geometry, random flip, tensor, normalize.

	Args:
		image_size: target size handed to ``Resize`` (and ``CenterCrop``).
		flip_prob: probability given to ``RandomHorizontalFlip``.
		crop_d2c: when true, apply the D2C crop followed by ``Resize``;
			otherwise ``Resize`` followed by ``CenterCrop``.

	Returns:
		A ``transforms.Compose`` ending in ``ToTensor`` and a 0.5/0.5
		per-channel ``Normalize``.
	"""
	if crop_d2c:
		geometry = [d2c_crop(), transforms.Resize(image_size)]
	else:
		geometry = [
			transforms.Resize(image_size),
			transforms.CenterCrop(image_size),
		]
	# the tail of the pipeline is the same for both geometry variants
	tail = [
		transforms.RandomHorizontalFlip(p=flip_prob),
		transforms.ToTensor(),
		transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
	]
	return transforms.Compose(geometry + tail)


class FFHQlmdb(Dataset):
	"""FFHQ images read from an LMDB store, with optional train/test split.

	Args:
		path: location of the LMDB database.
		image_size: size passed to ``Resize``.
		original_resolution: resolution prefix of the stored record keys.
		split: ``None`` for all records, ``'test'`` for the first 10000
			records, ``'train'`` for everything after the first 10000.
		as_tensor: append ``ToTensor`` to the pipeline.
		do_augment: append ``RandomHorizontalFlip`` to the pipeline.
		do_normalize: append ``Normalize`` with mean/std 0.5 per channel.
		**kwargs: accepted and ignored.

	Raises:
		NotImplementedError: for any other ``split`` value.
	"""
	def __init__(self,
				path=os.path.expanduser('datasets/ffhq256.lmdb'),
				image_size=256,
				original_resolution=256,
				split=None,
				as_tensor: bool = True,
				do_augment: bool = True,
				do_normalize: bool = True,
				**kwargs):
		self.original_resolution = original_resolution
		self.data = BaseLMDB(path, original_resolution, zfill=5)
		self.length = len(self.data)

		if split is None:
			self.offset = 0
		elif split == 'train':
			# everything after the first 10k records
			self.length = self.length - 10000
			self.offset = 10000
		elif split == 'test':
			# first 10k
			self.length = 10000
			self.offset = 0
		else:
			raise NotImplementedError(f'unknown split: {split!r}')

		transform = [
			transforms.Resize(image_size),
		]
		if do_augment:
			transform.append(transforms.RandomHorizontalFlip())
		if as_tensor:
			transform.append(transforms.ToTensor())
		if do_normalize:
			transform.append(
				transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
		self.transform = transforms.Compose(transform)

	def __len__(self):
		"""Return the number of records in the selected split."""
		return self.length

	def __getitem__(self, index):
		"""Return ``{'img', 'index'}`` for split-relative position ``index``.

		The returned ``'index'`` is the position in the underlying store
		(i.e. after the split offset has been added).

		Raises:
			IndexError: if ``index`` is not below the split length.
		"""
		# explicit exception instead of `assert`, which is stripped under -O
		if index >= self.length:
			raise IndexError(
				f'index {index} out of range for dataset of length {self.length}')
		index = index + self.offset
		img = self.data[index]
		if self.transform is not None:
			img = self.transform(img)
		return {'img': img, 'index': index}


class Crop:
	"""Callable transform that cuts a fixed rectangle out of an image.

	The rectangle spans ``x1..x2`` and ``y1..y2``; it is forwarded to
	``Ftrans.crop`` as ``(img, x1, y1, x2 - x1, y2 - y1)``.
	"""
	def __init__(self, x1, x2, y1, y2):
		self.x1, self.x2 = x1, x2
		self.y1, self.y2 = y1, y2

	def __call__(self, img):
		# NOTE(review): torchvision documents crop as (img, top, left,
		# height, width), so `x` appears to index rows here -- confirm
		span_x = self.x2 - self.x1
		span_y = self.y2 - self.y1
		return Ftrans.crop(img, self.x1, self.y1, span_x, span_y)

	def __repr__(self):
		cls_name = self.__class__.__name__
		return f"{cls_name}(x1={self.x1}, x2={self.x2}, y1={self.y1}, y2={self.y2})"


def d2c_crop():
	"""Return the fixed 128x128 ``Crop`` for CelebA (values from the D2C paper)."""
	center_x, center_y = 89, 121
	half = 64
	# Crop's x-range is built around cy and its y-range around cx
	return Crop(
		center_y - half,
		center_y + half,
		center_x - half,
		center_x + half,
	)


class CelebAlmdb(Dataset):
	"""CelebA images read from an LMDB store; also supports the D2C crop.

	Args:
		path: location of the LMDB database.
		image_size: size passed to ``Resize`` (and ``CenterCrop``).
		original_resolution: resolution prefix of the stored record keys.
		split: only ``None`` is supported.
		as_tensor: append ``ToTensor`` to the pipeline.
		do_augment: append ``RandomHorizontalFlip`` to the pipeline.
		do_normalize: append ``Normalize`` with mean/std 0.5 per channel.
		crop_d2c: apply the D2C crop then ``Resize`` instead of ``Resize``
			then ``CenterCrop``.
		**kwargs: accepted and ignored.

	Raises:
		NotImplementedError: if ``split`` is not ``None``.
	"""
	def __init__(self,
				path,
				image_size,
				original_resolution=128,
				split=None,
				as_tensor: bool = True,
				do_augment: bool = True,
				do_normalize: bool = True,
				crop_d2c: bool = False,
				**kwargs):
		self.original_resolution = original_resolution
		self.data = BaseLMDB(path, original_resolution, zfill=7)
		self.length = len(self.data)
		self.crop_d2c = crop_d2c

		if split is None:
			self.offset = 0
		else:
			raise NotImplementedError(f'unknown split: {split!r}')

		if crop_d2c:
			transform = [
				d2c_crop(),
				transforms.Resize(image_size),
			]
		else:
			transform = [
				transforms.Resize(image_size),
				transforms.CenterCrop(image_size),
			]

		if do_augment:
			transform.append(transforms.RandomHorizontalFlip())
		if as_tensor:
			transform.append(transforms.ToTensor())
		if do_normalize:
			transform.append(
				transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
		self.transform = transforms.Compose(transform)

	def __len__(self):
		"""Return the number of records."""
		return self.length

	def __getitem__(self, index):
		"""Return ``{'img', 'index'}`` for the record at ``index``.

		Raises:
			IndexError: if ``index`` is not below the dataset length.
		"""
		# explicit exception instead of `assert`, which is stripped under -O
		if index >= self.length:
			raise IndexError(
				f'index {index} out of range for dataset of length {self.length}')
		index = index + self.offset
		img = self.data[index]
		if self.transform is not None:
			img = self.transform(img)
		return {'img': img, 'index': index}


class CIFAR10(Dataset):
	"""CIFAR-10 wrapper that yields dict samples.

	The underlying torchvision dataset is created with ``download=True`` and
	applies ``ToTensor`` followed by a 0.5/0.5 per-channel ``Normalize``.

	Args:
		path: accepted but unused.
			NOTE(review): the root is hard-coded to '/data/dataset' --
			confirm whether ``path`` should be used instead.
		image_size, as_tensor, do_augment, do_normalize: accepted for
			signature compatibility; they do not affect the samples.
		original_resolution: stored on the instance only.
		split: only ``None`` is supported.
		crop_d2c: stored on the instance only.
		**kwargs: accepted and ignored.

	Raises:
		NotImplementedError: if ``split`` is not ``None``.
	"""
	def __init__(self,
				path,
				image_size,
				original_resolution=128,
				split=None,
				as_tensor: bool = True,
				do_augment: bool = True,
				do_normalize: bool = True,
				crop_d2c: bool = False,
				**kwargs):
		self.original_resolution = original_resolution
		self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
		self.data = CIFAR10_torch(download=True,root='/data/dataset', transform=transforms.Compose([
			transforms.ToTensor(),
			self.normalize,
			]))

		self.length = len(self.data)
		self.crop_d2c = crop_d2c

		if split is None:
			self.offset = 0
		else:
			raise NotImplementedError(f'unknown split: {split!r}')

		# Only ToTensor + Normalize are kept here; the resize/crop/flip
		# pipeline that used to be built first was always overwritten.
		self.transform = transforms.Compose([transforms.ToTensor(), self.normalize])

	def __len__(self):
		"""Return the number of samples in the underlying dataset."""
		return self.length

	def __getitem__(self, index):
		"""Return ``{'img', 'target', 'index', 'type'}`` for ``index``.

		``'type'`` is always 0.

		Raises:
			IndexError: if ``index`` is not below the dataset length.
		"""
		# explicit exception instead of `assert`, which is stripped under -O
		if index >= self.length:
			raise IndexError(
				f'index {index} out of range for dataset of length {self.length}')
		index = index + self.offset

		# fetch the sample once; indexing twice would run the underlying
		# dataset's __getitem__ (and its transform) twice per sample
		sample = self.data[index]
		img = sample[0]
		target = sample[1]
		return {'img': img, 'target': target, 'index': index, 'type': 0}


class CIFAR100(Dataset):
	"""CIFAR-100 wrapper that yields dict samples.

	The underlying torchvision dataset is created with ``download=True`` and
	applies ``ToTensor`` followed by a 0.5/0.5 per-channel ``Normalize``.

	Args:
		path: accepted but unused.
			NOTE(review): the root is hard-coded to '/data/dataset' --
			confirm whether ``path`` should be used instead.
		image_size, as_tensor, do_augment, do_normalize: accepted for
			signature compatibility; they do not affect the samples.
		original_resolution: stored on the instance only.
		split: only ``None`` is supported.
		crop_d2c: stored on the instance only.
		**kwargs: accepted and ignored.

	Raises:
		NotImplementedError: if ``split`` is not ``None``.
	"""
	def __init__(self,
				path,
				image_size,
				original_resolution=128,
				split=None,
				as_tensor: bool = True,
				do_augment: bool = True,
				do_normalize: bool = True,
				crop_d2c: bool = False,
				**kwargs):
		self.original_resolution = original_resolution
		self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
		self.data = CIFAR100_torch(download=True,root='/data/dataset', transform=transforms.Compose([
			transforms.ToTensor(),
			self.normalize,
			]))

		self.length = len(self.data)
		self.crop_d2c = crop_d2c

		if split is None:
			self.offset = 0
		else:
			raise NotImplementedError(f'unknown split: {split!r}')

		# Only ToTensor + Normalize are kept here; the resize/crop/flip
		# pipeline that used to be built first was always overwritten.
		self.transform = transforms.Compose([transforms.ToTensor(), self.normalize])

	def __len__(self):
		"""Return the number of samples in the underlying dataset."""
		return self.length

	def __getitem__(self, index):
		"""Return ``{'img', 'target', 'index', 'type'}`` for ``index``.

		``'type'`` is always 0.

		Raises:
			IndexError: if ``index`` is not below the dataset length.
		"""
		# explicit exception instead of `assert`, which is stripped under -O
		if index >= self.length:
			raise IndexError(
				f'index {index} out of range for dataset of length {self.length}')
		index = index + self.offset

		# fetch the sample once; indexing twice would run the underlying
		# dataset's __getitem__ (and its transform) twice per sample
		sample = self.data[index]
		img = sample[0]
		target = sample[1]
		return {'img': img, 'target': target, 'index': index, 'type': 0}


class Horse_lmdb(Dataset):
	"""Horse images read from an LMDB store.

	The pipeline is ``Resize`` and ``CenterCrop`` at ``image_size``, then
	optionally a random horizontal flip, ``ToTensor`` and a 0.5/0.5
	per-channel ``Normalize``. The store path is printed on construction.
	Extra keyword arguments are accepted and ignored.
	"""
	def __init__(self,
				path=os.path.expanduser('datasets/horse256.lmdb'),
				image_size=128,
				original_resolution=256,
				do_augment: bool = True,
				do_transform: bool = True,
				do_normalize: bool = True,
				**kwargs):
		self.original_resolution = original_resolution
		print(path)
		self.data = BaseLMDB(path, original_resolution, zfill=7)
		self.length = len(self.data)

		# geometry steps are always present; the rest is opt-in
		steps = [transforms.Resize(image_size)]
		steps += [transforms.CenterCrop(image_size)]
		if do_augment:
			steps += [transforms.RandomHorizontalFlip()]
		if do_transform:
			steps += [transforms.ToTensor()]
		if do_normalize:
			steps += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
		self.transform = transforms.Compose(steps)

	def __len__(self):
		return self.length

	def __getitem__(self, index):
		sample = self.data[index]
		if self.transform is None:
			return {'img': sample, 'index': index}
		return {'img': self.transform(sample), 'index': index}


class Bedroom_lmdb(Dataset):
	"""Bedroom images read from an LMDB store.

	The pipeline is ``Resize`` and ``CenterCrop`` at ``image_size``, then
	optionally a random horizontal flip, ``ToTensor`` and a 0.5/0.5
	per-channel ``Normalize``. The store path is printed on construction.
	Extra keyword arguments are accepted and ignored.
	"""
	def __init__(self,
				path=os.path.expanduser('datasets/bedroom256.lmdb'),
				image_size=128,
				original_resolution=256,
				do_augment: bool = True,
				do_transform: bool = True,
				do_normalize: bool = True,
				**kwargs):
		self.original_resolution = original_resolution
		print(path)
		self.data = BaseLMDB(path, original_resolution, zfill=7)
		self.length = len(self.data)

		# geometry steps are always present; the rest is opt-in
		steps = [transforms.Resize(image_size)]
		steps += [transforms.CenterCrop(image_size)]
		if do_augment:
			steps += [transforms.RandomHorizontalFlip()]
		if do_transform:
			steps += [transforms.ToTensor()]
		if do_normalize:
			steps += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
		self.transform = transforms.Compose(steps)

	def __len__(self):
		return self.length

	def __getitem__(self, index):
		# the transform is applied unconditionally here
		transformed = self.transform(self.data[index])
		return {'img': transformed, 'index': index}


class CelebAttrDataset(Dataset):
	"""CelebA images on disk paired with their attribute annotations.

	Image files with extension ``ext`` are discovered recursively under
	``folder``; the annotation table is restricted to rows whose index
	matches one of those files (with the extension replaced by ``.jpg``).

	Args:
		folder: directory holding the image files.
		image_size: size passed to ``Resize`` (and ``CenterCrop``).
		attr_path: whitespace-separated attribute file; its first line is
			skipped before parsing.
		ext: extension of the image files on disk.
		only_cls_name: if given, keep only rows where this attribute equals
			``only_cls_value``.
		only_cls_value: value compared against ``only_cls_name``.
		do_augment: append ``RandomHorizontalFlip`` to the pipeline.
		do_transform: append ``ToTensor`` to the pipeline.
		do_normalize: append ``Normalize`` with mean/std 0.5 per channel.
		d2c: apply the D2C crop then ``Resize`` instead of ``Resize`` then
			``CenterCrop``.
	"""

	id_to_cls = [
		'5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
		'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
		'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
		'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
		'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
		'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
		'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
		'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
		'Wearing_Necklace', 'Wearing_Necktie', 'Young'
	]
	cls_to_id = {v: k for k, v in enumerate(id_to_cls)}

	def __init__(self,
				folder,
				image_size=64,
				attr_path=os.path.expanduser(
					'datasets/celeba_anno/list_attr_celeba.txt'),
				ext='png',
				only_cls_name: str = None,
				only_cls_value: int = None,
				do_augment: bool = False,
				do_transform: bool = True,
				do_normalize: bool = True,
				d2c: bool = False):
		super().__init__()
		self.folder = folder
		self.image_size = image_size
		self.ext = ext

		# relative paths (make it shorter, saves memory and faster to sort)
		paths = [
			str(p.relative_to(folder))
			for p in Path(f'{folder}').glob(f'**/*.{ext}')
		]
		# the annotation index uses '.jpg' names regardless of `ext`
		paths = [str(each).split('.')[0] + '.jpg' for each in paths]

		if d2c:
			transform = [
				d2c_crop(),
				transforms.Resize(image_size),
			]
		else:
			transform = [
				transforms.Resize(image_size),
				transforms.CenterCrop(image_size),
			]
		if do_augment:
			transform.append(transforms.RandomHorizontalFlip())
		if do_transform:
			transform.append(transforms.ToTensor())
		if do_normalize:
			transform.append(
				transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
		self.transform = transforms.Compose(transform)

		with open(attr_path) as f:
			# discard the top line
			f.readline()
			# sep=r'\s+' is the documented replacement for the deprecated
			# delim_whitespace=True
			self.df = pd.read_csv(f, sep=r'\s+')
			self.df = self.df[self.df.index.isin(paths)]

		if only_cls_name is not None:
			self.df = self.df[self.df[only_cls_name] == only_cls_value]

	def pos_count(self, cls_name):
		"""Return how many rows have ``cls_name`` equal to 1."""
		return (self.df[cls_name] == 1).sum()

	def neg_count(self, cls_name):
		"""Return how many rows have ``cls_name`` equal to -1."""
		return (self.df[cls_name] == -1).sum()

	def __len__(self):
		"""Return the number of annotation rows kept."""
		return len(self.df)

	def __getitem__(self, index):
		"""Return ``{'img', 'index', 'labels'}`` for the row at ``index``.

		``labels`` is a tensor with one entry per name in ``id_to_cls``.
		"""
		row = self.df.iloc[index]
		# swap the annotation's extension for the on-disk one
		name = row.name.split('.')[0]
		name = f'{name}.{self.ext}'

		path = os.path.join(self.folder, name)
		img = Image.open(path)

		labels = [0] * len(self.id_to_cls)
		for k, v in row.items():
			labels[self.cls_to_id[k]] = int(v)

		if self.transform is not None:
			img = self.transform(img)

		return {'img': img, 'index': index, 'labels': torch.tensor(labels)}


class CelebD2CAttrDataset(CelebAttrDataset):
	"""
	CelebA attribute dataset as used in the D2C paper.
	Differs from the parent only in its defaults: 'jpg' files and the
	D2C-specific crop of the original CelebA images enabled.
	"""
	def __init__(self,
				folder,
				image_size=64,
				attr_path=os.path.expanduser(
					'datasets/celeba_anno/list_attr_celeba.txt'),
				ext='jpg',
				only_cls_name: str = None,
				only_cls_value: int = None,
				do_augment: bool = False,
				do_transform: bool = True,
				do_normalize: bool = True,
				d2c: bool = True):
		# everything is forwarded unchanged; only the defaults above differ
		super().__init__(
			folder=folder,
			image_size=image_size,
			attr_path=attr_path,
			ext=ext,
			only_cls_name=only_cls_name,
			only_cls_value=only_cls_value,
			do_augment=do_augment,
			do_transform=do_transform,
			do_normalize=do_normalize,
			d2c=d2c,
		)


class CelebAttrFewshotDataset(Dataset):
	"""Few-shot CelebA subset for a single attribute.

	The rows come from a pre-generated CSV under ``data/celeba_fewshots``
	chosen by ``K``, ``cls_name``, ``seed`` and ``all_neg``; images are read
	from ``img_folder`` with extension ``ext``. Each sample carries the
	value of ``cls_name`` as its label.
	"""
	def __init__(
		self,
		cls_name,
		K,
		img_folder,
		img_size=64,
		ext='png',
		seed=0,
		only_cls_name: str = None,
		only_cls_value: int = None,
		all_neg: bool = False,
		do_augment: bool = False,
		do_transform: bool = True,
		do_normalize: bool = True,
		d2c: bool = False,
	) -> None:
		self.cls_name = cls_name
		self.K = K
		self.img_folder = img_folder
		self.ext = ext

		csv_path = (f'data/celeba_fewshots/K{K}_allneg_{cls_name}_{seed}.csv'
					if all_neg else
					f'data/celeba_fewshots/K{K}_{cls_name}_{seed}.csv')
		table = pd.read_csv(csv_path, index_col=0)
		if only_cls_name is not None:
			# keep only rows whose attribute has the requested value
			table = table[table[only_cls_name] == only_cls_value]
		self.df = table

		if d2c:
			steps = [d2c_crop(), transforms.Resize(img_size)]
		else:
			steps = [transforms.Resize(img_size), transforms.CenterCrop(img_size)]
		if do_augment:
			steps += [transforms.RandomHorizontalFlip()]
		if do_transform:
			steps += [transforms.ToTensor()]
		if do_normalize:
			steps += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
		self.transform = transforms.Compose(steps)

	def pos_count(self, cls_name):
		positives = self.df[cls_name] == 1
		return positives.sum()

	def neg_count(self, cls_name):
		negatives = self.df[cls_name] == -1
		return negatives.sum()

	def __len__(self):
		return len(self.df)

	def __getitem__(self, index):
		row = self.df.iloc[index]
		# replace the extension recorded in the table with the on-disk one
		stem = row.name.split('.')[0]
		filename = f'{stem}.{self.ext}'
		img = Image.open(os.path.join(self.img_folder, filename))

		# scalar tensor gains one trailing axis -> shape (1,)
		label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)

		if self.transform is not None:
			img = self.transform(img)

		return {'img': img, 'index': index, 'labels': label}


class CelebD2CAttrFewshotDataset(CelebAttrFewshotDataset):
	"""Few-shot CelebA subset with D2C defaults ('jpg' files, D2C crop on).

	Additionally records ``is_negative`` on the instance.
	"""
	def __init__(self,
				cls_name,
				K,
				img_folder,
				img_size=64,
				ext='jpg',
				seed=0,
				only_cls_name: str = None,
				only_cls_value: int = None,
				all_neg: bool = False,
				do_augment: bool = False,
				do_transform: bool = True,
				do_normalize: bool = True,
				is_negative=False,
				d2c: bool = True) -> None:
		# everything except `is_negative` is forwarded to the parent
		super().__init__(
			cls_name=cls_name,
			K=K,
			img_folder=img_folder,
			img_size=img_size,
			ext=ext,
			seed=seed,
			only_cls_name=only_cls_name,
			only_cls_value=only_cls_value,
			all_neg=all_neg,
			do_augment=do_augment,
			do_transform=do_transform,
			do_normalize=do_normalize,
			d2c=d2c,
		)
		self.is_negative = is_negative


class CelebHQAttrDataset(Dataset):
	"""CelebA-HQ images from an LMDB store paired with attribute annotations.

	Args:
		path: location of the LMDB database.
		image_size: size passed to ``Resize`` and ``CenterCrop``. With
			``None`` (the default) both steps are skipped.
			NOTE(review): previously ``None`` was passed straight to
			``Resize``, which torchvision is expected to reject -- confirm.
		attr_path: whitespace-separated attribute file; its first line is
			skipped before parsing.
		original_resolution: resolution prefix of the stored record keys.
		do_augment: append ``RandomHorizontalFlip`` to the pipeline.
		do_transform: append ``ToTensor`` to the pipeline.
		do_normalize: append ``Normalize`` with mean/std 0.5 per channel.
	"""
	id_to_cls = [
		'5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
		'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
		'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
		'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
		'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
		'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
		'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
		'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
		'Wearing_Necklace', 'Wearing_Necktie', 'Young'
	]
	cls_to_id = {v: k for k, v in enumerate(id_to_cls)}

	def __init__(self,
				path=os.path.expanduser('datasets/celebahq256.lmdb'),
				image_size=None,
				attr_path=os.path.expanduser(
					'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
				original_resolution=256,
				do_augment: bool = False,
				do_transform: bool = True,
				do_normalize: bool = True):
		super().__init__()
		self.image_size = image_size
		self.data = BaseLMDB(path, original_resolution, zfill=5)

		transform = []
		# only resize/crop when a target size was actually requested
		if image_size is not None:
			transform.append(transforms.Resize(image_size))
			transform.append(transforms.CenterCrop(image_size))
		if do_augment:
			transform.append(transforms.RandomHorizontalFlip())
		if do_transform:
			transform.append(transforms.ToTensor())
		if do_normalize:
			transform.append(
				transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
		self.transform = transforms.Compose(transform)

		with open(attr_path) as f:
			# discard the top line
			f.readline()
			# sep=r'\s+' is the documented replacement for the deprecated
			# delim_whitespace=True
			self.df = pd.read_csv(f, sep=r'\s+')

	def pos_count(self, cls_name):
		"""Return how many rows have ``cls_name`` equal to 1."""
		return (self.df[cls_name] == 1).sum()

	def neg_count(self, cls_name):
		"""Return how many rows have ``cls_name`` equal to -1."""
		return (self.df[cls_name] == -1).sum()

	def __len__(self):
		"""Return the number of annotation rows."""
		return len(self.df)

	def __getitem__(self, index):
		"""Return ``{'img', 'index', 'labels'}`` for the row at ``index``.

		``labels`` is a tensor with one entry per name in ``id_to_cls``.
		"""
		row = self.df.iloc[index]
		img_name = row.name
		# the part before the dot is the lookup key in the image store; the
		# two-way unpack still requires exactly one dot in the name
		img_idx, _ext = img_name.split('.')
		img = self.data[img_idx]

		labels = [0] * len(self.id_to_cls)
		for k, v in row.items():
			labels[self.cls_to_id[k]] = int(v)

		if self.transform is not None:
			img = self.transform(img)
		return {'img': img, 'index': index, 'labels': torch.tensor(labels)}


class CelebHQAttrFewshotDataset(Dataset):
	"""Few-shot CelebA-HQ subset for a single attribute.

	Rows come from ``data/celebahq_fewshots/K{K}_{cls_name}.csv``; images
	are fetched from the LMDB store at ``path``. Each sample carries the
	value of ``cls_name`` as its label.
	"""
	def __init__(self,
				cls_name,
				K,
				path,
				image_size,
				original_resolution=256,
				do_augment: bool = False,
				do_transform: bool = True,
				do_normalize: bool = True):
		super().__init__()
		self.image_size = image_size
		self.cls_name = cls_name
		self.K = K
		self.data = BaseLMDB(path, original_resolution, zfill=5)

		# geometry steps are always present; the rest is opt-in
		steps = [transforms.Resize(image_size)]
		steps += [transforms.CenterCrop(image_size)]
		if do_augment:
			steps += [transforms.RandomHorizontalFlip()]
		if do_transform:
			steps += [transforms.ToTensor()]
		if do_normalize:
			steps += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
		self.transform = transforms.Compose(steps)

		csv_path = f'data/celebahq_fewshots/K{K}_{cls_name}.csv'
		self.df = pd.read_csv(csv_path, index_col=0)

	def pos_count(self, cls_name):
		positives = self.df[cls_name] == 1
		return positives.sum()

	def neg_count(self, cls_name):
		negatives = self.df[cls_name] == -1
		return negatives.sum()

	def __len__(self):
		return len(self.df)

	def __getitem__(self, index):
		row = self.df.iloc[index]
		# the row name must contain exactly one dot: '<store index>.<ext>'
		stem, _suffix = row.name.split('.')
		img = self.data[stem]

		# scalar tensor gains one trailing axis -> shape (1,)
		label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)

		if self.transform is not None:
			img = self.transform(img)

		return {'img': img, 'index': index, 'labels': label}


class Repeat(Dataset):
	"""Present a dataset as if it had ``new_len`` items by cycling through it.

	Indices wrap around modulo the wrapped dataset's length, which is
	captured once at construction time.
	"""
	def __init__(self, dataset, new_len) -> None:
		super().__init__()
		self.dataset = dataset
		self.new_len = new_len
		self.original_len = len(dataset)

	def __len__(self):
		return self.new_len

	def __getitem__(self, index):
		# fold the requested position back into the wrapped dataset's range
		wrapped = index % self.original_len
		return self.dataset[wrapped]
