import csv
import os.path
import sys

# A row describes a vertex only if it carries no edge endpoints
def is_valid_vertex(vertex):
	"""Return True when `vertex` contains neither a "start" nor an "end" key.

	`vertex` is the mapping of field name -> value built for one data row.
	"""
	has_endpoint = "start" in vertex or "end" in vertex
	return not has_endpoint

# A row describes an edge only if both of its endpoints are given
def is_valid_edge(edge):
	"""Return True when `edge` contains both a "start" and an "end" key.

	`edge` is the mapping of field name -> value built for one data row.
	"""
	return all(required in edge for required in ("start", "end"))

# Transform input csv into 2 csv files, one for valid vertices and one for valid edges
def preprocess(path_in, filename_in, path_out, filename_out, id_field):
	"""Split a combined vertex/edge CSV into separate vertex and edge CSVs.

	Reads `path_in + filename_in + ".csv"` (latin-1 encoded) and writes
	`path_out + filename_out + "_vertices.csv"` and
	`path_out + filename_out + "_edges.csv"`. Paths are joined by plain string
	concatenation, so non-empty directory arguments must end with a path
	separator. `path_out` is created if it does not exist; an empty `path_out`
	means the current working directory.

	Row handling:
	- Field names are the header names with all underscores removed.
	- Empty cells are left out of a row's data.
	- Rows with fewer cells than the header are counted as invalid; surplus
	  cells are ignored.
	- A row is a vertex if it has neither "start" nor "end", an edge if it has
	  both, and invalid otherwise.

	Vertices are renumbered "0", "1", ... in file order and edge endpoints are
	rewritten to these new ids. Edges that reference an unknown vertex are
	dropped. Edge ids are the integer position among all edge rows read, so
	ids of the written edges may have gaps.

	The output columns are the fields seen while reading the input, so the
	new "id" value is only written if the input already had an "id" field.

	Args:
		path_in: Directory prefix of the input file.
		filename_in: Input file name without the ".csv" extension.
		path_out: Directory prefix for the output files.
		filename_out: Base name of the output files.
		id_field: Field name (underscores removed) holding the original
			vertex id that edges refer to in "start"/"end".

	Raises:
		KeyError: If a vertex row has no value for `id_field`.
	"""

	# Lists to store data
	vertices = []
	vertex_fields = []
	edges = []
	edge_fields = []
	invalid_row_count = 0

	# Read input csv
	with open(path_in + filename_in + ".csv", newline='', encoding='latin-1') as csvfile:
		csvreader = csv.reader(csvfile, delimiter=',', quotechar='"')
		# An empty file has no header row; treat it as a dataset without columns
		# instead of failing with a bare StopIteration
		headers = next(csvreader, [])
		field_names = [header.replace("_", "") for header in headers]
		for row in csvreader:
			# Invalid row: Wrong number of fields
			if len(row) < len(headers):
				invalid_row_count += 1
				continue
			# Read row, skipping empty cells (zip ignores surplus cells)
			data = {name: value for (name, value) in zip(field_names, row) if value != ""}
			if is_valid_vertex(data):
				vertices.append(data)
				known_fields = vertex_fields
			elif is_valid_edge(data):
				edges.append(data)
				known_fields = edge_fields
			else:
				# Invalid row: Can not clearly be identified as vertex or edge
				invalid_row_count += 1
				continue
			# Remember every field in first-seen order; it becomes an output column
			for field in data:
				if field not in known_fields:
					known_fields.append(field)

	# Print some info about the dataset
	print("After reading file:")
	print("Nodes: %d" % len(vertices))
	print("Edges: %d" % len(edges))
	print("Invalid rows in .csv file: %d" % invalid_row_count)
	print()

	# Remap vertex ids
	old_to_new_vertex_id = {}
	for (i, vertex) in enumerate(vertices):
		# Deliberately a hard failure when id_field is missing: silently
		# skipping it would drop every edge that refers to this vertex
		old_to_new_vertex_id[vertex[id_field]] = str(i)
		vertex["id"] = str(i)

	# Adapt start- and end-fields in edges to remapped vertex ids
	valid_edges = []
	for (i, edge) in enumerate(edges):
		edge["id"] = i
		if edge["start"] in old_to_new_vertex_id and edge["end"] in old_to_new_vertex_id:
			edge["start"] = old_to_new_vertex_id[edge["start"]]
			edge["end"] = old_to_new_vertex_id[edge["end"]]
			valid_edges.append(edge)

	# os.makedirs("") raises, so only create a directory when one is named
	if path_out:
		os.makedirs(path_out, exist_ok=True)

	# Write new CSV files for vertices and edges
	outputs = (
		("_vertices.csv", vertex_fields, vertices),
		("_edges.csv", edge_fields, valid_edges),
	)
	for (suffix, fields, records) in outputs:
		# newline='' lets the csv module control line endings (avoids "\r\r\n"
		# on Windows); the with-block closes the file even if a write fails
		with open(path_out + filename_out + suffix, 'w', newline='') as out_file:
			writer = csv.writer(out_file)
			writer.writerow(fields)
			for record in records:
				writer.writerow([record.get(field, "") for field in fields])

if __name__ == '__main__':
	# The dataset name is required; stop here instead of running into an
	# IndexError on sys.argv[1] below
	if len(sys.argv) < 2:
		print("Usage: preprocess_dataset.py <datasetname>")
		sys.exit(1)

	dataset = sys.argv[1]
	# preprocess() strips underscores from header names, so this also matches
	# a column called e.g. "entity_id"
	id_field = "entityid"
	# Input is read from raw_data/<dataset>.csv
	path_in = "raw_data/"
	path_out = "GraphGymPyG/datasets/%s/raw/" % dataset
	preprocess(path_in, dataset, path_out, dataset, id_field)

