import os
import argparse
from abc import ABC, abstractmethod
import random

import imageio
import numpy as np
import cv2

from tqdm import tqdm


class DataSet(ABC):
	"""Abstract base class for dataset readers.

	Stores the dataset root directory and the raw extra arguments; concrete
	subclasses provide the actual loading logic in :meth:`read`.
	"""

	def __init__(self, base_dir=None, extras=None):
		"""Remember where the dataset lives and which extra arguments were given.

		Args:
			base_dir: Root directory of the dataset (kept in ``_base_dir``).
			extras: Additional arguments for the reader (kept in ``_extras``).
		"""
		super().__init__()
		self._base_dir, self._extras = base_dir, extras

	@abstractmethod
	def read(self):
		"""Load the dataset; every concrete subclass must override this."""


class AFHQ(DataSet):
	"""Reader for datasets laid out as ``<base_dir>/<split>/<class>/<image>``.

	Extra arguments (parsed with argparse):
		-sp / --split:    'train' or 'val' (required).
		-is / --img-size: side length every image is resized to (default 256).
	"""

	def __init__(self, base_dir, extras):
		"""Parse the reader options from ``extras``.

		Args:
			base_dir: Root directory containing one sub-directory per split.
			extras: List of command-line style arguments (see class docstring).
		"""
		# Forward `extras` as well so `_extras` is populated, as CelebA does.
		super().__init__(base_dir, extras)

		parser = argparse.ArgumentParser()
		parser.add_argument('-sp', '--split', type=str, choices=['train', 'val'], required=True)
		parser.add_argument('-is', '--img-size', type=int, default=256)

		args = parser.parse_args(extras)
		self.split = args.split
		self.img_size = args.img_size

	def read(self):
		"""Load every image of the selected split.

		Returns:
			dict with
				'img':   all images, resized to ``img_size`` x ``img_size`` and
				         concatenated along the first axis,
				'class': uint32 array holding, for each image, the index of its
				         class directory in sorted order.
		"""
		split_dir = os.path.join(self._base_dir, self.split)

		# Only directories are classes; stray files (e.g. OS metadata) would
		# otherwise crash the per-class listing below.
		class_ids = sorted(
			entry for entry in os.listdir(split_dir)
			if os.path.isdir(os.path.join(split_dir, entry))
		)

		target_size = (self.img_size, self.img_size)
		imgs = []
		classes = []

		for class_idx, class_id in enumerate(class_ids):
			class_dir = os.path.join(split_dir, class_id)

			# os.listdir order is arbitrary; sort so the sample order is reproducible.
			class_paths = [os.path.join(class_dir, f) for f in sorted(os.listdir(class_dir))]

			imgs.append(np.stack([cv2.resize(imageio.imread(path), dsize=target_size) for path in class_paths], axis=0))
			classes.append(np.full((len(class_paths), ), fill_value=class_idx, dtype=np.uint32))

		return {
			'img': np.concatenate(imgs, axis=0),
			'class': np.concatenate(classes, axis=0)
		}


class CelebA(DataSet):
	"""Reader for the aligned CelebA images with identity and attribute labels.

	Extra arguments (parsed with argparse):
		-cs / --crop-size:    (height, width) of the centre crop (default 128 128).
		-ts / --target-size:  (height, width) the crop is resized to (default 128 128).
		-ni / --n-identities: optional number of identities to sample at random.
	"""

	def __init__(self, base_dir, extras):
		"""Parse the reader options and derive the image / annotation paths.

		Args:
			base_dir: Root directory containing the ``Img`` and ``Anno`` folders.
			extras: List of command-line style arguments (see class docstring).
		"""
		super().__init__(base_dir, extras)

		parser = argparse.ArgumentParser()
		parser.add_argument('-cs', '--crop-size', type=int, nargs=2, default=(128, 128))
		parser.add_argument('-ts', '--target-size', type=int, nargs=2, default=(128, 128))
		parser.add_argument('-ni', '--n-identities', type=int, required=False)

		args = parser.parse_args(extras)
		self.crop_size = args.crop_size
		self.target_size = args.target_size
		self.n_identities = args.n_identities

		self.__imgs_dir = os.path.join(self._base_dir, 'Img', 'img_align_celeba_png.7z', 'img_align_celeba_png')
		self.__identity_map_path = os.path.join(self._base_dir, 'Anno', 'identity_CelebA.txt')
		self.__attribute_map_path = os.path.join(self._base_dir, 'Anno', 'list_attr_celeba.txt')

	def __list_imgs(self):
		"""Return parallel lists of image paths and identity labels.

		Each non-blank line of the identity file is expected to hold an image
		name followed by its identity; the image extension is replaced by
		``.png``.
		"""
		with open(self.__identity_map_path, 'r') as fd:
			lines = fd.read().splitlines()

		img_paths = []
		identities = []

		for line in lines:
			# Whitespace-agnostic split; tolerate blank (e.g. trailing) lines.
			tokens = line.split()
			if not tokens:
				continue

			img_name, identity = tokens
			img_paths.append(os.path.join(self.__imgs_dir, os.path.splitext(img_name)[0] + '.png'))
			identities.append(identity)

		return img_paths, identities

	def __list_attributes(self):
		"""Return a dict mapping image name (without extension) to its attribute vector.

		The first two lines of the attribute file are skipped (presumably the
		entry count and the attribute names); values of -1 are stored as 0.
		"""
		with open(self.__attribute_map_path, 'r') as fd:
			lines = fd.read().splitlines()[2:]

		attributes = dict()
		for line in lines:
			tokens = line.split()
			if not tokens:
				continue

			img_name = os.path.splitext(tokens[0])[0]
			img_attributes = np.array([int(token) for token in tokens[1:]])
			img_attributes[img_attributes == -1] = 0
			attributes[img_name] = img_attributes

		return attributes

	@staticmethod
	def _center_crop(img, crop_size):
		"""Cut a (height, width) window around the centre of ``img``.

		The window start is clamped at 0 so a crop larger than the image yields
		the whole axis instead of a wrapped-around (negative index) slice.
		"""
		crop_h, crop_w = crop_size
		top = max(0, img.shape[0] // 2 - crop_h // 2)
		left = max(0, img.shape[1] // 2 - crop_w // 2)
		return img[top:top + crop_h, left:left + crop_w]

	def read(self):
		"""Load, centre-crop and resize the images together with their labels.

		Returns:
			dict with
				'img':        uint8 array of shape (N, target_h, target_w, 3),
				'class':      int32 array with the index of each image's identity,
				'attributes': int8 array of shape (N, 40) with the attribute vector.
		"""
		img_paths, identity_ids = self.__list_imgs()
		attribute_map = self.__list_attributes()

		# Sorted so the identity -> class index mapping (and seeded sampling)
		# does not depend on the arbitrary iteration order of a set of strings.
		unique_identities = sorted(set(identity_ids))

		if self.n_identities:
			unique_identities = random.sample(unique_identities, k=self.n_identities)
			selected = set(unique_identities)
			kept = [(path, identity) for path, identity in zip(img_paths, identity_ids) if identity in selected]
			img_paths = [path for path, _ in kept]
			identity_ids = [identity for _, identity in kept]

		# Constant-time lookup instead of a linear list.index() per image.
		identity_index = {identity: idx for idx, identity in enumerate(unique_identities)}

		target_h, target_w = self.target_size
		n_imgs = len(img_paths)

		imgs = np.empty(shape=(n_imgs, target_h, target_w, 3), dtype=np.uint8)
		identities = np.empty(shape=(n_imgs, ), dtype=np.int32)
		attributes = np.empty(shape=(n_imgs, 40), dtype=np.int8)

		for i, (img_path, identity) in enumerate(tqdm(zip(img_paths, identity_ids), total=n_imgs)):
			img_name = os.path.splitext(os.path.basename(img_path))[0]
			img = self._center_crop(imageio.imread(img_path), self.crop_size)

			# cv2.resize expects dsize as (width, height), while the output
			# array above is laid out as (height, width).
			imgs[i] = cv2.resize(img, dsize=(target_w, target_h))
			identities[i] = identity_index[identity]
			attributes[i] = attribute_map[img_name]

		return {
			'img': imgs,
			'class': identities,
			'attributes': attributes
		}


# Registry mapping a dataset name to the reader class that loads it.
supported_datasets = {
	'afhq': AFHQ,
	'celebahq': AFHQ,  # handled by the AFHQ reader (same structure)
	'celeba': CelebA
}
