from abc import abstractmethod, ABCMeta
from typing import Callable

import numpy as np


def get_centered_radial_cosine(ord: int = 2, scale: float = 20.0, freq: float = 2.0):
  """Build a cosine that oscillates with distance from the origin.

  Args:
    ord: norm order handed to ``np.linalg.norm`` for the radial distance.
    scale: amplitude of the cosine.
    freq: number of cycles per unit of radial distance.

  Returns:
    A callable computing ``scale * cos(2*pi*freq * ||x||_ord)``, where the
    norm is taken over axis 0 of ``x``.
  """
  def radial_cosine(x):
    radius = np.linalg.norm(x, axis=0, ord=ord)
    phase = freq * (2*np.pi) * radius
    return scale * np.cos(phase)

  return radial_cosine

def get_diagonal_cosine(scale: float = 20.0, freq: float = 2.0):
  """Build a cosine wave that travels along the diagonal ``x[0] + x[1]``.

  Args:
    scale: amplitude of the wave.
    freq: number of cycles per unit of ``x[0] + x[1]``. Defaults to 2.0,
      mirroring the ``freq`` parameter of ``get_centered_radial_cosine``.

  Returns:
    A callable ``f(x)`` returning ``scale * cos(2*pi*freq * (x[0] + x[1]))``;
    ``x`` must hold its two coordinates along the first axis.
  """
  def f(x):
    r = x[0] + x[1]
    return scale * np.cos(freq * (2*np.pi) * r)
  return f

def get_radial_gaussian(mean, std, ord: int = 2):
  """Build a Gaussian bump in the radial distance from the origin.

  The returned callable evaluates the 1-D normal pdf with the given ``mean``
  and ``std`` at ``r = ||x||_ord`` (norm over axis 0 of ``x``), i.e.
  ``exp(-0.5 * ((r - mean) / std)**2) / (std * sqrt(2*pi))``. It is
  normalised in ``r``, not as a density over the plane.
  """
  def radial_gaussian(x):
    radius = np.linalg.norm(x, axis=0, ord=ord)
    z = (radius - mean) / std
    density = np.exp(-.5 * z**2)
    density *= (1 / (std * np.sqrt(2*np.pi)))
    return density

  return radial_gaussian

def get_radial_step(radius, m, scale=1., ord: int = 2):
  """Build a radial profile from two tanh-shaped transitions.

  With ``r = ||x||_ord`` (norm over axis 0 of ``x``) and, for each edge
  ``c`` in ``(radius[0], radius[1])``, the offset ``d = r - c``, the
  returned callable computes
  ``scale * (tanh(m*d0 / (1 - d0**2)) + 1 - tanh(m*d1 / (1 - d1**2)))``.

  NOTE(review): the ``1 - d**2`` denominator is zero at ``|d| == 1`` and
  negative for ``|d| > 1``, so a tanh term flips sign once ``r`` is more
  than one unit from its edge -- confirm callers keep ``|r - c| < 1``.
  """
  def transition(r, edge):
    # Sharpness grows with ``m``; offset is measured from this edge.
    offset = r - edge
    return np.tanh((m * offset) / (1 - offset**2))

  def radial_step(x):
    r = np.linalg.norm(x, axis=0, ord=ord)
    rise = transition(r, radius[0])
    fall = transition(r, radius[1])
    profile = rise + (1 - fall)
    profile *= scale
    return profile

  return radial_step

class FunctionDistribution(metaclass=ABCMeta):
  """Abstract interface for objects that produce callables on demand.

  Concrete subclasses implement ``draw``; the class itself cannot be
  instantiated because ``draw`` is abstract.
  """

  @abstractmethod
  def draw(self) -> Callable:
    """Return a callable sampled from this distribution."""

class ConstantFunctionGenerator(FunctionDistribution):
  """Degenerate distribution whose every draw is the same callable."""

  def __init__(self, func):
    # Kept as-is; ``draw`` hands back this exact object (no copy).
    self.func = func

  def draw(self):
    """Return the wrapped callable."""
    fixed = self.func
    return fixed

class RandomRadialSines(FunctionDistribution):
  """Random superposition of radial sine waves about a random centre.

  Each call to ``draw`` samples a new function
  ``f(x) = scale * sum_k c_k * sin(2*pi*(k+1)*r + phi_k)``. Here ``r`` is
  the ``ord``-norm distance of ``x[:2]`` from a centre drawn uniformly in
  ``[-1, 1]^2``. The phases ``phi_k`` are uniform in ``[0, 2*pi)``. The
  weights ``c_k`` are the gaps between sorted uniform samples on ``[0, 1]``,
  so they are non-negative and sum to one.

  Args:
    modes: number of sine terms (must be >= 1 when drawing).
    ord: norm order passed to ``np.linalg.norm`` for the radial distance.
    scale: overall amplitude multiplier.
    rng: optional randomness source exposing ``uniform`` (e.g. a
      ``np.random.Generator``). When None, the global ``np.random`` state is
      used with the same draw order as before.
  """

  def __init__(self, modes: int = 2, ord: int = 2, scale: float = 20.,
               rng=None):
    self.modes = modes
    self.ord = ord
    self.scale = scale
    self.rng = rng

  def draw(self):
    """Sample one random radial-sine function.

    Returns:
      A callable ``f(x)`` whose first axis holds the coordinates; only
      ``x[:2]`` is used. Any trailing shape is supported: ``(2,)`` yields a
      scalar, ``(2, N)`` yields ``(N,)``, and ``(2, H, W)`` yields ``(H, W)``.

    Raises:
      ValueError: if ``modes`` is less than 1.
    """
    modes = self.modes
    if modes < 1:
      raise ValueError(f"modes must be >= 1, got {modes}")
    rng = np.random if self.rng is None else self.rng
    # Freeze configuration so later mutation of ``self`` cannot change a
    # function that has already been drawn.
    norm_ord = self.ord
    scale = self.scale

    # Draw order (centre, phases, weight breakpoints) is preserved so that
    # runs seeded through the global numpy state reproduce earlier results.
    center = rng.uniform(-1, 1, size=2)
    shifts = rng.uniform(low=0., high=2*np.pi, size=(modes,))
    breaks = np.sort(
      np.concatenate(([0.], rng.uniform(size=(modes - 1,)), [1.])))
    coeffs = np.diff(breaks)

    def f(x):
      pts = np.asarray(x)[:2]
      # Shape the centre as (2, 1, ..., 1) so it broadcasts over every
      # trailing axis of ``pts``; a fixed (2, 1) only works for 2-D input.
      c = center.reshape((2,) + (1,) * (pts.ndim - 1))
      r = np.linalg.norm(pts - c, axis=0, ord=norm_ord)
      terms = [
        coeffs[k] * np.sin(2*(k+1)*np.pi*r + shifts[k])
        for k in range(modes)
      ]
      return scale * np.sum(np.stack(terms), axis=0)
    return f
