import copy
from math import pow
from typing import Callable, Optional, Tuple
from contextlib import contextmanager
import numpy as np
import torch
import torch.optim
import torch.distributed as dist
from torch import Tensor


# Type of the ``closure`` argument accepted by ``IVONPCM.step``: a callable
# that takes no arguments and returns a Tensor (the loss).
ClosureType = Callable[[], Tensor]


def _welford_mean(avg: Optional[Tensor], newval: Tensor, count: int) -> Tensor:
    return newval if avg is None else avg + (newval - avg) / count


class IVONPCM(torch.optim.Optimizer):
    """Variational optimizer with an "old model" correction term.

    The stored parameters are treated as the mean of a diagonal Gaussian.
    :meth:`sampled_params` temporarily perturbs every parameter with noise
    of variance ``1 / (ess * (hess + weight_decay))`` and, when entered with
    ``train=True``, folds the gradients found on exit into running means.
    :meth:`step` then refreshes a momentum vector, a diagonal Hessian
    estimate and finally the parameter means.

    ``step``, ``collect_full_grads`` and ``refresh_old_model`` index
    ``param_groups[0]`` and ``param_groups[1]`` directly, so the optimizer
    must be constructed with two parameter groups:

    * ``param_groups[0]`` -- the parameters that ``step`` updates;
    * ``param_groups[1]`` -- an "old" copy.  ``step`` never modifies it; it
      is overwritten only by ``refresh_old_model``, which copies parameter
      by parameter and therefore needs matching shapes.

    All flat vectors (gradient means, noise, ...) are laid out as the
    concatenation of every group's parameters in order, so group 0 occupies
    ``[0, numel_0)`` and group 1 occupies ``[numel_0, numel_0 + numel_1)``.

    NOTE(review): the class name and update rule suggest this derives from
    the IVON optimizer -- confirm against the project's references.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate (``>= 0``).
        ess: Effective sample size (``> 0``); scales the sampling precision
            and the Hessian estimate.
        hess_init: Initial value of every entry of the Hessian estimate
            (``> 0``).
        beta1: Exponential-averaging coefficient of the momentum.
        beta2: Exponential-averaging coefficient of the Hessian estimate.
        weight_decay: Added to the Hessian estimate wherever it is used and
            applied as ``weight_decay * param`` in the parameter update.
        mc_samples: Number of times ``step`` evaluates its closure.
        hess_approx: ``'price'`` (mean of ``noise * grad`` times
            ``hess + weight_decay``) or ``'gradsq'`` (mean squared gradient).
        clip_radius: Element-wise bound applied to the preconditioned
            update before it is scaled by the learning rate.
        sync: When true and ``torch.distributed`` is initialised, average
            the sample accumulators across processes before using them.
        debias: Divide the momentum by ``1 - beta1 ** current_step``.
        rescale_lr: Multiply ``lr`` by ``hess_init + weight_decay``.
        alpha: Weight of the old-model correction terms.
        h_term_weight: Extra weight on the Hessian-difference term of the
            parameter update.
        rho1: Averaging factor for the stored old-model gradient in
            ``collect_full_grads`` (``1.0`` overwrites it).
        rho2: Averaging factor for the stored old-model Hessian estimate in
            ``collect_full_grads`` (``1.0`` overwrites it).

    Raises:
        ValueError: If one of the validated hyper-parameters is out of
            range, or the parameters span several devices or dtypes.
    """

    hessian_approx_methods = (
        'price',
        'gradsq',
    )

    def __init__(
        self,
        params,
        lr: float,
        ess: float,
        hess_init: float = 1.0,
        beta1: float = 0.9,
        beta2: float = 0.99999,
        weight_decay: float = 1e-4,
        mc_samples: int = 1,
        hess_approx: str = 'price',
        clip_radius: float = float("inf"),
        sync: bool = False,
        debias: bool = True,
        rescale_lr: bool = True,
        alpha: float = 1.0,
        h_term_weight: float = 1.0,
        rho1: float = 1.0,
        rho2: float = 1.0
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 1 <= mc_samples:
            raise ValueError(
                "Invalid number of MC samples: {}".format(mc_samples)
            )
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight decay: {}".format(weight_decay))
        if not 0.0 < hess_init:
            raise ValueError(
                "Invalid Hessian initialization: {}".format(hess_init)
            )
        if not 0.0 < ess:
            raise ValueError("Invalid effective sample size: {}".format(ess))
        if not 0.0 < clip_radius:
            raise ValueError("Invalid clipping radius: {}".format(clip_radius))
        if not 0.0 <= beta1 <= 1.0:
            raise ValueError("Invalid beta1 parameter: {}".format(beta1))
        if not 0.0 <= beta2 <= 1.0:
            raise ValueError("Invalid beta2 parameter: {}".format(beta2))
        if hess_approx not in self.hessian_approx_methods:
            # Report the offending value itself (this used to print beta2).
            raise ValueError(
                "Invalid hess_approx parameter: {}".format(hess_approx)
            )

        defaults = dict(
            lr=lr,
            mc_samples=mc_samples,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            hess_init=hess_init,
            ess=ess,
            clip_radius=clip_radius,
        )
        super().__init__(params, defaults)

        self.mc_samples = mc_samples
        self.hess_approx = hess_approx
        self.sync = sync
        self._numel, self._device, self._dtype = self._get_param_configs()
        self.current_step = 0
        self.debias = debias
        self.rescale_lr = rescale_lr
        self.alpha = alpha
        self.htermweight = h_term_weight
        self.rho1 = rho1
        self.rho2 = rho2

        # set initial temporary running averages
        self._reset_samples()
        # init all states
        self._init_buffers()

    def _get_param_configs(self) -> Tuple[int, torch.device, torch.dtype]:
        """Record each group's size and return ``(numel, device, dtype)``.

        Stores the element count of every group under ``group["numel"]``.

        Raises:
            ValueError: If the parameters do not share one device and one
                dtype (the flat buffers are allocated with a single pair).
        """
        all_params = []
        for pg in self.param_groups:
            pg["numel"] = sum(p.numel() for p in pg["params"] if p is not None)
            all_params += [p for p in pg["params"] if p is not None]
        if len(all_params) == 0:
            return 0, torch.device("cpu"), torch.get_default_dtype()
        devices = {p.device for p in all_params}
        if len(devices) > 1:
            raise ValueError(
                "Parameters are on different devices: "
                f"{[str(d) for d in devices]}"
            )
        device = next(iter(devices))
        dtypes = {p.dtype for p in all_params}
        if len(dtypes) > 1:
            raise ValueError(
                "Parameters are on different dtypes: "
                f"{[str(d) for d in dtypes]}"
            )
        dtype = next(iter(dtypes))
        total = sum(pg["numel"] for pg in self.param_groups)
        return total, device, dtype

    def _reset_samples(self):
        """Clear the per-step sample counter and running means."""
        self.state['count'] = 0
        self.state['avg_grad'] = None
        self.state['avg_nxg'] = None
        self.state['avg_gsq'] = None

    def _init_buffers(self):
        """Allocate the flat ``momentum`` and ``hess`` buffer of each group."""
        for group in self.param_groups:
            hess_init, numel = group["hess_init"], group["numel"]

            # Allocated for every group, although `_update` only ever reads
            # the momentum of param_groups[0].
            group["momentum"] = torch.zeros(
                numel, device=self._device, dtype=self._dtype
            )
            group["hess"] = torch.zeros(
                numel, device=self._device, dtype=self._dtype
            ).add(torch.as_tensor(hess_init))

    @staticmethod
    def _has_reference(value) -> bool:
        """Tell whether an old-model reference value has been stored.

        ``None`` (never stored) and ``dict`` (the placeholder that
        ``Optimizer.state``, a ``defaultdict(dict)``, creates on a missing
        key) both mean "not collected yet".
        """
        return value is not None and not isinstance(value, dict)

    def _require_samples(self, caller: str) -> None:
        """Raise a descriptive error when no gradient sample was collected."""
        if self.state.get("avg_grad") is None:
            raise RuntimeError(
                f"{caller}() needs at least one gradient sample; run the "
                "forward/backward pass inside `sampled_params(train=True)` "
                "first."
            )

    @contextmanager
    def sampled_params(self, train: bool = False):
        """Context manager that runs its body with perturbed parameters.

        On entry every parameter is replaced by ``mean + noise``.  On exit
        the means are restored and, when ``train`` is true, the ``.grad``
        of every parameter is folded into the running sample means.

        If the body raises, the means are still restored (no gradient
        sample is recorded) and the exception propagates.
        """
        param_avg, noise = self._sample_params()
        try:
            yield
        except BaseException:
            # Never leave the model stuck on a noisy sample.
            self._restore_param_average(False, param_avg, noise)
            raise
        self._restore_param_average(train, param_avg, noise)

    def _restore_param_average(
        self, train: bool, param_avg: Tensor, noise: Tensor
    ):
        """Write the means back and optionally record a gradient sample.

        Args:
            train: When true, update the running means of the gradient and
                of the quantity needed by ``hess_approx``.
            param_avg: Flat vector of parameter means from ``_sample_params``.
            noise: Flat noise vector from ``_sample_params``.
        """
        param_grads = []
        offset = 0
        for group in self.param_groups:
            for p in group["params"]:
                if p is None:
                    continue

                numel = p.numel()
                p.data = param_avg[offset : offset + numel].view(p.shape)
                if train:
                    if p.requires_grad and p.grad is not None:
                        param_grads.append(p.grad.flatten())
                    else:
                        # Keep the flat layout aligned for parameters
                        # that received no gradient.
                        param_grads.append(torch.zeros_like(p).flatten())
                offset += numel
        assert offset == self._numel  # sanity check

        if not train:
            return

        # collect grad sample for training
        grad_sample = torch.cat(param_grads, 0)
        count = self.state["count"] + 1
        self.state["count"] = count
        self.state["avg_grad"] = _welford_mean(
            self.state["avg_grad"], grad_sample, count
        )
        if self.hess_approx == 'price':
            self.state['avg_nxg'] = _welford_mean(
                self.state['avg_nxg'], noise * grad_sample, count)
        elif self.hess_approx == 'gradsq':
            self.state['avg_gsq'] = _welford_mean(
                self.state['avg_gsq'], grad_sample.square(), count)

    @torch.no_grad()
    def step(
        self,
        closure: "Optional[ClosureType]" = None,
        use_sarah=False,
        sarah_factor=1,
    ) -> Optional[Tensor]:
        """Apply one update to ``param_groups[0]`` from the collected samples.

        Gradient samples are only recorded by ``sampled_params(train=True)``,
        so either the closure or the caller must have run the
        forward/backward pass inside that context before the update.

        Args:
            closure: Optional callable evaluated ``mc_samples`` times with
                gradients enabled; the mean of its return values is returned.
            use_sarah: Forwarded to ``_update``, which currently ignores it.
            sarah_factor: Forwarded to ``_update``, which currently ignores
                it.

        Returns:
            The mean closure value, or ``None`` without a closure.

        Raises:
            RuntimeError: If no gradient sample has been collected.
        """
        if closure is None:
            loss = None
        else:
            losses = []
            for _ in range(self.mc_samples):
                with torch.enable_grad():
                    loss = closure()
                losses.append(loss)
            loss = sum(losses) / self.mc_samples
        # Fail before touching current_step or any buffer.
        self._require_samples("step")
        if self.sync and dist.is_initialized():  # explicit sync
            self._sync_samples()
        self._update(use_sarah=use_sarah, sarah_factor=sarah_factor)
        self._reset_samples()
        return loss

    @torch.no_grad()
    def collect_full_grads(self):
        """Store the old-model reference gradient and Hessian estimate.

        Reads the samples accumulated so far, takes the slice belonging to
        ``param_groups[1]`` and stores it in ``state["old_grads"]`` and
        ``state["old_hess_estimate"]`` (averaged with the previous values
        using ``rho1`` / ``rho2`` when they exist).  The sample accumulators
        are reset afterwards.

        Raises:
            RuntimeError: If no gradient sample has been collected.
        """
        self._require_samples("collect_full_grads")
        if self.sync and dist.is_initialized():  # explicit sync
            self._sync_samples()

        group, old_group = self.param_groups[0], self.param_groups[1]
        start = group["numel"]
        old_pg_slice = slice(start, start + old_group["numel"])

        new_grads = self.state["avg_grad"][old_pg_slice]
        old_grads = self.state.get("old_grads")
        if self._has_reference(old_grads):
            self.state["old_grads"] = (
                (1. - self.rho1) * old_grads + self.rho1 * new_grads
            )
        else:
            # clone() copies just this slice; deepcopy of a view would
            # duplicate the whole backing storage of avg_grad.
            self.state["old_grads"] = new_grads.clone()

        # Use the configured approximation so the stored reference matches
        # the f_old that `_new_hess` computes with the same method.
        old_f = IVONPCM._get_nll_hess(
            self.hess_approx,
            old_group["hess"] + old_group["weight_decay"],
            self.state["avg_nxg"],
            self.state["avg_gsq"],
            old_pg_slice
        ) * old_group["ess"]
        old_hess_estimate = self.state.get("old_hess_estimate")
        if self._has_reference(old_hess_estimate):
            self.state["old_hess_estimate"] = (
                (1. - self.rho2) * old_hess_estimate + self.rho2 * old_f
            )
        else:
            # old_f is a freshly computed tensor, no copy needed.
            self.state["old_hess_estimate"] = old_f
        self._reset_samples()

    @torch.no_grad()
    def refresh_old_model(self):
        """Copy parameters and Hessian of group 0 into the old group."""
        group, old_group = self.param_groups[0], self.param_groups[1]
        for p, p_old in zip(group["params"], old_group["params"]):
            p_old.copy_(p)
        # hess is kept 1d but params are kept as list of tensors
        old_group["hess"].copy_(group["hess"])

    def _sync_samples(self):
        """Average the collected sample means over all processes.

        Every accumulator that holds a tensor is reduced, so both the
        ``'price'`` (``avg_nxg``) and ``'gradsq'`` (``avg_gsq``) variants
        are covered; accumulators that are ``None`` are skipped.
        """
        world_size = dist.get_world_size()
        for key in ("avg_grad", "avg_nxg", "avg_gsq"):
            buf = self.state.get(key)
            if buf is None:
                continue
            dist.all_reduce(buf)
            buf.div_(world_size)

    def _sample_params(self) -> Tuple[Tensor, Tensor]:
        """Perturb every parameter in place with Gaussian noise.

        Returns:
            ``(param_avg, noise)``: flat copies of the parameter means and
            of the noise that was added, both laid out group after group.
        """
        noise_samples = []
        param_avgs = []

        offset = 0
        for group in self.param_groups:
            gnumel = group["numel"]
            # Standard normal noise scaled to variance
            # 1 / (ess * (hess + weight_decay)).
            noise_sample = (
                torch.randn(gnumel, device=self._device, dtype=self._dtype)
                / (
                    group["ess"] * (group["hess"] + group["weight_decay"])
                ).sqrt()
            )
            noise_samples.append(noise_sample)

            goffset = 0
            for p in group["params"]:
                if p is None:
                    continue

                p_avg = p.data.flatten()
                numel = p.numel()
                p_noise = noise_sample[goffset : goffset + numel]

                param_avgs.append(p_avg)
                p.data = (p_avg + p_noise).view(p.shape)
                goffset += numel
                offset += numel
            assert goffset == group["numel"]  # sanity check
        assert offset == self._numel  # sanity check

        return torch.cat(param_avgs, 0), torch.cat(noise_samples, 0)

    def _update(self, use_sarah=False, sarah_factor=1):
        """Update momentum, Hessian estimate and means of ``param_groups[0]``.

        The old-model correction terms are applied only once
        ``collect_full_grads`` has stored reference values.

        Args:
            use_sarah: Accepted for interface compatibility; not used.
            sarah_factor: Accepted for interface compatibility; not used.
        """
        self.current_step += 1

        group, old_group = self.param_groups[0], self.param_groups[1]
        lr, b1, b2 = group["lr"], group["beta1"], group["beta2"]
        numel = group["numel"]
        pg_slice = slice(0, numel)
        old_pg_slice = slice(numel, numel + old_group["numel"])

        avg_grad = self.state["avg_grad"]
        # .get() avoids inserting empty-dict placeholders into the state.
        old_grads = self.state.get("old_grads")
        old_hess_estimate = self.state.get("old_hess_estimate")

        param_avg = torch.cat(
            [p.flatten() for p in group["params"] if p is not None], 0
        )

        group["momentum"] = self._new_momentum(
            avg_grad[pg_slice],
            group["momentum"],
            avg_grad[old_pg_slice],
            old_grads,
            b1,
            self.alpha,
        )

        group["hess"], f_old, _ = self._new_hess(
            self.hess_approx,
            group["hess"],
            old_group["hess"],
            old_hess_estimate,
            self.state["avg_nxg"],
            self.state['avg_gsq'],
            pg_slice,
            old_pg_slice,
            group["ess"],
            b2,
            group["weight_decay"],
            self.alpha
        )

        if self.rescale_lr:
            step_size = lr * (group["hess_init"] + group["weight_decay"])
        else:
            step_size = lr
        if self.debias:
            debias = 1.0 - pow(b1, float(self.current_step))
        else:
            debias = 1.0

        param_avg = self._new_param_averages(
            param_avg,
            group["hess"],
            group["momentum"],
            step_size,
            group["weight_decay"],
            group["clip_radius"],
            debias,
            group["hess_init"],
            old_hess_estimate,
            f_old,
            torch.cat(
                [p.flatten() for p in old_group["params"] if p is not None], 0
            ),
            self.alpha,
            self.htermweight
        )

        # update params
        pg_offset = 0
        for p in group["params"]:
            if p is not None:
                p.data = param_avg[pg_offset : pg_offset + p.numel()].view(
                    p.shape
                )
                pg_offset += p.numel()
        assert pg_offset == group["numel"]  # sanity check

    @staticmethod
    def _get_nll_hess(method: str, hess, avg_nxg, avg_gsq, pg_slice) -> Tensor:
        """Return the raw Hessian estimate for one slice of the flat layout.

        ``'price'`` yields ``avg_nxg[pg_slice] * hess``; ``'gradsq'`` yields
        ``avg_gsq[pg_slice]``.

        Raises:
            NotImplementedError: For any other ``method``.
        """
        if method == 'price':
            return avg_nxg[pg_slice] * hess
        elif method == 'gradsq':
            return avg_gsq[pg_slice]
        else:
            raise NotImplementedError(f'unknown hessian approx.: {method}')

    @staticmethod
    def _new_momentum(avg_grad, m, avg_grad_old, avg_grad_old_full, b1, alpha) -> Tensor:
        """Return the exponentially averaged momentum.

        When a stored old-model gradient ``avg_grad_old_full`` is available
        the gradient is corrected by
        ``- alpha * (avg_grad_old - avg_grad_old_full)`` first; ``None`` or
        a ``dict`` placeholder means no reference is available.
        """
        grad = avg_grad
        if IVONPCM._has_reference(avg_grad_old_full):
            grad = avg_grad - alpha * (avg_grad_old - avg_grad_old_full)
        return b1 * m + (1.0 - b1) * grad

    @staticmethod
    def _new_hess(
        method, hess, old_hess, old_hess_estimate, avg_nxg, avg_gsq, pg_slice, old_pg_slice, ess, beta2, wd, alpha
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Return ``(new_hess, f_old, f)`` for the current group.

        ``f`` is the raw estimate scaled by ``ess``.  When a stored
        reference ``old_hess_estimate`` is available, ``f_old`` is the same
        estimate evaluated on the old group's slice and ``f`` is corrected
        by ``- alpha * (f_old - old_hess_estimate)``.  Without a reference
        (``None`` or a ``dict`` placeholder) the last two entries are
        ``None``.
        """
        f = IVONPCM._get_nll_hess(
            method, hess + wd, avg_nxg, avg_gsq, pg_slice
        ) * ess
        f_old = None
        f_corrected = None
        if IVONPCM._has_reference(old_hess_estimate):
            f_old = IVONPCM._get_nll_hess(
                method, old_hess + wd, avg_nxg, avg_gsq, old_pg_slice
            ) * ess
            f = f - alpha * (f_old - old_hess_estimate)
            f_corrected = f
        new_hess = beta2 * hess + (1.0 - beta2) * f + \
            (0.5 * (1 - beta2) ** 2) * (hess - f).square() / (hess + wd)
        return new_hess, f_old, f_corrected

    @staticmethod
    def _new_param_averages(
        param_avg, hess, momentum, lr, wd, clip_radius, debias, hess_init, h_out, f_old, m_out, alpha, h_factor
    ) -> Tensor:
        """Return the updated parameter means.

        The update direction is ``momentum / debias + wd * param_avg``,
        plus ``alpha * h_factor * (h_out - f_old) * (param_avg - m_out)``
        when ``f_old`` is not ``None``.  It is divided by ``hess + wd``,
        clipped element-wise to ``[-clip_radius, clip_radius]`` and scaled
        by ``lr``.  ``hess_init`` is accepted but not used here.
        """
        direction = momentum / debias + wd * param_avg
        if f_old is not None:
            direction = direction + (
                alpha * h_factor * (h_out - f_old) * (param_avg - m_out)
            )
        return param_avg - lr * torch.clip(
            direction / (hess + wd),
            min=-clip_radius,
            max=clip_radius,
        )
