import os, mi_utils, torchvision
import json, PIL, time, random
import torch, math, cv2

import numpy as np
import pandas as pd
from PIL import Image
import torch.nn.functional as F 
import torch.utils.data as data
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss
from matplotlib import pyplot
from torch.utils.data.sampler import SubsetRandomSampler


class ImageFolder(data.Dataset):
	"""In-memory image dataset built from a text file listing image paths.

	Each non-blank line of ``file_path`` is either ``<path>`` (mode ``"gan"``)
	or ``<path> <integer label>`` (any other mode; the label is the text after
	the last space). Only ``.png`` entries are kept. Every kept image is
	decoded to RGB and held in memory when the dataset is constructed.
	``__getitem__`` returns the transformed image, plus its label when
	``mode != "gan"``.

	Args:
		args: nested config mapping. Reads ``args["dataset"]["name"]`` and
			``args["dataset"]["model_name"]``; outside GAN mode it also reads
			``args["dataset"]["n_classes"]``.
		file_path: path to the image-list text file.
		mode: ``"train"`` adds a random horizontal flip; ``"gan"`` yields
			unlabeled images; any other value uses the deterministic pipeline.
	"""

	# Square side length each dataset is resized to after cropping.
	_RESIZE = {"tsrd": 32, "celeba": 128}

	# Crop geometry per dataset: (crop_size, ref_height, ref_width).
	# The crop is centered using the reference height/width.
	_CROP = {
		"celeba": (108, 218, 178),
		"facescrub": (762, 762, 762),
		"pubfig83": (70, 100, 100),
		# (670, 1024, 1024) would be the setting for the 1024*1024 FFHQ version.
		"ffhq": (82, 128, 128),  # thumbnail version
		"tsrd": (32, 32, 32),  # crop spans the whole image, i.e. no crop
	}

	def __init__(self, args, file_path, mode):
		self.args = args
		self.mode = mode
		self.data_name = args['dataset']['name']
		self.model_name = args["dataset"]["model_name"]
		# Built once here and reused for every sample in __getitem__.
		self.processor = self.get_processor()
		t1 = time.time()
		self.name_list, self.label_list = self.get_list(file_path)
		t2 = time.time()
		print("get list time:", t2 - t1)
		self.image_list = self.load_img()
		t3 = time.time()
		print("load image time:", t3 - t2)
		self.num_img = len(self.image_list)
		# Compare by value: `is not` on a string literal tests object identity.
		if self.mode != "gan":
			self.n_classes = args["dataset"]["n_classes"]
			print("Load " + str(self.num_img) + " images")

	def get_list(self, file_path):
		"""Parse the image-list file into parallel name and label lists.

		Blank lines are skipped, and so are entries whose path does not end
		in ``.png``. As a result, ``name_list[i]`` and ``label_list[i]`` always
		describe the same image, and ``load_img`` loads every name it is given.

		Returns:
			``(name_list, label_list)``. ``label_list`` is empty in ``"gan"`` mode.
		"""
		name_list, label_list = [], []
		with open(file_path, "r") as f:
			for line in f:
				line = line.strip()
				if not line:
					continue
				if self.mode == "gan":
					img_name, iden = line, None
				else:
					# Split on the last space so paths may contain spaces.
					img_name, iden = line.rsplit(' ', 1)
				# Filter here (not in load_img) so labels cannot drift out of
				# alignment with the loaded images.
				if not img_name.endswith(".png"):
					continue
				name_list.append(img_name)
				if iden is not None:
					label_list.append(int(iden))
		return name_list, label_list

	def load_img(self):
		"""Decode each path in ``self.name_list`` into an in-memory RGB image."""
		img_list = []
		for path in self.name_list:
			# convert() returns an independent copy. The context manager then
			# releases the file handle, which matters for large datasets.
			with Image.open(path) as img:
				img_list.append(img.convert('RGB'))
		return img_list

	def get_processor(self):
		"""Build the torchvision transform pipeline for ``self.data_name``.

		The pipeline runs: ToTensor, then a fixed centered crop, then back to
		PIL, then resize to ``(re_size, re_size)``. In ``"train"`` mode it then
		applies a random horizontal flip. It ends with ToTensor.

		Raises:
			ValueError: if the dataset has no crop geometry or no resize size.
		"""
		if self.data_name not in self._CROP:
			raise ValueError("Dataset is not supported! (got %r)" % (self.data_name,))
		if self.data_name not in self._RESIZE:
			# Fail fast with a clear message instead of an unbound `re_size`.
			raise ValueError(
				"No resize size configured for dataset %r" % (self.data_name,)
			)
		crop_size, ref_height, ref_width = self._CROP[self.data_name]
		re_size = self._RESIZE[self.data_name]
		offset_height = (ref_height - crop_size) // 2
		offset_width = (ref_width - crop_size) // 2

		def crop(x):
			# x is the channels-first tensor produced by ToTensor.
			return x[:, offset_height:offset_height + crop_size,
					 offset_width:offset_width + crop_size]

		proc = [
			transforms.ToTensor(),
			transforms.Lambda(crop),
			transforms.ToPILImage(),
			transforms.Resize((re_size, re_size)),
		]
		if self.mode == "train":
			proc.append(transforms.RandomHorizontalFlip(p=0.5))
		proc.append(transforms.ToTensor())
		return transforms.Compose(proc)

	def __getitem__(self, index):
		"""Return the transformed image, plus its label unless mode is "gan"."""
		# Reuse the pipeline built in __init__ instead of rebuilding it per sample.
		img = self.processor(self.image_list[index])
		if self.mode == "gan":
			return img
		return img, self.label_list[index]

	def __len__(self):
		return self.num_img


if __name__ == "__main__":
	# Running the module directly just reports that its imports succeeded.
	status = "ok"
	print(status)