import torch
from functools import reduce
from .optimizer import Optimizer

import numpy as np
#from scipy.linalg import eigvalsh_tridiagonal


def dPinv(d):
    """Return the guarded reciprocal of ``d``.

    Yields ``0.`` when ``abs(d) <= 1e-16`` and ``1. / d`` otherwise, so that
    callers dividing by a (near-)zero quantity get zero instead of a blow-up.
    """
    negligible = abs(d) <= 1e-16
    if negligible:
        return 0.
    return 1. / d


class RSTBFGS(Optimizer):
    """Optimizer wrapper that updates the parameters from two stored vectors.

    The wrapper shares ``param_groups``, ``state`` and ``defaults`` with the
    wrapped optimizer (the same objects, not copies) and keeps its own working
    buffers ``p``, ``q``, ``x_prev`` and ``r_prev`` inside that shared state.
    ``step`` computes the update itself; it never calls the wrapped
    optimizer's ``step``.

    Only a single parameter group is supported.

    Arguments:
        optimizer: optimizer to wrap; must not be ``None``.
        alpha (float): non-negative scale applied to the ``p``-based
            correction terms in ``step``.
        beta (float): non-negative scale applied to the residual
            (negated, weight-decayed gradient) in ``step``.
        damp1: damping coefficient used in the denominator of the ``p``/``q``
            recurrence.
        damp2: damping coefficient used in the denominator of ``rho``.
        mu: stored on the instance; not read anywhere in this class.
        precision: ``0`` keeps the dtype of the first parameter for all
            internal vectors; any other value uses ``torch.float64``.
    """

    def __init__(self, optimizer=None, alpha=1.0, beta=1.0,
                 damp1=1, damp2=0.01, mu=0, precision=1):
        if optimizer is None:
            raise ValueError("optimizer cannot be None")
        if alpha < 0.0:
            raise ValueError("Invalid RSTBFGS alpha parameter: {}".format(alpha))
        if beta < 0.0:
            raise ValueError("Invalid RSTBFGS beta parameter: {}".format(beta))

        self.optimizer = optimizer
        self.beta = beta
        self.alpha = alpha
        self.damp1 = damp1
        self.damp2 = damp2
        self.mu = mu
        self.precision = precision
        # Alias (do not copy) the wrapped optimizer's bookkeeping so both
        # objects observe the same parameter groups and state.
        self.param_groups = self.optimizer.param_groups
        self.state = self.optimizer.state
        self.defaults = self.optimizer.defaults

        if len(self.param_groups) != 1:
            raise ValueError("RSTBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

        N = self._numel()
        device = self._params[0].device
        dtype = self._work_dtype()
        state = self.state
        # setdefault keeps any buffers that already exist in the shared state.
        state.setdefault('step', 0)
        state.setdefault('p', torch.zeros(N, device=device, dtype=dtype))
        state.setdefault('q', torch.zeros(N, device=device, dtype=dtype))
        state.setdefault('x_prev', torch.zeros(N, dtype=dtype, device=device))
        state.setdefault('r_prev', torch.zeros(N, dtype=dtype, device=device))
        # Always reset; this entry is not read anywhere in this class.
        state['pq'] = 0.

    def __setstate__(self, state):
        super().__setstate__(state)

    def _work_dtype(self):
        """Return the dtype used for the flat working vectors.

        ``precision == 0`` selects the dtype of the first parameter; every
        other value selects ``torch.float64``. ``__init__`` and ``step`` both
        use this helper so the state buffers and the per-step vectors always
        share one dtype.
        """
        if self.precision == 0:
            return self._params[0].dtype
        return torch.float64

    def _numel(self):
        """Return the total number of scalar parameters (cached after first call)."""
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenate all parameter gradients into one flat tensor.

        Parameters without a gradient contribute zeros; sparse gradients are
        densified first.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new_zeros(p.numel())
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _gather_flat_data(self):
        """Concatenate all parameter values into one flat tensor."""
        return torch.cat([p.data.view(-1) for p in self._params], 0)

    def _iter_param_chunks(self, flat):
        """Yield ``(param, chunk)`` pairs, ``chunk`` being the slice of ``flat``
        that corresponds to ``param``, reshaped like ``param``."""
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view_as to avoid deprecated pointwise semantics
            yield p, flat[offset:offset + numel].view_as(p)
            offset += numel
        assert offset == self._numel()

    def _store_data(self, other):
        """Overwrite every parameter in place from the flat tensor ``other``."""
        for p, chunk in self._iter_param_chunks(other):
            p.copy_(chunk)

    def _store_grad(self, other):
        """Overwrite every parameter gradient in place from the flat tensor ``other``."""
        for p, chunk in self._iter_param_chunks(other):
            p.grad.copy_(chunk)

    def _add_grad(self, step_size, update):
        """Add ``step_size * update`` to the parameters in place."""
        for p, chunk in self._iter_param_chunks(update):
            p.add_(chunk, alpha=step_size)

    def _directional_evaluate(self, closure, x, g, t, d):
        """Evaluate ``closure`` at ``params + t * d`` and then restore state.

        Returns the loss, the flat parameters and the flat gradient at the
        trial point; afterwards the parameters are reset to ``x`` and the
        gradients to ``g``.
        """
        self._add_grad(t, d)
        loss = closure()
        xk = self._gather_flat_data()
        flat_grad = self._gather_flat_grad()
        self._store_data(x)
        self._store_grad(g)
        return loss, xk, flat_grad

    def setfullgrad(self, length):
        """Store the current flat gradient divided by ``length`` in ``self.fullgrad``."""
        self.fullgrad = self._gather_flat_grad().div(length)

    def settmpx(self):
        """Store a flat snapshot of the current parameters in ``self.xk``."""
        self.xk = self._gather_flat_data()

    def _get_x_delta(self, Xk, Rk, delta):
        # NOTE(review): `simpleQR`, `inv`, `res`, `alpha` and `beta` are neither
        # parameters, locals, nor names imported at the top of this file as
        # visible here, so calling this method raises NameError unless they are
        # provided as module globals elsewhere. `step` does not call it.
        # TODO confirm whether this method is dead code or needs its inputs wired up.
        Q, R, G = simpleQR(Rk)
        Xk = Xk.mm(G)
        Rk = Rk.mm(G)
        H_inv = inv(R + delta * inv(R).t() @ (Xk.t() @ Xk))
        Gamma = H_inv @ (Q.t() @ res)
        x_delta = beta * res - (alpha * Xk + alpha * beta * Rk) @ Gamma
        return x_delta

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        With ``rk`` the negated gradient (including the weight-decay term) and
        ``p``/``q`` the stored vectors, the parameters are moved by
        ``beta*rk - alpha*tmp1 + alpha*beta*tmp2`` as computed below. The
        first call zeroes ``p`` and ``q``, so it reduces to ``beta*rk``.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        assert len(self.param_groups) == 1

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        beta = self.beta
        alpha = self.alpha
        dtype = self._work_dtype()
        state = self.state

        xk = self._gather_flat_data().to(dtype)
        flat_grad = self._gather_flat_grad()
        # Wrapped optimizers without a weight_decay option are treated as 0.
        weight_decay = group.get('weight_decay', 0)
        rk = flat_grad.add(alpha=weight_decay, other=xk).neg().to(dtype)

        p, q = state['p'], state['q']
        r_prev, x_prev = state['r_prev'], state['x_prev']
        cnt = state['step']

        if cnt == 0:
            p.zero_()
            q.zero_()
        else:
            delta_xk = xk - x_prev
            delta_rk = rk - r_prev

            # Damped denominator; dPinv maps a (near-)zero value to 0.
            denom = torch.dot(p, q) - self.damp1 * (torch.dot(p, p) + torch.dot(q, q))
            xi = torch.dot(p, delta_rk) * dPinv(denom)

            # Right-hand sides are fully evaluated before the in-place copy.
            p.copy_(delta_xk - p * xi)
            q.copy_(delta_rk - q * xi)

        x_prev.copy_(xk)
        r_prev.copy_(rk)

        rho = dPinv(torch.dot(p, q) - self.damp2 * (torch.dot(p, p) + torch.dot(q, q)))

        # A positive rho is treated as a violated sign check: the p/q
        # correction is dropped for this step.
        if rho > 0:
            print('Note2: violate positive check, rho: {:e} '.format(float(rho)))
            rho = 0.

        p_rk_rho = torch.dot(p, rk) * rho
        tmp1 = p * p_rk_rho

        tmp2 = rk - q * p_rk_rho
        tmp2 -= p * (torch.dot(q, tmp2) * rho)
        tmp2 -= rk

        xk += beta * rk - alpha * tmp1 + (alpha * beta) * tmp2

        self._store_data(xk)
        state['step'] = cnt + 1

        return loss
