# NOTE: a loss function is called as v = loss_fn(cir, *loss_fn_args) and returns the loss value v
from typing import Callable, Optional
import numpy as np
from project_qsl.circuit import QCircuit

# Names exported by `from <module> import *`.
__all__ = ["finite_diff1", "finite_diff2", "paramshift_grad"]


def finite_diff1(
    cir: "QCircuit",
    loss_fn: Callable,
    step_size: float = 1e-3,
    params: Optional[np.ndarray] = None,
    loss_fn_args=()
) -> "np.ndarray[float]":
    """Estimate the loss gradient with a forward (first-order) finite difference.

    For every parameter index ``i`` the estimate is
    ``(loss(p + step_size * e_i) - loss(p)) / step_size``.

    Args:
        cir: Circuit exposing ``get_params()``, ``set_params()`` and
            ``num_params``.
        loss_fn: Called as ``loss_fn(cir, *loss_fn_args)``; returns the loss.
        step_size: Forward step added to one parameter at a time.
        params: Point at which to differentiate. When ``None`` the circuit's
            current parameters are used. The array is never modified.
        loss_fn_args: Extra positional arguments forwarded to ``loss_fn``.

    Returns:
        Array with one derivative estimate per circuit parameter.

    On return the circuit's parameters are set to the evaluation point.
    """
    if params is None:
        base = np.array(cir.get_params(), copy=True)
    else:
        base = np.array(params, copy=True)
    # Integer/bool inputs cannot hold a fractional shift; promote them.
    if not np.issubdtype(base.dtype, np.inexact):
        base = base.astype(float)
    if params is not None:
        cir.set_params(base)

    v0 = loss_fn(cir, *loss_fn_args)

    # Perturb a private copy so neither the caller's array nor the array
    # handed out by the circuit is modified.
    shifted = base.copy()
    dloss = []
    for i in range(cir.num_params):
        shifted[i] = base[i] + step_size
        cir.set_params(shifted)
        v1 = loss_fn(cir, *loss_fn_args)
        dloss.append((v1 - v0) / step_size)
        # Restore from the baseline instead of subtracting the step, which
        # would leave floating-point round-off behind.
        shifted[i] = base[i]
    cir.set_params(base)
    return np.asarray(dloss)


def finite_diff2(
    cir: "QCircuit",
    loss_fn: Callable,
    step_size: float = 1e-3,
    params: Optional[np.ndarray] = None,
    loss_fn_args=()
) -> "np.ndarray[float]":
    """Evaluate a central second difference of the loss for each parameter.

    For every parameter index ``i`` the value is
    ``(loss(p + h e_i) - 2 loss(p) + loss(p - h e_i)) / h**2`` with
    ``h = step_size``.

    NOTE(review): this stencil estimates the second derivative along each
    parameter, not the gradient. If a gradient was intended, the central
    formula would be ``(vf - vb) / (2 * h)``. Confirm with callers before
    changing it.

    Args:
        cir: Circuit exposing ``get_params()``, ``set_params()`` and
            ``num_params``.
        loss_fn: Called as ``loss_fn(cir, *loss_fn_args)``; returns the loss.
        step_size: Symmetric step applied to one parameter at a time.
        params: Point at which to evaluate. When ``None`` the circuit's
            current parameters are used. The array is never modified.
        loss_fn_args: Extra positional arguments forwarded to ``loss_fn``.

    Returns:
        Array with one value per circuit parameter.

    On return the circuit's parameters are set to the evaluation point.
    """
    if params is None:
        base = np.array(cir.get_params(), copy=True)
    else:
        base = np.array(params, copy=True)
    # Integer/bool inputs cannot hold a fractional shift; promote them.
    if not np.issubdtype(base.dtype, np.inexact):
        base = base.astype(float)
    if params is not None:
        cir.set_params(base)

    vc = loss_fn(cir, *loss_fn_args)

    # Perturb a private copy so neither the caller's array nor the array
    # handed out by the circuit is modified.
    shifted = base.copy()
    dloss = []
    for i in range(cir.num_params):
        shifted[i] = base[i] + step_size
        cir.set_params(shifted)
        vf = loss_fn(cir, *loss_fn_args)
        shifted[i] = base[i] - step_size
        cir.set_params(shifted)
        vb = loss_fn(cir, *loss_fn_args)
        dloss.append((vf - 2*vc + vb) / step_size**2)
        # Restore from the baseline so no round-off accumulates.
        shifted[i] = base[i]
    cir.set_params(base)
    return np.asarray(dloss)


def paramshift_grad(cir: "QCircuit", loss_fn: Callable, params: Optional[np.ndarray] = None, loss_fn_args=()) -> "np.ndarray[float]":
    """Compute the loss gradient with the +/- pi/2 parameter-shift formula.

    For every parameter index ``i`` the value is
    ``0.5 * (loss(p + pi/2 * e_i) - loss(p - pi/2 * e_i))``.

    Args:
        cir: Circuit exposing ``get_params()``, ``set_params()`` and
            ``num_params``.
        loss_fn: Called as ``loss_fn(cir, *loss_fn_args)``; returns the loss.
        params: Point at which to differentiate. When ``None`` the circuit's
            current parameters are used. The array is never modified.
        loss_fn_args: Extra positional arguments forwarded to ``loss_fn``.

    Returns:
        Array with one derivative value per circuit parameter.

    On return the circuit's parameters are set to the evaluation point.
    """
    if params is None:
        base = np.array(cir.get_params(), copy=True)
    else:
        base = np.array(params, copy=True)
    # Integer/bool inputs cannot hold a pi/2 shift; promote them.
    if not np.issubdtype(base.dtype, np.inexact):
        base = base.astype(float)

    half_pi = np.pi / 2
    # Perturb a private copy so neither the caller's array nor the array
    # handed out by the circuit is modified.
    shifted = base.copy()
    dloss = []
    for i in range(cir.num_params):
        shifted[i] = base[i] + half_pi
        cir.set_params(shifted)
        vf = loss_fn(cir, *loss_fn_args)
        shifted[i] = base[i] - half_pi
        cir.set_params(shifted)
        vb = loss_fn(cir, *loss_fn_args)
        dloss.append(0.5*(vf - vb))
        # Restore from the baseline: +pi/2, -pi, +pi/2 does not round-trip
        # exactly in floating point.
        shifted[i] = base[i]
    cir.set_params(base)
    return np.asarray(dloss)
