import numpy as np
from typing import Callable, Optional


''' There are probably many other ways to compute the gradient update,
but BallGradApprox (a random-direction finite-difference gradient estimator) is what I use for most things.

DatasetFunction is a convenience class that packages a dataset together with a two-argument
function f(theta, data) into a single-argument callable f(theta).
This is useful, for example, to create a gradient function and a loss function that can be
optimized by the IterativeOptimizer.

'''

class BallGradApprox:
	"""Derivative-free gradient estimator using random Gaussian directions.

	For a scalar-valued ``function`` f, calling the estimator at ``theta`` returns

		g(theta) = 1/(n*sigma) * sum_k (f(theta + sigma*u_k) - f(theta)) * u_k

	where n = ``n_components`` and the directions u_k ~ N(0, I) are drawn fresh
	from numpy's global random state on every call. If a regularization is
	selected, the gradient of the penalty (scaled by ``lam``) is added.

	Args:
		function: callable taking a parameter vector and returning a scalar.
		sigma: finite-difference step length along each random direction.
		n_components: number of random directions sampled per estimate.
		lam: regularization strength; only used when ``regularization`` is set.
		regularization: None, 'l2'/'ridge', or 'l1'/'lasso' (case-insensitive).

	Raises:
		ValueError: if ``regularization`` is not one of the recognized names.
	"""

	_RIDGE_NAMES = ('l2', 'ridge')
	_LASSO_NAMES = ('l1', 'lasso')

	def __init__(self, function:Callable, sigma:float, n_components:int, lam:float = 0, regularization:Optional[str] = None):
		self.sigma = sigma
		self.n_components = n_components
		self.function = function
		# The penalty gradients read self.lam, so it must be stored.
		self.lam = lam
		self.regularizer = None
		if regularization is not None:
			# Use equality (via `in`), not `is`: .lower() returns a new string
			# object, so an identity check against a literal is unreliable.
			name = regularization.lower()
			if name in self._RIDGE_NAMES:
				self.regularizer = self.ridgeGradient
			elif name in self._LASSO_NAMES:
				self.regularizer = self.lassoGradient
			else:
				raise ValueError(
					"unknown regularization %r; expected one of %s"
					% (regularization, self._RIDGE_NAMES + self._LASSO_NAMES))

	def ridgeGradient(self, theta):
		"""Gradient of the ridge penalty lam * ||theta||_2^2."""
		return self.lam*2*theta

	def lassoGradient(self, theta):
		"""Subgradient of the lasso penalty lam * ||theta||_1 (0 at theta_i == 0)."""
		return self.lam*np.sign(theta)

	def __call__(self, theta):
		"""Return the (optionally regularized) gradient estimate at ``theta``.

		The result is random unless the caller seeds numpy's global RNG, because
		new directions are drawn on each call.
		"""
		theta = np.asarray(theta)
		d = theta.shape[0]
		directions = np.random.normal(size = (self.n_components,d))
		# Keep a floating theta's precision, but never accumulate float updates
		# into an integer array (in-place += would raise a casting error).
		dtype = theta.dtype if np.issubdtype(theta.dtype, np.floating) else np.float64
		out = np.zeros(theta.shape, dtype=dtype)
		fx = self.function(theta)
		scale = 1.0 / (self.n_components * self.sigma)
		for di in directions:
			out += scale * (self.function(theta + self.sigma*di) - fx) * di
		if self.regularizer is not None:
			out += self.regularizer(theta)
		return out


class DatasetFunction:
	"""Bind a dataset to a two-argument function, producing a one-argument callable.

	An instance called with ``theta`` evaluates ``function(theta, dataset)``. This
	lets a function such as ``linear_regression_loss(theta, data)`` be used wherever
	a single-argument objective or gradient is expected.
	"""

	def __init__(self, function:Callable, dataset):
		self.data = dataset
		self.function = function

	def __call__(self, theta):
		bound_fn = self.function
		dataset = self.data
		return bound_fn(theta, dataset)

class FourierFeatures:
	"""Random Fourier feature map built from cos/sin pairs.

	Each input x of dimension ``d`` maps to
		z(x) = [cos(w_1.x), sin(w_1.x), ..., cos(w_D.x), sin(w_D.x)] / sqrt(D)
	with rows w_j of ``W`` drawn from N(0, I). Then z(x).z(y) = mean_j cos(w_j.(x-y)),
	a Monte Carlo estimate of the Gaussian kernel exp(-||x-y||^2 / 2).

	This paper was helpful to understand random FFs:
	https://arxiv.org/pdf/1506.02785.pdf
	"""

	def __init__(self, d, D):
		# W has shape (D, d): one random frequency vector per feature pair.
		self.W = np.random.normal(0,1,size = (D,d))
		self.D = D
		self.d = d

	def featurize(self, X, intercept = False):
		"""Map a single sample (shape (d,)) or a batch (shape (n, d)) to features.

		Args:
			X: 1-D sample or 2-D array of samples (one per row).
			intercept: if True, append a constant column after the 2*D features.
				Note that this column is scaled by 1/sqrt(D) like the others,
				so its value is 1/sqrt(D), not 1.

		Returns:
			Array of shape (2*D [+1],) for 1-D input or (n, 2*D [+1]) for 2-D input.
			Floating inputs keep their dtype; integer inputs produce float64.

		Raises:
			ValueError: if X is not 1-D or 2-D.
		"""
		X = np.asarray(X)
		if X.ndim not in (1, 2):
			raise ValueError("X must be 1-D or 2-D, got ndim=%d" % X.ndim)
		single = X.ndim == 1
		args = np.atleast_2d(X) @ self.W.T  # (n, D) projections w_j . x
		# Integer X must not select an integer output buffer: cos/sin would be
		# truncated and the in-place float scaling below would raise.
		dtype = X.dtype if np.issubdtype(X.dtype, np.floating) else np.float64
		n_cols = 2*self.D + (1 if intercept else 0)
		features = np.empty((args.shape[0], n_cols), dtype=dtype)
		features[:, 0:2*self.D:2] = np.cos(args)
		features[:, 1:2*self.D:2] = np.sin(args)
		if intercept:
			features[:, -1] = 1
		features *= 1.0/np.sqrt(self.D)
		return features[0] if single else features


def linear_regression_loss(theta, data):
	"""Return the sum of squared residuals ||X @ theta - y||^2.

	``data`` holds the features in every column but the last; the last column is y.
	"""
	features, targets = data[:, :-1], data[:, -1]
	residual = np.dot(features, theta) - targets
	return np.linalg.norm(residual)**2

def surrogate_regression_loss(theta, data, p = 4):
	"""Arccos-based surrogate regression loss.

	With residuals r = clip(X @ theta - y, -0.9999, 0.9999), each sample
	contributes
		0.5*(1 - arccos(r)/pi)**p + 0.5*(1 - arccos(-r)/pi)**p
	and the per-sample terms are summed. The last column of ``data`` is y.
	"""
	features, targets = data[:, :-1], data[:, -1]
	r = np.clip(np.dot(features, theta) - targets, -0.9999, 0.9999)
	inv_pi = 1/np.pi
	upper = (1 - inv_pi * np.arccos(r))**p
	lower = (1 - inv_pi * np.arccos(-r))**p
	return np.sum(1/2.0*upper + 1/2.0*lower)

def linear_regression_grad(theta, data):
	"""Return the gradient of ||X @ theta - y||^2 with respect to theta.

	Computed as 2 * X^T (X @ theta - y). Multiplying by the residual directly
	costs O(n*d), instead of first forming the d x d Gram matrix X^T X (O(n*d^2)).
	The last column of ``data`` is y.
	"""
	y = data[:,-1]
	X = data[:,:-1]
	residual = np.dot(X, theta) - y
	return 2*np.dot(X.T, residual)

def surrogate_regression_grad(theta, data, p = 4):
	"""Gradient w.r.t. theta of the arccos surrogate regression loss.

	The loss per sample is 0.5*a**p + 0.5*b**p with a = 1 - arccos(r)/pi,
	b = 1 - arccos(-r)/pi, and r = clip(X @ theta - y, -0.9999, 0.9999).
	Since da/dr = 1/(pi*sqrt(1-r^2)) = -db/dr,
		dL/dr = (p*a**(p-1) - p*b**(p-1)) / (2*pi*sqrt(1 - r^2))
	and the gradient is sum_i dL/dr_i * x_i.

	NOTE(review): residuals outside the clip range use the derivative at the
	clip boundary rather than 0. This matches the original behavior.
	"""
	y = data[:,-1]
	X = data[:,:-1]
	yhat = np.dot(X, theta)
	ip = yhat - y
	ip = np.clip(ip, -0.9999, 0.9999)
	numerator = p*(1 - 1/np.pi * np.arccos(ip))**(p-1) - p*(1 - 1/np.pi * np.arccos(-ip))**(p-1)
	# Divide by 2*pi*sqrt(1 - r^2). The previous code divided by the reciprocal of
	# this quantity, which scaled the gradient by (2*pi)^2 * (1 - r^2).
	dloss_dr = numerator / (2*np.pi*np.sqrt(1 - ip**2))
	return np.dot(dloss_dr, X)

def surrogate_classification_loss(theta, data, p = 2):
	"""Arccos-based surrogate classification loss.

	With t = clip(-(X @ theta) * y, -0.9999, 0.9999), the loss is
	sum((1 - arccos(t)/pi)**p). The last column of ``data`` holds the labels y.
	"""
	features, labels = data[:, :-1], data[:, -1]
	neg_margin = np.clip(-(np.dot(features, theta) * labels), -0.9999, 0.9999)
	return np.sum((1 - 1/np.pi * np.arccos(neg_margin))**p)

def surrogate_classification_grad(theta, data, p = 2):
	"""Gradient w.r.t. theta of the arccos surrogate classification loss.

	The loss is sum_i a_i**p with a_i = 1 - arccos(t_i)/pi and
	t_i = clip(-(x_i . theta) * y_i, -0.9999, 0.9999). Using
	da/dt = 1/(pi*sqrt(1 - t^2)) and dt/dtheta = -y_i * x_i:
		grad = sum_i p*a_i**(p-1) / (pi*sqrt(1 - t_i^2)) * (-y_i) * x_i

	Returns a vector shaped like theta (the previous version returned a scalar
	sum, which is not a gradient).

	NOTE(review): samples outside the clip range use the derivative at the
	clip boundary rather than 0, consistent with surrogate_regression_grad.
	"""
	y = data[:,-1]
	X = data[:,:-1]
	yhat = np.dot(X, theta)
	t = -1 * yhat * y
	t = np.clip(t, -0.9999, 0.9999)
	a = 1 - 1/np.pi * np.arccos(t)
	dloss_dt = p * a**(p-1) / (np.pi*np.sqrt(1 - t**2))
	# Chain rule through t_i = -y_i * (x_i . theta).
	return np.dot(-dloss_dt * y, X)



