import torch


class Shuffle:
	"""Randomly permute the points, using one shared permutation for every entry of ``attr``."""

	def __call__(self, attr):
		"""Reorder every value of ``attr`` along dim 0 with the same random permutation.

		Args:
			attr: dict of tensors. The permutation length is taken from
				``attr['xyz'].shape[0]`` and every value is indexed along dim 0
				(presumably all values share that leading size -- confirm with callers).

		Returns:
			The same ``attr`` dict, with each value replaced by its permuted copy.
		"""
		xyz = attr['xyz']
		# Build the permutation on the coordinates' device so indexing stays on-device.
		shuffle_idx = torch.randperm(xyz.shape[0], device=xyz.device)
		# Rebinding values of existing keys does not resize the dict, so iterating is safe.
		for key in attr:
			attr[key] = attr[key][shuffle_idx]
		return attr
	
	
class AttrNormalize:
	"""Center ``attr['xyz']`` on its centroid and scale it so the farthest point has norm 1.

	The same divisor is applied to ``attr['scale']`` so per-point scales stay
	consistent with the rescaled coordinates.
	"""

	def __call__(self, attr):
		"""Normalize ``attr['xyz']`` in place and rescale ``attr['scale']`` to match.

		Args:
			attr: dict holding at least ``'xyz'`` (rows are points; the norm is
				taken over the last dim -- presumably shape (N, 3), confirm with
				callers) and ``'scale'``.

		Returns:
			The same ``attr`` dict. An empty cloud is returned unchanged. If all
			points coincide (zero radius), ``xyz`` is only centered and ``scale``
			is left untouched instead of being divided by zero.
		"""
		xyz = attr['xyz']
		if xyz.shape[0] == 0:
			# mean() of an empty tensor is NaN and amax() of one raises.
			return attr
		centroid = torch.mean(xyz, dim=0)
		xyz = xyz - centroid
		m = torch.amax(torch.linalg.norm(xyz, ord=2, dim=-1))
		if m > 0:
			attr['xyz'] = xyz / m
			attr['scale'] = attr['scale'] / m
		else:
			# Degenerate cloud (all points identical): dividing by 0 would produce NaN.
			attr['xyz'] = xyz
		return attr

class RandomScalePointCloud:
	"""Rescale the whole cloud by a single random factor.

	Args:
		scale_low: Lower bound for the uniformly sampled factor.
		scale_high: Upper bound for the uniformly sampled factor.
	"""

	def __init__(self, scale_low=0.8, scale_high=1.25):
		self.scale_low = scale_low
		self.scale_high = scale_high

	def __call__(self, attr):
		"""Multiply ``attr['xyz']`` and ``attr['scale']`` by one shared random factor.

		Returns:
			The same ``attr`` dict with both entries replaced by scaled tensors.
		"""
		xyz = attr['xyz']
		# Sample on xyz's device: a 1-element (non 0-dim) CPU tensor cannot be
		# mixed with CUDA tensors in arithmetic. On CPU this draws the same value as before.
		scale = torch.empty(1, device=xyz.device).uniform_(self.scale_low, self.scale_high)
		attr['xyz'] = xyz * scale
		attr['scale'] = attr['scale'] * scale
		return attr


class RandomShiftPointCloud:
	"""Translate the whole cloud by a single random offset vector.

	Args:
		shift_range: Each of the three offset components is drawn uniformly
			between ``-shift_range`` and ``shift_range``.
	"""

	def __init__(self, shift_range=0.1):
		self.shift_range = shift_range

	def __call__(self, attr):
		"""Add one random (1, 3) offset to every row of ``attr['xyz']``.

		Only ``attr['xyz']`` is modified; other entries are left untouched.

		Returns:
			The same ``attr`` dict.
		"""
		xyz = attr['xyz']
		# Sample on xyz's device: a (1, 3) CPU tensor cannot be added to a CUDA
		# tensor. On CPU this draws the same values as before.
		shift = torch.empty(1, 3, device=xyz.device).uniform_(-self.shift_range, self.shift_range)
		attr['xyz'] = xyz + shift
		return attr


class JitterPointCloud:
	"""Add clipped, independent Gaussian noise to every coordinate.

	Args:
		sigma: Standard deviation of the Gaussian noise before clipping.
		clip: Noise values are clipped to ``[-clip, clip]``.
	"""

	def __init__(self, sigma=0.01, clip=0.05):
		self.sigma = sigma
		self.clip = clip

	def __call__(self, attr):
		"""Perturb ``attr['xyz']`` element-wise with clipped Gaussian noise.

		Only ``attr['xyz']`` is modified.

		Returns:
			The same ``attr`` dict.
		"""
		xyz = attr['xyz']
		# Generate noise on xyz's device so CUDA inputs work; on CPU this yields
		# the same samples as torch.randn(*xyz.shape).
		noise = torch.randn(xyz.shape, device=xyz.device) * self.sigma
		noise = torch.clip(noise, -self.clip, self.clip)
		attr['xyz'] = xyz + noise
		return attr


class RandomPointDropout:
	"""Randomly "drop" points by overwriting them with the first point.

	The number of points is preserved: dropped rows are replaced rather than removed.

	Args:
		max_dropout_ratio: Upper bound of the per-call dropout probability, which
			is itself sampled uniformly as ``torch.rand(1) * max_dropout_ratio``.
	"""

	def __init__(self, max_dropout_ratio=0.5):
		self.max_dropout_ratio = max_dropout_ratio

	def __call__(self, attr):
		"""Overwrite a random subset of rows in every entry of ``attr`` with row 0.

		Each point is selected independently when its uniform draw is
		``<=`` the sampled ratio. The number of points is read from
		``attr['xyz'].shape[0]``. Every value must be a tensor (it is cloned);
		the original tensors are not mutated, but ``attr`` is rebound to the copies.

		Returns:
			The same ``attr`` dict.
		"""
		ratio = torch.rand(1) * self.max_dropout_ratio
		num_points = attr['xyz'].shape[0]
		drop_mask = torch.rand([num_points]) <= ratio
		drop_idx = drop_mask.nonzero(as_tuple=True)[0]
		if drop_idx.numel() == 0:
			return attr
		for key in list(attr):
			source = attr[key]
			replaced = source.clone()
			# Fill from the untouched source, so row 0's original values are used
			# even when row 0 itself is selected for dropping.
			replaced[drop_idx] = source[0]
			attr[key] = replaced
		return attr
	


