import numpy as np
from typing import Callable, Optional
import sys

class IterativeOptimizer:
	"""Iteratively update ``theta`` by stepping against a user-supplied update rule.

	``theta`` may be any value supporting addition and multiplication by a
	scalar (a Python float, a NumPy array, ...). Each iteration performs::

		v     = beta * v + (1 - beta) * update(theta)
		theta = theta - eta * v

	and then, if an ``enforcer`` was supplied, ``theta = enforcer(theta)``.
	"""

	def __init__(self, function:Callable, update:Callable, enforcer:Optional[Callable]=None):
		"""Store the objective, the update rule and the optional constraint enforcer.

		Args:
			function: objective ``function(theta) -> scalar``. It is only evaluated
				when ``optimize`` records (``compute_function``) or prints
				(``verbose``) objective values.
			update: ``update(theta)`` giving the step direction. The step is
				*subtracted* (scaled by ``eta``), so for descent this would be e.g.
				the gradient of ``function``.
			enforcer: optional ``enforcer(theta) -> theta`` applied after every
				step, e.g. to keep ``theta`` inside a constraint set.
		"""
		self.update = update
		self.function = function
		self.enforcer = enforcer

	def optimize(self, theta, eta, beta = 0, max_iters = 1000, compute_function = False, verbose = False):
		"""Run ``max_iters`` update steps starting from ``theta``.

		Args:
			theta: initial parameter value. It is rebound each step, not mutated
				in place (unless ``update``/``enforcer`` mutate it themselves).
			eta: step size multiplying the momentum vector ``v``.
			beta: momentum coefficient for the exponential moving average of
				updates; ``beta = 0`` gives plain ``theta -= eta * update(theta)``.
			max_iters: number of update steps to perform.
			compute_function: if True, record ``function(theta)`` before every step.
			verbose: if True, every 20 iterations print the iteration index, the
				objective at the pre-update ``theta`` and ``theta[-1]`` (or the
				whole ``theta`` when it is not indexable, e.g. a scalar).

		Returns:
			``(theta, function_values)``. When ``compute_function`` is True,
			``function_values`` is a float ndarray of length ``max_iters`` whose
			entry ``i`` is the objective evaluated *before* step ``i`` (the final
			``theta`` is not evaluated). Otherwise it is an empty list.
		"""
		function_values = np.zeros(max_iters) if compute_function else []

		v = 0  # momentum accumulator: exponential moving average of updates
		for i in range(max_iters):
			log_this_iter = verbose and i % 20 == 0

			# Objective at the current (pre-update) theta. When only logging, it is
			# evaluated just on the printing iterations, so verbose output no
			# longer requires compute_function=True (previously an IndexError).
			if compute_function:
				function_values[i] = self.function(theta)
				current_value = function_values[i]
			elif log_this_iter:
				current_value = self.function(theta)

			# momentum step
			v = beta * v + (1 - beta) * self.update(theta)
			theta = theta - eta * v

			# enforce constraints
			if self.enforcer is not None:
				theta = self.enforcer(theta)

			# verbose output
			if log_this_iter:
				print(f"{i:d}: {float(current_value):.3E}")
				# Scalars (float, np.float64, 0-d arrays) cannot be indexed; show them whole.
				try:
					tail = theta[-1]
				except (TypeError, IndexError):
					tail = theta
				print(f"{tail}")
				sys.stdout.flush()

		return (theta, function_values)

	# def optimizeNonDescent(self, theta, eta, beta = 0, max_iters = 1000):
	# 	function_values = np.zeros(max_iters)

	# 	vmin = float('inf')
	# 	theta_min = np.copy(theta)
	# 	v = 0 # acceleration parameter
	# 	for i in range(max_iters):

	# 		# compute the function
	# 		function_values[i] = self.function(theta)
	# 		if function_values[i] < vmin:
	# 			vmin = function_values[i]
	# 			theta_min = theta

	# 		# do the update
	# 		v = beta * v + (1 - beta) * self.update(theta)
	# 		theta = theta - eta*v

	# 		# enforce constraints
	# 		if self.enforcer is not None:
	# 			theta = self.enforcer(theta)

	# 	return (theta_min, theta, function_values)




