import torch

# from numpy.random import gamma
from torch.optim import Optimizer


class H_SA_SGHMC(Optimizer):
    """Scale-adapted Stochastic Gradient Hamiltonian Monte-Carlo sampler.

    Reproduced from https://github.com/JavierAntoran/Bayesian-Neural-Networks.
    During burn-in, per-parameter running averages of the gradient (``g``)
    and of the squared gradient (``V_hat``) are adapted, together with their
    averaging window (``tau``). ``V_hat ** -1/2`` is then used as a diagonal
    preconditioner for both the momentum update and the injected noise.

    Equation references "[1]" in the comments refer to the paper followed by
    the original implementation (presumably Springenberg et al., 2016,
    "Bayesian Optimization with Robust Bayesian Neural Networks" — verify).

    Args:
        params: iterable of parameters or parameter-group dicts to sample.
        lr: step size; must be non-negative.
        base_C: friction term applied to the momentum; must be non-negative.
        eps: small constant added to denominators for numerical stability;
            must be non-negative.
    """

    def __init__(self, params, lr=1e-2, base_C=0.05, eps=1e-6):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if base_C < 0:
            raise ValueError("Invalid friction term: {}".format(base_C))
        if eps < 0:
            raise ValueError("Invalid epsilon value: {}".format(eps))

        # Kept as an instance attribute (not a group default) as in the
        # original implementation, so it is not part of state_dict().
        self.eps = eps

        defaults = dict(lr=lr, base_C=base_C)
        super(H_SA_SGHMC, self).__init__(params, defaults)

    @torch.no_grad()
    def step(
        self,
        burn_in=False,
        resample_momentum=False,
        resample_prior=False,
        closure=None,
    ):
        """Simulate discretized Hamiltonian dynamics for one step.

        Args:
            burn_in: if True, update the adaptive estimates ``tau``, ``g`` and
                ``V_hat`` from the current gradient before taking the step.
            resample_momentum: if True, draw a fresh momentum from
                N(0, lr**2 * V_hat**-1/2) before the update.
            resample_prior: accepted for interface compatibility only; it
                currently has no effect.
            closure: optional callable that re-evaluates the model and returns
                the loss, as in the standard ``torch.optim`` API. Pass it by
                keyword: the first positional argument is ``burn_in``.

        Returns:
            The value returned by ``closure``, or None if no closure is given.
        """
        loss = None
        if closure is not None:
            # The update below runs under no_grad, but the closure needs
            # autograd enabled to compute gradients.
            with torch.enable_grad():
                loss = closure()

        eps = self.eps
        for group in self.param_groups:
            # Hyperparameters are per-group, so read them once per group.
            base_C, lr = group["base_C"], group["lr"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                d_p = p.grad

                state = self.state[p]
                if len(state) == 0:
                    state["iteration"] = 0
                    state["tau"] = torch.ones_like(p)
                    state["g"] = torch.ones_like(p)
                    state["V_hat"] = torch.ones_like(p)
                    state["v_momentum"] = torch.zeros_like(p)
                # Step counter: not used by the update, but kept in the state.
                state["iteration"] += 1

                tau, g, V_hat = state["tau"], state["g"], state["V_hat"]

                if burn_in:
                    # Adapt the moving-average window first (Eq 9 left in [1]).
                    tau.add_(-tau * (g**2) / (V_hat + eps) + 1)
                    tau_inv = 1.0 / (tau + eps)
                    # Running average of the gradient (Eq 9 right in [1]).
                    g.add_(-tau_inv * g + tau_inv * d_p)
                    # Running average of the squared gradient (Eq 8 in [1]).
                    V_hat.add_(-tau_inv * V_hat + tau_inv * (d_p**2))

                # Diagonal preconditioner V_hat^{-1/2}.
                V_inv_sqrt = 1.0 / (torch.sqrt(V_hat) + eps)

                if resample_momentum:
                    # Equivalent to var = M under momentum reparametrisation.
                    state["v_momentum"] = torch.normal(
                        mean=torch.zeros_like(d_p),
                        std=torch.sqrt((lr**2) * V_inv_sqrt),
                    )
                v_momentum = state["v_momentum"]

                # Injected-noise variance. Clamp it so that sqrt stays real
                # and torch.normal receives a strictly positive std.
                noise_var = 2.0 * (lr**2) * V_inv_sqrt * base_C - (lr**4)
                noise_std = torch.sqrt(torch.clamp(noise_var, min=1e-16))
                noise_sample = torch.normal(mean=torch.zeros_like(d_p), std=noise_std)

                # Update momentum (Eq 10 right in [1]).
                v_momentum.add_(
                    -(lr**2) * V_inv_sqrt * d_p - base_C * v_momentum + noise_sample
                )

                # Update theta (Eq 10 left in [1]).
                p.add_(v_momentum)

        return loss
