# Dense network: (input_shape = (2048,))
# img(1024).cap(1024)->d(512)[relu]->drop(0.8)->d(256)[relu]->drop(0.8)->d(1)

# Conv2d network: (input_shape = (64,32,1))
# img(32,32,1).cap(32,32,1)->c2d(64,3,3)[relu]->drop(0.2)->c2d(64,3,3)[relu]->maxp2d(2, 2)->flat->d(256)[relu]->d(1)

# Conv2d with 2 channels: (input_shape = (32,32,2)); '.{2}' passes axis=2 to the '.' operator
# img(32,32,1).{2}cap(32,32,1)->c2d(32,3,3)[relu]->drop(0.2)->c2d(64,3,3)[relu]->maxp2d(2, 2)->flat->d(256)[relu]->d(1)

# from tensorflow.keras.layers import Embedding, SimpleRNN, LSTM, concatenate, Dense, Reshape, InputLayer, Activation, Dropout, BatchNormalization, Conv2D, MaxPool2D, Flatten

def parse_model_params(params_string):
	'''
	Parses a model definition string of the form
	'<input formula>-><layer>-><layer>...', where the input formula is optional.

	The first '->'-separated segment is treated as an input formula (see
	parse_input_formula) when it starts with 'img', 'cap', 'emb' or '('.
	Every remaining segment is parsed as one layer by _parse_layer.

	Returns (input_formula, layers, is_valid):
	- input_formula is None when no formula segment is present.
	- layers is the list of layers parsed so far.
	- On an internal error (empty params_string or a bad input formula),
	  (None, None, False) is returned.
	Errors are reported with print.
	'''
	if not params_string:
		# str.split never returns an empty list: '' splits into [''], which
		# would otherwise be accepted as a model with one empty layer.
		print("Internal error: Somehow attempted to parse empty model parameters")
		return None, None, False

	input_formula = None
	layer_strings = params_string.split('->')
	start = 0
	first_layer_string = layer_strings[0]
	if first_layer_string.startswith(('img', 'cap', 'emb', '(')):
		input_formula, formula_len, is_valid = parse_input_formula(first_layer_string)
		if not is_valid:
			return None, None, False
		if formula_len != len(first_layer_string):
			print("Internal error: parse_input_formula returned a valid status but did not match the length of the formula string")
			return None, None, False
		start = 1

	layers = []
	is_valid = True
	for index, layer_string in enumerate(layer_strings[start:]):
		layer = _parse_layer(layer_string, index)
		if layer is None:
			is_valid = False
			break
		layers.append(layer)

	if is_valid and len(layers) < 1:
		print("Error parsing new model parameters: No layers given")
		is_valid = False
	return input_formula, layers, is_valid

def _parse_layer(layer_string, index):
	'''
	Parses one layer definition, e.g. 'd(256)[relu]', 'drop(0.2)', 'flat' or '[relu]'.

	Recognised layer types: d(n), reshape(...), batchnorm, c2d(...), drop(rate),
	maxp2d(...) and flat. Any of them may be followed by an activation in
	brackets, except reshape, drop, maxp2d and flat. A bare '[act]' becomes an
	Activation layer.

	index is the layer's position and is used only in error messages.
	Returns the layer as a list: the type name, its argument(s), then an
	optional activation name. Returns None after printing an error.
	'''
	def fail(reason):
		print(f"Error parsing layer definition {index} '{layer_string}': {reason}")
		return None

	def parse_args(pos, min_len, max_len, expected, to_type=int):
		# Parse the '(...)' argument tuple at pos and check its element count
		# (max_len None means unbounded).
		values, tuple_len = parse_tuple(layer_string, pos, to_type)
		if values is None:
			return fail("Invalid tuple"), 0
		if len(values) < min_len or (max_len is not None and len(values) > max_len):
			return fail(f"Expected {expected} in tuple, got {len(values)}"), 0
		return values, tuple_len

	if not layer_string:
		# A leading, trailing or doubled '->' would otherwise yield an empty layer []
		return fail("Empty layer definition")

	layer = []
	allows_activation = True
	pos = 0
	if layer_string.startswith('d('):
		pos = 1
		args, tuple_len = parse_args(pos, 1, 1, "1 element")
		if args is None:
			return None
		pos += tuple_len
		layer = ['Dense', args[0]]
	elif layer_string.startswith('reshape('):
		pos = 7
		args, tuple_len = parse_args(pos, 1, None, "at least 1 element")
		if args is None:
			return None
		pos += tuple_len
		layer = ['Reshape', args]
		allows_activation = False
	elif layer_string.startswith('batchnorm'):
		pos = 9
		layer = ['BatchNormalization']
	elif layer_string.startswith('c2d('):
		pos = 3
		args, tuple_len = parse_args(pos, 2, 3, "2 or 3 elements")
		if args is None:
			return None
		pos += tuple_len
		layer = ['Conv2D', args]
	elif layer_string.startswith('drop('):
		pos = 4
		args, tuple_len = parse_args(pos, 1, 1, "1 element", float)
		if args is None:
			return None
		pos += tuple_len
		layer = ['Dropout', args[0]]
		allows_activation = False
	elif layer_string.startswith('maxp2d('):
		pos = 6
		args, tuple_len = parse_args(pos, 2, 3, "2 or 3 elements")
		if args is None:
			return None
		pos += tuple_len
		layer = ['MaxPool2D', args]
		allows_activation = False
	elif layer_string.startswith('flat'):
		pos = 4
		layer = ['Flatten']
		allows_activation = False

	# Optional activation suffix such as '[relu]'
	if layer_string.startswith('[', pos):
		if not allows_activation:
			return fail("Layer type does not allow for activation function")
		activation, closing, _ = layer_string[pos + 1:].partition(']')
		if closing != ']':
			return fail("No closing ']' in activation function")
		if layer:
			layer.append(activation)
		else:
			layer = ['Activation', activation]
		pos += len(activation) + 2  # '[' + name + ']'

	# Anything left over (including an unrecognised layer type) is an error
	if pos != len(layer_string):
		return fail("Extraneous symbols")
	return layer

def parse_input_formula(formula_string, start=0, level=0):
	'''
	Parses an input formula such as 'img(32,32,1).{2}cap(32,32,1)' or '(img+cap)*emb',
	starting at position start of formula_string.

	Operands are 'img', 'cap' and 'emb', each optionally followed by a shape
	tuple, or a parenthesised subexpression. Operators are '+', '-', '*', '/'
	and '.'. The '.' operator may carry an integer parameter in braces, e.g.
	'.{2}' is stored as ['.', {'axis': 2}].

	level is the subexpression nesting depth. At level > 0, parsing stops at the
	')' that closes the subexpression without consuming it; at level 0 a ')'
	is an error.

	Returns (result, result_len, is_valid):
	- result alternates operands and operators; a subexpression is a nested
	  list of the same form.
	- result_len is the number of characters consumed from start.
	- is_valid is False if any error was found.
	Errors are reported with print.
	'''
	result = []
	is_valid = True

	pos = start
	while pos < len(formula_string):
		# Operands occupy even indices of result and operators odd ones
		expects_operand = len(result) % 2 == 0
		char = formula_string[pos]
		operand = next((name for name in ('img', 'cap', 'emb') if formula_string.startswith(name, pos)), None)
		if operand is not None:
			if not expects_operand:
				print(f"Error parsing input formula: Expected operator, found '{operand}'")
				is_valid = False
				break
			part = [operand]
			pos += len(operand)
			if pos < len(formula_string) and formula_string[pos] == '(':
				shape, shape_len = parse_tuple(formula_string, pos)
				if shape is None:
					is_valid = False
					break
				part.append(shape)
				pos += shape_len
			result.append(part)
		elif char in '+-*/.':
			if expects_operand:
				print(f"Error parsing input formula[{pos}]: Expected operand, found '{char}'")
				is_valid = False
				break
			pos += 1
			if char == '.' and formula_string.startswith('{', pos):
				pos += 1
				param_string, closing, _ = formula_string[pos:].partition('}')
				if closing != '}':
					print(f"Error parsing input formula[{pos}]: No closing '}}' in operator parameters '{{{param_string}'")
					is_valid = False
					break
				try:
					op_param = int(param_string)
				except ValueError:
					print(f"Error parsing input formula[{pos}]: Could not convert operator parameter to int: '{param_string}'")
					is_valid = False
					break
				pos += len(param_string) + 1  # parameter text + '}'
				result.append(['.', {'axis': op_param}])
			else:
				result.append([char])
		elif char == '(':
			if not expects_operand:
				print(f"Error parsing input formula[{pos}]: Expected operator, found subexpression")
				is_valid = False
				break
			pos += 1
			sub_part, sub_len, is_valid = parse_input_formula(formula_string, pos, level + 1)
			if not is_valid:
				break
			pos += sub_len
			if pos >= len(formula_string) or formula_string[pos] != ')':
				print(f"Error parsing input formula[{pos}]: No closing ')' in subexpression '({formula_string[pos - sub_len:]}'")
				is_valid = False
				break
			pos += 1
			result.append(sub_part)
		elif char == ')':
			# End of a subexpression; the caller consumes the ')'
			if level == 0:
				print(f"Error parsing input formula[{pos}]: Found ')' at top level")
				is_valid = False
			break
		else:
			print(f"Error parsing input formula[{pos}]: Unrecognized token: {formula_string[pos:]}")
			is_valid = False
			break

	# An empty (sub)expression or a trailing operator leaves the expression
	# without a final operand.
	if is_valid and len(result) % 2 == 0:
		print(f"Error parsing input formula[{pos}]: Expected operand, found end of expression")
		is_valid = False

	return result, pos - start, is_valid

def parse_tuple(string, start, to_type=int):
	'''
	Parses a tuple of the form (42,69,420) that occurs in the string at position start.
	Returns a tuple of items cast to to_type corresponding to the parsed tuple and the
	length of the string taken up by the tuple, including both parentheses.
	Returns (None, 0) if there is no '(' at start, no closing ')', or a member cannot
	be converted; the last two cases also print an error.
	Nested parentheses are not supported: the tuple ends at the first ')' after start.
	'''
	# startswith (unlike indexing) is safe when start is at or past the end of string
	if not string.startswith('(', start):
		return None, 0

	contents, closing, _ = string[start + 1:].partition(')')
	if closing != ')':
		print(f"No closing ')' in tuple '({contents}'")
		return None, 0
	try:
		result = tuple(map(to_type, contents.split(',')))
	except ValueError:
		print(f"Could not convert all members to {to_type} in tuple '({contents})'")
		return None, 0
	return result, len(contents) + 2
