import arff
import torch
import pickle
import numpy as np
import networkx as nx
import random
from networkx.algorithms.approximation import maximum_independent_set

def get_graph(graph_type, num_arms, edge_prob=0.75):
	"""Build a ``num_arms x num_arms`` 0/1 adjacency matrix for a feedback graph.

	Entry ``graph[i, j] == 1`` marks an edge from arm ``i`` to arm ``j``.

	Args:
		graph_type: One of
			``'bandit'``    -- self-loops only (identity matrix);
			``'ij'``        -- upper triangle including the diagonal (edge i->j for j >= i);
			``'robs_cops'`` -- every edge except self-loops;
			``'full_info'`` -- every edge, including self-loops;
			``'random'``    -- self-loops plus each off-diagonal edge kept
			                   independently with probability ``edge_prob``.
		num_arms: Number of arms (rows/columns of the matrix).
		edge_prob: Off-diagonal edge probability; only used by ``'random'``.

	Returns:
		A float64 ``numpy`` array of shape ``(num_arms, num_arms)``.

	Raises:
		ValueError: If ``graph_type`` is not one of the names above.
	"""
	if graph_type == 'bandit':
		return np.eye(num_arms)
	if graph_type == 'ij':
		return np.triu(np.ones((num_arms, num_arms)))
	if graph_type == 'robs_cops':
		return np.ones((num_arms, num_arms)) - np.eye(num_arms)
	if graph_type == 'full_info':
		return np.ones((num_arms, num_arms))
	if graph_type == 'random':
		graph = np.eye(num_arms)
		# Row-major loop that draws from ``random`` only for off-diagonal cells,
		# in the same order as before, so a given ``random.seed`` still
		# reproduces the same graph.
		for i in range(num_arms):
			for j in range(num_arms):
				if j != i and random.random() < edge_prob:
					graph[i, j] = 1
		return graph
	raise ValueError(
		"unknown graph_type %r; expected one of "
		"'bandit', 'ij', 'robs_cops', 'full_info', 'random'" % (graph_type,)
	)

def normalization(X):
	"""Standardize each column of ``X`` to zero mean and unit variance.

	Mean and standard deviation are computed per column (``axis=0``) using
	NumPy's default population std (``ddof=0``). Columns whose standard
	deviation is zero (constant features) are only centered, so they come out
	as all zeros instead of NaN/inf from a division by zero.

	Args:
		X: Array-like of samples; statistics are taken along axis 0.

	Returns:
		The standardized array, same shape as ``X``.
	"""
	mean = np.mean(X, axis=0)
	std = np.std(X, axis=0)
	# Divide constant columns by 1 so they stay at 0 after centering.
	safe_std = np.where(std == 0, np.ones_like(std), std)
	return (X - mean) / safe_std

def calc_ind(graph):
	"""Approximate a maximum independent set of a feedback graph.

	The (possibly asymmetric) adjacency matrix is converted to an undirected
	``networkx.Graph``. Arms ``i`` and ``j`` are connected if
	``graph[i, j] > 0`` or ``graph[j, i] > 0``. Diagonal entries (self-loops)
	are ignored, so an arm observing itself does not affect independence.

	Args:
		graph: Square ``(k, k)`` array; positive entries mark edges.

	Returns:
		The node-index set produced by
		``networkx.algorithms.approximation.maximum_independent_set``. This is
		an approximation and not guaranteed to be a maximum independent set.
	"""
	k = graph.shape[0]
	g = nx.Graph()
	g.add_nodes_from(range(k))
	rows, cols = np.nonzero(np.asarray(graph) > 0)
	# Skip the diagonal here. The previous self-loop removal ran before any edge
	# existed, so it never removed anything.
	g.add_edges_from((int(i), int(j)) for i, j in zip(rows, cols) if i != j)
	return maximum_independent_set(g)

def load_scene(file_path='./scene.arff'):
	"""Load and shuffle the ``scene`` multi-label ARFF dataset.

	Columns 0-293 are read as float32 features. Columns 294-299 are six binary
	labels stored as the strings ``'TRUE'``/``'FALSE'``. Rows are shuffled in
	place with NumPy's global RNG (``np.random.shuffle``), so seed
	``np.random`` beforehand for a reproducible order.

	Args:
		file_path: Path to the ARFF file.

	Returns:
		``(X, y)`` where ``X`` is float32 of shape ``(n, 294)`` and ``y`` is a
		float64 0/1 array of shape ``(n, 6)``.

	Raises:
		ValueError: If the data has fewer than 300 columns, or a label cell is
			neither ``'TRUE'`` nor ``'FALSE'``.
	"""
	with open(file_path, 'r') as f:
		data = arff.load(f)['data']
	data_np = np.array(data)
	if data_np.ndim != 2 or data_np.shape[1] < 300:
		raise ValueError(
			"expected a 2-D table with at least 300 columns in %s, got shape %s"
			% (file_path, data_np.shape)
		)
	np.random.shuffle(data_np)
	X = data_np[:, :294].astype(np.float32)
	labels = data_np[:, 294:300]
	is_true = labels == 'TRUE'
	valid = is_true | (labels == 'FALSE')
	# Explicit check rather than ``assert`` so it still runs under ``python -O``.
	if not valid.all():
		raise ValueError(
			"scene labels must be 'TRUE' or 'FALSE'; found %r" % (labels[~valid][0],)
		)
	y = is_true.astype(np.float64)
	return X, y

def _load_pickle(file_path):
	"""Unpickle and return the object stored at ``file_path``, closing the file."""
	# SECURITY: pickle.load can execute arbitrary code embedded in the file;
	# only point these loaders at trusted, locally produced data files.
	with open(file_path, 'rb') as f:
		return pickle.load(f)

def load_rcv1(file_path='./rcv1/data_50k.gt'):
	"""Load the pickled RCV1 subset as float32 ``(X, y)``.

	The pickle is read as a mapping, and its ``'X'`` and ``'y'`` entries are
	cast with ``.astype(np.float32)``. Presumably both are NumPy arrays; confirm
	against the script that produced the file.
	"""
	data = _load_pickle(file_path)
	return data['X'].astype(np.float32), data['y'].astype(np.float32)

def load_rcv1_full(file_path='./rcv1/data_full_50k.gt'):
	"""Load the pickled full RCV1 subset as dense float32 ``(X, y)``.

	``data['X']`` is densified with ``.toarray()`` (presumably a SciPy sparse
	matrix; confirm) before the float32 cast, and ``data['y']`` is cast directly.
	"""
	data = _load_pickle(file_path)
	return data['X'].toarray().astype(np.float32), data['y'].astype(np.float32)

def load_inventory(file_path='./inventory/data_inventory.gt'):
	"""Load the pickled inventory dataset as float32 ``(X, y)``.

	The target is read from the pickle's ``'d'`` key rather than ``'y'``.
	"""
	data = _load_pickle(file_path)
	return data['X'].astype(np.float32), data['d'].astype(np.float32)

def load_media(file_path='./dataset/data_media.gt'):
	"""Load the pickled media dataset as float32 ``(X, y)``."""
	# pca -> 120 features, 101 labels
	data = _load_pickle(file_path)
	return data['X'].astype(np.float32), data['y'].astype(np.float32)

if __name__ == "__main__":
	# Quick smoke check: load the media dataset and report the array shapes.
	features, labels = load_media()
	print(features.shape, labels.shape)


