import torch
from torch.optim.optimizer import Optimizer, required
from scipy import linalg

def solve_a(Rt: torch.Tensor, sigma_2: float):
    """Solve ``toeplitz(Rt) @ a = rhs`` and return ``a`` scaled to sum to one.

    ``rhs`` is a copy of ``Rt`` whose first entry is replaced by
    ``max(0, Rt[0] - sigma_2)``.

    Args:
        Rt: 1-D tensor used both as the generator of the Toeplitz matrix
            (via ``scipy.linalg.toeplitz``) and as the basis of the
            right-hand side.  It is not modified.
        sigma_2: Amount subtracted from the first right-hand-side entry
            (the result is clamped at zero).

    Returns:
        A tensor with the dtype and device of ``Rt`` whose entries are
        divided by their sum.

    Note:
        ``scipy.linalg.solve`` is called directly, so whatever it raises
        for a singular system propagates to the caller.
    """
    # Mixing weight towards the first coefficient; fixed at zero here, so the
    # two lines that use it below leave the coefficients unchanged.
    beta = 0.0

    rt_np = Rt.detach().cpu().numpy()
    A = linalg.toeplitz(rt_np)
    # ``numpy()`` shares memory with a CPU tensor, so the right-hand side must
    # be a copy; otherwise the assignment below would overwrite ``Rt[0]`` in
    # the caller's tensor (and only when it lives on the CPU).
    B = rt_np.copy()
    B[0] = max(0, B[0] - sigma_2)

    an = torch.tensor(linalg.solve(A, B)).to(Rt)
    an = an / an.sum()
    an = an * (1 - beta)
    an[0] += beta
    return an / an.sum()

class LMSSGD(Optimizer):
    """SGD whose update direction is a weighted sum of recent gradients.

    For every parameter the optimizer keeps the ``n - 1`` most recent
    gradients (``g_tau``) and a vector ``Rgg`` of exponentially averaged
    inner products between the current gradient and each stored gradient
    (``Rgg[0]`` averages the squared gradient norm).  From the second step
    on, ``solve_a`` turns ``Rgg`` into ``n`` coefficients and the gradient
    is replaced by ``an[0] * g[t] + an[1] * g[t-1] + ...`` before the
    parameter update.

    Arguments:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate.
        dampening: stored in the group defaults; not read by ``step``.
        weight_decay: if non-zero, parameters are multiplied by
            ``1 - lr * weight_decay`` before the gradient update.
        sigma: squared, multiplied by the number of elements of the
            parameter and passed to ``solve_a`` as ``sigma_2``.
        n: number of coefficients.  ``n <= 1`` disables the gradient
            combination (a single coefficient always normalises to one).
        beta: averaging factor for ``Rgg``.
        beta2: if not ``None``, the update is divided element-wise by the
            square root of a running average of squared raw gradients
            (clamped from below at ``1e-6``).

    Note:
        ``p.grad`` is overwritten in place with the combined direction.
    """

    def __init__(self, params, lr=required, dampening=0,
                 weight_decay=0, sigma=.0, n=3, beta=0.9, beta2=None):
        defaults = dict(lr=lr, dampening=dampening,
                        weight_decay=weight_decay, sigma=sigma, n=n,
                        beta=beta, beta2=beta2)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            The value returned by ``closure``, or ``None`` without one.
        """
        loss = None
        if closure is not None:
            loss = closure()

        EPSILON = 1e-6

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            n = group['n']
            beta = group['beta']
            sigma_sq = group['sigma'] ** 2
            beta2 = group['beta2']
            for p in group['params']:
                if p.grad is None:
                    continue
                if weight_decay != 0:
                    p.data.mul_(1 - lr * weight_decay)
                d_p = p.grad.data
                param_state = self.state[p]

                if beta2 is not None:
                    # Second-moment statistics use the raw gradient, i.e.
                    # they are updated before d_p is overwritten below.
                    if 'exp_avg_sq' not in param_state:
                        param_state['exp_avg_sq'] = torch.zeros_like(d_p).to(d_p)
                        param_state['ct'] = torch.tensor(0).to(d_p)
                    param_state['exp_avg_sq'].mul_(beta2).addcmul_(d_p, d_p, value=1 - beta2)
                    # ct accumulates the total averaging weight so that
                    # exp_avg_sq / ct is a properly normalised average.
                    param_state['ct'].mul_(beta2).add_(1 - beta2)

                # With a single coefficient there is no history to combine
                # (and a zero-row history buffer cannot be indexed), so only
                # n >= 2 enters the filtering path.
                if n > 1:
                    # d[t] = a[0]g[t] + a[1]g[t-1] + ... + a[n-1]g[t-n+1]
                    flat = d_p.reshape(-1)
                    if 'Rgg' not in param_state:
                        # First iteration: start the history with the current
                        # gradient and leave the gradient itself untouched.
                        history = torch.zeros(n - 1, d_p.numel(),
                                              dtype=d_p.dtype, device=d_p.device)
                        history[0] = flat.clone()
                        param_state['g_tau'] = history
                        param_state['Rgg'] = torch.zeros(n).to(d_p)
                        param_state['Rgg'][0] = d_p.norm().pow(2)
                    else:
                        # Later iterations: refresh the averaged correlations,
                        # solve for the coefficients, combine, shift history.
                        g_tau = param_state['g_tau']
                        Rgg = param_state['Rgg']
                        cross = torch.mul(g_tau, flat).sum(1)
                        Rgg[1:] = beta * Rgg[1:] + (1 - beta) * cross
                        Rgg[0] = beta * Rgg[0] + (1 - beta) * d_p.norm().pow(2)
                        sigma_2 = sigma_sq * d_p.numel()
                        an = solve_a(Rgg, sigma_2)

                        # Keep the raw gradient for the history before d_p is
                        # overwritten in place.
                        g_temp = d_p.reshape(1, -1).clone()
                        d_p.mul_(an[0])
                        d_p.add_(torch.einsum('i,ij->j', an[1:], g_tau).view(d_p.size()))
                        # Newest gradient first, oldest one dropped.
                        param_state['g_tau'] = torch.cat((g_temp, g_tau[:-1]))

                if beta2 is not None:
                    denom = param_state['exp_avg_sq'].div(param_state['ct']).sqrt().clamp_min(EPSILON)
                    p.data.addcdiv_(d_p, denom, value=-lr)
                else:
                    p.data.add_(d_p, alpha=-lr)
        return loss
    

class LMSSGD_step():
    """Applies a fixed recursive filter to the gradients of a model.

    For every parameter with a gradient, ``step`` rewrites ``param.grad`` as

        out[t] = in[t] - sum_{k>=1} a[k] * out[t-k] + sum_{k>=1} b[k] * in[t-k]

    where ``in`` is the gradient found on the parameter and ``out`` the
    gradient left on it.  ``a[0]`` and ``b[0]`` are never read.  Past
    outputs and inputs are kept as attributes ``a_buffer_<k>`` and
    ``b_buffer_<k>`` on the parameter itself.

    Args:
        a: indexable sequence of feedback coefficients.
        b: indexable sequence of feed-forward coefficients.

    Attributes:
        first_batch: ``True`` until the first call to ``step`` completes.
            While it is ``True`` the buffers are (re)initialised and the
            gradients are left unchanged.
    """

    def __init__(self, a, b) -> None:
        self.a = a
        self.b = b
        self.first_batch = True

    @torch.no_grad()
    def step(self, model):
        """Filter the gradients of ``model`` in place and return ``model``.

        Parameters that do not require gradients or currently have no
        gradient are skipped.
        """
        num_a = len(self.a) - 1
        num_b = len(self.b) - 1
        for _name, param in model.named_parameters():
            if not param.requires_grad or param.grad is None:
                continue
            grad = param.grad
            # ``a_buffer_0`` is written at the end of every visit, so its
            # absence marks a parameter that has not been initialised yet
            # (e.g. it had no gradient during the first batch).
            if self.first_batch or not hasattr(param, 'a_buffer_0'):
                # Slot 0 starts with the current gradient, older slots with
                # zeros; the gradient itself is not changed.
                for o in range(num_a):
                    initial = grad.clone() if o == 0 else torch.zeros_like(grad)
                    setattr(param, 'a_buffer_' + str(o), initial)
                for o in range(num_b):
                    initial = grad.clone() if o == 0 else torch.zeros_like(grad)
                    setattr(param, 'b_buffer_' + str(o), initial)
            else:
                # Snapshot the unfiltered gradient before it is modified; it
                # becomes the newest entry of the input history.
                newest_input = grad.clone()

                # Shift the output history by one slot while accumulating the
                # feedback terms.  Slot 0 receives a placeholder here and the
                # real filtered gradient after both loops.
                carry = torch.zeros_like(grad)
                for o in range(num_a):
                    key = 'a_buffer_' + str(o)
                    previous = getattr(param, key)
                    setattr(param, key, carry)
                    carry = previous
                    grad.sub_(carry * self.a[o + 1])

                # Same shift for the input history, adding the feed-forward
                # terms.
                carry = newest_input
                for o in range(num_b):
                    key = 'b_buffer_' + str(o)
                    previous = getattr(param, key)
                    setattr(param, key, carry)
                    carry = previous
                    grad.add_(carry * self.b[o + 1])
            # The newest output is the gradient that is left on the parameter.
            param.a_buffer_0 = grad.clone()
        # Without this the filter would stay in its initialisation branch
        # forever and never be applied.
        self.first_batch = False
        return model