from typing import Callable, Optional, Dict
import numpy as np
import tqdm
from project_qsl.circuit import QCircuit
from project_qsl.gradient import paramshift_grad
import scipy

# Public API of this module: the two optimizer classes and the three driver functions.
__all__ = ["Adam", "SMO", "minimize_adam", "minimize_smo", "minimize_scipyopt"]


class Adam:
    """Adam optimizer that updates the parameters of a circuit in place.

    The gradient is obtained by calling
    ``grad_fn(cir, loss_fn, *grad_fn_args, loss_fn_args=loss_fn_args)``.
    Each :meth:`update` call applies one bias-corrected Adam step.

    Note: the ``beta1``/``beta2`` defaults of this class are 0.1, whereas
    ``minimize_adam`` passes 0.9/0.999 explicitly.
    """

    def __init__(
        self,
        loss_fn: Callable,
        grad_fn: Callable,
        beta1: float = 0.1,
        beta2: float = 0.1,
        learning_rate: float = 0.5,
        loss_fn_args=(),
        grad_fn_args=()
    ) -> None:
        # Hyper-parameters.
        self.lr = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = 1e-8  # keeps the step finite when the second moment is ~0
        # Optimizer state: running moment estimates and counters.
        self.moment1 = 0.0
        self.moment2 = 0.0
        self.num_iters = 0
        self._num_grad_eval = 0

        def _gradient(cir):
            # Bind loss_fn and the extra arguments so update() only passes the circuit.
            return grad_fn(cir, loss_fn, *grad_fn_args, loss_fn_args=loss_fn_args)

        self.grad_fn = _gradient

    def update(self, cir: QCircuit) -> QCircuit:
        """Apply one Adam step to ``cir``'s parameters and return ``cir``."""
        theta = cir.get_params()
        grad = self.grad_fn(cir)
        self._num_grad_eval += 1
        self.num_iters += 1

        step, b1, b2 = self.num_iters, self.beta1, self.beta2
        # Exponential moving averages of the gradient and of its square.
        self.moment1 = b1*self.moment1 + (1 - b1)*grad
        self.moment2 = b2*self.moment2 + (1 - b2)*(grad**2)
        # Bias correction for the zero-initialised moment estimates.
        m_hat = self.moment1 / (1 - b1**step)
        v_hat = self.moment2 / (1 - b2**step)

        theta -= self.lr * m_hat / (np.sqrt(v_hat) + self.epsilon)
        cir.set_params(theta)
        return cir


def minimize_adam(loss_fn: Callable, cir: QCircuit, grad_fn: Callable = paramshift_grad,
                  max_iters: Optional[int] = None, tol: float = 1e-5,
                  beta1: float = 0.9, beta2: float = 0.999, learning_rate: float = 0.1,
                  return_all: bool = False, grad_fn_args=(), loss_fn_args=()):
    """Minimize ``loss_fn(cir, *loss_fn_args)`` over ``cir``'s parameters with :class:`Adam`.

    Every iteration performs one Adam step and then re-evaluates the loss.
    The loop stops early once the absolute change in loss between two
    consecutive iterations is ``<= tol``. Otherwise it runs ``max_iters``
    iterations, which defaults to ``200 * cir.num_params``. Progress is shown
    with a ``tqdm`` bar, and a short summary is printed at the end.

    Args:
        loss_fn: Called as ``loss_fn(cir, *loss_fn_args)``.
        cir: Circuit whose parameters are updated in place by each step.
        grad_fn: Gradient routine handed to :class:`Adam`.
        max_iters: Iteration budget; ``None`` selects ``200 * cir.num_params``.
        tol: Convergence threshold on the per-iteration loss change.
        beta1, beta2, learning_rate: Adam hyper-parameters.
        return_all: If true, record parameters and loss after every iteration.
        grad_fn_args, loss_fn_args: Extra positional arguments for the callables.

    Returns:
        Tuple ``(final_loss, cir, param_history, loss_history)``. The histories
        start with the initial point and are filled only when ``return_all``
        is true; otherwise both lists are empty.
    """
    def evaluate():
        return loss_fn(cir, *loss_fn_args)

    current = evaluate()
    change = np.inf
    n_steps = cir.num_params*200 if max_iters is None else max_iters
    param_history = [cir.get_params()] if return_all else []
    loss_history = [current] if return_all else []
    converged = False

    optimizer = Adam(loss_fn, grad_fn, beta1, beta2, learning_rate, loss_fn_args, grad_fn_args)

    with tqdm.trange(n_steps) as pbar:
        for step in pbar:
            optimizer.update(cir)
            previous, current = current, evaluate()
            change = abs(current - previous)

            if return_all:
                param_history.append(cir.get_params())
                loss_history.append(current)

            if change <= tol:
                converged = True
                break

            pbar.set_postfix({"loss": f"{current:.5E}"})

    if not converged:
        print("Maximum number of iterations has been reached!")
        print(f"|f1-f0|={change:.5f}.")
    else:
        print("Convergence reached!")
        print(f"Number of optimization steps: {step + 1:d},")
        print(f"Number of gradient evaluations: {optimizer._num_grad_eval:d},")
        print(f"Optimal parameters: {cir.get_params().tolist()},")
        print(f"Optimal function value: {current:.5f}.")

    return current, cir, param_history, loss_history
    

# NOTE: this is not a general-purpose optimizer; it is probably effective only for
# objectives such as a VQE energy <H> that depend sinusoidally on each parameter.
# Reference: https://arxiv.org/pdf/1903.12166.pdf
class SMO:
    """Sequential minimal optimization (SMO) for sinusoidal loss landscapes.

    One :meth:`update` call sweeps over every parameter of the circuit. Each
    parameter is moved, in turn, to the minimizer of a sinusoid
    ``A*sin(theta + B) + C`` (period ``2*pi``). The sinusoid is fitted through
    three loss values: at the current angle and at ``+/- pi/2`` from it.
    The step is exact only if the loss really has that form in each
    parameter; this is assumed, not checked.
    """

    def __init__(self, loss_fn: Callable, loss_fn_args=()) -> None:
        # Bind the extra arguments once so update() can simply call loss_fn(cir).
        self.loss_fn = lambda cir: loss_fn(cir, *loss_fn_args)
        # Number of loss evaluations performed by update() so far.
        self._num_fn_eval = 0

    def update(self, cir: "QCircuit") -> "QCircuit":
        """Run one coordinate sweep over ``cir``'s parameters in place and return ``cir``.

        Uses ``2 * cir.num_params + 1`` loss evaluations.
        """
        params = cir.get_params()
        # Loss at the *current* parameters. It must be refreshed after every
        # coordinate update; otherwise the three-point fit for the next
        # coordinate mixes values taken at different points.
        v0 = self.loss_fn(cir)
        for i in range(cir.num_params):
            params[i] += np.pi/2
            cir.set_params(params)
            vp = self.loss_fn(cir)
            params[i] -= np.pi
            cir.set_params(params)
            vm = self.loss_fn(cir)
            # For f(theta) = A*sin(theta + B) + C and x = theta_i + B:
            #   C = (vp + vm)/2,  A*sin(x) = v0 - C,  A*cos(x) = (vp - vm)/2.
            offset = 0.5*(vm + vp)
            a_sin = v0 - offset
            a_cos = 0.5*(vp - vm)
            phi = np.arctan2(-a_cos, a_sin)
            # params[i] now holds theta_i - pi/2; this lands it on the fitted minimizer.
            params[i] -= (phi + np.pi/2)
            # The fitted minimum C - |A| is the loss at the new point.
            # Computing it here avoids an extra evaluation.
            v0 = offset - np.hypot(a_sin, a_cos)
        self._num_fn_eval += 2*cir.num_params + 1
        cir.set_params(params)
        return cir


def minimize_smo(loss_fn: Callable, cir: QCircuit, max_iters: Optional[int] = None, tol: float = 1e-5,
                 return_all: bool = False, loss_fn_args=()):
    """Minimize ``loss_fn(cir, *loss_fn_args)`` by repeated :class:`SMO` sweeps.

    Each iteration runs one full SMO sweep over ``cir``'s parameters and then
    re-evaluates the loss. The loop stops early once the absolute change in
    loss between consecutive iterations is ``<= tol``. Otherwise it runs
    ``max_iters`` sweeps, which defaults to ``200 * cir.num_params``. A summary
    is printed at the end.

    Args:
        loss_fn: Called as ``loss_fn(cir, *loss_fn_args)``.
        cir: Circuit whose parameters are updated in place.
        max_iters: Sweep budget; ``None`` selects ``200 * cir.num_params``.
        tol: Convergence threshold on the per-iteration loss change.
        return_all: If true, record parameters and loss after every sweep.
        loss_fn_args: Extra positional arguments for ``loss_fn``.

    Returns:
        Tuple ``(final_loss, cir, param_history, loss_history)``. The histories
        start with the initial point and are filled only when ``return_all``
        is true; otherwise both lists are empty.
    """
    v0 = loss_fn(cir, *loss_fn_args)
    # Loss evaluations made directly by this driver: the initial one above plus
    # one per sweep. SMO counts its own evaluations in ``smo._num_fn_eval``.
    num_driver_evals = 1
    delta = np.inf
    if max_iters is None:
        max_iters = cir.num_params*200
    if return_all:
        param_history = [cir.get_params()]
        loss_history = [v0]
    else:
        param_history = []
        loss_history = []
    converged = False

    smo = SMO(loss_fn, loss_fn_args)

    with tqdm.trange(max_iters) as pbar:
        for k in pbar:
            smo.update(cir)
            v1 = loss_fn(cir, *loss_fn_args)
            num_driver_evals += 1
            delta = abs(v1 - v0)
            v0 = v1

            if return_all:
                param_history.append(cir.get_params())
                loss_history.append(v0)

            if delta <= tol:
                converged = True
                break

            pbar.set_postfix({"loss": f"{v0:.5E}"})

    if converged:
        print("Convergence reached!")
        print(f"Number of optimization steps: {k+1:d},")
        print(f"Number of function evaluations: {smo._num_fn_eval + num_driver_evals:d},")
        print(f"Optimal parameters: {cir.get_params().tolist()},")
        print(f"Optimal function value: {v0:.5f}.")
    else:
        print("Maximum number of iterations has been reached!")
        print(f"|f1-f0|={delta:.5f}.")

    return v0, cir, param_history, loss_history


def minimize_scipyopt(
    loss_fn: Callable, x0: np.ndarray, cir: "QCircuit", shots: int = 1024, qiskit_backend=None,
    error_mitigator=None, method: str = "COBYLA", max_iters: Optional[int] = 100, return_all: bool = False,
    if_print: bool = False, tol: float = 1e-8
):
    """Minimize ``loss_fn`` with ``scipy.optimize.minimize``.

    ``loss_fn`` is called as ``loss_fn(x, cir, shots, qiskit_backend, error_mitigator)``,
    where ``x`` is the candidate parameter vector.

    Args:
        loss_fn: Objective; receives the parameter vector first.
        x0: Starting parameter vector.
        cir: Circuit passed through to ``loss_fn``. On return its parameters
            are set to the optimizer's final point ``res.x``.
        shots, qiskit_backend, error_mitigator: Passed through to ``loss_fn``.
        method: One of ``"Nelder-Mead"``, ``"COBYLA"``, ``"SLSQP"``.
        max_iters: SciPy ``maxiter`` option; ``None`` selects ``200 * cir.num_params``.
        return_all: If true, record the parameters and loss at every optimizer
            iteration. This costs one extra ``loss_fn`` call per iteration.
        if_print: If true, print progress and the final SciPy result.
        tol: SciPy termination tolerance.

    Returns:
        Tuple ``(res.fun, cir, param_history, loss_history)``. The histories
        start at ``x0`` and are filled only when ``return_all`` is true;
        otherwise both lists are empty.

    Raises:
        ValueError: If ``method`` is not one of the supported methods.
    """
    # Imported explicitly: a bare ``import scipy`` does not load the
    # ``optimize`` subpackage on older SciPy releases.
    from scipy import optimize

    args = (cir, shots, qiskit_backend, error_mitigator)

    if method not in ("Nelder-Mead", "COBYLA", "SLSQP"):
        raise ValueError(f"Current method: {method} is not available.")

    if max_iters is None:
        max_iters = cir.num_params*200

    callback = None
    param_history = []
    loss_history = []
    if return_all:
        # Seed both histories at x0 so param_history[i] always pairs with loss_history[i].
        param_history.append(np.array(x0, dtype=float))
        loss_history.append(loss_fn(x0, *args))

        def callback(xk) -> None:
            loss = loss_fn(xk, *args)
            # Copy: the optimizer may reuse the array it hands to the callback.
            param_history.append(np.copy(xk))
            loss_history.append(loss)
            if if_print:
                print(f"loss={loss:.5f}", end="\r")

    res = optimize.minimize(loss_fn, x0, args, method=method, tol=tol,
                            options={"maxiter": max_iters}, callback=callback)
    # Leave the circuit at the optimizer's final point, like the other drivers
    # in this module, so the returned ``cir`` is consistent with ``res.fun``.
    cir.set_params(res.x)
    if if_print:
        print(res)
        print("-"*30)

    return res.fun, cir, param_history, loss_history