from math import pow
from typing import Callable, Optional, Tuple
from contextlib import contextmanager
import torch
import torch.optim
import torch.distributed as dist
from torch import Tensor


# Type of the closure accepted by `IVON.step`: it is called with no arguments
# and its return value (annotated as a Tensor) is treated as the loss.
ClosureType = Callable[[], Tensor]


def _welford_mean(avg: Optional[Tensor], newval: Tensor, count: int) -> Tensor:
    """Fold ``newval`` into a running mean that now covers ``count`` samples.

    Args:
        avg: Mean accumulated so far, or ``None`` if nothing was accumulated.
        newval: The new sample.
        count: Number of samples including ``newval``.

    Returns:
        ``newval`` itself when ``avg`` is ``None``; otherwise the incrementally
        updated mean ``avg + (newval - avg) / count``.
    """
    if avg is None:
        # First sample: the mean is the sample itself.
        return newval
    delta = newval - avg
    return avg + delta / count

# extends the IVON algorithm to handle Lagrange multipliers and general prior
class IVON(torch.optim.Optimizer):
    """IVON-style optimizer with Lagrange-multiplier (dual) terms and a general prior.

    For every parameter group the optimizer keeps two flat buffers, ``momentum``
    and ``hess``, each with one entry per scalar parameter of the group. The
    parameters themselves hold the current mean. Inside ``sampled_params`` the
    parameters are temporarily replaced by ``mean + noise`` where the noise is
    standard normal divided by ``sqrt(ess * (hess + weight_decay + local_wd))``.
    Gradient samples recorded there (``train=True``) are averaged and consumed
    by ``step``.

    Vector-valued quantities (``prior_prec``, ``prior_mean``, ``dual_mean``,
    ``dual_prec``) are laid out as one flat vector over all parameters, in
    param-group order, and are sliced per group when stored.
    """

    hessian_approx_methods = (
        'price',
        'gradsq',
    )

    def __init__(
        self,
        params,
        lr: float,
        ess: float,
        prior_prec,
        prior_mean,
        dual_mean,
        dual_prec,
        hess_init: float = 1.0,
        beta1: float = 0.9,
        beta2: float = 0.99999,
        mc_samples: int = 1,
        hess_approx: str = 'price',
        clip_radius: float = float("inf"),
        sync: bool = False,
        debias: bool = True,
        rescale_lr: bool = False,
        local_wd: float = 0.0  # additional L2 regularization, divided by ess so scale is like priorprec
    ):
        """Create the optimizer and initialise all per-group buffers.

        Args:
            params: Parameters or param-group dicts, as for any
                ``torch.optim.Optimizer``.
            lr: Learning rate (must be non-negative).
            ess: Effective sample size (must be positive). It scales the
                sampling precision and divides ``prior_prec`` and ``local_wd``.
            prior_prec: Prior precision over all parameters; stored per group
                as ``weight_decay = prior_prec / ess``.
            prior_mean: Prior mean over all parameters, or ``None``, in which
                case the prior term uses the parameters themselves.
            dual_mean: Dual (Lagrange multiplier) linear term ``v``; the
                gradient ``-v + V * mean`` is added to the update.
            dual_prec: Dual quadratic term ``V``; also added to the Hessian
                estimate.
            hess_init: Initial value of every entry of the ``hess`` buffer.
            beta1: Momentum averaging coefficient, in ``[0, 1]``.
            beta2: Hessian averaging coefficient, in ``[0, 1]``.
            mc_samples: Number of times ``step`` calls the closure.
            hess_approx: One of ``hessian_approx_methods``.
            clip_radius: Element-wise clipping bound of the preconditioned
                update (must be positive).
            sync: If true and ``torch.distributed`` is initialised, sample
                averages are all-reduced and averaged across workers in
                ``step``.
            debias: If true, the momentum is divided by ``1 - beta1**t``.
            rescale_lr: If true, the learning rate is multiplied by
                ``hess_init + weight_decay``.
            local_wd: Additional L2 regularisation; stored divided by ``ess``.

        Raises:
            ValueError: If any hyper-parameter is outside its valid range, or
                the parameters live on different devices or have different
                dtypes.
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 1 <= mc_samples:
            raise ValueError(
                "Invalid number of MC samples: {}".format(mc_samples)
            )
        if not 0.0 <= hess_init:
            raise ValueError(
                "Invalid Hessian initialization: {}".format(hess_init)
            )
        # ess is used as a divisor below, so zero must be rejected here.
        if not 0.0 < ess:
            raise ValueError("Invalid effective sample size: {}".format(ess))
        if not 0.0 < clip_radius:
            raise ValueError("Invalid clipping radius: {}".format(clip_radius))
        if not 0.0 <= beta1 <= 1.0:
            raise ValueError("Invalid beta1 parameter: {}".format(beta1))
        if not 0.0 <= beta2 <= 1.0:
            raise ValueError("Invalid beta2 parameter: {}".format(beta2))
        if hess_approx not in self.hessian_approx_methods:
            raise ValueError(
                "Invalid hess_approx parameter: {}".format(hess_approx)
            )

        defaults = dict(
            lr=lr,
            mc_samples=mc_samples,
            beta1=beta1,
            beta2=beta2,
            prior_mean=prior_mean,
            prior_prec=prior_prec,
            dual_mean=dual_mean,
            dual_prec=dual_prec,
            hess_init=hess_init,
            ess=ess,
            clip_radius=clip_radius,
            local_wd=(local_wd / ess)
        )
        super().__init__(params, defaults)

        self.mc_samples = mc_samples
        self.hess_approx = hess_approx
        self.sync = sync
        self._numel, self._device, self._dtype = self._get_param_configs()
        self.current_step = 0
        self.debias = debias
        self.rescale_lr = rescale_lr
        self.local_wd = (local_wd / ess)
        self.prior_mean = prior_mean
        self.prior_prec = prior_prec

        # set initial temporary running averages
        self._reset_samples()
        # init all states
        self._init_buffers()
        self._set_prior(prior_mean, prior_prec)
        # The defaults put the *full* dual vectors into every group; slice them
        # so each group only sees its own entries.
        self._set_dual(dual_mean, dual_prec)

    @staticmethod
    def _group_entries(value, g_slice):
        """Return the part of ``value`` that belongs to one parameter group.

        ``None``, Python numbers and one-element tensors are returned
        unchanged (they broadcast against any group); everything else is
        indexed with ``g_slice``.
        """
        if value is None or isinstance(value, (int, float)):
            return value
        if isinstance(value, Tensor) and value.numel() == 1:
            return value
        return value[g_slice]

    def _set_prior(self, prior_mean, prior_prec):
        """Store the prior per group.

        Sets ``group["weight_decay"]`` to the group's slice of ``prior_prec``
        divided by the group's ``ess`` and ``group["prior_mean"]`` to the
        group's slice of ``prior_mean`` (or ``None``).
        """
        g_offset = 0
        for group in self.param_groups:
            g_slice = slice(g_offset, g_offset + group['numel'])
            group["weight_decay"] = (
                self._group_entries(prior_prec, g_slice) / group['ess']
            )
            group["prior_mean"] = self._group_entries(prior_mean, g_slice)
            g_offset += group['numel']
        assert g_offset == self._numel  # sanity check

    def _set_dual(self, dual_mean, dual_prec):
        """Store each group's slice of the dual terms in ``group["dual_mean"]``
        and ``group["dual_prec"]``."""
        g_offset = 0
        for group in self.param_groups:
            g_slice = slice(g_offset, g_offset + group['numel'])
            group["dual_prec"] = self._group_entries(dual_prec, g_slice)
            group["dual_mean"] = self._group_entries(dual_mean, g_slice)
            g_offset += group['numel']
        assert g_offset == self._numel  # sanity check

    def _set_posterior(self, posterior_mean, posterior_prec):
        """Set mean and precision from per-group sequences.

        ``posterior_mean`` and ``posterior_prec`` are iterated in lockstep with
        the param groups (the format returned by ``_get_posterior``). The
        ``hess`` buffer is set so that
        ``ess * (hess + weight_decay + local_wd)`` equals the given precision.
        """
        for group, m, s in zip(self.param_groups, posterior_mean, posterior_prec):
            # set parameters; `m` is indexed relative to the start of the group
            offset = 0
            for p in group["params"]:
                if p is None:
                    continue

                p_slice = slice(offset, offset + p.numel())
                p.data = m[p_slice].view(p.shape)
                offset += p.numel()
            assert offset == group["numel"]  # sanity check (per group)

            # set hessian
            group['hess'] = s / group['ess'] - group['weight_decay'] - self.local_wd

    def _set_posterior_simple(self, posterior_mean, posterior_prec):
        """Set mean and precision from flat vectors over all parameters.

        ``posterior_mean`` is indexed with a running offset across all groups;
        ``posterior_prec`` is sliced per group (scalars and one-element
        tensors are used as-is).
        """
        offset = 0
        for group in self.param_groups:
            g_start = offset
            # set parameters
            for p in group["params"]:
                if p is None:
                    continue

                p_slice = slice(offset, offset + p.numel())
                p.data = posterior_mean[p_slice].view(p.shape)
                offset += p.numel()
            assert offset - g_start == group["numel"]  # sanity check

            # set hessian
            g_prec = self._group_entries(posterior_prec, slice(g_start, offset))
            group['hess'] = (g_prec / group['ess']) - group['weight_decay'] - self.local_wd
        assert offset == self._numel  # sanity check

    def _get_posterior(self):
        """Return ``(means, precisions)``: two lists with one flat tensor per
        param group. The precision is ``ess * (hess + weight_decay + local_wd)``."""
        mean = []
        prec = []

        for group in self.param_groups:
            mean.append(torch.cat([p.flatten() for p in group['params'] if p is not None]))
            prec.append(group['ess'] * (group['hess'] + group['weight_decay'] + self.local_wd))

        return mean, prec

    def _get_param_configs(self):
        """Record ``numel`` in every param group and return ``(total, device, dtype)``.

        Raises:
            ValueError: If the parameters are on more than one device or have
                more than one dtype.
        """
        all_params = []
        for pg in self.param_groups:
            pg["numel"] = sum(p.numel() for p in pg["params"] if p is not None)
            all_params += [p for p in pg["params"] if p is not None]
        if len(all_params) == 0:
            return 0, torch.device("cpu"), torch.get_default_dtype()
        devices = {p.device for p in all_params}
        if len(devices) > 1:
            raise ValueError(
                "Parameters are on different devices: "
                f"{[str(d) for d in devices]}"
            )
        device = next(iter(devices))
        dtypes = {p.dtype for p in all_params}
        if len(dtypes) > 1:
            raise ValueError(
                "Parameters are on different dtypes: "
                f"{[str(d) for d in dtypes]}"
            )
        dtype = next(iter(dtypes))
        total = sum(pg["numel"] for pg in self.param_groups)
        return total, device, dtype

    def _reset_samples(self):
        """Clear the sample counter and the running sample averages."""
        self.state['count'] = 0
        self.state['avg_grad'] = None
        self.state['avg_nxg'] = None
        self.state['avg_gsq'] = None

    def _init_buffers(self):
        """Create per-group ``momentum`` (zeros) and ``hess`` (``hess_init``) buffers."""
        for group in self.param_groups:
            hess_init, numel = group["hess_init"], group["numel"]

            group["momentum"] = torch.zeros(
                numel, device=self._device, dtype=self._dtype
            )
            group["hess"] = torch.zeros(
                numel, device=self._device, dtype=self._dtype
            ).add(torch.as_tensor(hess_init))

    @contextmanager
    def sampled_params(self, train: bool = False):
        """Context manager that runs its body with sampled parameters.

        On entry the parameters are replaced by ``mean + noise``; on exit the
        means are restored. With ``train=True`` the gradients present on the
        parameters at exit are recorded as one gradient sample.

        If the body raises, the means are restored without recording a sample
        and the exception propagates.
        """
        param_avg, noise = self._sample_params()
        try:
            yield
        except BaseException:
            # Never leave the model at the noisy sample; gradients may be
            # missing or incomplete here, so do not record them.
            self._restore_param_average(False, param_avg, noise)
            raise
        self._restore_param_average(train, param_avg, noise)

    def _restore_param_average(
        self, train: bool, param_avg: Tensor, noise: Tensor
    ):
        """Write the means back into the parameters and optionally record grads.

        Args:
            train: If true, fold the current gradients into the running
                averages in ``self.state``.
            param_avg: Flat tensor of all parameter means.
            noise: Flat tensor of the noise that was added when sampling.
        """
        param_grads = []
        offset = 0
        for group in self.param_groups:
            for p in group["params"]:
                if p is None:
                    continue

                p_slice = slice(offset, offset + p.numel())

                p.data = param_avg[p_slice].view(p.shape)
                if train:
                    # Parameters that are frozen or received no gradient
                    # contribute a zero gradient.
                    if p.requires_grad and p.grad is not None:
                        param_grads.append(p.grad.flatten())
                    else:
                        param_grads.append(torch.zeros_like(p).flatten())
                offset += p.numel()
        assert offset == self._numel  # sanity check

        if train:  # collect grad sample for training
            grad_sample = torch.cat(param_grads, 0)
            count = self.state["count"] + 1
            self.state["count"] = count
            self.state["avg_grad"] = _welford_mean(
                self.state["avg_grad"], grad_sample, count
            )
            if self.hess_approx == 'price':
                self.state['avg_nxg'] = _welford_mean(
                    self.state['avg_nxg'], noise * grad_sample, count)
            elif self.hess_approx == 'gradsq':
                self.state['avg_gsq'] = _welford_mean(
                    self.state['avg_gsq'], grad_sample.square(), count)

    @torch.no_grad()
    def step(self, closure: "Optional[ClosureType]" = None) -> Optional[Tensor]:
        """Perform one optimisation step from the recorded gradient samples.

        Args:
            closure: Optional callable returning the loss. It is called
                ``mc_samples`` times with gradients enabled. The closure is
                expected to record gradient samples itself (presumably by
                running forward/backward inside ``sampled_params(train=True)``).

        Returns:
            The mean of the closure losses, or ``None`` without a closure.

        Raises:
            RuntimeError: If no gradient sample was recorded since the last step.
        """
        if closure is None:
            loss = None
        else:
            losses = []
            for _ in range(self.mc_samples):
                with torch.enable_grad():
                    loss = closure()
                losses.append(loss)
            loss = sum(losses) / self.mc_samples
        if self.sync and dist.is_initialized():  # explicit sync
            self._sync_samples()
        self._update()
        self._reset_samples()
        return loss

    def _sync_samples(self):
        """Average the recorded sample statistics across distributed workers.

        Only the statistics that were actually accumulated are reduced
        (``avg_nxg`` for ``'price'``, ``avg_gsq`` for ``'gradsq'``).
        """
        world_size = dist.get_world_size()
        for key in ("avg_grad", "avg_nxg", "avg_gsq"):
            buf = self.state[key]
            if buf is None:
                continue
            dist.all_reduce(buf)
            buf.div_(world_size)

    def _sample_params(self) -> Tuple[Tensor, Tensor]:
        """Replace the parameters by ``mean + noise``.

        Returns:
            ``(param_avg, noise)``: the flat means before sampling and the flat
            noise that was added, both concatenated over all groups.
        """
        noise_samples = []
        param_avgs = []

        offset = 0
        for group in self.param_groups:
            gnumel = group["numel"]
            noise_sample = (
                torch.randn(gnumel, device=self._device, dtype=self._dtype)
                / (
                    group["ess"] * (group["hess"] + group["weight_decay"] + self.local_wd)
                ).sqrt()
            )
            noise_samples.append(noise_sample)

            goffset = 0
            for p in group["params"]:
                if p is None:
                    continue

                p_avg = p.data.flatten()
                numel = p.numel()
                p_noise = noise_sample[goffset : goffset + numel]

                param_avgs.append(p_avg)
                p.data = (p_avg + p_noise).view(p.shape)
                goffset += numel
                offset += numel
            assert goffset == group["numel"]  # sanity check
        assert offset == self._numel  # sanity check

        return torch.cat(param_avgs, 0), torch.cat(noise_samples, 0)

    def _update(self):
        """Update momentum, Hessian estimate and parameter means of all groups.

        Raises:
            RuntimeError: If no gradient sample has been recorded.
        """
        if self.state["avg_grad"] is None:
            raise RuntimeError(
                "No gradient samples were collected before the update; "
                "compute gradients inside `sampled_params(train=True)` first."
            )
        self.current_step += 1

        offset = 0
        for group in self.param_groups:
            lr = group["lr"]
            b1 = group["beta1"]
            b2 = group["beta2"]
            pg_slice = slice(offset, offset + group["numel"])

            param_avg = torch.cat(
                [p.flatten() for p in group["params"] if p is not None], 0
            )

            group["momentum"] = self._new_momentum(
                self.state["avg_grad"][pg_slice], group["momentum"], b1
            )

            group["hess"] = self._new_hess(
                self.hess_approx,
                group["hess"],
                self.state["avg_nxg"],
                self.state['avg_gsq'],
                pg_slice,
                group["ess"],
                b2,
                group["weight_decay"],
                group["dual_prec"],
                self.local_wd
            )

            if self.rescale_lr:
                step_size = lr * (group["hess_init"] + group["weight_decay"])
            else:
                step_size = lr
            if self.debias:
                debias = 1.0 - pow(b1, float(self.current_step))
            else:
                debias = 1.0

            param_avg = self._new_param_averages(
                param_avg,
                group["prior_mean"],
                group["hess"],
                group["momentum"],
                step_size,
                group["weight_decay"],
                group["clip_radius"],
                debias,
                group["hess_init"],
                group["dual_mean"],
                group["dual_prec"],
                self.local_wd
            )

            # update params
            pg_offset = 0
            for p in group["params"]:
                if p is not None:
                    p.data = param_avg[pg_offset : pg_offset + p.numel()].view(
                        p.shape
                    )
                    pg_offset += p.numel()
            assert pg_offset == group["numel"]  # sanity check
            offset += group["numel"]
        assert offset == self._numel  # sanity check

    @staticmethod
    def _get_nll_hess(method: str, hess, avg_nxg, avg_gsq, pg_slice) -> Tensor:
        """Return the group's Hessian estimate for the chosen method.

        ``'price'`` returns ``avg_nxg[pg_slice] * hess``; ``'gradsq'`` returns
        ``avg_gsq[pg_slice]``.

        Raises:
            NotImplementedError: For any other ``method``.
        """
        if method == 'price':
            return avg_nxg[pg_slice] * hess
        elif method == 'gradsq':
            return avg_gsq[pg_slice]
        else:
            raise NotImplementedError(f'unknown hessian approx.: {method}')

    @staticmethod
    def _new_momentum(avg_grad, m, b1) -> Tensor:
        """Return the exponential moving average ``b1 * m + (1 - b1) * avg_grad``."""
        return b1 * m + (1.0 - b1) * avg_grad

    @staticmethod
    def _new_hess(
        method, hess, avg_nxg, avg_gsq, pg_slice, ess, beta2, wd, dp, loc_wd
    ) -> Tensor:
        """Return the updated ``hess`` buffer of one group.

        The new estimate ``f`` is the method's Hessian estimate times ``ess``
        plus the dual precision ``dp``. The result is
        ``beta2 * hess + (1 - beta2) * f`` plus the correction term
        ``0.5 * (1 - beta2)**2 * (hess - f)**2 / (hess + wd + loc_wd)``.
        """
        f = IVON._get_nll_hess(
            method, hess + wd + loc_wd, avg_nxg, avg_gsq, pg_slice
        ) * ess

        # f is the reparametrization-trick Hessian (\widehat h in the pseudo-code)
        # for ADMM, we correct by the Hessian of 1/2 \theta^T V \theta from the
        # Lagrange multiplier, which is just V.
        f += dp

        return beta2 * hess + (1.0 - beta2) * f + \
            (0.5 * (1 - beta2) ** 2) * (hess - f).square() / (hess + wd + loc_wd)

    @staticmethod
    def _new_param_averages(
        param_avg, prior_mean, hess, momentum, lr, wd, clip_radius, debias, hess_init, dm, dp, loc_wd
    ) -> Tensor:
        """Return the updated parameter means of one group.

        The update direction is the sum of the debiased momentum, the
        Lagrange-multiplier gradient ``-dm + dp * param_avg``, the prior term
        ``wd * (param_avg - prior_mean)`` (``wd * param_avg`` when
        ``prior_mean`` is ``None``) and ``loc_wd * param_avg``, divided by
        ``hess + wd + loc_wd``, clipped element-wise to
        ``[-clip_radius, clip_radius]`` and scaled by ``lr``.
        ``hess_init`` is accepted but not used here.
        """
        # we add to the stochastic gradient also the gradient of the lagrange multiplier term
        # in our case, this is -v+Vm
        lagr_grad = -dm + dp * param_avg

        if prior_mean is None:
            prior_grad = wd * param_avg
        else:
            prior_grad = wd * (param_avg - prior_mean)

        return param_avg - lr * torch.clip(
            (momentum / debias + lagr_grad + prior_grad + loc_wd * param_avg) / (hess + wd + loc_wd),
            min=-clip_radius,
            max=clip_radius,
        )

