import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF


def transpose(x, source='NHWC', target='NCHW'):
	"""Reorder the axes of ``x`` from the ``source`` layout to ``target``.

	Layouts are strings of single-character axis names such as 'NHWC';
	``target`` should name the same axes as ``source`` in the desired order.

	Args:
		x: array with a NumPy-style ``transpose(axes)`` method (e.g. ndarray).
		source (str): axis names describing the current layout of ``x``.
		target (str): axis names describing the wanted layout.

	Returns:
		``x.transpose(...)`` with the computed axis permutation.
	"""
	# For every axis in target order, look up where it currently lives.
	permutation = list(map(source.index, target))
	return x.transpose(permutation)

def pad(x, border=4):
	"""Reflect-pad the two spatial axes of a 3-D array.

	The first axis (channels for a ``(C, H, W)`` image) is left untouched;
	the second and third axes each gain ``border`` mirrored rows/columns on
	both sides.

	Args:
		x (np.ndarray): 3-D array to pad.
		border (int): number of elements added on each side of axes 1 and 2.

	Returns:
		np.ndarray: the padded array.
	"""
	spatial = (border, border)
	widths = ((0, 0), spatial, spatial)
	return np.pad(x, widths, mode='reflect')

class RandomPadandCrop(object):
	"""Reflect-pad an image by 4 pixels, then crop a random window from it.

	The image is a 3-D array: padding (via ``pad``) is applied to axes 1 and
	2, and the crop slices those same two axes, so a channel-first
	``(C, H, W)`` layout is expected.

	Args:
		output_size (tuple or int): Desired output size ``(height, width)``.
			If int, square crop is made.
	"""

	def __init__(self, output_size):
		assert isinstance(output_size, (int, tuple))
		if isinstance(output_size, int):
			self.output_size = (output_size, output_size)
		else:
			assert len(output_size) == 2
			self.output_size = output_size

	def __call__(self, x):
		"""Return a random ``output_size`` crop of the padded image ``x``.

		Raises:
			ValueError: if the padded image is smaller than ``output_size``
				(propagated from ``np.random.randint``).
		"""
		x = pad(x, 4)

		h, w = x.shape[1:]
		new_h, new_w = self.output_size

		# randint's ``high`` is exclusive; ``+ 1`` makes the offset that puts
		# the crop flush with the bottom/right edge reachable, and lets a crop
		# equal to the padded size return the whole image instead of raising.
		top = np.random.randint(0, h - new_h + 1)
		left = np.random.randint(0, w - new_w + 1)

		x = x[:, top: top + new_h, left: left + new_w]

		return x

class RandomFlip(object):
	"""Mirror an image along its third axis with probability 0.5.

	For a channel-first ``(C, H, W)`` image the third axis is the width, so
	this is a horizontal flip.
	"""

	def __call__(self, x):
		"""Return a copy of ``x``, reversed along axis 2 half of the time."""
		should_flip = np.random.rand() < 0.5
		result = x[:, :, ::-1] if should_flip else x
		# Copy unconditionally: the reversed slice is a negative-stride view,
		# and a fresh contiguous array is presumably needed downstream
		# (``torch.from_numpy`` rejects negative strides).
		return result.copy()


class ToTensor(object):
	"""Convert a NumPy array into a ``torch.Tensor``.

	Conversion uses ``torch.from_numpy``, so the tensor keeps the array's
	dtype and shares its memory (no copy is made).
	"""

	def __call__(self, x):
		"""Wrap the ndarray ``x`` as a tensor."""
		return torch.from_numpy(x)


def addGaussianNoise(x_train, ratio=0.01, scale=1):
	"""Add scaled Gaussian noise to a tensor and clamp the result to [0, 1].

	Noise with standard deviation ``scale`` and the same shape as
	``x_train`` is drawn from NumPy's global RNG, multiplied by ``ratio``
	and added to ``x_train``.

	Args:
		x_train (torch.Tensor): input tensor; values presumably in [0, 1].
		ratio (float): multiplier applied to the noise before adding it.
		scale (float): standard deviation of the raw noise.

	Returns:
		torch.Tensor: float32 tensor on the same device as ``x_train``.
	"""
	noise = np.random.normal(scale=scale, size=tuple(x_train.shape))
	# Put the noise on the input's device rather than hard-coding ``.cuda()``,
	# so CPU tensors and tensors on a non-current GPU both work.
	eps = torch.from_numpy(noise).to(x_train.device)
	return torch.clamp(x_train + eps * ratio, 0, 1).float()

def randomPadandCrop(x_train, output_size=32):
	"""Reflect-pad each image of a batch by 4 pixels and take a random crop.

	Each image gets its own independently drawn crop offset.

	Args:
		x_train (torch.Tensor): batch of channel-first images ``(N, C, H, W)``.
		output_size (int): side length of the square crop.

	Returns:
		torch.Tensor: ``(N, C, output_size, output_size)`` tensor of the
		default floating dtype, on the same device as ``x_train``.

	Raises:
		ValueError: if the padded images are smaller than ``output_size``
			(propagated from ``np.random.randint``).
	"""
	images = x_train.detach().cpu().numpy()
	n, c = images.shape[:2]
	new_h, new_w = output_size, output_size
	# Preallocate with the default dtype, as the original ``torch.Tensor``
	# buffer did; channels come from the input instead of a hard-coded 3.
	out = torch.empty((n, c, new_h, new_w), dtype=torch.get_default_dtype())
	for i, image in enumerate(images):
		# Pad and crop the i-th image (every crop used to come from image 0).
		padded = pad(image, 4)
		h, w = padded.shape[1:]
		# randint's ``high`` is exclusive; ``+ 1`` makes the flush
		# bottom/right offset reachable.
		top = np.random.randint(0, h - new_h + 1)
		left = np.random.randint(0, w - new_w + 1)
		out[i] = torch.from_numpy(padded[:, top: top + new_h, left: left + new_w])
	# Return on the input's device instead of hard-coding ``.cuda()``.
	return out.to(x_train.device)

def flip(x_train):
	"""Mirror every image of a batch along axis 3 (no randomness).

	For an ``(N, C, H, W)`` batch axis 3 is the width, so every image is
	flipped horizontally.

	Args:
		x_train (torch.Tensor): batch with at least 4 dimensions.

	Returns:
		torch.Tensor: detached, flipped copy with the same dtype, on the same
		device as ``x_train``.
	"""
	# torch.flip copies on the input's device, avoiding the round trip through
	# NumPy and the hard-coded ``.cuda()``.
	return torch.flip(x_train.detach(), dims=(3,))

def Brightness(x_unlabel, sd=1, bmin=0.001, bmax=0.3, image_shape=(16, 16)):
	"""Scale the brightness of a batch of flattened single-channel images.

	One brightness factor is drawn from a normal with mean
	``(bmin + bmax) / 2`` and standard deviation ``sd``, truncated to
	``[bmin, bmax]``. It is applied to every image with
	``torchvision.transforms.functional.adjust_brightness``. A factor of 1
	keeps the image and 0 makes it black; with the default bounds the images
	are always darkened.

	Args:
		x_unlabel (np.ndarray): ``N`` images, each holding
			``image_shape[0] * image_shape[1]`` values (e.g. shape ``(N, 256)``).
		sd (float): standard deviation of the truncated normal.
		bmin (float): lower bound of the brightness factor.
		bmax (float): upper bound of the brightness factor.
		image_shape (tuple of int): ``(H, W)`` of one image; defaults to the
			previously hard-coded ``(16, 16)``.

	Returns:
		torch.Tensor: ``(N, H * W)`` tensor of the default floating dtype.
	"""
	brightness_factor = get_truncated_normal(mean=(bmin + bmax) / 2, sd=sd, low=bmin, upp=bmax).rvs()

	height, width = image_shape
	n_unlabel = len(x_unlabel)
	# adjust_brightness expects [..., 1 or 3, H, W]: add a singleton channel.
	images = torch.from_numpy(x_unlabel).reshape(n_unlabel, 1, height, width)
	# The same factor is used for every image, so one batched call replaces
	# the former per-image loop.
	brightened = TF.adjust_brightness(images, brightness_factor)
	# Explicit H * W (not -1) so an empty batch can be reshaped; the result is
	# cast to the default dtype, as the former ``torch.Tensor`` buffer was.
	return brightened.reshape(n_unlabel, height * width).to(torch.get_default_dtype())

from scipy.ndimage.filters import gaussian_filter
from scipy.stats import truncnorm


def get_truncated_normal(mean=0, sd=1, low=0, upp=10):
	"""Return a frozen ``scipy.stats.truncnorm`` distribution.

	It is a normal distribution N(mean, sd**2) restricted to ``[low, upp]``.
	``truncnorm`` takes its clip points in standard-normal units, so each
	bound is converted with ``(bound - mean) / sd``.
	"""
	a_std = (low - mean) / sd
	b_std = (upp - mean) / sd
	return truncnorm(a_std, b_std, loc=mean, scale=sd)

def GaussianBlur(x_unlabel, mean=1, scale=1, image_shape=(16, 16)):
	"""Blur every image in a batch of flattened images with one random sigma.

	A single blur width is drawn from a normal centred at ``mean`` with
	standard deviation ``scale``, truncated to ``[0, 100]``. Each image is
	then filtered with ``gaussian_filter`` using that sigma.

	Args:
		x_unlabel (np.ndarray): ``N`` images, each holding
			``image_shape[0] * image_shape[1]`` values (e.g. shape ``(N, 256)``).
		mean (float): centre of the (pre-truncation) normal used for sigma.
		scale (float): standard deviation of that normal.
		image_shape (tuple of int): ``(H, W)`` of one image; defaults to the
			previously hard-coded ``(16, 16)``.

	Returns:
		np.ndarray: blurred images flattened back to ``(N, H * W)``, with the
		same dtype as ``x_unlabel``.
	"""
	sigma = get_truncated_normal(mean=mean, sd=scale, low=0, upp=100).rvs()
	height, width = image_shape
	n_unlabel = len(x_unlabel)
	images = x_unlabel.reshape(n_unlabel, height, width)

	# Fill a preallocated array rather than np.stack-ing a list, which raises
	# on an empty batch.
	blurred = np.empty_like(images)
	for i, image in enumerate(images):
		blurred[i] = gaussian_filter(image, sigma=sigma)

	# Explicit H * W (not -1): reshape cannot infer -1 when N == 0.
	return blurred.reshape(n_unlabel, height * width)