import arff
import torch
import pickle
import numpy as np
import networkx as nx
import random
from collections import defaultdict
from networkx.algorithms.approximation import maximum_independent_set

def get_graph(graph_type, num_arms, edge_prob=0.75):
	"""Build a num_arms x num_arms 0/1 adjacency matrix of the requested type.

	Entry [i, j] is 1.0 when the graph has an edge i -> j and 0.0 otherwise
	(presumably "playing arm i also reveals arm j" -- confirm with callers).

	Supported graph_type values:
		'bandit'     -- self-loops only (identity matrix).
		'ij'         -- edge i -> j for every j >= i (upper triangle incl. diagonal).
		'robs_cops'  -- every edge except self-loops.
		'full_info'  -- every edge, including self-loops.
		'random'     -- self-loops plus each off-diagonal edge i -> j kept
		                independently with probability edge_prob (directed).
		'undirected' -- self-loops plus each unordered pair {i, j} connected in
		                both directions with probability edge_prob.

	The random types draw from the stdlib `random` module (not numpy), one
	draw per candidate edge/pair in loop order, so seed `random` for
	reproducibility.

	Args:
		graph_type: one of the names above.
		num_arms: number of nodes.
		edge_prob: edge probability for the random types; ignored otherwise.

	Returns:
		np.ndarray of shape (num_arms, num_arms), dtype float64.

	Raises:
		ValueError: if graph_type is not one of the supported names.
	"""
	if graph_type == 'bandit':
		return np.eye(num_arms)
	if graph_type == 'ij':
		return np.triu(np.ones((num_arms, num_arms)))
	if graph_type == 'robs_cops':
		return np.ones((num_arms, num_arms)) - np.eye(num_arms)
	if graph_type == 'full_info':
		return np.ones((num_arms, num_arms))
	if graph_type == 'random':
		graph = np.eye(num_arms)
		for i in range(num_arms):
			for j in range(num_arms):
				# Short-circuit: no RNG draw is consumed on the diagonal.
				if j != i and random.random() < edge_prob:
					graph[i, j] = 1
		return graph
	if graph_type == 'undirected':
		graph = np.eye(num_arms)
		for i in range(num_arms):
			# Only the strict lower triangle is sampled; each hit is mirrored.
			for j in range(i):
				if random.random() < edge_prob:
					graph[i, j] = graph[j, i] = 1
		return graph
	raise ValueError(
		f"unknown graph_type {graph_type!r}; expected one of "
		"'bandit', 'ij', 'robs_cops', 'full_info', 'random', 'undirected'"
	)

def normalization(X):
	"""Standardise each column of X to zero mean and unit variance.

	Columns with zero standard deviation (constant features) are only
	centred. Dividing them by their std would be 0/0 and produce NaN, so
	their scale is treated as 1 and they map to all zeros.

	Args:
		X: array of samples stacked along axis 0.

	Returns:
		np.ndarray: (X - mean) / std, with mean and std taken along axis 0.
	"""
	mean = np.mean(X, axis=0)
	std = np.std(X, axis=0)
	# Guard constant columns against 0/0 -> NaN; keeps the input float dtype.
	std = np.where(std == 0, 1, std)
	return (X - mean) / std

def calc_ind(graph):
	"""Approximate a maximum independent set of the graph given by a matrix.

	Any entry graph[i, j] > 0 with i != j becomes an undirected edge
	i -- j in an nx.Graph, so directed inputs are symmetrised. Diagonal
	entries (self-loops) are skipped: a node must not be adjacent to itself
	in the independent-set problem. The previous code tried to remove
	self-loops before any edges existed, so that removal did nothing.

	Args:
		graph: square (k x k) adjacency matrix; entries > 0 are edges.

	Returns:
		set: node indices chosen by networkx's approximation
		``maximum_independent_set`` (heuristic, not guaranteed optimal).
	"""
	k = graph.shape[0]
	g = nx.Graph()
	g.add_nodes_from(range(k))
	rows, cols = np.nonzero(np.asarray(graph)[:k, :k] > 0)
	# int() keeps node labels as plain Python ints, matching add_nodes_from.
	g.add_edges_from((int(i), int(j)) for i, j in zip(rows, cols) if i != j)
	return maximum_independent_set(g)

def calc_ind_rand(graph, num_trials=1000000):
	"""Estimate an independence number of `graph` by randomized greedy search.

	Each trial visits the nodes in a fresh random order
	(np.random.permutation). It selects every node that is not yet covered,
	then marks that node and all of its out-neighbours (nonzero entries of
	its row) as covered. The largest number of selected nodes over all
	trials is returned.

	Args:
		graph: square (k x k) adjacency matrix; nonzero entries are edges.
		num_trials: number of random orderings to try. The default of
			1,000,000 matches the previously hard-coded value.

	Returns:
		int: the largest greedy selection size found (0 for an empty graph).
	"""
	adjacency = np.asarray(graph) != 0
	k = adjacency.shape[0]
	max_ind = 0
	for _ in range(num_trials):
		covered = np.zeros(k, dtype=bool)
		ind = 0
		# Scan in the random order; scanning in index order (as before)
		# made every trial produce the same result.
		for node in np.random.permutation(k):
			if not covered[node]:
				covered[node] = True
				covered |= adjacency[node]
				ind += 1
		max_ind = max(max_ind, ind)
		if max_ind == k:
			# No ordering can select more than every node; stop early.
			break
	return max_ind

def load_rcv1(file_path='./rcv1/data_50k.gt'):
	"""Load the pickled RCV1 data and return float32 (X, y).

	The pickle is expected to be a mapping with keys 'X' and 'y' whose
	values support ``.astype`` (e.g. numpy arrays). The file handle is
	closed as soon as the pickle has been read.

	Security: pickle can execute arbitrary code while loading; only use
	this on trusted, locally produced files.

	Returns:
		tuple: (X, y), both cast to np.float32.
	"""
	with open(file_path, 'rb') as f:
		data = pickle.load(f)
	X, y = data['X'].astype(np.float32), data['y'].astype(np.float32)
	return X, y

def load_rcv1_full(file_path='./rcv1/data_full_50k.gt'):
	"""Load the pickled full RCV1 data and return dense float32 (X, y).

	The pickle is expected to be a mapping with keys 'X' and 'y'. 'X' must
	provide ``.toarray()`` (presumably a scipy.sparse matrix -- confirm
	against the file producer) and is densified here. The file handle is
	closed as soon as the pickle has been read.

	Security: pickle can execute arbitrary code while loading; only use
	this on trusted, locally produced files.

	Returns:
		tuple: (X, y), X densified; both cast to np.float32.
	"""
	with open(file_path, 'rb') as f:
		data = pickle.load(f)
	X, y = data['X'].toarray().astype(np.float32), data['y'].astype(np.float32)
	return X, y

def load_inventory(file_path='./inventory/data_inventory.gt'):
	"""Load the pickled inventory data and return float32 (X, y).

	The pickle is expected to be a mapping with keys 'X' and 'd'. The 'd'
	entry is returned as the target y. The file handle is closed as soon
	as the pickle has been read.

	Security: pickle can execute arbitrary code while loading; only use
	this on trusted, locally produced files.

	Returns:
		tuple: (X, y) where y = data['d']; both cast to np.float32.
	"""
	with open(file_path, 'rb') as f:
		data = pickle.load(f)
	X, y = data['X'].astype(np.float32), data['d'].astype(np.float32)
	return X, y

def load_media(file_path='./dataset/data_media.gt'):
	"""Load the pickled media data and return float32 (X, y).

	The pickle is expected to be a mapping with keys 'X' and 'y' whose
	values support ``.astype`` (e.g. numpy arrays). The file handle is
	closed as soon as the pickle has been read.

	Security: pickle can execute arbitrary code while loading; only use
	this on trusted, locally produced files.

	Returns:
		tuple: (X, y), both cast to np.float32.
	"""
	with open(file_path, 'rb') as f:
		data = pickle.load(f)
	X, y = data['X'].astype(np.float32), data['y'].astype(np.float32)
	return X, y

if __name__ == "__main__":
	# Smoke check: load the media dataset from its default path and
	# print the feature and target array shapes.
	features, targets = load_media()
	print(features.shape, targets.shape)


