import ast
import os
import random
import string
from concurrent.futures import ProcessPoolExecutor
from argparse import ArgumentParser
import pandas as pd

class VariableRenamer:
	"""Rename variables of a parsed snippet to unused single-letter names.

	Only identifiers whose every binding site is an ``ast.Name`` node are
	renamed. Builtins (``print``, ``range``, ...) and names bound by a
	``def``/``class`` statement, a parameter, an import, a ``global``/
	``nonlocal`` declaration, an ``except ... as`` clause or a ``match``
	capture are left untouched: this class only rewrites ``ast.Name`` nodes,
	so renaming those identifiers would leave their binding site under the
	old name and break the snippet.

	Args:
		p: Per-variable renaming probability, used when ``mode`` is neither
			"all" nor "one".
		mode: "all" renames every eligible variable, "one" renames a single
			randomly chosen one, any other value ("proba") renames each
			eligible variable independently with probability ``p``.
	"""

	def __init__(self, p=0.5, mode="one"):
		self.p = p
		self.mode = mode
		self.available_names = set(string.ascii_lowercase)

	@staticmethod
	def _protected_names(tree):
		"""Return the identifiers of ``tree`` that must never be renamed."""
		import builtins  # local import: only this helper needs it

		protected = set(dir(builtins))
		for node in ast.walk(tree):
			if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
				protected.add(node.name)
			elif isinstance(node, ast.arg):
				protected.add(node.arg)
			elif isinstance(node, ast.alias):
				# "import a.b" binds "a"; "import a.b as c" binds "c".
				protected.add((node.asname or node.name).split(".")[0])
			elif isinstance(node, (ast.Global, ast.Nonlocal)):
				protected.update(node.names)
			elif isinstance(node, ast.ExceptHandler) and node.name:
				protected.add(node.name)
			elif type(node).__name__ in ("MatchAs", "MatchStar", "MatchMapping"):
				# Pattern captures store the bound name as a plain string.
				# Checked by class name so older ast modules still work.
				captured = getattr(node, "name", None) or getattr(node, "rest", None)
				if captured:
					protected.add(captured)
		return protected

	def rename(self, tree):
		"""Rename the selected variables in place and return ``tree``.

		The tree is returned unchanged when it has no renamable variable or
		when no unused single-letter name is left.
		"""
		var_positions = {}  # var_name -> list of ast.Name nodes
		for node in ast.walk(tree):
			if isinstance(node, ast.Name):
				var_positions.setdefault(node.id, []).append(node)

		protected = self._protected_names(tree)
		# Sorted so that a seeded ``random`` gives reproducible results
		# regardless of set iteration order.
		renamable = sorted(name for name in var_positions if name not in protected)
		if not renamable:
			return tree

		if self.mode == "all":
			rename_candidates = renamable
		elif self.mode == "one":
			rename_candidates = [random.choice(renamable)]
		else:  # "proba"
			rename_candidates = [v for v in renamable if random.random() < self.p]

		# A fresh name must not collide with any identifier of the snippet,
		# including the protected ones.
		used_names = set(var_positions) | protected
		for var in rename_candidates:
			possible_names = sorted(self.available_names - used_names)
			if not possible_names:
				break
			new_name = random.choice(possible_names)
			used_names.add(new_name)
			for node in var_positions[var]:
				node.id = new_name

		return tree

class NeutralOpAdder:
	"""Append a neutral ``+ 0`` or ``- 0`` to simple assignment values.

	Args:
		p: Per-assignment probability, used when ``mode`` is neither "all"
			nor "one".
		mode: "all" rewrites every eligible assignment, "one" rewrites a
			single randomly chosen one, any other value ("proba") rewrites
			each eligible assignment independently with probability ``p``.
	"""

	def __init__(self, p=0.5, mode="proba"):
		self.p = p
		self.mode = mode

	@staticmethod
	def _is_eligible_value(value):
		"""Tell whether ``value + 0`` can be neutral for this expression."""
		if isinstance(value, ast.Constant):
			# Strings, bytes, None and Ellipsis would raise TypeError, and
			# "True + 0" silently turns a bool into an int.
			return (
				isinstance(value.value, (int, float, complex))
				and not isinstance(value.value, bool)
			)
		if isinstance(value, ast.UnaryOp):
			# "(not x) + 0" would turn a bool into an int.
			return not isinstance(value.op, ast.Not)
		# NOTE(review): a bare name is assumed to hold a number; its type
		# cannot be checked statically here.
		return isinstance(value, ast.Name)

	def transform(self, tree):
		"""Rewrite the selected assignments in place.

		Returns:
			The mutated ``tree``, or ``False`` when it contains no eligible
			single-target assignment. In "proba" mode the tree may come back
			unchanged if no assignment was drawn.
		"""
		eligible_nodes = [
			node for node in ast.walk(tree)
			if isinstance(node, ast.Assign)
			and len(node.targets) == 1
			and self._is_eligible_value(node.value)
		]
		if not eligible_nodes:
			return False

		# Select nodes to transform based on mode
		if self.mode == "all":
			to_transform = eligible_nodes
		elif self.mode == "one":
			to_transform = [random.choice(eligible_nodes)]
		else:  # "proba"
			to_transform = [node for node in eligible_nodes if random.random() < self.p]

		# Apply neutral operations
		for node in to_transform:
			op = ast.Add() if random.choice([True, False]) else ast.Sub()
			node.value = ast.BinOp(left=node.value, op=op, right=ast.Constant(value=0))

		return tree

class ArithmeticCommutativityTransformer:
	"""Swap the operands of ``+`` expressions (``a + b`` becomes ``b + a``).

	Additions in which an operand is statically known to be a string, bytes,
	list or tuple are skipped: for those types ``+`` is concatenation and
	swapping the operands changes the result.

	Args:
		p: Per-node probability, used when ``mode`` is neither "all" nor "one".
		mode: "all" swaps every eligible addition, "one" swaps a single
			randomly chosen one, any other value ("proba") swaps each
			eligible addition independently with probability ``p``.
	"""

	# Expression nodes that always evaluate to a sequence or string.
	_SEQUENCE_NODES = (ast.JoinedStr, ast.List, ast.Tuple, ast.ListComp)

	def __init__(self, p=0.5, mode="one"):
		self.p = p
		self.mode = mode

	@classmethod
	def _is_sequence_like(cls, node):
		"""Tell whether ``node`` is statically known to be a sequence/string."""
		if isinstance(node, ast.Constant):
			return isinstance(node.value, (str, bytes))
		if isinstance(node, cls._SEQUENCE_NODES):
			return True
		if isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Mult, ast.Mod)):
			# "'a' + b", "'ab' * 3" and "'%s' % x" are sequence-valued too.
			return cls._is_sequence_like(node.left) or cls._is_sequence_like(node.right)
		return False

	def transform(self, tree):
		"""Swap the selected additions in place.

		Returns:
			The mutated ``tree``, or ``False`` when it contains no eligible
			addition. In "proba" mode the tree may come back unchanged if no
			node was drawn.
		"""
		commutable_nodes = [
			node for node in ast.walk(tree)
			if isinstance(node, ast.BinOp)
			and isinstance(node.op, ast.Add)
			and not self._is_sequence_like(node.left)
			and not self._is_sequence_like(node.right)
		]
		if not commutable_nodes:
			return False

		# Choose nodes to transform
		if self.mode == "all":
			to_transform = commutable_nodes
		elif self.mode == "one":
			to_transform = [random.choice(commutable_nodes)]
		else:  # "proba"
			to_transform = [node for node in commutable_nodes if random.random() < self.p]

		for node in to_transform:
			left, right = node.left, node.right
			if (
				isinstance(left, ast.UnaryOp)
				and isinstance(left.op, ast.USub)
				and isinstance(left.operand, ast.Constant)
			):
				# "-c + x" reads better as "x - c" than as "x + -c".
				node.op = ast.Sub()
				node.left = right
				node.right = left.operand
			else:
				node.left, node.right = right, left

		return tree

class ComparisonSymmetryTransformer:
	"""Mirror simple ordering comparisons (``a < b`` becomes ``b > a``).

	Only comparisons with exactly one operator among ``<``, ``>``, ``<=``
	and ``>=`` are considered; chained comparisons are left alone.

	Args:
		p: Per-node probability, used when ``mode`` is neither "all" nor "one".
		mode: "all" mirrors every eligible comparison, "one" mirrors a single
			randomly chosen one, any other value ("proba") mirrors each
			eligible comparison independently with probability ``p``.
	"""

	def __init__(self, p=0.5, mode="one"):
		self.mode = mode
		self.p = p

	def transform(self, tree):
		"""Mirror the selected comparisons in place.

		Returns:
			The mutated ``tree``, or ``False`` when it contains no eligible
			comparison.
		"""
		# (operator, mirrored operator) pairs, tested in this order.
		mirrored_ops = (
			(ast.Lt, ast.Gt),
			(ast.Gt, ast.Lt),
			(ast.LtE, ast.GtE),
			(ast.GtE, ast.LtE),
		)
		ordering_ops = tuple(kind for kind, _ in mirrored_ops)

		candidates = [
			cmp for cmp in ast.walk(tree)
			if isinstance(cmp, ast.Compare)
			and len(cmp.ops) == 1
			and len(cmp.comparators) == 1
			and isinstance(cmp.ops[0], ordering_ops)
		]
		if not candidates:
			return False

		if self.mode == "all":
			chosen = candidates
		elif self.mode == "one":
			chosen = [random.choice(candidates)]
		else:  # "proba"
			chosen = [cmp for cmp in candidates if random.random() < self.p]

		for cmp in chosen:
			flipped = next(
				(mirror for kind, mirror in mirrored_ops if isinstance(cmp.ops[0], kind)),
				None,
			)
			if flipped is None:
				continue  # cannot happen for the candidates collected above
			cmp.left, cmp.comparators[0] = cmp.comparators[0], cmp.left
			cmp.ops[0] = flipped()

		return tree

class NeutralAssignmentInserter:
	"""Insert one ``<name> = <int>`` assignment at a random statement slot.

	The assigned name is a single lowercase letter that the snippet does not
	use, and the value is a random integer in ``[-99, 99]``.
	"""

	def __init__(self):
		self.insertion_sites = []
		self.used_names = set()

	def get_used_names(self, tree):
		"""Add every identifier bound or read in ``tree`` to ``used_names``."""
		for node in ast.walk(tree):
			if isinstance(node, ast.Name):
				self.used_names.add(node.id)
			elif isinstance(node, ast.arg):
				self.used_names.add(node.arg)
			elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
				# Otherwise an inserted "f = 3" could clobber "def f(): ...".
				self.used_names.add(node.name)
			elif isinstance(node, ast.alias):
				# "import a.b" binds "a"; "import a.b as c" binds "c".
				self.used_names.add((node.asname or node.name).split(".")[0])

	def collect_all_insertion_sites(self, node):
		"""Record every ``(body_list, index)`` slot of ``node`` and its descendants."""
		if hasattr(node, "body") and isinstance(node.body, list):
			for i in range(len(node.body) + 1):
				self.insertion_sites.append((node.body, i))
		for child in ast.iter_child_nodes(node):
			self.collect_all_insertion_sites(child)

	def make_assignment(self):
		"""Build the assignment node, or return ``None`` when all 26 letters are taken."""
		available_names = list(set(string.ascii_lowercase) - self.used_names)
		if not available_names:
			return None
		var_name = random.choice(available_names)
		value = random.randint(-99, 99)
		return ast.Assign(
			targets=[ast.Name(id=var_name, ctx=ast.Store())],
			value=ast.Constant(value=value)
		)

	def insert_assignment(self, tree):
		"""Insert one neutral assignment into ``tree`` in place.

		Returns:
			The mutated ``tree`` (unchanged when no free name is left), or
			``False`` when ``tree`` has no statement list to insert into.
		"""
		# Start from a clean state so a reused instance never picks a slot
		# (or avoids a name) that belongs to a previously processed tree.
		self.insertion_sites = []
		self.used_names = set()
		self.get_used_names(tree)
		self.collect_all_insertion_sites(tree)
		if not self.insertion_sites:
			return False
		target_body, index = random.choice(self.insertion_sites)
		assign = self.make_assignment()
		if assign:
			ast.fix_missing_locations(assign)
			target_body.insert(index, assign)
		return tree

class NeutralExpressionInserter:
	"""Insert one ``<name> = <int> +/- <int>`` assignment at a random statement slot.

	The assigned name is a single lowercase letter that appears neither as a
	variable nor as a parameter in the snippet.
	"""

	def __init__(self):
		self.used_names = set()
		self.insertion_sites = []

	def get_used_names(self, tree):
		"""Add every variable and parameter name of ``tree`` to ``used_names``."""
		self.used_names.update(
			found.id if isinstance(found, ast.Name) else found.arg
			for found in ast.walk(tree)
			if isinstance(found, (ast.Name, ast.arg))
		)

	def collect_all_insertion_sites(self, node):
		"""Record every ``(body_list, index)`` slot of ``node`` and its descendants."""
		statements = getattr(node, "body", None)
		if isinstance(statements, list):
			# One slot before each statement plus one after the last.
			self.insertion_sites.extend(
				(statements, slot) for slot in range(len(statements) + 1)
			)
		for nested in ast.iter_child_nodes(node):
			self.collect_all_insertion_sites(nested)

	def make_expression(self):
		"""Build the assignment node, or return ``None`` when all 26 letters are taken."""
		free_names = list(set(string.ascii_lowercase) - self.used_names)
		if not free_names:
			return None
		target = ast.Name(id=random.choice(free_names), ctx=ast.Store())
		lhs = random.randint(-99, 99)
		rhs = random.randint(-99, 99)
		# A negative right operand is always added, never subtracted.
		operator_node = ast.Add() if rhs < 0 else random.choice([ast.Add(), ast.Sub()])
		return ast.Assign(
			targets=[target],
			value=ast.BinOp(
				left=ast.Constant(value=lhs),
				op=operator_node,
				right=ast.Constant(value=rhs),
			),
		)

	def insert_expression(self, tree):
		"""Insert one neutral assignment into ``tree`` in place.

		Returns:
			The mutated ``tree`` (unchanged when no free name is left), or
			``False`` when ``tree`` has no statement list to insert into.
		"""
		# Reset per-call state so the instance can be reused across trees.
		self.insertion_sites, self.used_names = [], set()
		self.get_used_names(tree)
		self.collect_all_insertion_sites(tree)
		if not self.insertion_sites:
			return False
		statements, slot = random.choice(self.insertion_sites)
		new_stmt = self.make_expression()
		if new_stmt:
			ast.fix_missing_locations(new_stmt)
			statements.insert(slot, new_stmt)
		return tree

# --- Transformation dispatcher ---

def transform_snippet_with_args(args):
	"""Apply one perturbation to one snippet.

	Args:
		args: A ``(snippet, perturbation, perturbation_mode, proba)`` tuple.
			``perturbation`` selects the transformation: 1 = variable
			renaming, 2 = neutral ``+/- 0`` operator, 3 = arithmetic
			commutativity, 4 = comparison symmetry, 5 = neutral assignment
			insertion, 6 = neutral expression insertion.
			``perturbation_mode`` and ``proba`` are forwarded to
			perturbations 1-4 and ignored by 5 and 6.

	Returns:
		The transformed source code as a string, or ``False`` when no result
		could be produced (the snippet does not parse, the perturbation
		number is unknown, the snippet has nothing to transform, or the
		transformed tree cannot be unparsed). Failures are reported on
		stderr; every non-``False`` result is real code, never an error
		message.
	"""
	import sys  # local import: only needed for the diagnostics below

	snippet, perturbation, perturbation_mode, proba = args
	try:
		tree = ast.parse(snippet)
	except Exception as e:  # batch boundary: one bad snippet must not stop the run
		print(f"Could not parse snippet:\n{snippet}\nError: {e}\n", file=sys.stderr)
		return False

	if perturbation == 1:
		tree = VariableRenamer(p=proba, mode=perturbation_mode).rename(tree)
	elif perturbation == 2:
		tree = NeutralOpAdder(p=proba, mode=perturbation_mode).transform(tree)
	elif perturbation == 3:
		tree = ArithmeticCommutativityTransformer(p=proba, mode=perturbation_mode).transform(tree)
	elif perturbation == 4:
		tree = ComparisonSymmetryTransformer(p=proba, mode=perturbation_mode).transform(tree)
	elif perturbation == 5:
		tree = NeutralAssignmentInserter().insert_assignment(tree)
	elif perturbation == 6:
		tree = NeutralExpressionInserter().insert_expression(tree)
	else:
		print(f"Unknown perturbation: {perturbation}", file=sys.stderr)
		return False

	if tree is False:  # No eligible nodes found
		return False

	try:
		return ast.unparse(tree)
	except Exception as e:
		print(f"Could not unparse transformed AST: {e}", file=sys.stderr)
		return False

# --- Batch processor ---
def process_batch(batch_args):
	"""Transform every argument tuple of a batch and return the results in order."""
	outcomes = []
	for snippet_args in batch_args:
		outcomes.append(transform_snippet_with_args(snippet_args))
	return outcomes

# --- Batching logic ---
def batch_snippets(snippets, batch_size):
	"""Yield consecutive slices of ``snippets`` holding at most ``batch_size`` items."""
	starts = range(0, len(snippets), batch_size)
	yield from (snippets[start:start + batch_size] for start in starts)


def main():
	"""Perturb every snippet of the input CSV and write the unique results.

	Reads its settings from the module-level names set from the command
	line: ``perturbation``, ``perturbation_mode``, ``input_path``,
	``output_path``, ``proba``, ``duplication_factor`` and
	``max_code_length``.

	Perturbation numbers:
		1 = variable renaming
		2 = neutral operator (+/- 0)
		3 = arithmetic commutativity
		4 = comparison commutativity
		5 = neutral assignment insertion
		6 = neutral expression insertion

	Perturbation modes (used by perturbations 1-4 only): "all", "one", or
	"proba" (apply to each candidate with probability ``proba``).

	Each snippet is the part of the CSV column ``example_input`` that
	precedes the first "\\n#STEP" marker. Results that are too long,
	duplicated, or identical to an input snippet are dropped; the rest are
	written to ``output_path`` separated by blank lines.
	"""

	def print_progress_bar(iteration, total, prefix='', suffix='', length=40, fill='█'):
		# Redraw a one-line progress bar in place; end the line once complete.
		percent = f"{100 * (iteration / float(total)):.1f}"
		filled_length = int(length * iteration // total)
		bar = fill * filled_length + '-' * (length - filled_length)
		print(f'\r{prefix} |{bar}| {percent}% {suffix}', end='\r')
		if iteration == total:
			print()

	# Loading snippets
	print('Reading the csv file ...')
	df = pd.read_csv(input_path)
	total = len(df)
	snippets = []
	for i, input_prompt in enumerate(df["example_input"]):
		# pandas reads empty cells as NaN (a float); those rows hold no code.
		if isinstance(input_prompt, str):
			snippets.append(input_prompt.split("\n#STEP")[0])
		print_progress_bar(i + 1, total, prefix='Loading snippets', suffix='Done')

	print(f"Loaded {len(snippets)} snippets.\n")

	# Deterministic perturbations yield the same output for every copy of a
	# snippet, so duplicating the input would only produce duplicates.
	if not (perturbation in (3, 4) and perturbation_mode == "all"):
		snippets = [snippet for snippet in snippets for _ in range(duplication_factor)]
		print(f"Duplicated each snippet {duplication_factor} times. Total snippets: {len(snippets)}\n")

	# Set membership keeps the "unchanged snippet" check below O(1) per result.
	original_snippets = set(snippets)

	args_list = [(snippet, perturbation, perturbation_mode, proba) for snippet in snippets]
	# range() rejects a step of 0, so never let the batch size drop below 1.
	batch_size = max(1, duplication_factor * 6)
	batches = list(batch_snippets(args_list, batch_size))

	print(f"\nProcessing {len(snippets)} snippets in {len(batches)} batches...\n")

	results = []
	total_batches = len(batches)
	with ProcessPoolExecutor(max_workers=32) as executor:
		batch_results = executor.map(process_batch, batches)
		for idx, batch in enumerate(batch_results, 1):
			results.extend(batch)
			print_progress_bar(idx, total_batches, prefix='Processing', suffix='Done')

	seen = set()
	trans_list = []
	for trans in results:
		# Anything that is not a string (False, error tuples) is a failed
		# transformation and has no code to write.
		if not isinstance(trans, str):
			continue
		trans = trans.strip()
		if len(trans.split("\n")) > max_code_length:
			continue
		if trans in seen:
			continue
		seen.add(trans)
		# ast.unparse indents with four spaces; convert to tabs before
		# comparing against the input snippets.
		trans = trans.replace("    ", "\t")
		if trans in original_snippets:
			continue
		trans_list.append(trans)

	with open(output_path, "w", encoding="utf-8") as out_f:
		out_f.writelines(trans + "\n\n" for trans in trans_list)

	print(f"Transformed snippets written to {output_path}")


# Worker processes started with the "spawn" method re-import this module, so
# command-line parsing and the call to main() must only run in the parent.
if __name__ == "__main__":
	arg_parser = ArgumentParser(
		description="Apply a semantics-preserving perturbation to the code snippets of a CSV file."
	)
	arg_parser.add_argument(
		"perturbation", type=int, choices=range(1, 7),
		help="1=variable renaming, 2=neutral operator, 3=arithmetic commutativity, "
		"4=comparison commutativity, 5=neutral assignment, 6=neutral expression",
	)
	arg_parser.add_argument(
		"perturbation_mode", type=str,
		help='"all", "one" or "proba" (ignored by perturbations 5 and 6)',
	)
	arg_parser.add_argument("input_path", type=str, help='CSV file with an "example_input" column')
	arg_parser.add_argument("output_path", type=str, help="file the transformed snippets are written to")
	arg_parser.add_argument("proba", type=float, help='probability used by the "proba" mode')
	arg_parser.add_argument("duplication_factor", type=int, help="number of copies made of each snippet")
	arg_parser.add_argument("max_code_length", type=int, help="maximum number of lines of a kept result")
	cli_args = arg_parser.parse_args()

	# main() reads these settings as module-level names.
	perturbation = cli_args.perturbation
	perturbation_mode = cli_args.perturbation_mode
	input_path = cli_args.input_path
	output_path = cli_args.output_path
	proba = cli_args.proba
	duplication_factor = cli_args.duplication_factor
	max_code_length = cli_args.max_code_length

	main()