import copy
import pickle

import mkl

mkl.get_max_threads()

import random
import time
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.cluster import KMeans
from collections import Counter, namedtuple
import torch
import faiss


def choose_device(cuda=False):
	'''
	Pick the torch device that computations should run on.

	:param cuda: when truthy, the first CUDA device ("cuda:0") is selected;
				 otherwise the CPU is selected
	:return: the corresponding torch.device object
	'''
	return torch.device("cuda:0" if cuda else "cpu")


def setup_seed(seed):
	'''
	Seed every random number generator used by this module so runs are repeatable.

	Seeds torch (CPU and all CUDA devices), numpy and the stdlib `random`
	module, and switches cuDNN to deterministic mode.

	:param seed: integer seed shared by all generators
	'''
	seeders = (
		torch.manual_seed,
		torch.cuda.manual_seed_all,
		np.random.seed,
		random.seed,
	)
	for seed_fn in seeders:
		seed_fn(seed)
	torch.backends.cudnn.deterministic = True

def torch_rbf_kernel(x1, x2, gamma):
	'''
	Gaussian (RBF) kernel matrix computed with torch.

	Entry (i, j) of the result is exp(-gamma * ||x1[i] - x2[j]||^2), where the
	squared distance is expanded as ||a||^2 - 2 a.b + ||b||^2.

	:param x1: tensor whose rows are the first set of points
	:param x2: tensor whose rows are the second set of points
	:param gamma: bandwidth of the Gaussian kernel
	:return: kernel matrix with one row per row of x1 and one column per row of x2
	'''
	sq_norm_1 = torch.sum(x1 ** 2, 1, keepdim=True)
	sq_norm_2 = torch.sum(x2 ** 2, 1, keepdim=True)
	cross = 2 * x1 @ x2.T
	sq_dist = sq_norm_1 - cross + sq_norm_2.T
	return torch.exp(-sq_dist * gamma)


class StaSpec():
	'''
	Statistical specification: a small weighted set of points (z, beta) that
	summarises a dataset under a Gaussian kernel (referred to as RKME in the
	messages printed by save()).
	'''
	
	def __init__(self, gamma=0.1, cuda=False, info=None):
		'''
		General statistical-specification class.

		:param gamma: bandwidth of the Gaussian kernel
		:param cuda: whether to use CUDA
		:param info: arbitrary value stored on the instance; not used by this class
		:var z: samples of the reduced set
		:var beta: weights of the reduced-set samples
		:var device: CPU or GPU, decided by the `cuda` parameter
		:var num_points: number of rows of the dataset given to the last fit()
						 (None until fit() is called or a fitted spec is loaded)
		'''
		self.z = []
		self.beta = []
		self.gamma = gamma
		self.device = choose_device(cuda=cuda)
		# Declared here (instead of only inside fit) so every instance has the
		# same attribute set, whether it was fitted, loaded, or neither.
		self.num_points = None
		self.cuda = cuda
		self.info = info
	
	def preprocess(self, X, weight):
		'''
		Scale each column of X by the matching entry of `weight`.

		:param X: numpy array whose columns are to be scaled
		:param weight: numpy array of per-column weights; reshaped to a single row
		:return: the scaled array cast to float32
		'''
		weight = weight.reshape(1, -1)
		return (X * weight).astype(np.float32)
	
	def _to_device(self, value):
		'''
		Return `value` as a torch tensor located on self.device.

		:param value: a torch tensor, or a numpy array (wrapped with torch.from_numpy)
		:return: the tensor on self.device
		'''
		if not torch.is_tensor(value):
			value = torch.from_numpy(value)
		return value.to(self.device)
	
	def fit(self, X, K, step_size, steps):
		'''
		Build the reduced set for a dataset.

		The result is stored on the instance (self.z and self.beta); nothing is
		returned.

		:param X: dataset (torch tensor or numpy array, one sample per row)
		:param K: size of the reduced set
		:param step_size: step size of the gradient descent on z
		:param steps: number of alternating optimisation rounds
		'''
		alpha = None
		self.num_points = X.shape[0]
		
		# Initialize Z by clustering
		self._init_z_by_faiss(X, K)
		# Move the data to the target device once, rather than on every
		# update call inside the loop below.
		X = self._to_device(X)
		self._update_beta(alpha, X)
		# Alternately optimize Z and beta
		for _ in range(steps):
			self._update_z(alpha, X, step_size)
			self._update_beta(alpha, X)
	
	def _init_z_by_faiss(self, X, K):
		'''
		Initialise self.z with the centroids of a faiss k-means clustering.

		:param X: dataset
		:param K: size of the reduced set (number of clusters)
		'''
		if X.shape[0] == 1 and K == 1:
			# A single sample is its own reduced set; no clustering needed.
			self.z = X
			return
		if torch.is_tensor(X):
			# faiss is handed a host-side numpy array, so unwrap tensor input.
			X = X.detach().cpu().numpy()
		numDim = X.shape[1]
		kmeans = faiss.Kmeans(numDim, K, niter=100, verbose=False)
		kmeans.train(X)
		self.z = torch.from_numpy(kmeans.centroids)
	
	def _update_beta(self, alpha, X):
		'''
		With z fixed, compute the weights beta in closed form and store them in
		self.beta.

		:param alpha: currently unused
		:param X: dataset
		'''
		Z = self._to_device(self.z)
		X = self._to_device(X)
		
		K_z = torch_rbf_kernel(Z, Z, gamma=self.gamma)
		K_zx = torch_rbf_kernel(Z, X, gamma=self.gamma)
		
		# A small multiple of the identity is added before inverting K_z.
		ridge = torch.eye(K_z.shape[0]).to(self.device) * 1e-5
		beta = torch.sum(torch.linalg.inv(K_z + ridge) @ K_zx,
						 dim=1) / X.shape[0]
		self.beta = beta
	
	def _update_z(self, alpha, X, step_size):
		'''
		With beta fixed, take one gradient-descent step on z and store the
		result in self.z.

		:param alpha: optional weights applied to the kernel values against X;
					  when None, every sample is weighted by 1 / self.num_points
		:param X: dataset
		:param step_size: step size of the gradient descent
		'''
		gamma = self.gamma
		
		Z = self._to_device(self.z)
		beta = self._to_device(self.beta)
		X = self._to_device(X)
		
		grad_Z = torch.zeros_like(Z)
		
		# All rows of the gradient are computed from the same (old) Z; the
		# update is applied once after the loop.
		for i in range(Z.shape[0]):
			z_i = Z[i, :].reshape(1, -1)
			# Contribution of the other reduced-set points.
			term_1 = (beta * torch_rbf_kernel(z_i, Z, gamma)) @ (z_i - Z)
			# Contribution of the data points.
			if alpha is not None:
				term_2 = -2 * (alpha * torch_rbf_kernel(z_i, X, gamma)) @ (z_i - X)
			else:
				term_2 = -2 * (torch_rbf_kernel(z_i, X, gamma) / self.num_points) @ (z_i - X)
			grad_Z[i, :] = -2 * gamma * beta[i] * (term_1 + term_2)
		self.z = Z - step_size * grad_Z
	
	def save(self, savepath='saved_spec.pkl', verbose=False):
		'''
		Save the specification (all instance attributes) to a pickle file.

		:param savepath: path of the file to write
		:param verbose: whether to print a message listing the saved attributes
		'''
		rkme_to_save = copy.deepcopy(self.__dict__)
		with open(savepath, 'wb') as fw:
			pickle.dump(rkme_to_save, fw)
		if verbose:
			print('RKME save attr: {} in {}'.format(rkme_to_save.keys(), savepath))
	
	def load(self, filepath='saved_spec.pkl', verbose=False):
		'''
		Restore a specification from a file written by save().

		Every saved attribute is restored, including ones (such as num_points)
		that the receiving instance has not set itself.

		:param filepath: path of the file to read
		:param verbose: whether to print each attribute as it is restored
		'''
		# NOTE: pickle.load can execute arbitrary code; only load files that
		# come from a trusted source.
		with open(filepath, 'rb') as fr:
			rkme_load = pickle.load(fr)
		for name, value in rkme_load.items():
			if verbose:
				print('Loading attr: {} from {}, {}'.format(name, filepath, value))
			setattr(self, name, value)


class RKME_Searcher():
	'''
	Searches a fixed pool of statistical specifications for those closest to
	a user's specification.
	'''
	
	def __init__(self, list_of_Phi):
		'''
		Statistical-specification search class.

		:param list_of_Phi: list with the statistical specification (StaSpec) of
							every candidate learnware; this is the search scope.
							Must be non-empty (the first element is indexed).
		:var Phi: the list_of_Phi argument
		:var c: number of learnwares
		:var gamma: bandwidth of the Gaussian kernel, taken from the first
					element; all entries of list_of_Phi are expected to share it
		'''
		self.Phi = list_of_Phi
		self.c = len(list_of_Phi)
		self.gamma = list_of_Phi[0].gamma
	
	def cal_dist(self, Phi_t, cuda=False, omit_term1=False):
		'''
		Compute the distance between the user's specification and every
		candidate specification.

		:param Phi_t: statistical specification (StaSpec) of the user task; its
					  gamma is expected to match that of the candidates
		:param cuda: whether to use CUDA
		:param omit_term1: forwarded to MMD
		:return: a list with one distance per candidate; an entry is inf when
				 computing that distance failed, otherwise a real number
		'''
		dist_list = []
		for i in range(self.c):
			Phi_s = self.Phi[i]
			try:
				z_t, beta_t = Phi_t.z, Phi_t.beta
				z_s, beta_s = Phi_s.z, Phi_s.beta
				dist = MMD(z_t, z_s, beta1=beta_t,
						   beta2=beta_s, gamma=self.gamma,
						   cuda=cuda, omit_term1=omit_term1)
				dist_list.append(dist.cpu().item())
			except Exception:
				# Ordinary failures are reported as an infinite distance.
				# KeyboardInterrupt / SystemExit are not Exception subclasses
				# and therefore still propagate to the caller.
				dist_list.append(float("inf"))
		return dist_list


def _mmd_weights(beta, count):
	'''
	Return the weights of a reduced set as a (1, count)-shaped row tensor.

	:param beta: weights (tensor, numpy array or sequence), or None
	:param count: number of samples the weights belong to
	:return: `beta` reshaped to one row, or uniform weights 1 / count when
			 `beta` is None or the sum of its absolute values is below 1e-5
	'''
	if beta is not None:
		# Convert first so that numpy arrays and lists support the check below.
		beta = torch.as_tensor(beta)
		if not (beta.abs().sum() < 1e-5):
			return beta.reshape(1, -1)
	return (torch.ones(count) / count).reshape(1, -1)


def MMD(x1, x2, beta1=None, beta2=None, omit_term1=False, gamma=1.0, cuda=False):
	'''
	Compute the distance (MMD) between two kernel mean embeddings:

		|| sum_i beta1_i k(x1_i, .) - sum_j beta2_j k(x2_j, .) ||_H^2

	:param x1: samples of the first reduced set (one per row)
	:param x2: samples of the second reduced set (one per row)
	:param beta1: weights of the first reduced set; uniform when missing or
				  (near) zero
	:param beta2: weights of the second reduced set; uniform when missing or
				  (near) zero
	:param omit_term1: when True, the x1-vs-x1 term is skipped and treated as 0,
					   which avoids building the x1-by-x1 kernel matrix; the
					   result then differs from the full value by that term,
					   which depends only on x1 and beta1
	:param gamma: bandwidth of the Gaussian kernel
	:param cuda: whether to use CUDA
	:return: a scalar tensor on the CPU holding the distance
	:raises AssertionError: when weights and samples disagree in length, or
							x1 and x2 have different numbers of columns
	'''
	device = choose_device(cuda=cuda)
	
	# Normalise the inputs to tensors before inspecting them.
	x1 = torch.as_tensor(x1)
	x2 = torch.as_tensor(x2)
	beta1 = _mmd_weights(beta1, x1.shape[0])
	beta2 = _mmd_weights(beta2, x2.shape[0])
	
	assert beta1.size(1) == x1.size(0) and beta2.size(1) == x2.size(0), \
		'number of weights must match number of samples'
	assert x1.size(1) == x2.size(1), 'x1 and x2 must have the same number of columns'
	
	# All arithmetic is done in float64 on the selected device.
	x1 = x1.double().to(device)
	x2 = x2.double().to(device)
	beta1 = beta1.double().to(device)
	beta2 = beta2.double().to(device)
	
	if omit_term1:
		term1 = 0
	else:
		term1 = torch.sum(torch_rbf_kernel(
			x1, x1, gamma) * (beta1.T @ beta1))
	term2 = torch.sum(torch_rbf_kernel(x1, x2, gamma) * (beta1.T @ beta2))
	term3 = torch.sum(torch_rbf_kernel(x2, x2, gamma) * (beta2.T @ beta2))
	return (term1 - 2 * term2 + term3).detach().cpu()


if __name__ == '__main__':
	# Demo: fit a reduced set on random data and report how well it matches
	# the data, together with the elapsed time.
	Num = 60000
	dim = 80
	spec_num = 100
	gamma = 0.01
	# Use the GPU when one is present instead of failing on CPU-only machines.
	use_cuda = torch.cuda.is_available()
	
	setup_seed(1234)
	
	# IMPORTANT: Input should be float 32
	X = np.random.randn(Num, dim).astype(np.float32)
	
	# Re-seed so the fit does not depend on how much randomness the data
	# generation above consumed.
	setup_seed(1234)
	t0 = time.time()
	rkme2 = StaSpec(gamma=gamma, cuda=use_cuda)
	rkme2.fit(X, spec_num, 0.01, 3)
	specification2 = np.array(rkme2.z.detach().cpu())
	print('RKME_fast finished in {}s'.format(time.time() - t0))
	# Evaluate with the same kernel bandwidth the specification was fitted
	# with; omit_term1=True skips the Num-by-Num kernel matrix of X.
	print('MMD Disc: {}'.format(
		MMD(X, specification2, beta2=rkme2.beta.detach().cpu(),
			gamma=rkme2.gamma, cuda=False, omit_term1=True)))
	print('RKME+MMD finished in {}s'.format(time.time() - t0))
