from typing import Callable, Any, Tuple, Sequence
import operator
from functools import partial
import contextlib
import torch
from torch.func import grad, jvp, vjp, functional_call, jacrev, jacfwd
from torch.distributions import Normal
from optree import tree_map, tree_map_, tree_reduce, tree_flatten, tree_leaves
from optree.integration.torch import tree_ravel
from posteriors import tree_size, hvp
from posteriors.types import TensorTree, ForwardFn, Tensor
def _vdot_real_part(x: Tensor, y: Tensor) -> float:
    """Vector dot-product guaranteed to have a real valued result despite
    possibly complex input. Thus neglects the real-imaginary cross-terms.

    Args:
        x: First tensor in the dot product.
        y: Second tensor in the dot product.

    Returns:
        The result vector dot-product, a real float
    """
    # all our uses of vdot() in CG are for computing an operator of the form
    #  z^H M z
    #  where M is positive definite and Hermitian, so the result is
    # real valued:
    # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
    real_part = torch.vdot(x.real.flatten(), y.real.flatten())
    if torch.is_complex(x) or torch.is_complex(y):
        imag_part = torch.vdot(x.imag.flatten(), y.imag.flatten())
        return real_part + imag_part
    return real_part


def _vdot_real_tree(x, y) -> TensorTree:
    """Real dot-product of two tensor trees.

    Applies ``_vdot_real_part`` to each pair of corresponding leaves of ``x`` and
    ``y`` and adds the per-leaf results, starting the accumulation from ``0``.

    Args:
        x: First tensor tree.
        y: Second tensor tree, with structure compatible with ``x``.

    Returns:
        The accumulated real dot-product (the integer ``0`` if there are no leaves).
    """
    leafwise = tree_map(_vdot_real_part, x, y)
    total = 0
    for term in tree_leaves(leafwise):
        total = total + term
    return total


def _mul(scalar, tree) -> TensorTree:
    """Scale every leaf of ``tree`` by ``scalar``.

    Each leaf is replaced by ``scalar * leaf`` (scalar on the left).

    Args:
        scalar: Multiplier applied to each leaf.
        tree: Tensor tree to scale.

    Returns:
        A new tree with the same structure as ``tree``.
    """

    def scale_leaf(leaf):
        return scalar * leaf

    return tree_map(scale_leaf, tree)


def _add(tree, *rests, **kwargs):
    """Leaf-wise sum of two tensor trees, ``tree + other``.

    All arguments are forwarded unchanged to ``tree_map(operator.add, ...)``.
    """
    return tree_map(operator.add, tree, *rests, **kwargs)


def _sub(tree, *rests, **kwargs):
    """Leaf-wise difference of two tensor trees, ``tree - other``.

    All arguments are forwarded unchanged to ``tree_map(operator.sub, ...)``.
    """
    return tree_map(operator.sub, tree, *rests, **kwargs)


def _identity(x):
    return x

def ggnvp(
    forward: Callable,
    loss: Callable,
    primals: tuple,
    tangents: tuple,
    forward_has_aux: bool = False,
    loss_has_aux: bool = False,
    normalize: bool = True,
) -> (
    Tuple[float, TensorTree]
    | Tuple[float, TensorTree, Any]
    | Tuple[float, TensorTree, Any, Any]
):
    """Generalised Gauss-Newton vector product.
    Equivalent to the (non-empirical) Fisher vector product when `loss` is the negative
    log likelihood of an exponential family distribution as a function of its natural
    parameter.
    Defined as
    $$
    G(θ) = J_f(θ)^T H_l(z) J_f(θ)
    $$
    where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$
    is the scalar output of the loss function.
    Thus $J_f(θ)$ is the Jacobian of the forward function $f$ evaluated
    at `primals` $θ$, with dimensions `(dz, dθ)`.
    And $H_l(z)$ is the Hessian of the loss function $l$ evaluated at `z = f(θ)`, with
    dimensions `(dz, dz)`.
    The product $G(θ) v$ is computed without materialising any matrix: a
    forward-mode JVP gives $J_f v$, a Hessian-vector product with `loss` gives
    $H_l J_f v$, and a reverse-mode VJP gives $J_f^T H_l J_f v$.
    More info on Fisher and GGN matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).

    Args:
        forward: A function with tensor output.
        loss: A function that maps the output of forward to a scalar output.
        primals: Tuple of e.g. tensor or dict with tensor values to evaluate f at.
        tangents: Tuple matching structure of primals.
        forward_has_aux: Whether forward returns auxiliary information.
        loss_has_aux: Whether loss returns auxiliary information.
        normalize: Whether to normalize, i.e. divide $H_l J_f v$ by the leading
            dimension (e.g. batch size) of the first leaf of the output of f.
    Returns:
        Returns a (output, ggnvp_out) tuple containing the output of func evaluated at
            primals and the GGN-vector product. If forward_has_aux or loss_has_aux is
            True, then instead returns a (output, ggnvp_out, aux) or
            (output, ggnvp_out, forward_aux, loss_aux) tuple accordingly.
            output is a tuple of (forward(primals), grad(loss)(forward(primals))).
            ggnvp_out is the VJP component for the first element of `primals` only.
    """
    # Forward-mode pass: z = f(θ) and J_f v (plus forward aux if requested).
    jvp_output = jvp(forward, primals, tangents, has_aux=forward_has_aux)
    z = jvp_output[0]
    Jv = jvp_output[1]

    # Loss Hessian at z applied to J_f v; HJv_output[0] is grad(loss)(z).
    HJv_output = hvp(loss, (z,), (Jv,), has_aux=loss_has_aux)
    HJv = HJv_output[1]

    if normalize:
        output_dim = tree_flatten(z)[0][0].shape[0]
        HJv = tree_map(lambda x: x / output_dim, HJv)

    # Reverse-mode pass J_f^T (H_l J_f v); note this evaluates `forward` again.
    forward_vjp = vjp(forward, *primals, has_aux=forward_has_aux)[1]
    # vjp yields one cotangent per primal; only the first one is returned.
    JTHJv = forward_vjp(HJv)[0]

    return (z, HJv_output[0]), JTHJv, *jvp_output[2:], *HJv_output[2:]

def _hess_and_jac_for_ggn(
    flat_params_to_forward,
    loss,
    argnums,
    forward_has_aux,
    loss_has_aux,
    normalize,
    flat_params,
) -> Tuple[Tensor, Tensor, list]:
    """Compute the forward Jacobian and the loss Hessian that make up the GGN.

    Args:
        flat_params_to_forward: Function of the flat parameter tensor returning the
            forward output ``z`` (plus aux if ``forward_has_aux``).
        loss: Function of ``z`` returning a scalar (plus aux if ``loss_has_aux``).
        argnums: Passed to ``jacrev`` when differentiating the forward function.
        forward_has_aux: Whether ``flat_params_to_forward`` returns aux.
        loss_has_aux: Whether ``loss`` returns aux.
        normalize: If True, the Hessian is scaled by ``1 / z.shape[0]`` (batch size).
        flat_params: Flat parameter tensor at which everything is evaluated.

    Returns:
        ``(jac, hess, aux)``: ``jac`` has shape ``(d, dθ)`` where ``d`` is the number
        of elements of ``z``; ``hess`` is the square ``(d, d)`` Hessian of ``loss``
        at ``z``; ``aux`` is a list holding the forward aux and then the loss aux,
        each only if requested.
    """
    jac_output = jacrev(
        flat_params_to_forward, argnums=argnums, has_aux=forward_has_aux
    )(flat_params)
    jac = jac_output[0] if forward_has_aux else jac_output  # (..., dθ)
    # Convert to tensor (assumes jac has tensor output).
    jac = torch.stack(tree_leaves(jac))[0]
    # Maybe normalize by batchsize; taken from the leading dim before flattening.
    rescale = 1 / jac.shape[0] if normalize else 1
    jac = jac.flatten(end_dim=-2)  # (d, dθ)

    # Evaluate the forward function again to obtain z itself.
    z = flat_params_to_forward(flat_params)
    z = z[0] if forward_has_aux else z

    hess_output = jacfwd(jacrev(loss, has_aux=loss_has_aux), has_aux=loss_has_aux)(z)
    hess = hess_output[0] if loss_has_aux else hess_output  # (*z.shape, *z.shape)
    # Convert to tensor (assumes loss has tensor input).
    hess = torch.stack(tree_leaves(hess))[0]
    z_ndim = hess.ndim // 2
    # Collapse the trailing z dims into one, then the leading z dims into one.
    # After the first flatten the tensor has z_ndim + 1 dims, so the leading
    # block ends at dim -2; this yields a (d, d) matrix for any z.ndim >= 1.
    hess = hess.flatten(start_dim=z_ndim).flatten(end_dim=-2)

    hess *= rescale

    # Collect aux outputs
    aux = []
    if forward_has_aux:
        aux.append(jac_output[1])
    if loss_has_aux:
        aux.append(loss(z)[1])

    return jac, hess, aux

def jac_and_hess(
    forward: Callable,
    loss: Callable,
    argnums: int | Sequence[int] = 0,
    forward_has_aux: bool = False,
    loss_has_aux: bool = False,
    normalize: bool = False,
) -> Callable:
    """
    Constructs function to compute the Generalised Gauss-Newton matrix.

    Equivalent to the (non-empirical) Fisher when `loss` is the negative
    log likelihood of an exponential family distribution as a function of its natural
    parameter.

    Defined as
    $$
    G(θ) = J_f(θ) H_l(z) J_f(θ)^T
    $$
    where $z = f(θ)$ is the output of the forward function $f$ and $l(z)$
    is a loss function with scalar output.

    Thus $J_f(θ)$ is the Jacobian of the forward function $f$ evaluated
    at `primals` $θ$. And $H_l(z)$ is the Hessian of the loss function $l$ evaluated
    at `z = f(θ)`.

    Requires output from `forward` to be a tensor and therefore `loss` takes a tensor as
    input. Although both support `aux` output.

    If `normalize=True`, then $G(θ)$ is divided by the size of the leading dimension of
    outputs from `forward` (i.e. batchsize).

    The GGN will be provided as a square tensor with respect to the
    ravelled parameters.
    `flat_params, params_unravel = optree.tree_ravel(params)`.

    Follows API from [`torch.func.jacrev`](https://pytorch.org/functorch/stable/generated/functorch.jacrev.html).

    More info on Fisher and GGN matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).

    Examples:
        ```python
        from functools import partial
        import torch
        from posteriors import ggn

        # Load model that outputs logits
        # Load batch = {'inputs': ..., 'labels': ...}

        def forward(params, inputs):
            return torch.func.functional_call(model, params, inputs)

        def loss(logits, labels):
            return torch.nn.functional.cross_entropy(logits, labels)

        params = dict(model.parameters())
        ggn_result = ggn(
            partial(forward, inputs=batch['inputs']),
            partial(loss, labels=batch['labels']),
        )(params)
        ```

    Args:
        forward: A function with tensor output.
        loss: A function that maps the output of forward to a scalar output.
            Takes a single input and returns a scalar (and possibly aux).
        argnums: Optional, integer or sequence of integers. Specifies which
            positional argument(s) to differentiate `forward` with respect to.
        forward_has_aux: Whether forward returns auxiliary information.
        loss_has_aux: Whether loss returns auxiliary information.
        normalize: Whether to normalize, divide by the first dimension of the output
            from f.

    Returns:
        A function with the same arguments as f that returns the tensor GGN.
            If has_aux is True, then the function instead returns a tuple of (F, aux).
    """
    assert argnums == 0, "Only argnums=0 is supported for now."

    def internal_ggn(params):
        flat_params, params_unravel = tree_ravel(params)

        def flat_params_to_forward(fps):
            return forward(params_unravel(fps))

        jac, hess, aux = _hess_and_jac_for_ggn(
            flat_params_to_forward,
            loss,
            argnums,
            forward_has_aux,
            loss_has_aux,
            normalize,
            flat_params,
        )

        if aux:
            return (jac, hess), *aux
        else:
            return (jac, hess)

    return internal_ggn

def cg(
    A: Callable,
    b: TensorTree,
    x0: TensorTree = None,
    *,
    maxiter: int = None,
    damping: float = 0.0,
    tol: float = 1e-5,
    atol: float = 0.0,
    M: Callable = _identity,
) -> Tuple[TensorTree, Any]:
    """Use Conjugate Gradient iteration to solve ``Ax = b``.
    ``A`` is supplied as a function instead of a matrix.

    Adapted from [`jax.scipy.sparse.linalg.cg`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.sparse.linalg.cg.html).

    Args:
        A:  Callable that calculates the linear map (matrix-vector
            product) ``Ax`` when called like ``A(x)``. ``A`` must represent
            a hermitian, positive definite matrix, and must return array(s) with the
            same structure and shape as its argument.
        b:  Right hand side of the linear system representing a single vector.
        x0: Starting guess for the solution. Must have the same structure as ``b``.
            Defaults to zeros.
        maxiter: Maximum number of iterations.  Iteration will stop after maxiter
            steps even if the specified tolerance has not been achieved.
            Defaults to ``10 * tree_size(b)``.
        damping: damping term for the mvp function. Acts as regularization: the
            system actually solved is ``(A + damping * I) x = b``.
        tol: Tolerance for convergence.
        atol: Tolerance for convergence. ``norm(residual) <= max(tol*norm(b), atol)``.
            The behaviour will differ from SciPy unless you explicitly pass
            ``atol`` to SciPy's ``cg``.
        M: Preconditioner for A: maps a residual ``r`` to ``z = M(r)`` and typically
            approximates the action of ``(A + damping * I)^{-1}``.
            See [the preconditioned CG method.](https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method)

    Returns:
        x : The converged solution. Has the same structure as ``b``.
        info : Dict with ``"error"`` (squared norm of the final residual),
            ``"converged"`` (whether that is ``<= max(tol**2 * norm(b)**2, atol**2)``)
            and ``"niter"`` (number of iterations performed).
    """
    if x0 is None:
        x0 = tree_map(torch.zeros_like, b)

    if maxiter is None:
        maxiter = 10 * tree_size(b)  # copied from scipy

    # Promote tolerances to 1-element tensors. Rebind rather than use ``*=`` so a
    # tensor passed in by the caller is never modified in place.
    tol = tol * torch.tensor([1.0])
    atol = atol * torch.tensor([1.0])

    # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg:
    # iterate while ||r||^2 > max(tol^2 ||b||^2, atol^2)
    bs = _vdot_real_tree(b, b)
    atol2 = torch.maximum(torch.square(tol) * bs, torch.square(atol))

    def A_damped(p):
        # (A + damping * I) p
        return _add(A(p), _mul(damping, p))

    def cond_fun(value):
        _, r, gamma, _, k = value
        # Without a preconditioner gamma = r^H r already; otherwise recompute it.
        rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
        return (rs > atol2) & (k < maxiter)

    def body_fun(value):
        x, r, gamma, p, k = value
        Ap = A_damped(p)
        alpha = gamma / _vdot_real_tree(p, Ap)
        x_ = _add(x, _mul(alpha, p))
        r_ = _sub(r, _mul(alpha, Ap))
        z_ = M(r_)
        gamma_ = _vdot_real_tree(r_, z_)
        beta_ = gamma_ / gamma
        p_ = _add(z_, _mul(beta_, p))
        return x_, r_, gamma_, p_, k + 1

    r0 = _sub(b, A_damped(x0))
    # Preconditioned CG starts from the preconditioned residual z0 = M(r0), and
    # gamma must be r^H M r so that beta_ = gamma_ / gamma in body_fun is the
    # correct PCG ratio from the first iteration on.
    z0 = M(r0)
    p0 = z0
    gamma0 = _vdot_real_tree(r0, z0)
    initial_value = (x0, r0, gamma0, p0, 0)

    value = initial_value

    while cond_fun(value):
        value = body_fun(value)

    x_final, r, gamma, _, k = value
    # compute the final error and whether it has converged.
    rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
    converged = rs <= atol2

    # additional info output structure
    info = {"error": rs, "converged": converged, "niter": k}

    return x_final, info

def thermo_solve_fvp(
    A: Callable,
    b: TensorTree,
    x0: TensorTree = None,
    *,
    iterations: int = None,
    step: float = 0.1,
    damping: float = 0.0,
    average_regularization: bool = False,
    noise_variance: float = 0.0,
) -> TensorTree:
    """Approximately solve ``A_damped(x) = b`` by an explicit (noisy) relaxation.

    Runs ``iterations`` explicit Euler steps of the linear dynamics
    ``dx/dt = b - A_damped(x)``, optionally with additive Gaussian noise:
    $$
    x_{k+1} = x_k + h\\,(b - A_d(x_k)) + \\sqrt{2 σ^2 h}\\,ξ_k,
    \\qquad ξ_k \\sim N(0, I),
    $$
    with ``h = step`` and ``σ^2 = noise_variance``. Without noise, any fixed point
    of the iteration satisfies ``A_damped(x) = b``. Convergence depends on ``A``
    and ``step``; for a symmetric ``A_damped`` it needs all eigenvalues to lie in
    ``(0, 2 / step)``.

    Args:
        A: Callable computing the matrix-vector product ``A(x)``; must return a tree
            with the same structure as its argument.
        b: Right hand side of the linear system.
        x0: Starting point. Defaults to zeros shaped like ``b``.
        iterations: Number of update steps to run. Required.
        step: Step size ``h`` of each update.
        damping: Regularization. ``A_damped(x) = A(x) + damping * x``, or
            ``(1 - damping) * A(x) + damping * x`` if ``average_regularization``.
        average_regularization: Use the convex-combination form of damping.
        noise_variance: ``σ^2`` of the injected noise. When ``0.0`` the iteration
            is deterministic and no random numbers are drawn.

    Returns:
        The iterate after ``iterations`` steps, with the structure of ``x0``.

    Raises:
        TypeError: If ``iterations`` is not given.
        ValueError: If ``2 * noise_variance * step`` is negative.
    """
    if iterations is None:
        raise TypeError("thermo_solve_fvp() requires the `iterations` argument")

    if x0 is None:
        x0 = tree_map(torch.zeros_like, b)

    # Standard deviation of the per-step Gaussian increment. Kept as a Python
    # float so it broadcasts against every leaf without changing the leaf's
    # device, dtype or shape (a 1-element CPU tensor fails for CUDA leaves and
    # turns 0-dim leaves into shape (1,)).
    if noise_variance != 0.0:
        diffusion_var = 2 * noise_variance * step
        if diffusion_var < 0:
            raise ValueError("2 * noise_variance * step must be non-negative")
        diffusion = diffusion_var**0.5
    else:
        diffusion = 0.0

    if average_regularization:

        def A_damped(p):
            # Convex combination of A and the identity.
            return _add(_mul(1 - damping, A(p)), _mul(damping, p))

    else:

        def A_damped(p):
            return _add(A(p), _mul(damping, p))

    x = x0
    k = 0
    while k < iterations:
        increment = _mul(step, _sub(b, A_damped(x)))
        if diffusion:
            noise = tree_map(torch.randn_like, x)
            increment = _add(increment, _mul(diffusion, noise))
        x = _add(x, increment)
        k += 1

    return x



def thermo_solve_heun(
    A: Callable,
    b: TensorTree,
    x0: TensorTree = None,
    *,
    iterations: int = None,
    step: float = 0.05,
    damping: float = 0.0
) -> TensorTree:
    """Approximately solve ``(A + damping * I) x = b`` with Heun's method.

    Integrates the linear ODE ``dx/dt = f(x) = b - A(x) - damping * x`` for
    ``iterations`` steps of size ``h = step`` using Heun's (explicit trapezoidal)
    rule:
    $$
    \\tilde x = x_k + h f(x_k), \\qquad
    x_{k+1} = x_k + \\frac{h}{2}\\big(f(x_k) + f(\\tilde x)\\big).
    $$
    Any fixed point of the iteration satisfies ``A(x) + damping * x = b``. Each
    step evaluates ``A`` twice.

    Args:
        A: Callable computing the matrix-vector product ``A(x)``; must return a tree
            with the same structure as its argument.
        b: Right hand side of the linear system.
        x0: Starting point. Defaults to zeros shaped like ``b``.
        iterations: Number of Heun steps to run. Required.
        step: Step size ``h``.
        damping: Adds ``damping * x`` to ``A(x)``.

    Returns:
        The iterate after ``iterations`` steps, with the structure of ``x0``.

    Raises:
        TypeError: If ``iterations`` is not given.
    """
    if iterations is None:
        raise TypeError("thermo_solve_heun() requires the `iterations` argument")

    if x0 is None:
        x0 = tree_map(torch.zeros_like, b)

    def A_damped(p):
        return _add(A(p), _mul(damping, p))

    def drift(p):
        # Right-hand side of the ODE: f(p) = b - (A + damping * I) p.
        return _sub(b, A_damped(p))

    x = x0
    k = 0
    while k < iterations:
        f_x = drift(x)
        x_pred = _add(x, _mul(step, f_x))  # Euler predictor
        f_pred = drift(x_pred)
        # Trapezoidal corrector: average the slopes at x and at the predictor.
        x = _add(x, _mul(step / 2, _add(f_x, f_pred)))
        k += 1

    return x


def woodbury_fisher(
    f: Callable,
    grad: TensorTree,
    argnums: int | Sequence[int] = 0,
    has_aux: bool = False,
    normalize: bool = True,
) -> Callable:
    """
    Constructs function to compute the Gram matrix of per-sample gradients of a
    function f, i.e. the "dual" `(n, n)` form of the empirical Fisher.

    Let $J$ be the `(n, dθ)` matrix whose rows are the flattened gradients
    $∇_θ f_θ(x_i, y_i)$ of the `n` outputs of `f`. The returned matrix is
    (unnormalized)
    $$
    K(θ) = J J^T, \\qquad K_{ij} = ∇_θ f_θ(x_i, y_i)^T ∇_θ f_θ(x_j, y_j).
    $$
    It is NOT the `(dθ, dθ)` empirical Fisher
    $F(θ) = J^T J = \\sum_i ∇_θ f_θ(x_i, y_i) ∇_θ f_θ(x_i, y_i)^T$, although both
    share their nonzero eigenvalues. The `(n, n)` form is presumably intended for
    the Woodbury identity, e.g.
    $(λ I + J^T J)^{-1} = \\frac{1}{λ}\\big(I - J^T (λ I + J J^T)^{-1} J\\big)$.
    Typically $f_θ(x_i, y_i)$ is the log likelihood $\\log p(y_i | x_i,θ)$ of a
    model with parameters $θ$ given inputs $x_i$ and labels $y_i$.

    If `normalize=True`, then the result is divided by the number of outputs from f
    (i.e. batchsize).

    Follows API from [`torch.func.jacrev`](https://pytorch.org/functorch/stable/generated/functorch.jacrev.html).

    More info on empirical Fisher matrices can be found in
    [Martens, 2020](https://jmlr.org/papers/volume21/17-678/17-678.pdf).

    Args:
        f:  A Python function that takes one or more arguments, one of which must be a
            Tensor, and returns one or more Tensors.
            Typically this is the [per-sample log likelihood of a model](https://pytorch.org/tutorials/intermediate/per_sample_grads.html).
        grad: Unused by the current implementation (kept for API compatibility).
        argnums: Optional, integer or sequence of integers. Specifies which
            positional argument(s) to differentiate with respect to. Defaults to 0.
        has_aux: Whether f returns auxiliary information.
        normalize: Whether to normalize, divide by the dimension of the output from f.

    Returns:
        A function with the same arguments as f that returns the `(n, n)` matrix
            $J J^T$. If has_aux is True, then the function instead returns a tuple of
            (matrix, aux).
    """

    def fisher(*args, **kwargs):
        jac_output = jacrev(f, argnums=argnums, has_aux=has_aux)(*args, **kwargs)
        jac = jac_output[0] if has_aux else jac_output

        # Ravel each per-output gradient (vmapped over the leading output dim) into
        # one row of J; assumes f returns a 1-D tensor of n values -> J is (n, dθ).
        jac = torch.vmap(lambda x: tree_ravel(x)[0])(jac)

        rescale = 1 / jac.shape[0] if normalize else 1
        gram = jac @ jac.T * rescale  # (n, n)

        if has_aux:
            return gram, jac_output[1]
        return gram

    return fisher


def adjust_damping(damping, delta_loss, ngrad, grad, fvp_fun, eps=1/4, factor=3/2):
    """Levenberg-Marquardt style update of a damping coefficient.

    Compares the observed loss reduction with the reduction predicted by the local
    quadratic model, ``predicted = -g^T d - 0.5 * d^T (F d)``. Here ``g`` is
    ``grad``, ``d`` is ``ngrad`` and ``F d`` is ``fvp_fun(d)``. The comparison uses
    the ratio ``rho = delta_loss / predicted``:

    - ``rho < eps``: damping is multiplied by ``factor``;
    - ``rho > 1 - eps``: damping is divided by ``factor``;
    - otherwise damping is returned unchanged.

    Args:
        damping: Current damping value (Python number or tensor).
        delta_loss: Observed loss reduction for the step ``ngrad``. NOTE(review):
            assumed to follow the same sign convention as the predicted reduction
            (positive = improvement); confirm against callers.
        ngrad: Step / natural-gradient tree.
        grad: Gradient tree, structured like ``ngrad``.
        fvp_fun: Callable returning a matrix-vector product for a tree (presumably
            a Fisher or GGN vector product).
        eps: Threshold on ``rho``.
        factor: Multiplicative adjustment applied to ``damping``.

    Returns:
        The updated damping. The ``damping`` argument is never modified in place,
        so a tensor passed by the caller is left untouched.
    """
    actual_reduction = delta_loss
    # Reduction predicted by the quadratic model m(d) = g^T d + 1/2 d^T F d.
    predicted_reduction = - _vdot_real_tree(grad, ngrad) - 0.5 * _vdot_real_tree(
        ngrad, fvp_fun(ngrad)
    )
    rho = actual_reduction / predicted_reduction
    # Rebind instead of ``*=`` / ``/=``: those would mutate a tensor argument in place.
    if rho < eps:
        damping = damping * factor
    elif rho > 1 - eps:
        damping = damping / factor

    return damping