from scipy import io
from utils import *
import numpy as np 
import matplotlib.pyplot as plt 
import matplotlib.cm as cm 
from sklearn.kernel_approximation import RBFSampler

# Synthetic dataset - counterexample for Frobenius norm
def synthetic_matrix(n, k, noise=None, noise_factor=0.01):
	"""Build an (n + k) x (n + k) block-diagonal test matrix with shuffled columns.

	The leading k x k block is n ** 1.5 times the identity, the trailing
	n x n block is all ones, and both off-diagonal blocks are zero.

	Args:
		n: size of the all-ones block.
		k: size of the scaled-identity block.
		noise: None for no noise; "Gaussian" adds
			np.random.normal(0, scale=noise_factor) to every entry;
			"Cauchy" adds noise_factor * dense_cauchy_matrix(...).
			Any other value adds no noise.
		noise_factor: scale applied to the noise.

	Returns:
		A float ndarray of shape (n + k, n + k) whose columns have been
		randomly permuted with np.random.shuffle.
	"""
	dim = n + k
	matrix = np.zeros((dim, dim))
	matrix[:k, :k] = np.identity(k) * n ** 1.5
	matrix[k:, k:] = 1

	if noise == "Gaussian":
		perturbation = np.random.normal(0, scale=noise_factor, size=(dim, dim))
	elif noise == "Cauchy":
		perturbation = noise_factor * dense_cauchy_matrix((dim, dim))
	else:
		perturbation = None
	if perturbation is not None:
		matrix += perturbation

	# Shuffling the rows of the transposed view permutes the columns of
	# `matrix` in place, so a greedy Frobenius-norm algorithm cannot rely
	# on the block layout when columns are partitioned.
	np.random.shuffle(matrix.T)
	return matrix

def load_bcsstk13(filepath):
	"""Load a Matrix Market file (such as bcsstk13.mtx) as a dense ndarray.

	``scipy.io.mmread`` returns a sparse matrix for coordinate-format files
	but already returns a dense ndarray for array-format files. A plain
	ndarray has no ``todense``/``toarray`` method, so only sparse results
	are densified here.

	Args:
		filepath: path to the Matrix Market file.

	Returns:
		A dense numpy ndarray (never an ``np.matrix``).
	"""
	matrix = io.mmread(filepath)
	if hasattr(matrix, "toarray"):
		# Sparse result: convert straight to an ndarray.
		matrix = matrix.toarray()
	return np.asarray(matrix)

def load_image(filepath):
	"""Read an image file via matplotlib and return its pixels as a float64 array."""
	pixels = plt.imread(filepath)
	return pixels.astype('float64')

def save_image(data, filepath):
	"""Write the array `data` to `filepath` as an image using the gray colormap."""
	colormap = cm.gray
	plt.imsave(filepath, data, cmap=colormap)

# Example input: filepath = 'isolet1+2+3+4.csv'
def load_isolet(filepath):
	"""Load the ISOLET feature matrix from a comma-separated file.

	Every line is parsed as a row of floats, and the last column (the label)
	is dropped.

	Args:
		filepath: path to the CSV file, e.g. 'isolet1+2+3+4.csv'.

	Returns:
		A 2-D float ndarray with one row per line of the file and every
		column except the last.
	"""
	# ndmin=2 keeps a single-line file 2-D, so the column slice below still
	# works (without it np.loadtxt returns a 1-D array and slicing fails).
	data = np.loadtxt(fname=filepath, delimiter=',', ndmin=2)
	# Remove the last column, as it is the label.
	return data[:, :-1]

# Loads the transpose of isolet and returns (A, list of 5 column blocks of A),
# since the columns are split between 5 servers.
def load_isolet_transpose(filepath, num_parts=5):
	"""Load ISOLET, transpose it, and split its columns into contiguous blocks.

	With m columns in the transposed matrix, block i covers columns
	[floor(i * m / num_parts), floor((i + 1) * m / num_parts)). The default
	of 5 corresponds to splitting the columns between 5 servers.

	Args:
		filepath: path to the ISOLET CSV file (read with load_isolet).
		num_parts: number of column blocks to produce; must be >= 1.

	Returns:
		(A, column_partition_list): the transposed matrix and a list of
		num_parts column slices of it (basic slices, i.e. views into A).

	Raises:
		ValueError: if num_parts < 1.
	"""
	if num_parts < 1:
		raise ValueError(f"num_parts must be at least 1, got {num_parts!r}")

	A = load_isolet(filepath).T
	_, num_cols = A.shape

	# Cut points floor(i * num_cols / num_parts); the first is 0 and the
	# last is num_cols, so the blocks tile every column exactly once.
	bounds = [(i * num_cols) // num_parts for i in range(num_parts + 1)]
	column_partition_list = [A[:, lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]

	return A, column_partition_list


def preprocess_forest_cover(filepath, num_sample, num_random_fourier_features, save_filepath, *, num_features=54, replace=True):
	"""Sample rows of a comma-separated data file, map them through random
	Fourier features, and save the transposed feature matrix with np.save.

	Args:
		filepath: path to the comma-separated input file (e.g. 'covtype.data').
		num_sample: number of rows drawn with np.random.choice.
		num_random_fourier_features: n_components of the RBFSampler.
		save_filepath: destination passed to np.save.
		num_features: number of leading fields of each row parsed as floats;
			any remaining fields are ignored.
		replace: passed to np.random.choice. The default True draws rows
			with replacement, so the sample may contain duplicate rows.

	Raises:
		RuntimeError: if the RBFSampler output differs from the hand
			computation sqrt(2) * cos(A @ Z + b) / sqrt(D) by 1e-10 or more
			in any entry.
	"""
	# Read the whole file up front and close it; skip blank lines (e.g. a
	# doubled trailing newline), which would otherwise crash float() if sampled.
	with open(filepath) as f:
		lines = [line for line in f if line.strip()]

	indices = np.random.choice(len(lines), num_sample, replace=replace)
	rows = []
	for index in indices:
		fields = lines[index].split(',')
		rows.append([float(fields[i]) for i in range(num_features)])
	A = np.array(rows)

	rbf_feature = RBFSampler(n_components=num_random_fourier_features)
	A_features = rbf_feature.fit_transform(A)

	# Sanity check: recompute the features by hand from the sampler's fitted
	# weights and offsets and require entrywise agreement.
	Z = rbf_feature.random_weights_
	b = rbf_feature.random_offset_
	A_hand_calculation = np.sqrt(2) * np.cos(A @ Z + b) / np.sqrt(num_random_fourier_features)
	if not np.all(np.abs(A_features - A_hand_calculation) < 1e-10):
		# Raise instead of print() + exit(): exit() reports status 0 (success)
		# to the caller and depends on the `site` module providing it.
		raise RuntimeError("Wrong Calculations of Random Fourier Features")

	np.save(save_filepath, A_features.T)

def load_forest_cover(numpy_filepath, num_parts=5):
	"""Load a saved 2-D array and split its columns into contiguous blocks.

	With m columns, block i covers columns
	[floor(i * m / num_parts), floor((i + 1) * m / num_parts)). The default
	of 5 reproduces the five-way split used for the distributed setting.

	Args:
		numpy_filepath: path to a .npy file holding a 2-D array.
		num_parts: number of column blocks to produce; must be >= 1.

	Returns:
		(A, Ais): the loaded matrix and a list of num_parts column slices of
		it (basic slices, i.e. views into A).

	Raises:
		ValueError: if num_parts < 1, or if the stored array is not 2-D.
	"""
	if num_parts < 1:
		raise ValueError(f"num_parts must be at least 1, got {num_parts!r}")

	A = np.load(numpy_filepath)
	_, num_cols = A.shape

	# Cut points floor(i * num_cols / num_parts); the first is 0 and the
	# last is num_cols, so the blocks tile every column exactly once.
	bounds = [(i * num_cols) // num_parts for i in range(num_parts + 1)]
	Ais = [A[:, lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]

	return A, Ais

# Each dataset is returned transposed relative to how it is read from disk, to make the problem more interesting.
def load_additional_dataset(dataset_name, data_dir="Additional Datasets"):
	"""Load one of the auxiliary datasets, transposed relative to its on-disk layout.

	Args:
		dataset_name: one of "gastro_lesions", "madelon", "secom", "mnist",
			"musk" or "speech".
		data_dir: directory containing the per-dataset sub-folders. The
			default is the relative path the code has always used.

	Returns:
		The transposed data matrix as a numpy array.

	Raises:
		NotImplementedError: if dataset_name is not one of the names above.
	"""
	import os

	if dataset_name == "gastro_lesions":
		# Remove first 3 rows, which correspond to text/categorical data, then
		# take the transpose, because originally the rows are features.
		file_path = os.path.join(data_dir, "gastroenterology_dataset", "data.txt")
		A = np.loadtxt(file_path, delimiter=',', dtype="str")
		return A[3:, :].astype(np.float64).T

	if dataset_name == "madelon":
		madelon_dir = os.path.join(data_dir, "MADELON")
		# Stack the test, train and validation files (in that order) row-wise.
		split_files = ("madelon_test.data", "madelon_train.data", "madelon_valid.data")
		matrices = [np.loadtxt(os.path.join(madelon_dir, name)) for name in split_files]
		return np.vstack(matrices).T

	if dataset_name == "secom":
		A = np.loadtxt(os.path.join(data_dir, "secom", "secom.data"))
		# np.nan_to_num maps NaN entries to 0.0 and +/-inf to large finite values.
		return np.nan_to_num(A).T

	if dataset_name in ("mnist", "musk", "speech"):
		file_path = os.path.join(data_dir, "ODDS_datasets", dataset_name + ".mat")
		mat_file = io.loadmat(file_path)
		return mat_file["X"].T

	raise NotImplementedError(f"Unknown dataset: {dataset_name!r}")


if __name__ == '__main__':
	# Smoke run: load one additional dataset and print its shape and the
	# sum of the absolute values of its entries.
	#
	# Other invocations that have been run from here:
	#   load_isolet('isolet1+2+3+4.csv')  -> print .shape / np.count_nonzero(...)
	#   preprocess_forest_cover("covtype.data", num_sample=3000,
	#       num_random_fourier_features=500,
	#       save_filepath="forest_cover_500x3000.npy")
	speech = load_additional_dataset("speech")
	print(speech.shape)
	print(np.sum(np.abs(speech)))
