"""
(Combinatorial) Data Pruning over Embeddings.
"""
import math
import os
from collections import Counter
from collections import defaultdict

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from src_robust_estimate import robust_mean_estimate
from utils_corruption import corrupt_image


# ===============================================================
# ---------------------- Robustness Strategies ------------------
# ===============================================================
def apply_robust_filter(
		phi: torch.Tensor,
		coreset_size: int,
		filter_fraction: float = 1,
		dist_measure: str = 'L2',
):
	"""
	( Proposed )
	Byzantine Robust Filter
	-------------------------
	Filter step of a two-step (filter, then prune) scheme:
	1. Filter : estimate a robust center (geometric median) of the samples and
	   keep the samples closest to it. This function performs this step.
	2. Prune : the caller applies a base pruning strategy on the retained samples.

	@:param phi: torch.Tensor of size N x D
	@:param coreset_size: int, size of the pruned dataset (coreset); used as a
	lower bound on the number of retained samples
	@:param filter_fraction: float, fraction of the N samples to retain.
	A value of 1 disables filtering and every sample is retained.
	@:param dist_measure: str, 'L2' (smallest Euclidean distance to the center)
	or 'cosine' (largest cosine similarity to the center)

	@return selected_indices: 1-D torch.Tensor of row indices into phi. When
	filtering is active the indices are ordered from closest to farthest.

	@raise NotImplementedError: if filtering is active and dist_measure is not supported
	"""
	if filter_fraction == 1:
		# Nothing to filter. Return *indices* (not phi itself) so that both code
		# paths hand back the same kind of object and phi[result] is always valid.
		return torch.arange(phi.shape[0], device=phi.device)
	
	# fail before running the (expensive) robust mean estimation
	if dist_measure not in ('cosine', 'L2'):
		raise NotImplementedError("Unsupported dist_measure: {}".format(dist_measure))
	
	# how many samples to retain as robust; never fewer than the coreset needs
	filter_size = max(
		coreset_size,
		round(filter_fraction * phi.shape[0])
	)
	
	# robust mean estimation (geometric median) runs on a CPU numpy copy
	robust_class_center = robust_mean_estimate(
		data=phi.cpu().numpy(),
		estimator='geo_med',
		eps=1e-6,
		max_iter=1000
	)
	# bring the estimate back to the dtype / device of phi
	robust_class_center = torch.tensor(robust_class_center, dtype=phi.dtype, device=phi.device)
	
	if dist_measure == 'cosine':
		# cosine similarity = dot product of L2-normalized vectors; larger is closer
		phi_normalized = torch.nn.functional.normalize(phi, p=2, dim=1)
		robust_class_center_normalized = torch.nn.functional.normalize(robust_class_center, p=2, dim=0)
		similarities = torch.matmul(phi_normalized, robust_class_center_normalized)
		return torch.argsort(similarities, descending=True)[:filter_size]
	
	# dist_measure == 'L2': smaller is closer
	distances = torch.norm(phi - robust_class_center, dim=1)
	return torch.argsort(distances, descending=False)[:filter_size]


def compute_center(target_moment: str, phi: torch.Tensor) -> torch.Tensor:
	"""
	Compute the Center of the Class
	--------------------------------
	:param target_moment: str, one of
		- 'mean' : arithmetic mean of the rows of phi
		- 'geo_med' : geometric median of the rows of phi
		- 'interior_geo_med' : geometric median of the half of the rows that
		  lie closest (L2) to a first geometric median estimate
	:param phi: torch.Tensor of size N x D (N samples in D-dimensional space)
	:return: torch.Tensor of size D with the dtype and device of phi
	:raises NotImplementedError: for any other target_moment
	"""
	if target_moment == 'mean':
		class_center = torch.mean(phi, dim=0)
	
	elif target_moment == 'geo_med':
		class_center = robust_mean_estimate(
			data=phi.cpu().numpy(),
			estimator='geo_med',
			eps=1e-5,
			max_iter=1000
		)
		class_center = torch.tensor(class_center, dtype=phi.dtype, device=phi.device)
	
	elif target_moment == 'interior_geo_med':
		# The filter keeps max(coreset_size, round(filter_fraction * N)) samples.
		# Passing N as coreset_size would retain every sample and turn the filter
		# into a no-op, so only ask for a lower bound of a single sample here.
		pot_good_points = apply_robust_filter(phi=phi, coreset_size=1, filter_fraction=0.5)
		class_center = robust_mean_estimate(
			data=phi[pot_good_points].cpu().numpy(),
			estimator='geo_med',
			eps=1e-5,
			max_iter=1000
		)
		class_center = torch.tensor(class_center, dtype=phi.dtype, device=phi.device)
	else:
		raise NotImplementedError("Unsupported target_moment: {}".format(target_moment))
	
	return class_center


# ===============================================================
# Geometric Pruning Strategies
# ===============================================================
def moderate(
		phi: torch.Tensor,
		coreset_size: int,
		target_moment='mean',  # mean, geo_med, co_med
		**kwargs
) -> torch.Tensor:
	"""
	
	Moderate Coreset Sampling
	-------------------------
	Selects samples around the median distance from the class center.
	Reference:
	https://openreview.net/pdf?id=7D5EECbOaf9
	MODERATE CORESET: A UNIVERSAL METHOD OF DATA SELECTION FOR REAL-WORLD DATA-EFFICIENT DEEP LEARNING; ICLR 2023

	@param phi:
	torch.Tensor of size N x D from p(x | y = j)
	
	@param coreset_size: int
	size of the pruned dataset (coreset); capped at N
	
	@param target_moment: str
	how to compute the class center, forwarded to compute_center
	
	@param kwargs:
	accepted for a uniform call signature across pruning strategies; ignored

	@return selected_indices:
	torch.Tensor of min(coreset_size, N) indices of the selected samples, ordered
	by increasing distance from the class center
	
	"""
	
	# compute the center of the class [Eq. 1]
	class_center = compute_center(target_moment, phi)
	# select samples around the sample with median distance from the center [Eq. 2]
	distances = torch.norm(phi - class_center, dim=1)
	n_samples = distances.shape[0]
	
	# cannot select more samples than are available (nor a negative amount)
	window = max(0, min(coreset_size, n_samples))
	# Center the window on the median rank. The upper bound is derived from the
	# lower one instead of being rounded separately: round() rounds halves to
	# even, so two independent roundings can yield a window of the wrong size
	# (e.g. N=4, coreset_size=1 -> round(1.5) == round(2.5) == 2 -> empty slice).
	low_idx = round(0.5 * n_samples - window / 2)
	# keep the window inside [0, N]; a negative start would wrap around in slicing
	low_idx = min(max(low_idx, 0), n_samples - window)
	high_idx = low_idx + window
	
	sorted_idx = torch.argsort(distances)
	selected_indices = sorted_idx[low_idx:high_idx]
	
	return selected_indices


def hard(
		phi: torch.Tensor,
		coreset_size: int,
		target_moment='mean',  # mean, geo_med, co_med
		**kwargs
) -> torch.Tensor:
	"""
	Hard Samples
	--------------------------------
	Keeps the samples that lie farthest from the class center, i.e. the
	hardest examples of the class.
	
	Reference:
	Beyond neural scaling laws: beating power law scaling via data pruning. NeurIPS 2022.
	https://openreview.net/forum?id=UmvSlP-PyV

	@param phi:
	torch.Tensor of size N x D
	
	@param coreset_size: int
	number of samples to keep
	
	@param target_moment: str
	how to compute the class center, forwarded to compute_center
	
	@param kwargs:
	accepted for a uniform call signature across pruning strategies; ignored
	
	@return hem_indices:
	torch.Tensor of indices of the selected samples, farthest first
	"""
	center = compute_center(target_moment, phi)
	
	# Euclidean distance of every sample to the center
	offsets = phi - center
	dist_to_center = torch.norm(offsets, dim=1)
	
	# rank from farthest to closest and keep the leading entries
	ranking = torch.argsort(dist_to_center, descending=True)
	return ranking[:coreset_size]


def easy(
		phi: torch.Tensor,
		coreset_size: int,
		target_moment='mean',
		**kwargs
) -> torch.Tensor:
	"""
	Easy Examples
	--------------------------------
	Keeps the samples that lie closest to the class center, i.e. the
	easiest examples of the class.
	
	@param phi:
	torch.Tensor of size N x D
	
	@param coreset_size: int
	number of samples to keep
	
	@param target_moment: str
	how to compute the class center, forwarded to compute_center
	
	@param kwargs:
	accepted for a uniform call signature across pruning strategies; ignored
	
	@return eem_indices:
	torch.Tensor of indices of the selected samples, closest first
	"""
	center = compute_center(target_moment, phi)
	
	# Euclidean distance of every sample to the center
	offsets = phi - center
	dist_to_center = torch.norm(offsets, dim=1)
	
	# rank from closest to farthest and keep the leading entries
	ranking = torch.argsort(dist_to_center, descending=False)
	return ranking[:coreset_size]


def geo_med_easy(
		phi: torch.Tensor,
		coreset_size: int,
		**kwargs
):
	"""
	Easy around GM
	--------------------------------
	Same selection as `easy`, with the class center fixed to the geometric
	median ('geo_med').
	
	@param phi:
	torch.Tensor of size N x D
	
	@param coreset_size: int
	number of samples to keep
	
	@param kwargs:
	forwarded unchanged to `easy`
	
	@return:
	whatever `easy` returns for target_moment='geo_med'
	"""
	center_kind = 'geo_med'
	return easy(phi=phi, coreset_size=coreset_size, target_moment=center_kind, **kwargs)


def uniform(
		phi: torch.Tensor,
		coreset_size: int,
		**kwargs
) -> torch.Tensor:
	"""
	Random Sampling
	--------------------------------
	Draws sample indices uniformly at random without replacement.
	
	@param phi:
	torch.Tensor of size N x D; only N is used
	
	@param coreset_size: int
	number of samples to keep
	
	@param kwargs:
	accepted for a uniform call signature across pruning strategies; ignored

	@return rnd_indices:
	torch.Tensor with the first coreset_size entries of a random permutation
	of 0 .. N-1 (all N entries if coreset_size exceeds N)
	
	"""
	n_samples = phi.shape[0]
	# a random permutation truncated to the requested size
	shuffled = torch.randperm(n_samples)
	return shuffled[:coreset_size]


# ===============================================================
# Herding
# ===============================================================

def herding(
		phi: torch.Tensor,
		coreset_size: int,
		dist_measure: str = 'cosine',  # L2, cosine
		init: str = 'random',  # mean , random , target_moment
		target_moment: str = 'mean',  # mean, geo_med, co_med
		**kwargs
) -> torch.Tensor:
	"""

	Herding on Hypersphere with PyTorch :
	==========================================
	Greedy moment matching: starting from an initial direction w_0, repeatedly
	pick the not-yet-selected sample that best matches w_t and update
	w_{t+1} = w_t + class_center - phi[selected].

	@param phi:
	torch.Tensor of size N x D

	@param coreset_size: int
	size of the pruned dataset (coreset); capped at N since every sample can be
	selected at most once

	@param dist_measure: str
	Distance Measure to use for Herding (L2, cosine). With 'cosine', phi and the
	class center are L2-normalized and the sample with the largest dot product
	is picked; with 'L2' the sample with the smallest Euclidean distance is picked.

	@param init: str
	Initialization for Herding (mean, random, target_moment, farthest)

	@param target_moment: str
	robustness parameter for pruning algorithms; forwarded to compute_center

	@param kwargs:
	accepted for a uniform call signature across pruning strategies; ignored

	@return kh_indices:
	Python list of int indices of the selected samples, in selection order
	(note: a list, not a tensor, despite the annotation)

	@raise NotImplementedError: for an unsupported init or dist_measure

	"""
	
	# Set the Moment Matching Objective for Herding
	class_center = compute_center(target_moment=target_moment, phi=phi)
	
	# Normalize if cosine distance is used
	if dist_measure == 'cosine':
		phi = torch.nn.functional.normalize(phi, p=2, dim=1)
		class_center = torch.nn.functional.normalize(class_center, p=2, dim=0)
	
	# initialize the direction to explore
	if init == 'mean':
		w_t = torch.mean(phi, dim=0)
	
	elif init == 'random':
		w_t = phi[torch.randperm(phi.size(0))[:1]][0]
	
	elif init == 'target_moment':
		w_t = class_center
	
	elif init == 'farthest':
		if dist_measure == 'cosine':
			# initialize with the farthest sample from the class center
			s = torch.matmul(phi, class_center)
			w_t = class_center - phi[torch.argmin(s)]
		
		else:
			# initialize with the farthest sample from the class center
			d = torch.norm(phi - class_center, dim=1)
			w_t = phi[torch.argmax(d)]
	
	else:
		raise NotImplementedError
	
	# every sample can be picked only once, so at most N samples can be selected
	n_select = min(coreset_size, phi.size(0))
	
	kh_indices = []
	already_selected = set()  # O(1) membership test alongside the ordered list
	while len(kh_indices) < n_select:
		# compute scores
		if dist_measure == 'cosine':
			scores = torch.matmul(phi, w_t)
			indices = scores.argsort(descending=True)
		
		elif dist_measure == 'L2':
			scores = torch.norm(phi - w_t, dim=1)
			indices = scores.argsort(descending=False)
		
		else:
			raise NotImplementedError
		
		# best-ranked sample that has not been selected yet; a single tolist()
		# avoids one device synchronization per inspected candidate
		new_ind = next(idx for idx in indices.tolist() if idx not in already_selected)
		
		# Out-of-place update on purpose: w_t may share memory with class_center
		# (init='target_moment') or with a row of phi (init='farthest'), and an
		# in-place `+=` would silently modify the target moment / the data.
		w_t = w_t + (class_center - phi[new_ind])
		
		kh_indices.append(new_ind)
		already_selected.add(new_ind)
	
	return kh_indices


def geo_med_herding(
		phi: torch.Tensor,
		coreset_size: int,
		**kwargs
) -> torch.Tensor:
	"""
	Robust Herding on Hypersphere with PyTorch
	------------------------------------------
	Calls `herding` with cosine distance, random initialization and the
	geometric median ('geo_med') as the target moment.
	
	@param phi:
	torch.Tensor of size N x D
	
	@param coreset_size: int
	number of samples to keep
	
	@param kwargs:
	forwarded unchanged to `herding`
	
	@return:
	whatever `herding` returns for this configuration
	"""
	# fixed herding configuration of this variant
	herding_config = {
		'dist_measure': 'cosine',
		'init': 'random',
		'target_moment': 'geo_med',
	}
	return herding(phi=phi, coreset_size=coreset_size, **herding_config, **kwargs)


def trimmed_herding(
		phi: torch.Tensor,
		coreset_size: int,
		filtered_fraction: float = 1,
		**kwargs
) -> torch.Tensor:
	"""
	Trimmed Herding on Hypersphere with PyTorch
	-------------------------------------------
	Mean-matching herding (cosine distance, random init) on the samples that
	survive a trimming step. The trimming keeps the samples closest to the
	geometric median, as selected by `easy`.
	
	@param phi:
	torch.Tensor of size N x D
	
	@param coreset_size: int
	number of samples to keep
	
	@param filtered_fraction: float
	fraction of the N samples to keep before herding; 1 disables trimming
	
	@param kwargs:
	forwarded to `herding` when trimming is disabled, otherwise to `easy`
	
	@return:
	indices into phi of the selected samples
	"""
	if filtered_fraction == 1:
		# nothing to trim: plain herding towards the mean on all samples
		return herding(
			phi=phi,
			coreset_size=coreset_size,
			dist_measure='cosine',
			init='random',
			target_moment='mean',
			**kwargs
		)
	
	# trimming step: keep at least coreset_size samples around the geometric median
	n_keep = max(coreset_size, round(filtered_fraction * phi.size(0)))
	kept = easy(
		phi=phi,
		coreset_size=n_keep,
		target_moment='geo_med',
		**kwargs
	)
	
	# herding on the kept samples; the result indexes into `kept`
	picks_in_kept = herding(
		phi=phi[kept],
		coreset_size=coreset_size,
		dist_measure='cosine',
		init='random',
		target_moment='mean',
	)
	
	# translate back to indices into the original phi
	return kept[picks_in_kept]


def trimmed_geo_med_herding(
		phi: torch.Tensor,
		coreset_size: int,
		filtered_fraction: float = 1,
		**kwargs
) -> torch.Tensor:
	"""
	Trimmed Robust Herding on Hypersphere with PyTorch
	--------------------------------------------------
	Geometric-median-matching herding (cosine distance, random init) on the
	samples that survive a trimming step. The trimming keeps the samples
	closest to the geometric median, as selected by `easy`.
	
	@param phi:
	torch.Tensor of size N x D
	
	@param coreset_size: int
	number of samples to keep
	
	@param filtered_fraction: float
	fraction of the N samples to keep before herding; 1 disables trimming
	
	@param kwargs:
	forwarded to `herding` when trimming is disabled, otherwise to `easy`
	
	@return:
	indices into phi of the selected samples
	"""
	if filtered_fraction == 1:
		# nothing to trim: same call as geo_med_herding on all samples
		return herding(
			phi=phi,
			coreset_size=coreset_size,
			dist_measure='cosine',
			init='random',
			target_moment='geo_med',
			**kwargs
		)
	
	# trimming step: keep at least coreset_size samples around the geometric median
	n_keep = max(coreset_size, round(filtered_fraction * phi.size(0)))
	kept = easy(
		phi=phi,
		coreset_size=n_keep,
		target_moment='geo_med',
		**kwargs
	)
	
	# herding on the kept samples; the result indexes into `kept`
	picks_in_kept = herding(
		phi=phi[kept],
		coreset_size=coreset_size,
		dist_measure='cosine',
		init='random',
		target_moment='geo_med',
	)
	
	# translate back to indices into the original phi
	return kept[picks_in_kept]


class PrunedDataset:
	"""
	Pruned Dataset
	--------------
	Wrapper around the pruned dataset.
	
	On construction the source dataset is iterated batch by batch; within every
	batch the samples of each class are embedded with the proxy model and a
	subset is selected by the chosen pruning algorithm. The selected inputs and
	their labels are stored on the CPU in `pruned_feat` / `pruned_targets` and
	exposed through `__len__` / `__getitem__`.
	"""
	
	def __init__(
			self,
			dataset: Dataset,
			proxy_nw_arch: str,
			proxy_model: torch.nn.Module,
			# ------ Pruning Parameters ------
			pruning_algorithm: str,  # uniform, easy, hard, moderate, herding
			coreset_size: int,  # for sampling
			encoding_batch_size: int = 16,  # Batch Size for Encoding pre-pruning
			pruning_batch_size: int = None,  # Batch Size for Pruning - if None then batched pruning is disabled
			filtered_fraction: float = 1,  # robustness parameter ~ fraction to keep if filtering
			# --- corruption parameters: for online corruption ---
			corruption_config: dict = None,
	):
		"""

		@param dataset: Dataset
		Source dataset yielding (input, label) pairs. Must expose a `targets`
		attribute holding all labels (list, tensor or array).

		@param proxy_nw_arch: str
		Architecture name of the proxy model; selects how embeddings are
		computed (see get_embeddings).

		@param proxy_model: torch.nn.Module
		Model used to embed the inputs. It is moved to CUDA when available and
		put in eval mode.

		@param pruning_algorithm: str
		Pruning Algorithm to use from:
			- uniform
			- easy,
			- hard,
			- moderate,
			- herding
			- geo_med_herding
			- trimmed_geo_med_herding
			- geo_med_easy
		Any other name raises KeyError.

		@param coreset_size: int
		Target size of the pruned dataset. It is split across classes in
		proportion to their frequency. Per-class and per-batch sizes are
		rounded up and at least one sample is taken per class per batch, so
		the final size can exceed this value.

		@param encoding_batch_size: int
		Batch size used when embedding inputs with the proxy model.

		@param pruning_batch_size: int
		Batch Size for Pruning - if None then batched pruning is disabled
		(the whole dataset is processed as a single batch).

		@param filtered_fraction: float
		Robustness parameter forwarded to the pruning algorithm (fraction of
		samples to keep if the algorithm filters).

		@param corruption_config: dict
		Optional online corruption settings passed to corrupt_image. Its
		'attack_location' entry selects what is corrupted: 'sample' corrupts
		the raw inputs of every batch, 'feature' corrupts the embeddings used
		for selection.

		"""
		supported_pruning_algorithms = {
			"uniform": uniform,
			"easy": easy,
			"hard": hard,
			"moderate": moderate,
			"herding": herding,
			"geo_med_herding": geo_med_herding,
			# "trimmed_herding": trimmed_herding,
			"trimmed_geo_med_herding": trimmed_geo_med_herding,
			"geo_med_easy": geo_med_easy
		}
		self.sampling_function = supported_pruning_algorithms[pruning_algorithm]
		self.proxy_model = proxy_model
		
		# os.cpu_count() may return None when the count cannot be determined
		n_cpus = os.cpu_count() or 1
		
		# wrap around dataloader
		dataloader = DataLoader(
			dataset=dataset,
			batch_size=pruning_batch_size if pruning_batch_size is not None else len(dataset),
			num_workers=n_cpus // 2 if n_cpus > 2 else 1,
			shuffle=True
		)
		
		# calculate per class / per batch coreset size
		# Labels are converted to plain Python values first: the elements of a
		# tensor hash by identity, so counting them directly would treat every
		# sample as its own class and break the label.item() lookups below.
		all_targets = dataset.targets
		if hasattr(all_targets, 'tolist'):
			all_targets = all_targets.tolist()
		class_distribution = Counter(all_targets)
		print("CLASS DISTRIBUTION: ", class_distribution)
		total_count = sum(class_distribution.values())
		print("TOTAL SAMPLES: ", total_count)
		coreset_size_per_class = defaultdict(int)
		for label, count in class_distribution.items():
			coreset_size_per_class[label] = math.ceil(coreset_size * (count / total_count))
		print("CORESET SIZE PER CLASS: ", coreset_size_per_class)
		
		device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
		proxy_model.to(device)
		proxy_model.eval()
		
		self.pruned_feat = []
		self.pruned_targets = []
		
		if corruption_config is not None:
			attack_location = corruption_config.get('attack_location', None)
			print(corruption_config)
		else:
			attack_location = None
		
		# ================================================
		# Pruning : iterate over the dataset and prune over each batch
		# ================================================
		with torch.no_grad():
			# iterate over the dataset and prune over each batch
			for batch_ix, batch in enumerate(dataloader):
				feat, targets = batch
				if attack_location == 'sample':
					# corrupt the raw inputs; the corrupted inputs are what gets stored
					feat, _ = corrupt_image(
						phi=feat,
						corruption_cfg=corruption_config
					)
				# -------------------------
				# prune each marginal
				# -------------------------
				feat = feat.to(device)
				unique_labels = torch.unique(targets)
				for label in tqdm(
						unique_labels,
						desc="Selecting {} Coreset: batch {}/{}".format(
							coreset_size, batch_ix + 1, len(dataloader)
						)
				):
					# all the indices with y = label in the batch
					label_indices = torch.where(targets == label)[0]
					# get embeddings from the proxy model
					feat_given_label = feat[label_indices]
					embeddings = self.get_embeddings(
						features=feat_given_label,
						nw_arch=proxy_nw_arch,
						batch_size=encoding_batch_size
					)
					
					if attack_location == 'feature':
						# corrupt only the embeddings used for selection
						embeddings, _ = corrupt_image(
							phi=embeddings,
							corruption_cfg=corruption_config
						)
					
					# fraction of total samples y = label in dataset
					frac_label_indices = label_indices.size(0) / class_distribution[label.item()]
					# how many samples of this class to select from the batch
					batch_coreset_size = max(1, math.ceil(frac_label_indices * coreset_size_per_class[label.item()]))
					
					selected_indices = self.sampling_function(
						phi=embeddings,
						coreset_size=batch_coreset_size,
						filtered_fraction=filtered_fraction,
					)
					self.pruned_feat.append(feat_given_label[selected_indices].cpu())
					# one copy of the class label per selected sample
					self.pruned_targets.append(
						torch.full((len(selected_indices),), label.item(), dtype=targets.dtype)
					)
		
		self.pruned_feat = torch.cat(self.pruned_feat, dim=0)
		self.pruned_targets = torch.cat(self.pruned_targets, dim=0)
		assert self.pruned_feat.size(0) == self.pruned_targets.size(0)
		print("Pruned Dataset Size: feat: {}, labels: {} \n".
		      format(self.pruned_feat.size(0), self.pruned_targets.size(0)))
	
	def __len__(self):
		"""
		Length of the Dataset: number of selected samples.
		"""
		return self.pruned_feat.size(0)
	
	def __getitem__(self, idx):
		"""
		Get Item from the Dataset: the (input, label) pair stored at position idx.
		"""
		return self.pruned_feat[idx], self.pruned_targets[idx]
	
	def get_embeddings(
			self,
			features: torch.Tensor,
			nw_arch: str,
			batch_size=32
	) -> torch.Tensor:
		"""
		Get Embeddings from the Model

		@param features: torch.Tensor
		inputs to embed, batched along the first dimension

		@param nw_arch: str
		'clip-ViT-B/16' / 'clip-ViT-B/32' use proxy_model.encode_image;
		'tv-resnet18' / 'tv-resnet50' call proxy_model directly and flatten
		each output to a vector.

		@param batch_size: int
		number of inputs passed through the model at once

		@return embeddings: torch.Tensor
		embeddings of all inputs concatenated along the first dimension

		@raise NotImplementedError: if nw_arch is not one of the supported names
		"""
		self.proxy_model.eval()
		
		with torch.no_grad():
			embeddings = []
			for i in range(0, len(features), batch_size):
				batch_features = features[i:i + batch_size]
				if nw_arch in ['clip-ViT-B/16', 'clip-ViT-B/32']:
					batch_embeddings = self.proxy_model.encode_image(batch_features)
				elif nw_arch in ['tv-resnet18', 'tv-resnet50']:
					batch_embeddings = self.proxy_model(batch_features).flatten(start_dim=1)
				else:
					# fail explicitly instead of hitting an unbound local below
					raise NotImplementedError("Unsupported proxy architecture: {}".format(nw_arch))
				embeddings.append(batch_embeddings)
			
			embeddings = torch.cat(embeddings, dim=0)
		
		return embeddings
