import torch
from functools import reduce
from torch.optim.optimizer import Optimizer

#from torch import inverse as inv
#from torch import pinverse as pinv


#@profile
def mysvd(A):
    """Closed-form eigen-decomposition of a symmetric 2x2 matrix.

    Only ``A[0,0]``, ``A[1,1]`` and ``A[0,1]`` are read, so ``A`` is treated
    as symmetric (``A[1,0]`` is ignored).

    Args:
        A: 2x2 tensor.

    Returns:
        Tuple ``(s, V)``: ``s`` holds the two eigenvalues (shape ``(2,)``)
        and the columns of ``V`` are the matching unit eigenvectors, so that
        ``(V * s) @ V.t()`` reproduces a symmetric ``A``. Both results are
        freshly allocated and never share memory with ``A``.
    """
    a, b, c = A[0, 0], A[1, 1], A[0, 1]

    if c == 0:
        # Already diagonal: eigenvalues are the diagonal, eigenvectors the
        # identity. ``diag()`` returns a view, so clone it -- otherwise a
        # caller that edits the eigenvalues in place would corrupt ``A``.
        return A.diag().clone(), torch.eye(2, device=A.device, dtype=A.dtype)

    # Discriminant of the characteristic polynomial.
    Delta = ((a - b) ** 2 + 4 * (c ** 2)) ** 0.5
    s = torch.tensor([(a + b + Delta) / 2, (a + b - Delta) / 2],
                     device=A.device, dtype=A.dtype)

    # Un-normalised eigenvectors: (v00, v10) belongs to s[0] and
    # (v01, v11) to s[1].
    v00, v10, v01, v11 = Delta + a - b, 2 * c, Delta + b - a, -2 * c
    n1 = (v00 ** 2 + v10 ** 2) ** 0.5
    n2 = (v01 ** 2 + v11 ** 2) ** 0.5

    V = A.clone()
    V[0, 0], V[1, 0], V[0, 1], V[1, 1] = v00 / n1, v10 / n1, v01 / n2, v11 / n2
    return s, V

#@profile
def pinv(A,rcond=1e-15):
    """Pseudo-inverse of a symmetric 2x2 matrix via ``mysvd``.

    The decomposition is carried out in float64 and the result is cast back
    to the dtype of ``A``.

    Args:
        A: 2x2 tensor, treated as symmetric by ``mysvd``.
        rcond: Absolute cut-off. Eigenvalues that are not strictly greater
            than this value (including negative ones) contribute zero to
            the result instead of their reciprocal.

    Returns:
        2x2 tensor with the dtype of ``A``.
    """
    s, V = mysvd(A.double())
    # Build the reciprocals in a new tensor rather than writing into ``s``:
    # the caller's matrix then stays untouched even if ``s`` happens to
    # share storage with it.
    inv_s = torch.where(s > rcond, 1.0 / s, torch.zeros_like(s))
    # Broadcasting scales column j of V by inv_s[j].
    return ((V * inv_s) @ V.t()).to(A.dtype)

#@profile
def mymm(A,buf):
    """Write the 2x2 Gram matrix of the first two columns of ``A`` into ``buf``.

    Args:
        A: 2-D tensor whose columns 0 and 1 are used.
        buf: 2x2 tensor that is overwritten in place.

    Returns:
        ``buf`` itself, holding ``[[c0.c0, c0.c1], [c0.c1, c1.c1]]``.
    """
    col0 = A[:, 0]
    col1 = A[:, 1]
    # All three inner products are evaluated before anything is stored.
    diag0 = torch.dot(col0, col0)
    diag1 = torch.dot(col1, col1)
    cross = torch.dot(col0, col1)
    buf[0, 0] = diag0
    buf[0, 1] = cross
    buf[1, 0] = cross
    buf[1, 1] = diag1
    return buf

#@profile
def mymtv(A,b):
    """Return the inner products of the first two columns of ``A`` with ``b``.

    Args:
        A: 2-D tensor whose columns 0 and 1 are used.
        b: 1-D tensor.

    Returns:
        New tensor of shape ``(2,)`` with the dtype and device of ``A``.
    """
    out = torch.empty(2, dtype=A.dtype, device=A.device)
    for col in range(2):
        out[col] = A[:, col].dot(b)
    return out

def lambda_k(Yk,Zk,Rk):
    """Largest real part among the eigenvalues of ``([Yk Rk]^T [Yk Rk]) @ G``.

    ``G`` is a block matrix with ``pinv(Zk)`` in both off-diagonal blocks and
    zeros on the diagonal blocks.

    Args:
        Yk: 2-D tensor, concatenated column-wise with ``Rk``.
        Zk: matrix handed to ``pinv``.
        Rk: 2-D tensor with the same number of rows as ``Yk``.

    Returns:
        0-dim tensor holding the maximum real part of the eigenvalues.
    """
    YR = torch.cat((Yk, Rk), 1)
    Z_inv = pinv(Zk)
    m = Z_inv.size(1)
    G = torch.zeros(2 * m, 2 * m, dtype=Zk.dtype, device=Zk.device)
    G[0:m, m:2 * m].copy_(Z_inv)
    G[m:2 * m, 0:m].copy_(Z_inv)
    F = (YR.t() @ YR) @ G

    # ``torch.eig`` no longer exists in current PyTorch; prefer
    # ``torch.linalg.eigvals`` and only fall back on releases that lack it.
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "eigvals"):
        real_parts = linalg.eigvals(F).real
    else:
        # Legacy API: column 0 of the first result holds the real parts.
        real_parts = torch.eig(F)[0][:, 0]
    return real_parts.max()

class PCA(Optimizer):
    """Optimizer wrapper that mixes the last two corrected steps into each update.

    The error messages of this class refer to the method as "Conjugate
    Anderson". The wrapper shares ``param_groups``, ``state`` and
    ``defaults`` with the wrapped optimizer and stores its own buffers in
    that same state dict under string keys:

    * ``'step'``: number of completed calls to :meth:`step`.
    * ``'P'``, ``'Q'``: ``(N, 2)`` histories of corrected parameter and
      residual differences (``N`` = total number of parameter elements).
    * ``'x_prev'``, ``'res_prev'``: ``(N, 1)`` previous iterate / residual.
    * ``'P_mul'``, ``'Q_mul'``: ``2x2`` Gram matrices of ``P`` and ``Q``.

    Exactly one parameter group is supported, and that group must provide a
    ``'weight_decay'`` entry (it is read in :meth:`step`).

    Args:
        optimizer: Wrapped optimizer that performs the underlying update.
        alpha: Scale applied to the history correction of iterate/residual.
        beta: Scale applied to the displacement produced by the wrapped
            optimizer.
        damp1: Damping factor used while correcting the new history pair.
        damp2: Damping factor used when solving for the mixing coefficients.
        mu: Stored on the instance; not used by any method of this class.
        precision: ``1`` keeps the buffers in float64; any other value uses
            the dtype of the first parameter.

    Raises:
        ValueError: If ``optimizer`` is ``None``, ``alpha`` or ``beta`` is
            negative, or more than one parameter group is present.
    """

    def __init__(self, optimizer=None,alpha=1.0,beta=1.0,
                 damp1=1,damp2=1e-2,mu=0,precision=0):
        if optimizer is None:
            raise ValueError("optimizer cannot be None")
        if alpha < 0.0:
            raise ValueError("Invalid Conjugate Anderson alpha parameter: {}".format(alpha))
        if beta < 0.0:
            raise ValueError("Invalid Conjugate Anderson beta parameter: {}".format(beta))

        self.optimizer = optimizer
        self.beta = beta
        self.alpha = alpha
        self.damp1 = damp1
        self.damp2 = damp2
        self.mu = mu
        self.precision = precision

        # Share (not copy) the bookkeeping of the wrapped optimizer.
        self.param_groups = self.optimizer.param_groups
        self.state = self.optimizer.state
        self.defaults = self.optimizer.defaults

        if len(self.param_groups) != 1:
            raise ValueError("Conjugate Anderson doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

        N = self._numel()
        device = self._params[0].device
        dtype = self._params[0].dtype

        if self.precision == 1:
            dtype = torch.float64

        state = self.state

        # setdefault keeps buffers that are already present in the shared
        # state; the two Gram buffers are always re-created.
        state.setdefault('step', 0)
        state.setdefault('P', torch.zeros((N, 2), device=device, dtype=dtype))
        state.setdefault('Q', torch.zeros((N, 2), device=device, dtype=dtype))
        state.setdefault('x_prev', torch.zeros((N, 1), dtype=dtype, device=device))
        state.setdefault('res_prev', torch.zeros((N, 1), dtype=dtype, device=device))
        state['Q_mul'] = torch.zeros((2, 2), dtype=dtype, device=device)
        state['P_mul'] = torch.zeros((2, 2), dtype=dtype, device=device)

    def __setstate__(self, state):
        """Restore pickled state by delegating to ``Optimizer.__setstate__``."""
        super(PCA, self).__setstate__(state)

    def _numel(self):
        """Return the total number of parameter elements (computed once, then cached)."""
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        """Return all gradients concatenated into one ``(N, 1)`` column.

        Parameters without a gradient contribute zeros; sparse gradients
        are densified first.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new_zeros(p.numel())
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.unsqueeze(torch.cat(views, 0), 1)

    def _gather_flat_data(self):
        """Return all parameter values concatenated into one ``(N, 1)`` column."""
        views = [p.data.view(-1) for p in self._params]
        return torch.unsqueeze(torch.cat(views, 0), 1)

    def _store_data(self,other):
        """Copy the flat column ``other`` back into the parameters, in order."""
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.copy_(other[offset:offset + numel].view_as(p))
            offset += numel
        assert offset == self._numel()

    def _store_grad(self,other):
        """Write the flat column ``other`` into the ``.grad`` of each parameter.

        Existing gradients are overwritten in place. A missing gradient is
        created from the matching slice, cast to the parameter's dtype
        (``other`` may be float64 when ``precision == 1``; assigning a
        gradient of a different dtype is rejected by PyTorch).
        """
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            piece = other[offset:offset + numel].view_as(p).detach()
            if p.grad is None:
                p.grad = piece.to(dtype=p.dtype)
            else:
                p.grad.copy_(piece)
            offset += numel
        assert offset == self._numel()

    def _add_grad(self, step_size, update):
        """Add ``step_size * update`` (flat column) to the parameters in place."""
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _directional_evaluate(self, closure, x,g, t, d):
        """Evaluate ``closure`` at the current point moved by ``t * d``.

        Afterwards the parameters are reset to ``x`` and the gradients to
        ``g``.

        Returns:
            Tuple ``(loss, xk, flat_grad)`` measured at the trial point.
        """
        self._add_grad(t, d)
        loss = closure()
        xk = self._gather_flat_data()
        flat_grad = self._gather_flat_grad()
        self._store_data(x)
        self._store_grad(g)
        return loss, xk, flat_grad

    def setfullgrad(self,length):
        """Store the current flat gradient divided by ``length`` in ``self.fullgrad``."""
        self.fullgrad = self._gather_flat_grad().div(length)

    def settmpx(self):
        """Store the current flat parameter vector in ``self.xk``."""
        self.xk = self._gather_flat_data()

    def _get_x_delta(self,Xk,Rk,delta):
        # NOTE(review): this helper is not called anywhere in this class and
        # relies on names that are not defined in this file (`simpleQR`,
        # `inv`, `res`, `alpha`, `beta`); it raises NameError if invoked
        # as-is. Left unchanged pending confirmation of the intended helpers.
        Q, R, G = simpleQR(Rk)
        Xk = Xk.mm(G)
        Rk = Rk.mm(G)
        H_inv = inv(R + delta * inv(R).t() @ (Xk.t() @ Xk))
        Gamma = H_inv @ (Q.t() @ res)
        x_delta = beta * res - (alpha * Xk + alpha * beta * Rk) @ Gamma
        return x_delta

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        The residual is ``-(grad + weight_decay * x)``. The first call only
        records the current point and lets the wrapped optimizer step. Later
        calls:

        1. form the parameter / residual differences to the previous call,
           correct them against the stored history and write them into
           column ``step % 2`` of ``P`` / ``Q``;
        2. solve a damped 2x2 system for mixing coefficients and move the
           parameters and gradients to the mixed point;
        3. run the wrapped optimizer from there and scale the resulting
           displacement (relative to the starting point) by ``beta``;
        4. keep that displacement only if its inner product with the
           residual is positive; otherwise restore the starting point and
           gradients and take a plain wrapped-optimizer step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's loss, or ``None`` when no closure was given.
        """
        assert len(self.param_groups) == 1

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        group = self.param_groups[0]
        beta = self.beta
        alpha = self.alpha
        damp1 = self.damp1
        damp2 = self.damp2
        optimizer = self.optimizer
        device = self._params[0].device
        dtype = self._params[0].dtype
        if self.precision == 1:
            dtype = torch.float64

        state = self.state

        # Flat (N, 1) snapshots of the current point and its residual.
        xk = self._gather_flat_data().to(dtype).to(device)
        flat_grad = self._gather_flat_grad()
        weight_decay = group['weight_decay']
        res = flat_grad.add(alpha=weight_decay, other=xk).neg().to(dtype).to(device)

        P, Q = state['P'].to(device), state['Q'].to(device)
        res_prev, x_prev = state['res_prev'].to(device), state['x_prev'].to(device)

        cnt = state['step']

        eps = 1e-8
        if cnt == 0:
            P.zero_()
            Q.zero_()
        else:
            k = cnt % 2

            p = xk - x_prev
            q = res - res_prev

            # The vectors are (N, 1) columns: `[:, 0]` selects the whole
            # vector, whereas `[0]` would pick only its first entry.
            res_sq = torch.dot(res[:, 0], res[:, 0])
            delta1 = damp1 * res_sq / (torch.dot(p[:, 0], p[:, 0]) + eps)
            # Gram matrices of the history as left by the previous call.
            Q_mul, P_mul = state['Q_mul'].to(device), state['P_mul'].to(device)

            xi = -pinv(Q_mul + delta1 * P_mul) @ (Q.t() @ q)

            q += Q @ xi
            p += P @ xi

            # Overwrite the history column selected by step parity.
            P[:, k].copy_(p[:, 0])
            Q[:, k].copy_(q[:, 0])

        x_prev.copy_(xk)
        res_prev.copy_(res)

        if cnt == 0:
            optimizer.step(None)
        else:
            # Uses the corrected `p`, so the ratio differs from delta1.
            delta2 = damp2 * res_sq / (torch.dot(p[:, 0], p[:, 0]) + eps)
            # Refresh the Gram buffers in place for the updated history.
            Q_mul, P_mul = mymm(Q, state['Q_mul']).to(device), mymm(P, state['P_mul']).to(device)

            Gamma = pinv(Q_mul + delta2 * P_mul) @ (Q.t() @ res)
            xk_bar = xk - alpha * (P @ Gamma)
            rk_bar = res - alpha * (Q @ Gamma)

            # Let the wrapped optimizer step from the mixed point with the
            # gradient implied by the mixed residual.
            self._store_data(xk_bar)
            flat_grad_bar = (-rk_bar).add(alpha=-weight_decay, other=xk)
            self._store_grad(flat_grad_bar)
            optimizer.step(None)
            xk_next = self._gather_flat_data()

            x_delta = xk_next - xk
            x_delta *= beta
            tmp = torch.dot(x_delta[:, 0], res[:, 0])
            if tmp <= 0:
                # Reject the mixed step: restore the starting point and
                # gradients, then take a plain wrapped-optimizer step.
                self._store_data(xk)
                self._store_grad(flat_grad)
                optimizer.step(None)
            else:
                self._store_data(xk + x_delta)

        cnt += 1

        state['step'] = cnt

        return loss
