import numpy  
from sklearn import preprocessing
import csv
import numpy as np


def relabel(y):
	"""Replace arbitrary labels by consecutive integer ids.

	Ids are handed out in order of first appearance: the first distinct
	label seen becomes 0, the next new one 1, and so on.

	Args:
		y: Collection of hashable labels. It is traversed twice, so it
			must be re-iterable (e.g. a list), not a one-shot iterator.

	Returns:
		A tuple ``(ids, n_labels)``: a numpy array with the id of every
		entry of ``y`` in the original order, and the number of distinct
		labels that were found.
	"""
	label_to_id = {}
	for label in y:
		# len() is evaluated before the insert, so a new label receives the
		# next unused id; labels already present are left untouched.
		label_to_id.setdefault(label, len(label_to_id))
	encoded = [label_to_id[label] for label in y]
	return numpy.array(encoded), len(label_to_id)


def norm_data(X):
	"""Rescale ``X`` with a scikit-learn ``StandardScaler`` fitted on ``X``.

	Args:
		X: Data accepted by ``StandardScaler.fit`` / ``transform``
			(presumably a 2-D samples-by-features array; confirm with callers).

	Returns:
		The result of transforming ``X`` with the scaler fitted on it.
	"""
	# fit() returns the fitted scaler itself, so the two steps can be chained.
	return preprocessing.StandardScaler().fit(X).transform(X)



def load_data(dataset,norm=False):
	print("=========Loading "+dataset)

	if dataset == 'a1':
		X = []
		y = []
		with open('data/a1/a1.txt', mode ='r') as file:
			csvFile = csv.reader(file)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0].split()
				X.append(lines)

		with open('data/a1/a1-ga.pa', mode ='r') as file:
			csvFile = csv.reader(file)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0]
				y.append(lines)
			y, count = relabel(y)

		if norm == True:
			return norm_data(np.array(X).astype('float')),np.array(y),count
		else:
			return np.array(X).astype('float'),np.array(y),count


	if dataset == 'a2':
		X = []
		y = []
		with open('data/a2/a2.txt', mode ='r') as file:
			csvFile = csv.reader(file)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0].split()
				X.append(lines)

		with open('data/a2/a2-ga.pa', mode ='r') as file:
			csvFile = csv.reader(file)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0]
				y.append(lines)
			y, count = relabel(y)

		if norm == True:
			return norm_data(np.array(X).astype('float')),np.array(y),count
		else:
			return np.array(X).astype('float'),np.array(y),count



	if dataset == 'a3':
		X = []
		y = []
		with open('data/a3/a3.txt', mode ='r') as file:
			csvFile = csv.reader(file)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0].split()
				X.append(lines)

		with open('data/a3/a3-ga.pa', mode ='r') as file:
			csvFile = csv.reader(file)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0]
				y.append(lines)
			y, count = relabel(y)

		if norm == True:
			return norm_data(np.array(X).astype('float')),np.array(y),count
		else:
			return np.array(X).astype('float'),np.array(y),count



	if dataset == 'unbalance':
		X = []
		y = []
		with open('data/unbalance/unbalance.txt', mode ='r') as file:
			csvFile = csv.reader(file)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0].split()
				X.append(lines)

		with open('data/unbalance/unbalance-gt.pa', mode ='r') as file:
			csvFile = csv.reader(file)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0]
				y.append(lines)
			y, count = relabel(y)

		if norm == True:
			return norm_data(np.array(X).astype('float')),np.array(y),count
		else:
			return np.array(X).astype('float'),np.array(y),count


	if dataset == 's1':
		X = []
		y = []
		with open('data/s1/s1.txt', mode ='r') as file:
			csvFile = csv.reader(file)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0].split()
				X.append(lines)

		with open('data/s1/s1-label.pa', mode ='r') as file:
			csvFile = csv.reader(file)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0]
				y.append(lines)
			y, count = relabel(y)

		if norm == True:
			return norm_data(np.array(X).astype('float')),np.array(y),count
		else:
			return np.array(X).astype('float'),np.array(y),count


	if dataset == 's2':
		X = []
		y = []
		with open('data/s2/s2.txt', mode ='r') as file:
			csvFile = csv.reader(file)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0].split()
				X.append(lines)

		with open('data/s2/s2-label.pa', mode ='r') as file:
			csvFile = csv.reader(file)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0]
				y.append(lines)
			y, count = relabel(y)

		if norm == True:
			return norm_data(np.array(X).astype('float')),np.array(y),count
		else:
			return np.array(X).astype('float'),np.array(y),count


	if dataset == 's3':
		X = []
		y = []
		with open('data/s3/s3.txt', mode ='r') as file:
			csvFile = csv.reader(file)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0].split()
				X.append(lines)

		with open('data/s3/s3-label.pa', mode ='r') as file:
			csvFile = csv.reader(file)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0]
				y.append(lines)
			y, count = relabel(y)

		if norm == True:
			return norm_data(np.array(X).astype('float')),np.array(y),count
		else:
			return np.array(X).astype('float'),np.array(y),count



	if dataset == 's4':
		X = []
		y = []
		with open('data/s4/s4.txt', mode ='r') as file:
			csvFile = csv.reader(file)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0].split()
				X.append(lines)

		with open('data/s4/s4-label.pa', mode ='r') as file:
			csvFile = csv.reader(file)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			next(csvFile)
			for lines in csvFile:
				if len(lines) == 0:
					continue
				lines = lines[0]
				y.append(lines)
			y, count = relabel(y)

		if norm == True:
			return norm_data(np.array(X).astype('float')),np.array(y),count
		else:
			return np.array(X).astype('float'),np.array(y),count



if __name__ == '__main__':
	# Smoke test: load every supported dataset with normalisation switched on
	# and print what came back.
	names = ('a1', 'a2', 'a3', 'unbalance', 's1', 's2', 's3', 's4')

	for name in names:
		points, labels, n_labels = load_data(name, True)
		# There must be exactly one label per point.
		assert len(points) == len(labels)
		for item in (points, labels, n_labels):
			print(item)










