import torch
from functools import reduce
from .optimizer import Optimizer
#from torch import inverse as inv
#from torch import pinverse as pinv

def mysvd(A):
    """Eigendecomposition of a symmetric 2x2 matrix.

    Only ``A[0, 0]``, ``A[1, 1]`` and ``A[0, 1]`` are read; ``A[1, 0]`` is assumed
    to equal ``A[0, 1]``.

    Args:
        A: symmetric 2x2 tensor.

    Returns:
        ``(s, V)`` where ``s`` holds the two eigenvalues and the columns of ``V``
        are the matching unit eigenvectors, so ``A ~= V @ diag(s) @ V.T``.
        When ``A[0, 1] == 0`` the diagonal is returned unchanged with ``V = I``;
        otherwise ``s`` is ordered largest first.
    """
    a, b, c = A[0, 0], A[1, 1], A[0, 1]

    if c == 0:
        return A.diag(), torch.eye(2, device=A.device, dtype=A.dtype)

    # Rotation angle that diagonalises A: tan(2*theta) = 2c / (a - b).
    # Going through the angle avoids the cancellation in the closed-form vectors
    # (Delta + a - b, 2c) / (Delta + b - a, -2c), which lose all accuracy when
    # |c| is small relative to |a - b|. atan2 also never divides by a norm that
    # could underflow to zero.
    theta = 0.5 * torch.atan2(2 * c, a - b)
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)

    delta = ((a - b) ** 2 + 4 * c ** 2) ** 0.5
    s = torch.stack(((a + b + delta) / 2, (a + b - delta) / 2))

    # Column 0: eigenvector of the larger eigenvalue (first component >= 0).
    # Column 1: its orthogonal complement, scaled by sign(c) so that both columns
    # carry the same signs as the closed-form vectors used previously.
    sgn = torch.sign(c)
    V = torch.stack((torch.stack((cos_t, sgn * sin_t)),
                     torch.stack((sin_t, -sgn * cos_t))))
    return s, V

def pinv(A, rcond=1e-15):
    """Pseudo-inverse of a symmetric positive semi-definite 2x2 matrix.

    The eigendecomposition is obtained from ``mysvd`` in float64 and the result
    is cast back to ``A.dtype``.

    Args:
        A: symmetric 2x2 tensor (``mysvd`` reads only ``A[0, 0]``, ``A[1, 1]``
            and ``A[0, 1]``).
        rcond: relative cutoff. Eigenvalues not greater than
            ``rcond * max(eigenvalues)`` are treated as zero. This includes any
            negative eigenvalue, which for a PSD input can only come from
            rounding.

    Returns:
        ``V @ diag(1 / s) @ V.T`` over the kept eigenvalues, in ``A.dtype``.
    """
    s, V = mysvd(A.double())
    # A relative threshold (the ``torch.pinverse`` convention) keeps the result
    # independent of A's overall scale. A fixed absolute cutoff would zero out
    # every eigenvalue of a legitimately tiny Gram matrix.
    cutoff = rcond * s.max().clamp(min=0)
    s_inv = torch.where(s > cutoff, s.reciprocal(), torch.zeros_like(s))
    return ((V * s_inv) @ V.t()).to(A.dtype)

def lambda_k(Yk, Zk, Rk):
    """Largest eigenvalue of ``(YR^T YR) @ G``.

    Here ``YR = [Yk, Rk]`` (concatenated column-wise) and
    ``G = [[0, Z^+], [Z^+, 0]]`` with ``Z^+ = pinv(Zk)``. For the product to be
    defined, ``Yk`` and ``Rk`` are expected to have ``m = Zk.size(1)`` columns
    each.

    Returns:
        0-d tensor holding the largest (real part of the) eigenvalue.
    """
    YR = torch.cat((Yk, Rk), 1)
    Z_inv = pinv(Zk)
    m = Z_inv.size(1)
    G = torch.zeros(2 * m, 2 * m, dtype=Zk.dtype, device=Zk.device)
    G[:m, m:].copy_(Z_inv)
    G[m:, :m].copy_(Z_inv)
    F = (YR.t() @ YR) @ G
    # torch.eig is deprecated (and removed in recent PyTorch releases); prefer
    # torch.linalg.eigvals and fall back to torch.eig only on old versions.
    linalg_eigvals = getattr(getattr(torch, 'linalg', None), 'eigvals', None)
    if linalg_eigvals is not None:
        eigs = linalg_eigvals(F).real
    else:
        eigs = torch.eig(F)[0][:, 0]
    return eigs.max()

class CA(Optimizer):
    """Conjugate Anderson wrapper around another optimizer.

    The wrapped ``optimizer`` supplies the parameters (exactly one parameter
    group is supported). It takes the first step and serves as the fallback.
    From the second step on, ``step`` keeps two columns of damped differences:
    the flattened parameters go in ``P`` and the residuals in ``Q``, where the
    residual is ``-(grad + weight_decay * x)``. It then proposes
    ``x_next = x - alpha * P @ Gamma + beta * (res - alpha * Q @ Gamma)``.
    The proposal is applied only if ``(x_next - x) . res > 0``; otherwise the
    wrapped optimizer steps instead.

    CA's own entries ('step', 'P', 'Q', 'x_prev', 'res_prev') are stored under
    string keys in the wrapped optimizer's ``state`` dict, which CA shares.

    Args:
        optimizer: the wrapped optimizer; must not be None.
        alpha: non-negative weight on the history terms.
        beta: non-negative weight on the residual (mixing) term.
        damp1: damping scale used when inserting new history columns.
        damp2: damping scale used when solving for ``Gamma``.
        mu: stored but not used by the current update rule.
        precision: 0 keeps the parameters' dtype for the history buffers and
            update arithmetic; any other value uses float64.
    """

    def __init__(self, optimizer=None, alpha=1.0, beta=1.0,
                 damp1=1, damp2=0.01, mu=0, precision=0):
        if optimizer is None:
            raise ValueError("optimizer cannot be None")
        if alpha < 0.0:
            raise ValueError("Invalid Conjugate Anderson alpha parameter: {}".format(alpha))
        if beta < 0.0:
            raise ValueError("Invalid Conjugate Anderson beta parameter: {}".format(beta))

        self.optimizer = optimizer
        self.beta = beta
        self.alpha = alpha
        self.damp1 = damp1
        self.damp2 = damp2
        self.mu = mu
        self.precision = precision
        # Share the wrapped optimizer's groups, state and defaults.
        self.param_groups = self.optimizer.param_groups
        self.state = self.optimizer.state
        self.defaults = self.optimizer.defaults

        if len(self.param_groups) != 1:
            raise ValueError("Conjugate Anderson doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

        N = self._numel()
        device = self._params[0].device
        dtype = self._work_dtype()
        state = self.state
        state.setdefault('step', 0)
        state.setdefault('P', torch.zeros((N, 2), device=device, dtype=dtype))
        state.setdefault('Q', torch.zeros((N, 2), device=device, dtype=dtype))
        state.setdefault('x_prev', torch.zeros(N, dtype=dtype, device=device))
        state.setdefault('res_prev', torch.zeros(N, dtype=dtype, device=device))

    def __setstate__(self, state):
        super(CA, self).__setstate__(state)

    def _work_dtype(self):
        """Return the dtype of the history buffers and of the update arithmetic.

        ``__init__`` and ``step`` must agree on this rule. If they disagree,
        float64 buffers meet parameter-dtype vectors in a matmul and the call
        fails.
        """
        return torch.float64 if self.precision != 0 else self._params[0].dtype

    def _numel(self):
        """Return (and cache) the total number of elements over all parameters."""
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenate all parameter gradients into a new 1-D tensor.

        Missing gradients contribute zeros; sparse gradients are densified.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new_zeros(p.numel())
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.reshape(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _gather_flat_data(self):
        """Concatenate all parameter values into a new, detached 1-D tensor."""
        return torch.cat([p.detach().reshape(-1) for p in self._params], 0)

    def _store_data(self, other):
        """Copy the flat vector ``other`` back into the parameters, in order."""
        offset = 0
        # In-place writes to leaf parameters are only legal without autograd.
        with torch.no_grad():
            for p in self._params:
                numel = p.numel()
                p.copy_(other[offset:offset + numel].view_as(p))
                offset += numel
        assert offset == self._numel()

    def _store_grad(self, other):
        """Copy the flat vector ``other`` into the parameters' ``.grad`` tensors.

        Parameters without a gradient get a zero-initialised one first.
        """
        offset = 0
        for p in self._params:
            numel = p.numel()
            if p.grad is None:
                p.grad = torch.zeros_like(p)
            p.grad.copy_(other[offset:offset + numel].view_as(p))
            offset += numel
        assert offset == self._numel()

    def _add_grad(self, step_size, update):
        """Add ``step_size * update`` (a flat vector) to the parameters in place."""
        offset = 0
        with torch.no_grad():
            for p in self._params:
                numel = p.numel()
                p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
                offset += numel
        assert offset == self._numel()

    def _directional_evaluate(self, closure, x, g, t, d):
        """Evaluate ``closure`` after moving the parameters by ``t * d``.

        Afterwards, the parameters are reset to the flat vector ``x`` and the
        gradients to the flat vector ``g``. Assumes the parameters currently
        hold ``x``.

        Returns:
            ``(loss, flat_params, flat_grad)`` observed at the displaced point.
        """
        self._add_grad(t, d)
        # Match step(): the closure may call backward(), which needs grad mode.
        with torch.enable_grad():
            loss = closure()
        xk = self._gather_flat_data()
        flat_grad = self._gather_flat_grad()
        self._store_data(x)
        self._store_grad(g)
        return loss, xk, flat_grad

    def setfullgrad(self, length):
        """Store the current flat gradient divided by ``length`` in ``self.fullgrad``."""
        self.fullgrad = self._gather_flat_grad().div(length)

    def settmpx(self):
        """Store a flat copy of the current parameters in ``self.xk``."""
        self.xk = self._gather_flat_data()

    def _get_x_delta(self, Xk, Rk, delta, res=None):
        """Damped Anderson displacement computed through a QR factorisation of ``Rk``.

        Not called by ``step``. Returns
        ``beta * res - (alpha * Xk + alpha * beta * Rk) @ Gamma``, where
        ``Gamma = (R + delta * R^{-T} Xk^T Xk)^{-1} Q^T res`` and ``Xk``, ``Rk``
        have both been right-multiplied by ``G``.

        Args:
            Xk, Rk: matrices whose columns are combined through ``Gamma``.
            delta: damping weight.
            res: current residual vector; required.

        Raises:
            ValueError: if ``res`` is not supplied.
        """
        if res is None:
            raise ValueError("_get_x_delta requires the current residual `res`")
        # NOTE(review): `simpleQR` is not defined or imported in this file as
        # visible; presumably it returns (Q, R, G) with Rk @ G == Q @ R -- confirm.
        Q, R, G = simpleQR(Rk)
        Xk = Xk.mm(G)
        Rk = Rk.mm(G)
        R_inv = torch.inverse(R)
        H_inv = torch.inverse(R + delta * R_inv.t() @ (Xk.t() @ Xk))
        Gamma = H_inv @ (Q.t() @ res)
        return self.beta * res - (self.alpha * Xk + self.alpha * self.beta * Rk) @ Gamma

    def _update_history(self, xk, res, cnt, eps):
        """Write the newest damped differences into history column ``cnt % 2``.

        ``p = xk - x_prev`` and ``q = res - res_prev`` are corrected by
        ``P @ xi`` and ``Q @ xi``. Here ``xi`` minimises
        ``||q + Q xi||^2 + delta1 * ||P xi||^2``, computed with ``pinv``.

        Returns:
            ``p . p`` of the stored (corrected) ``p``.
        """
        state = self.state
        P, Q = state['P'], state['Q']
        k = cnt % 2
        p = xk - state['x_prev']
        q = res - state['res_prev']

        delta1 = self.damp1 * torch.dot(res, res) / (torch.dot(p, p) + eps)
        xi = -pinv(Q.t() @ Q + delta1 * P.t() @ P) @ (Q.t() @ q)
        q += Q @ xi
        p += P @ xi

        P[:, k].copy_(p)
        Q[:, k].copy_(q)
        return torch.dot(p, p)

    def _extrapolate(self, xk, res, sq_p, eps):
        """Apply the damped extrapolated update, or fall back to the wrapped optimizer.

        ``Gamma = pinv(Q^T Q + delta2 P^T P) @ Q^T res`` and the candidate is
        ``x_next = (xk - alpha P Gamma) + beta (res - alpha Q Gamma)``.
        If ``(x_next - xk) . res <= 0``, the wrapped optimizer steps instead.
        """
        P, Q = self.state['P'], self.state['Q']
        alpha, beta = self.alpha, self.beta

        delta2 = self.damp2 * torch.dot(res, res) / (sq_p + eps)
        gamma = pinv(Q.t() @ Q + delta2 * (P.t() @ P)) @ (Q.t() @ res)

        xk_bar = xk - alpha * (P @ gamma)
        rk_bar = res - alpha * (Q @ gamma)
        xk_next = xk_bar + beta * rk_bar

        tmp = (xk_next - xk).dot(res)
        if tmp <= 0:
            # Proposal does not point along the residual: use the wrapped step.
            self.optimizer.step(None)
            print("**Notice: (dir,res) <= 0", tmp)
        else:
            self._store_data(xk_next)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        The first call clears the history and delegates to the wrapped
        optimizer. Later calls first update the two-column history, then try
        the extrapolated update.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss returned by ``closure``, or None when no closure is given.
        """
        assert len(self.param_groups) == 1

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        state = self.state
        dtype = self._work_dtype()
        eps = 1e-8

        xk = self._gather_flat_data().to(dtype)
        flat_grad = self._gather_flat_grad()
        # Residual = negative weight-decayed gradient. Groups without a
        # 'weight_decay' entry (optimizers lacking it) are treated as 0.
        weight_decay = group.get('weight_decay', 0)
        res = flat_grad.add(xk, alpha=weight_decay).neg().to(dtype)

        cnt = state['step']
        if cnt == 0:
            # First step: start from an empty history, let the wrapped optimizer move.
            state['P'].zero_()
            state['Q'].zero_()
            state['x_prev'].copy_(xk)
            state['res_prev'].copy_(res)
            self.optimizer.step(None)
        else:
            # x_prev / res_prev must be read by the history update before being overwritten.
            sq_p = self._update_history(xk, res, cnt, eps)
            state['x_prev'].copy_(xk)
            state['res_prev'].copy_(res)
            self._extrapolate(xk, res, sq_p, eps)

        state['step'] = cnt + 1
        return loss
