"Differentiable Functions"

import random
import numpy as np
import torch

# Constant: pi stored on the torch namespace as a 0-dim tensor (acos(-1) == pi),
# so that `lib.pi` can be read when lib=torch.
# NOTE(review): recent PyTorch releases appear to define torch.pi already (as a
# Python float); this assignment replaces it with a tensor - confirm intended.
torch.pi = torch.tensor(-1.).acos()

def rosenbrock(tensor, lib=torch):
    """Evaluate the 2-D Rosenbrock function (1 - x)^2 + 100 * (y - x^2)^2.

    See https://en.wikipedia.org/wiki/Test_functions_for_optimization

    Args:
        tensor: Either a list/tuple ``[x, y]`` of coordinates, or an array
            (torch tensor, numpy array, ...) whose first axis has length 2
            and holds ``x`` followed by ``y``.
        lib: Array library matching ``tensor``. Kept for signature parity
            with the other test functions; the formula itself only needs
            arithmetic operators.

    Returns:
        The function value. For array input the leading axis is kept with
        length 1 (a shape ``(2,)`` input gives a shape ``(1,)`` result).

    Raises:
        ValueError: If ``tensor`` does not hold exactly two components.
    """
    if isinstance(tensor, (list, tuple)):
        x, y = tensor
    else:
        if len(tensor) != 2:
            raise ValueError(
                'expected exactly 2 components along the first axis, got {}'
                .format(len(tensor)))
        # Slice instead of lib.split(tensor, (1, 1)): torch reads (1, 1) as
        # chunk sizes but numpy reads it as split indices (yielding three
        # pieces), so slicing is the form that behaves the same in both.
        x, y = tensor[0:1], tensor[1:2]
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

def rosenbrock_grad(tensor, lib=torch):
    """Analytic gradient of the 2-D Rosenbrock function.

    See https://en.wikipedia.org/wiki/Test_functions_for_optimization

    Args:
        tensor: Either a list/tuple ``[x, y]`` of arrays (each at least 1-D,
            because the two partial derivatives are concatenated), or an
            array whose first axis has length 2 and holds ``x`` then ``y``.
        lib: Array library matching ``tensor`` (``torch`` or a numpy-like
            module); used to concatenate the two partial derivatives.

    Returns:
        ``[df/dx, df/dy]`` concatenated along the first axis.

    Raises:
        ValueError: If ``tensor`` does not hold exactly two components.
    """
    if isinstance(tensor, (list, tuple)):
        x, y = tensor
    else:
        if len(tensor) != 2:
            raise ValueError(
                'expected exactly 2 components along the first axis, got {}'
                .format(len(tensor)))
        # Slice instead of lib.split(tensor, (1, 1)): torch reads (1, 1) as
        # chunk sizes but numpy reads it as split indices (yielding three
        # pieces), so slicing is the form that behaves the same in both.
        x, y = tensor[0:1], tensor[1:2]

    residual = y - x ** 2
    df_dx = -(1 - x) * 2 - 100 * residual * 4 * x
    df_dy = 100 * residual * 2

    # Honour `lib` rather than hard-coding torch: torch exposes `cat`,
    # numpy-like modules expose `concatenate`.
    concat = lib.cat if hasattr(lib, 'cat') else lib.concatenate
    return concat([df_dx, df_dy])

def rastrigin(tensor, lib=torch):
    """Evaluate the 2-D Rastrigin function with A = 10.

    f(x, y) = 2*A + (x^2 - A*cos(2*pi*x)) + (y^2 - A*cos(2*pi*y))

    See https://en.wikipedia.org/wiki/Test_functions_for_optimization

    Args:
        tensor: Either a list/tuple ``[x, y]`` of coordinates (of a type
            accepted by ``lib.cos``), or an array whose first axis has
            length 2 and holds ``x`` followed by ``y``.
        lib: Array library matching ``tensor``; must provide ``cos`` and
            ``pi``.

    Returns:
        The function value. For array input the leading axis is kept with
        length 1 (a shape ``(2,)`` input gives a shape ``(1,)`` result).

    Raises:
        ValueError: If ``tensor`` does not hold exactly two components.
    """
    if isinstance(tensor, (list, tuple)):
        x, y = tensor
    else:
        if len(tensor) != 2:
            raise ValueError(
                'expected exactly 2 components along the first axis, got {}'
                .format(len(tensor)))
        # Slice instead of lib.split(tensor, (1, 1)): torch reads (1, 1) as
        # chunk sizes but numpy reads it as split indices (yielding three
        # pieces), so slicing is the form that behaves the same in both.
        x, y = tensor[0:1], tensor[1:2]

    A = 10
    x_term = x ** 2 - A * lib.cos(x * lib.pi * 2)
    y_term = y ** 2 - A * lib.cos(y * lib.pi * 2)
    return A * 2 + x_term + y_term

def rastrigin_grad(tensor, lib=torch):
    """Analytic gradient of the 2-D Rastrigin function with A = 10.

    Each partial derivative is ``2*t + 2*pi*A*sin(2*pi*t)`` for ``t`` in
    ``(x, y)``.

    See https://en.wikipedia.org/wiki/Test_functions_for_optimization

    Args:
        tensor: Either a list/tuple ``[x, y]`` of arrays (each at least 1-D,
            because the two partial derivatives are concatenated), or an
            array whose first axis has length 2 and holds ``x`` then ``y``.
        lib: Array library matching ``tensor``; must provide ``sin``, ``pi``
            and either ``cat`` or ``concatenate``.

    Returns:
        ``[df/dx, df/dy]`` concatenated along the first axis.

    Raises:
        ValueError: If ``tensor`` does not hold exactly two components.
    """
    if isinstance(tensor, (list, tuple)):
        x, y = tensor
    else:
        if len(tensor) != 2:
            raise ValueError(
                'expected exactly 2 components along the first axis, got {}'
                .format(len(tensor)))
        # Slice instead of lib.split(tensor, (1, 1)): torch reads (1, 1) as
        # chunk sizes but numpy reads it as split indices (yielding three
        # pieces), so slicing is the form that behaves the same in both.
        x, y = tensor[0:1], tensor[1:2]

    A = 10
    df = [
        t * 2 + A * lib.sin(t * lib.pi * 2) * lib.pi * 2
        for t in (x, y)
    ]

    # Honour `lib` rather than hard-coding torch: torch exposes `cat`,
    # numpy-like modules expose `concatenate`.
    concat = lib.cat if hasattr(lib, 'cat') else lib.concatenate
    return concat(df)


if __name__ == '__main__':
    # Gradient self-check. For a random point z and random direction v, the
    # forward difference (f(z + h*v) - f(z)) / h is compared against the
    # directional derivative <grad f(z), v>, scaled by ||grad f(z)||.

    # set random seed
    seed = 2024
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Pair every function with its analytic gradient explicitly instead of
    # resolving the names through eval().
    test_cases = [
        ('rosenbrock', rosenbrock, rosenbrock_grad),
        ('rastrigin', rastrigin, rastrigin_grad),
    ]
    h = 1e-5

    for f_name, fun, grad in test_cases:
        print('Testing {} function...'.format(f_name))

        # Double precision on purpose: in float32 the round-off in
        # f(z + h*v) - f(z), once divided by h = 1e-5, swamps the O(h)
        # truncation error this check is meant to expose.
        z = torch.randn(2, dtype=torch.float64)  # random point
        f, g = fun(z), grad(z)
        v = torch.randn(2, dtype=torch.float64)  # random direction

        df = (fun(z + h * v) - f) / h
        gv = torch.inner(g, v)
        relerr = torch.abs(df - gv) / torch.linalg.norm(g)
        print('The error between gradient and finite-difference approximation is: {}\n'.format(relerr.item()))