import torch
import numpy as np
import math
import scipy 
from scipy import linalg
from numpy import linalg as LA
from pyblas.level1 import dnrm2
# Selected compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class pytorchAA:
    """Anderson acceleration as described by Walker and Ni in doi:10.2307/23074353.

    The object keeps a window of the last ``depth`` differences of iterates
    (``S``) and of residuals (``Y``) in column ring buffers.  Each call to
    ``apply(x, fx)`` takes the current iterate ``x`` and the fixed-point map
    value ``fx`` and returns the next iterate.  The small ``depth x depth``
    regularised least-squares problem is solved on the CPU with
    ``scipy.linalg.lstsq``; everything else stays on the buffer device.

    Args:
        dimension: Length of the iterate vectors.
        depth: Number of past differences kept (window size).
        type2: Truthy (or ``None``) selects the type-II update (``type2``),
            ``False`` selects the type-I update (``type1``).  The chosen
            method is exposed as ``self.apply``.
        reg: Relative Tikhonov regularisation; it is scaled at every step by
            ``||Y||_F^2 + ||S||_F^2`` before being added to the system matrix.
        device: Device for the history buffers.  ``None`` (default) uses the
            module-level ``device``.
    """

    def __init__(self, dimension, depth, type2=True, reg=1e-8, device=None):
        self._dimension = dimension
        self._depth = depth
        if device is None:
            self._device = self._default_device()
        else:
            self._device = torch.device(device)

        # The regulariser lives on the CPU because the system is solved there.
        self.reg = reg * torch.eye(self._depth, device='cpu')
        # Column ring buffers: residual differences and iterate differences.
        self.Y = torch.zeros((self._dimension, self._depth), device=self._device)
        self.S = torch.zeros((self._dimension, self._depth), device=self._device)
        # Y^T Y for type II, S^T Y for type I; updated one row/column per step.
        self.xTx = torch.zeros((self._depth, self._depth), device=self._device)
        self.it = 0
        # Previous iterate and residual; populated by the first apply() call.
        self.xprev = None
        self.gprev = None

        # None keeps meaning "use the default", i.e. the type-II update.
        self.apply = self.type2 if (type2 is None or type2) else self.type1

    @staticmethod
    def _default_device():
        """Return the module-level ``device`` used when none is passed in."""
        return device

    def reset(self):
        """Forget all history so the next ``apply`` call starts a fresh run."""
        # Fresh tensors (rather than in-place zeroing) so that nothing which
        # still references the old buffers is affected.
        self.Y = torch.zeros_like(self.Y)
        self.S = torch.zeros_like(self.S)
        self.xTx = torch.zeros_like(self.xTx)
        self.xprev = None
        self.gprev = None
        self.it = 0

    def type2(self, x, fx):
        """Return the next iterate using the type-II system ``Y^T Y gamma = Y^T g``."""
        return self._step(x, fx, use_type1=False)

    def type1(self, x, fx):
        """Return the next iterate using the type-I system ``S^T Y gamma = S^T g``."""
        return self._step(x, fx, use_type1=True)

    def _step(self, x, fx, use_type1):
        """Record the newest differences and compute the accelerated iterate.

        On the very first call (or when ``depth`` is 0) there is no history,
        so the plain fixed-point value ``fx`` is returned unchanged.
        """
        g = x - fx  # residual of the current iterate
        if min(self.it, self._depth) > 0:
            col = (self.it - 1) % self._depth  # ring-buffer slot to overwrite
            s = x - self.xprev
            y = g - self.gprev
            self.S[:, col] = s
            self.Y[:, col] = y
            if use_type1:
                # S^T Y is not symmetric: refresh column and row separately.
                self.xTx[:, col] = y.matmul(self.S)
                self.xTx[col, :] = s.matmul(self.Y)
                rhs = self.S.t().mv(g)
            else:
                # Y^T Y is symmetric: the same vector fills row and column.
                gram = y.matmul(self.Y)
                self.xTx[col, :] = gram
                self.xTx[:, col] = gram
                rhs = self.Y.t().mv(g)
            xkp1 = fx - (self.S - self.Y) @ self._solve(rhs)
        else:
            xkp1 = fx

        self.it += 1
        self.xprev = x.detach().clone()
        self.gprev = g
        return xkp1

    def _solve(self, rhs):
        """Solve the regularised system on the CPU and return ``gamma``.

        The result is placed on the same device and in the same dtype as the
        history buffers, so it works on CPU-only hosts as well as with CUDA.
        """
        # Scale the regulariser by the energy of the stored differences.
        scale = (torch.norm(self.Y) ** 2 + torch.norm(self.S) ** 2).cpu()
        lhs = (self.xTx.cpu() + self.reg * scale).detach().numpy()
        gamma = linalg.lstsq(lhs, rhs.detach().cpu().numpy())[0]
        return torch.as_tensor(gamma, dtype=self.S.dtype, device=self.S.device)
    