"""Stochastic LBFGS optimizer with damping.
DESC:
    A naive distributed version of a stochastic LBFGS implementation.
    The main issue in distributing LBFGS is how to store and access the two
    history variables self.hist_dg and self.hist_dp.
    In the naive implementation, we split these two variables and assign them
    to four GPUs.
    When accessing them, we need to explicitly move them from one GPU to another.

AUTHOR:
    

NOTE:
"""
import torch
import torch.optim as optim
import math

# One torch.device per CUDA device reported by torch.cuda.device_count()
# (an empty list when that count is 0).
cuda_list = list(map(torch.device, range(torch.cuda.device_count())))
# Small constant added to denominators to guard against division by zero.
epsilon = 1e-5


class LBFGSOptimizer(optim.Optimizer):
    """Stochastic L-BFGS optimizer with damping and KL clipping.

    Curvature pairs are built from exponential moving averages of the
    parameters (``dp``) and of the gradients (``dg``) of the first parameter
    group. A new pair is sampled every ``update_freq`` steps and kept only if
    ``dp . dg >= rho_min``. Until the first pair is accepted, the optimizer is
    plain SGD with momentum and weight decay. Afterwards, the gradients of the
    first group are replaced by the (clipped) L-BFGS two-loop product before
    the momentum update.

    All stored history tensors live on the same device as the parameters.
    """

    def __init__(
        self,
        model_parameters,
        lr=1,
        momentum=0.9,
        weight_decay=0.0,
        rho_min=0.00001,
        mm_p=0.9,
        mm_g=0.99,
        update_freq=100,
        hist_sz=100,
        damping=0.2,
        kl_clip=0.005,
        kl_clip_fix_scaling=False,
    ):
        """
        Args:
            model_parameters: DNN model parameters
            lr: learning rate (default: 1)
            momentum: momentum factor of the SGD-style parameter update
            weight_decay: weight decay
            rho_min: curvature pairs with dp . dg below this threshold are
                discarded instead of stored
            mm_p: decay of the moving average of the parameters
            mm_g: decay of the moving average of the gradients
            update_freq: a new (dp, dg) pair is sampled every `update_freq` steps
            hist_sz: maximum number of stored (dp, dg) pairs
            damping: if dp . dg / dp . dp falls below this value, dg is damped
                toward dp; the damping mixing weight is also capped at
                1 - damping
            kl_clip: bound used to rescale the preconditioned gradient
            kl_clip_fix_scaling: if True, the KL-clip quantity is not
                multiplied by lr ** 2
        """
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super(LBFGSOptimizer, self).__init__(model_parameters, defaults)

        # Initialize essential variables
        self.hist_dg = []
        self.hist_dp = []
        self.hist_mdg = []  # estimates of B_k * s_k (see __get_mdg)
        self.rho_list = []

        # for debug purpose: dp.dg / dp.dp of the newest pair before/after damping
        self.tao_before = 1.0
        self.tao_after = 1.0

        # moving averages used to compute the (dp, dg) history
        self.avg_p = []
        self.hist_avg_p = []
        self.avg_g = []
        self.hist_avg_g = []
        self.has_avg_p = False
        self.has_avg_g = False
        self.has_hist_p = False
        self.has_hist_g = False

        self.snapshot_p = []
        self.snapshot_g = []

        self.model_param = self.param_groups[0]["params"]
        print("[LBFGS] number of param groups: {}".format(len(self.param_groups)))

        # optimizer hyper parameters
        self.rho_min = rho_min
        self.mm_p = mm_p
        self.mm_g = mm_g
        self.base_update_freq = update_freq
        self.hist_sz = hist_sz
        self.actual_update_freq = update_freq
        self.enable_damping = True
        self.damping = damping
        self.damp_low = damping
        self.damp_high = 2.0
        self.kl_clip = kl_clip
        self.kl_clip_fix_scaling = kl_clip_fix_scaling

        # iteration info
        self.init_lr = lr
        self.steps = 0  # batch-wise
        self.update_dg_dp = False
        self.epoch = 0
        self.start_lbfgs = False
        self.h0 = 1.0

        print(
            "[LBFGS] initialize LBFGS optimizer:\n"
            "-------------------------------------\n"
            f"Base Hessian update frequency: {self.base_update_freq}\n"
            f"History vector size: {self.hist_sz}\n"
            f"Enable damping: {self.enable_damping}\n"
            f"Momentum for param: {self.mm_p}\n"
            f"Momentum for grad: {self.mm_g}\n"
            "-------------------------------------"
        )

    def __params_with_grad(self, params=None):
        """Return the params (default: group 0) whose ``grad`` is not None, in order.

        Every flattened / cloned list in this class contains exactly one entry
        per such parameter. Pairing against this list, rather than against
        all params, keeps the correspondence correct when some params have
        no gradient.
        """
        if params is None:
            params = self.model_param
        return [p for p in params if p.grad is not None]

    def __flattern(self, tensorlist):
        """Concatenate the tensors of `tensorlist` into one 1-D tensor."""
        # reshape (unlike view) also accepts non-contiguous tensors.
        return torch.cat([t.reshape(-1) for t in tensorlist], 0)

    def __inv_flattern(self, vec, refparam):
        """Split flat `vec` into views shaped like the params of `refparam` that have a grad."""
        views = []
        offset = 0
        for p in self.__params_with_grad(refparam):
            numel = p.data.numel()
            views.append(vec[offset : offset + numel].view_as(p.data))
            offset += numel
        return views

    def __set_param(self, params):
        """Copy `params` (one entry per group-0 param with a grad) into those params."""
        for p, pdata in zip(self.__params_with_grad(), params):
            # .data so the copy also works outside torch.no_grad()
            p.data.copy_(pdata)

    def __set_grad(self, grads):
        """Copy `grads` (one entry per group-0 param with a grad) into their ``.grad``."""
        for p, g in zip(self.__params_with_grad(), grads):
            p.grad.data.copy_(g)

    def __clone_param(self):
        """Return contiguous copies of the group-0 params that have a grad."""
        return [
            p.data.clone(memory_format=torch.contiguous_format)
            for p in self.__params_with_grad()
        ]

    def __clone_grad(self):
        """Return contiguous copies of the grads of the group-0 params that have one."""
        return [
            p.grad.data.clone(memory_format=torch.contiguous_format)
            for p in self.__params_with_grad()
        ]

    def __update_avg(self, pdata, pdata_avg, stat_decay):
        """In place: pdata_avg <- stat_decay * pdata_avg + (1 - stat_decay) * pdata."""
        # Written without dividing by (1 - stat_decay), so stat_decay == 1
        # (a frozen average) is well defined.
        pdata_avg.mul_(stat_decay).add_(pdata, alpha=1 - stat_decay)

    def __get_dp(self):
        """Update the parameter moving average and, on update steps, record a new dp.

        dp is the change of the parameter moving average since the previous
        update step. It is flattened and appended to `hist_dp`; the oldest
        entry is evicted when the history is full.
        """
        dp = []
        for i, p in enumerate(self.__params_with_grad()):
            if not self.has_avg_p:
                self.avg_p.append(p.data.clone())
                continue
            self.__update_avg(p.data, self.avg_p[i], self.mm_p)
            if not self.update_dg_dp:
                continue
            if self.has_hist_p:
                dp.append(self.avg_p[i] - self.hist_avg_p[i])
                self.hist_avg_p[i].copy_(self.avg_p[i])
            else:
                self.hist_avg_p.append(self.avg_p[i].clone())
        self.has_avg_p = True
        self.has_hist_p = len(self.hist_avg_p) > 0

        # update delta_param
        if dp:
            if len(self.hist_dp) == self.hist_sz:
                self.hist_dp.pop(0)
            # Kept on the parameters' device so later dot products with the
            # gradient snapshots never mix devices (and CPU training works).
            self.hist_dp.append(self.__flattern(dp))

    def __get_dg(self):
        """Update the gradient moving average and, on update steps, record a damped dg.

        dg is the change of the gradient moving average since the previous
        update step, multiplied by lr / init_lr. It is then blended with the
        newest dp (B0 = I) according to the curvature ratio dp.dg / dp.dp.
        Finally it is stored and validated together with that dp by __get_rho.
        """
        lr = self.param_groups[0]["lr"]
        dg = []
        for i, p in enumerate(self.__params_with_grad()):
            g = p.grad.data.clone()
            if not self.has_avg_g:
                self.avg_g.append(g)
                continue
            self.__update_avg(g, self.avg_g[i], self.mm_g)
            if not self.update_dg_dp:
                continue
            if self.has_hist_g:
                dg.append(self.avg_g[i] - self.hist_avg_g[i])
                self.hist_avg_g[i].copy_(self.avg_g[i])
            else:
                self.hist_avg_g.append(self.avg_g[i].clone())
        self.has_avg_g = True
        self.has_hist_g = len(self.hist_avg_g) > 0

        if not dg:
            return

        # update delta_grad
        scaling = lr / self.init_lr
        dg_flatten = self.__flattern(dg)
        dg_flatten.mul_(scaling)

        # Damp dg: s is the newest dp and v = B0 s with B0 = I.
        s = self.hist_dp[-1]
        v = s
        sv = torch.dot(s, v) + epsilon
        self.tao_before = torch.dot(s, dg_flatten) / sv
        if self.enable_damping:
            if self.tao_before < self.damp_low:
                phi = (1 - self.damp_low) / (1 - self.tao_before)
            elif self.tao_before > self.damp_high:
                phi = (self.damp_high - 1) / (self.tao_before - 1)
            else:
                phi = 1.0
            # NOTE(review): this cap also applies when phi == 1.0, so dg is
            # always blended with dp; confirm this is intended.
            phi = min(phi, 1 - self.damp_low)
            dg_flatten.mul_(phi).add_(v, alpha=1 - phi)
        self.tao_after = torch.dot(s, dg_flatten) / sv

        if len(self.hist_dg) == self.hist_sz:
            self.hist_dg.pop(0)
            self.rho_list.pop(0)
        self.hist_dg.append(dg_flatten)

        self.__get_rho()

    def __get_mdg(self):
        """Append B_k * s_k for the newest dp, computed recursively from B0 = I.

        Not called by the other methods of this class. It requires exactly
        one more dp than stored estimates.
        """
        assert len(self.hist_mdg) == len(self.hist_dp) - 1

        l = len(self.hist_mdg)
        s = self.hist_dp[-1]
        B0 = 1  # B is initialized as identity matrix
        v = B0 * s  # new tensor, so the in-place updates below do not touch hist_dp
        if l == 0:
            self.hist_mdg.append(v)
            return
        for v_i, s_i, y_i in zip(self.hist_mdg, self.hist_dp, self.hist_dg):
            c1 = torch.dot(v_i, s) / (torch.dot(v_i, s_i) + epsilon)
            v.add_(v_i, alpha=-c1)

            c2 = torch.dot(y_i, s) / (torch.dot(y_i, s_i) + epsilon)
            v.add_(y_i, alpha=c2)
        self.hist_mdg.append(v)

    def __get_rho(self):
        """Accept or reject the newest (dp, dg) pair.

        rho = dp . dg. A pair with rho < rho_min is removed from both
        histories. Otherwise rho is stored, L-BFGS is switched on, and the
        initial inverse-Hessian scale is refreshed: h0 = rho / (dg . dg).
        """
        assert len(self.hist_dp) == len(self.hist_dg)
        rho = torch.dot(self.hist_dp[-1], self.hist_dg[-1])
        if rho < self.rho_min:
            self.hist_dp.pop()
            self.hist_dg.pop()
            return

        self.rho_list.append(rho)
        self.start_lbfgs = True
        dg = self.hist_dg[-1]
        self.h0 = rho / (torch.dot(dg, dg) + epsilon)

    def __update_gradient(self):
        """Overwrite the group-0 gradients with the clipped L-BFGS direction.

        Starting from the snapshots taken in `step`, g = grad + wd * param. The
        two-loop recursion then computes H g from the stored pairs with
        H0 = h0 * I. The result is scaled by
        nu = min(1, sqrt(kl_clip / vg)), where vg = grad . (H g), multiplied
        by lr ** 2 unless `kl_clip_fix_scaling`. When vg <= 0 (H g is not a
        descent direction), no clipping is applied.
        """
        l = len(self.hist_dp)
        assert l == len(self.hist_dg)
        assert l == len(self.rho_list)
        wd = self.param_groups[0]["weight_decay"]
        g_flat = self.__flattern(self.snapshot_g)
        p_flat = self.__flattern(self.snapshot_p)
        g = torch.add(g_flat, p_flat, alpha=wd)

        # Two-loop recursion, first loop: newest pair -> oldest.
        alpha_list = []
        for dp, dg, rho in zip(
            reversed(self.hist_dp), reversed(self.hist_dg), reversed(self.rho_list)
        ):
            alpha = torch.dot(dp, g) / (rho + epsilon)
            alpha_list.append(alpha)
            g.add_(dg, alpha=-alpha)
        g.mul_(self.h0)
        # Second loop: oldest pair -> newest, reusing the alphas in reverse.
        for dp, dg, rho, alpha in zip(
            self.hist_dp, self.hist_dg, self.rho_list, reversed(alpha_list)
        ):
            beta = torch.dot(dg, g) / (rho + epsilon)
            g.add_(dp, alpha=alpha - beta)

        # copy conditioned gradient and apply kl clip
        if self.kl_clip_fix_scaling:
            vg_sum = torch.dot(g_flat, g)
        else:
            lr = self.param_groups[0]["lr"]
            vg_sum = torch.dot(g_flat, g) * lr ** 2
        vg_sum = float(vg_sum)
        # sqrt of a negative ratio is undefined (math domain error).
        nu = min(1.0, math.sqrt(self.kl_clip / vg_sum)) if vg_sum > 0 else 1.0

        g_shaped = self.__inv_flattern(g, self.model_param)
        for p, g_p in zip(self.__params_with_grad(), g_shaped):
            p.grad.data.copy_(g_p).mul_(nu)

    @torch.no_grad()
    def step(self, closure=None, epoch=0, batch=0):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss
            epoch: current epoch (stored in `self.epoch`)
            batch: not used by this method

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        self.actual_update_freq = self.base_update_freq

        self.steps += 1
        self.epoch = epoch
        self.update_dg_dp = self.steps % self.actual_update_freq == 0
        # save gradients and params before anything below modifies them in place
        self.snapshot_g = self.__clone_grad()
        self.snapshot_p = self.__clone_param()

        self.__get_dp()
        self.__get_dg()

        if self.start_lbfgs:
            self.__update_gradient()

        for group in self.param_groups:
            momentum = group["momentum"]
            wd = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # Warm-up stage (SGD): fold in weight decay here. Once L-BFGS is
                # active, __update_gradient has already added it for group 0.
                # NOTE(review): other groups then get no weight decay; confirm.
                if not self.start_lbfgs:
                    d_p.add_(p.data, alpha=wd)
                if momentum != 0:
                    param_state = self.state[p]
                    if "momentum_buf" not in param_state:
                        buf = param_state["momentum_buf"] = d_p.clone()
                    else:
                        buf = param_state["momentum_buf"]
                        buf.mul_(momentum).add_(d_p)
                    d_p = buf
                p.data.add_(d_p, alpha=-group["lr"])
        return loss
