"""
Utility functions for corrupting data for training robust models
"""
import numpy as np
import torch
import os
from skimage.util import random_noise
from skimage import io
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class CorruptDataset(Dataset):
	"""
	Corrupt Dataset Class
	Takes a clean dataset D and materialises a corrupted version D' of it in memory.

	With |D| = G and frac_corrupt = psi:
		- extend_dataset=True: D' holds corrupt samples B plus every clean sample, so
		  |D'| = |D| + |B| with |B| ~= alpha * G and alpha = psi / (1 - psi).
		- extend_dataset=False (default): D' holds the clean samples with a fraction
		  psi of them overwritten by corrupt versions.

	Corruption is applied batch by batch and, inside each batch, class by class, so
	the per-group rounding done by `corrupt_image` makes the sizes above approximate.
	"""

	def __init__(
			self,
			clean_data: Dataset,
			# corruption config ----
			psi: float = 0.1,
			extend_dataset: bool = False,
			corruption: str = 's&p',
			attack_mode: str = 'un_coordinated',
			noise_amount: float = 0.7,
			noise_mean: float = 1,
			noise_var: float = 10,
			# batch processing params ---
			batch_size: int = None,
			num_workers: int = 1,
			# appended last so that existing positional callers keep working
			attack_location: str = 'sample',
	):
		"""
		Constructor for Corrupt Dataset Class
		=======================================================================================================
		:param clean_data: Dataset,
		clean dataset yielding (features, label) pairs

		:param psi: float,
		fraction of corrupt samples

		:param extend_dataset: bool,
		append corrupt samples to the clean data (True) or overwrite clean samples (False)

		:param corruption: str,
		noise type forwarded to `get_corrupt_image`

		:param attack_mode: str,
		'un_coordinated' or 'coordinated'

		:param noise_amount: float,
		amount of noise for 's&p', 'salt', 'pepper'

		:param noise_mean: float,
		mean of the noise for gaussian / speckle corruption

		:param noise_var: float,
		variance of the noise for gaussian / speckle corruption

		:param batch_size: int, optional
		batch size used while corrupting; defaults to 1/50th of the dataset (at least 1)

		:param num_workers: int,
		DataLoader workers, capped at the number of available CPUs

		:param attack_location: str,
		'sample' or 'feature'; `corrupt_image` reads this key from the config
		"""
		super().__init__()
		self.corruption_cfg = {
			"psi": psi,
			"extend_dataset": extend_dataset,
			"corruption": corruption,
			"attack_mode": attack_mode,
			# corrupt_image() looks this key up; without it every call is rejected
			"attack_location": attack_location,
			'noise_amount': noise_amount,
			'noise_mean': noise_mean,
			'noise_var': noise_var,
			# get_corrupt_image() reads the un-prefixed key names, so mirror the
			# values under those names to make the noise settings take effect
			'amount': noise_amount,
			'mean': noise_mean,
			'var': noise_var,
		}
		print(self.corruption_cfg)

		if batch_size is None:
			# len // 50 is 0 for datasets with fewer than 50 samples, which DataLoader rejects
			batch_size = max(1, len(clean_data) // 50)

		clean_dataloader = DataLoader(
			clean_data,
			batch_size=batch_size,
			shuffle=True,
			pin_memory=True,
			# os.cpu_count() may return None when the count cannot be determined
			num_workers=min(num_workers, os.cpu_count() or 1),
		)

		if extend_dataset:
			print("Expanding Dataset = G + ( B = alpha * G )")
		# In replace mode (DEFAULT) this already is the full dataset: clean samples
		# with a fraction of them overwritten by corrupt ones.
		self.noisy_feats, self.noisy_labels = self.get_corrupted_data(clean_dataloader=clean_dataloader)

		if extend_dataset:
			# combine clean and corrupt samples
			for feats, labels in clean_dataloader:
				self.noisy_feats.extend(feats)
				self.noisy_labels.extend(labels)

	def __len__(self):
		"""
		Length of the dataset
		"""
		return len(self.noisy_feats)

	def __getitem__(self, idx):
		"""
		Return the (features, label) pair stored at position idx
		"""
		return self.noisy_feats[idx], self.noisy_labels[idx]

	def get_corrupted_data(self, clean_dataloader: DataLoader) -> tuple:
		"""
		creates corrupt samples from each marginal
		=======================================================================================================
		@param clean_dataloader: DataLoader, clean data dataloader

		@return: (list of samples, list of labels) collected over all batches and classes
		"""
		corrupt_feats = []
		corrupt_labels = []

		for feats, targets in tqdm(clean_dataloader, desc='Corrupting Data ...'):
			for label in torch.unique(targets):
				# Create corrupt samples and labels for each class present in the batch
				label_indices = torch.where(targets == label)[0]
				corrupt_x, corrupt_y = corrupt_image(
					phi=feats[label_indices],
					label=label,
					corruption_cfg=self.corruption_cfg
				)
				# extend() iterates over the first axis, so tensors are stored row by row
				corrupt_feats.extend(corrupt_x)
				corrupt_labels.extend(corrupt_y)

		return corrupt_feats, corrupt_labels


def corrupt_image(
		phi: torch.Tensor,
		label: int = None,
		corruption_cfg: dict = None
) -> tuple:
	"""
	Apply corruption to samples from a marginal distribution
	===========================================================

	:param phi: torch.Tensor,
	input data from same distribution / class ~ p(x|y=i); the first axis indexes samples.
	The input tensor itself is never modified.

	:param label: int,
	class label of the marginal distribution y=i

	:param corruption_cfg: dict,
	{
		"psi": float, fraction of data to corrupt (default 0.4)
		"extend_dataset": bool, whether to extend the dataset or replace samples (default False)
		"attack_location": str, 'sample' (noise a clean sample) or 'feature' (random features);
			defaults to 'sample'
		"corruption": str, type of corruption to apply
		"attack_mode": str, 'un_coordinated' (independent corruptions) or
			'coordinated' (one corrupt sample repeated); defaults to 'un_coordinated'
		"noise_amount": float, amount of noise to add for 's&p', 'salt', 'pepper'
		"noise_mean": float, mean of noise to add for 'gaussian', 'speckle'
		"noise_var": float, variance of noise to add for 'gaussian', 'speckle'
	}
	The whole dict is forwarded to `get_corrupt_image` for 'sample' attacks.

	:return:
	corrupt samples, corrupt labels
	- extend_dataset=True: only the newly created corrupt samples (list or tensor) and
	  a list with one label per corrupt sample
	- extend_dataset=False: a copy of phi with the chosen rows overwritten and a list
	  with one label per row of phi

	:raises ValueError: for an unsupported attack mode / location or an out-of-range psi
	"""
	corruption_cfg = {} if corruption_cfg is None else corruption_cfg
	psi = corruption_cfg.get('psi', 0.4)
	extend_dataset = corruption_cfg.get('extend_dataset', False)
	attack_mode = corruption_cfg.get('attack_mode', 'un_coordinated')
	attack_location = corruption_cfg.get('attack_location', 'sample')

	# Validate up front so that a bad config fails even when nothing gets corrupted
	if attack_mode not in ('un_coordinated', 'coordinated'):
		raise ValueError(f"Attack mode {attack_mode} not supported")
	if attack_location not in ('sample', 'feature'):
		raise ValueError(f"Attack Location {attack_location} not supported")

	num_samples = phi.shape[0]
	# Shape of a single sample; lets 'feature' attacks work for any input rank
	feat_shape = tuple(phi.shape[1:])

	# ==========================================
	# Alpha Corruption
	# ==========================================
	if extend_dataset:
		if not 0 <= psi < 1:
			raise ValueError(f"psi must be in [0, 1) when extending the dataset, got {psi}")
		corrupt_samples = []
		corrupt_labels = []
		# Extend the dataset with corrupt samples: n_B = alpha * n_G
		alpha = psi / (1 - psi)
		max_adv = int(alpha * num_samples)
	else:
		if not 0 <= psi <= 1:
			raise ValueError(f"psi must be in [0, 1] when replacing samples, got {psi}")
		# n_B = psi * n_G, at least one sample but never more than are available
		max_adv = min(num_samples, int(max(1, psi * num_samples)))
		# Work on a copy so the caller's tensor is left untouched
		corrupt_samples = phi.clone()
		# Replaced samples keep the class label, so the label list never changes
		corrupt_labels = [label] * num_samples

	if max_adv == 0:
		return corrupt_samples, corrupt_labels

	# Randomly select samples to corrupt. alpha > 1 asks for more corrupt samples
	# than there are clean ones, which is only possible when sampling with replacement.
	target_indices = np.random.choice(num_samples, max_adv, replace=max_adv > num_samples)
	row_index = torch.as_tensor(target_indices, dtype=torch.long)

	if attack_mode == 'un_coordinated':
		# Each selected sample gets an independently sampled corruption
		if attack_location == 'sample':
			for ix in target_indices:
				adv_sample = get_corrupt_image(
					img=phi[ix],
					corruption_config=corruption_cfg
				)
				if extend_dataset:
					corrupt_samples.append(adv_sample)
					corrupt_labels.append(label)
				else:
					corrupt_samples[ix] = adv_sample
		else:
			adv_samples = torch.randn((max_adv,) + feat_shape)
			if extend_dataset:
				corrupt_samples = adv_samples
				corrupt_labels = [label] * max_adv
			else:
				# index assignment needs matching dtypes
				corrupt_samples[row_index] = adv_samples.to(corrupt_samples.dtype)

	else:
		# Each corrupt sample is set to the same corrupt image => targeted attack
		if attack_location == 'sample':
			target_ix = np.random.choice(num_samples)
			adv_sample = get_corrupt_image(
				img=phi[target_ix],
				corruption_config=corruption_cfg
			)
		else:
			# shaped like one sample so every branch yields samples of the same shape
			adv_sample = torch.rand(feat_shape)

		if extend_dataset:
			# NOTE: all entries reference the same tensor object
			corrupt_samples = [adv_sample] * max_adv
			corrupt_labels = [label] * max_adv
		else:
			corrupt_samples[row_index] = adv_sample.to(corrupt_samples.dtype)

	return corrupt_samples, corrupt_labels


def get_corrupt_image(
		img: torch.Tensor,
		corruption_config=None,
		**kwargs
) -> torch.Tensor:
	"""
	Corrupt an image with a specific corruption type
	==================================================

	:param img: Image to corrupt (torch tensor or array-like)

	:param corruption_config: dict, optional
	{
		"corruption": str, type of corruption to apply:
			's&p', 'salt', 'pepper', 'gaussian', 'speckle', 'poisson'
			('add-gaussian' / 'add-speckle' are accepted as aliases)
		"amount" or "noise_amount": float, amount of noise to add for 's&p', 'salt', 'pepper'
		"mean" or "noise_mean": float, mean of noise to add for 'gaussian', 'speckle'
		"var" or "noise_var": float, variance of noise to add for 'gaussian', 'speckle'
	}
	Unknown keys are ignored.

	:param kwargs: individual config entries; they override `corruption_config`

	:return: Corrupted sample as a float32 tensor

	:raises ValueError: if the corruption type is not supported
	"""
	# Merge into a fresh dict so the caller's config is never mutated
	cfg = dict(corruption_config or {})
	cfg.update(kwargs)

	def _lookup(short_key, default):
		# Accept both the short key and the 'noise_'-prefixed key used by CorruptDataset
		for key in (short_key, 'noise_' + short_key):
			if key in cfg:
				return cfg[key]
		return default

	corruption = cfg.get('corruption', 's&p')

	# ==========================================
	# Resolve the skimage noise mode and its parameters
	# ==========================================
	if corruption in ('s&p', 'salt', 'pepper'):
		mode = corruption
		noise_kwargs = {'amount': _lookup('amount', 0.7)}
	elif corruption in ('gaussian', 'speckle', 'add-gaussian', 'add-speckle'):
		# skimage only knows the un-prefixed mode names
		mode = corruption.replace('add-', '', 1)
		noise_kwargs = {'mean': _lookup('mean', 1), 'var': _lookup('var', 10)}
	elif corruption == 'poisson':
		mode = corruption
		noise_kwargs = {}
	else:
		raise ValueError(f"Corruption {corruption} not supported")

	if isinstance(img, torch.Tensor):
		pixels = img.detach().cpu().numpy()
	else:
		pixels = np.asarray(img)

	# Normalise by the maximum so the noise is applied on a unit scale; an all-zero
	# image would otherwise divide by zero and turn into NaNs.
	scaling = float(pixels.max())
	if scaling == 0:
		scaling = 1.0

	# ==========================================
	# Corruption
	# ==========================================
	corr_img = random_noise(pixels / scaling, mode=mode, **noise_kwargs)
	corr_img = np.clip(corr_img, 0, 1) * scaling

	return torch.from_numpy(np.asarray(corr_img)).float()


if __name__ == '__main__':
	# Small demo: corrupt a single image with salt & pepper noise.
	# get_corrupt_image calls tensor methods on its input, so hand it a float
	# tensor rather than the raw array returned by io.imread.
	sample_im = torch.from_numpy(io.imread('plots/corruption_demo/demo.jpeg')).float()
	noisy_im = get_corrupt_image(
		img=sample_im,
		# the corruption type is read from the config dict
		corruption_config={'corruption': 's&p'}
	)
