#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import torch
from torch.optim import Optimizer

from typing import Iterable


def get_grad_list(params):
    """Collect the ``.grad`` attribute of every element of ``params``.

    The result keeps the iteration order of ``params``; entries whose
    ``.grad`` is ``None`` are kept as ``None`` rather than dropped.
    """
    grads = []
    for param in params:
        grads.append(param.grad)
    return grads

def compute_grad_norm(grad_list):
    """Return the Euclidean norm of all gradients in ``grad_list`` combined.

    Args:
        grad_list: iterable of gradient tensors; ``None`` entries are skipped.

    Returns:
        A 0-dim tensor holding ``sqrt(sum_i sum(g_i * g_i))``. When the list
        is empty or contains only ``None`` entries, a 0-dim zero tensor is
        returned (previously this case raised ``TypeError`` because
        ``torch.sqrt`` was applied to a Python float).
    """
    squared_sum = None
    for g in grad_list:
        if g is None:
            continue
        term = torch.sum(torch.mul(g, g))
        squared_sum = term if squared_sum is None else squared_sum + term

    if squared_sum is None:
        # Nothing contributed a gradient: keep the return type a tensor so
        # callers can still use ``.pow()``, ``.item()`` and comparisons.
        return torch.zeros(())

    return torch.sqrt(squared_sum)

def get_lr(state, params):
    """Return the largest ``state[p]['lr']`` value over all ``p`` in ``params``."""
    rates = (state[p]['lr'] for p in params)
    return max(rates)


##############################################################################
# SMB Optimizer
##############################################################################

class SMB(Optimizer):
    """SMB optimizer: gradient step with an Armijo-type check and a model step.

    For every parameter group, ``step`` takes a plain gradient step
    ``p <- p - lr * grad`` and re-evaluates the loss. If the new loss does not
    satisfy ``loss_next <= loss - c * lr * ||grad||^2``, the gradient at the
    trial point is computed and each parameter's trial step is replaced by a
    step built from the old step, the old gradient and the gradient difference.

    Args:
        params: parameters (or parameter-group dicts) to optimize.
        lr: step length of the trial gradient step.
        eta: constant used in the ``sigma`` term of the model step.
        maxiniter: accepted for interface compatibility; not used by this
            implementation (at most one model step is taken per group).
        c: constant of the sufficient-decrease condition.
        n_batches_per_epoch: accepted for interface compatibility; not used.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 0.5,
        eta: float = 0.99,
        maxiniter:  int = 1,
        c: float = 1e-2,
        n_batches_per_epoch = 500,
    ):
        defaults = dict(lr=lr, eta=eta, c=c)
        super().__init__(params, defaults)

    @staticmethod
    def _model_step(s_old, grad, grad_t, eta):
        """Return the model-based step that replaces the trial step ``s_old``.

        ``grad`` is the gradient at the starting point and ``grad_t`` the
        gradient at the trial point. When the coefficients are degenerate
        (``sigma`` or ``theta`` not strictly positive, e.g. for an all-zero
        gradient), ``s_old`` is returned unchanged so the parameter keeps its
        plain gradient step instead of being overwritten with NaN.
        """
        g = torch.flatten(grad)
        gt = torch.flatten(grad_t)
        s = torch.flatten(s_old)

        y_t = grad_t.sub(grad, alpha=1.0)  # y = g_t - g
        y = torch.flatten(y_t)

        sg = torch.dot(s, g)
        sgt = torch.dot(s, gt)
        ys = sgt - sg  # equals dot(y, s)
        ss = torch.dot(s, s)
        yy = torch.dot(y, y)
        yg = torch.dot(y, g)
        gg = torch.dot(g, g)

        sigma = 0.5 * (torch.sqrt(ss) * (torch.sqrt(yy) + torch.sqrt(gg) / eta) - ys)
        theta = (ys + 2.0 * sigma) ** 2.0 - ss * yy

        # Both quantities are used as denominators below; substitute 1 when
        # they are unusable and discard the result at the end.
        valid = (sigma > 0) & (theta > 0)
        one = torch.ones_like(sigma)
        sigma = torch.where(valid, sigma, one)
        theta = torch.where(valid, theta, one)

        cg = -ss / (2.0 * sigma)
        cs = cg / theta * (-(ys + 2.0 * sigma) * yg + yy * sg)
        cy = cg / theta * (-(ys + 2.0 * sigma) * sg + ss * yg)

        s_new = s_old.mul(cs).add(grad, alpha=cg).add(y_t, alpha=cy)
        return torch.where(valid, s_new, s_old)

    def step(self, closure):
        """Perform one optimization step.

        Args:
            closure: callable that evaluates the model and returns the loss
                tensor. It is called once at the start and once more per
                parameter group after the trial step. ``backward()`` is called
                here, not expected inside the closure.
                NOTE(review): gradients are never zeroed by this method, so
                the closure is presumably expected to zero them - confirm
                against callers.

        Returns:
            Tuple ``(loss, initer)``: the loss returned by the first closure
            call and the number of model steps taken during this call.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        # Global bookkeeping lives next to the per-parameter state.
        self.state.setdefault('step', 0)

        loss = closure()
        loss.backward()

        initer = 0

        for group in self.param_groups:
            lr = group["lr"]
            eta = group["eta"]
            c = group["c"]

            # Parameters without a gradient take no part in the step.
            active = [p for p in group["params"] if p.grad is not None]
            if not active:
                continue

            sq_sum = 0.
            for p in active:
                sq_sum = sq_sum + torch.sum(torch.mul(p.grad, p.grad))
            grad_norm = torch.sqrt(sq_sum)

            # Right-hand side of the sufficient-decrease condition.
            cond = loss.item() - c * lr * grad_norm.pow(2).item()

            # Trial step: plain gradient step, remembering gradient and step.
            for p in active:
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("SMB does not support sparse gradients")

                state = self.state[p]
                state["grad_old"] = grad.clone().detach()

                s_new = grad.mul(-lr)
                state["s_old"] = s_new.clone().detach()

                p.data.add_(s_new, alpha=1.0)

            loss_next = closure()

            # Move on to the next group (instead of leaving the loop) when the
            # gradient is negligible or the trial step is accepted.
            if not bool(grad_norm >= 1e-8):
                continue
            if loss_next.item() <= cond:
                continue

            # Trial step rejected: get the gradient at the trial point and
            # replace each parameter's step by the model-based one.
            loss_next.backward()

            for p in active:
                if p.grad is None:
                    continue
                state = self.state[p]
                s_old = state["s_old"]
                s_new = self._model_step(s_old, state["grad_old"], p.grad.data, eta)
                p.data.sub_(s_old, alpha=1.0).add_(s_new, alpha=1.0)

            initer += 1

        self.state['step'] += 1

        return loss, initer


############################################################################


##############################################################################
# SMBi Optimizer (H_k with new independent batch)
##############################################################################


# Build a model with a new batch (if the Armijo cond not satisfied) so that H'_k and g_k are independent

class SMBi(Optimizer):
    """SMBi optimizer: SMB variant whose model step uses a later ``step`` call.

    ``step`` works in two phases, selected by ``self.state['model_step']``:

    * Trial phase (flag ``False``): take ``p <- p - lr * grad`` and re-evaluate
      the loss. If ``loss_next > loss - c * lr * ||grad||^2`` the step is
      undone, the gradient is kept as ``grad_prev`` and the flag is set so the
      next call runs the model phase.
    * Model phase (flag ``True``): take a gradient step with the current
      gradient, evaluate the gradient at the trial point, and replace the
      step by a model-based step that combines these with ``grad_prev``.

    The phase is read once per call, so all parameter groups run the same
    phase within one call.

    Args:
        params: parameters (or parameter-group dicts) to optimize.
        lr: step length of the trial gradient step.
        eta: constant used in the ``sigma`` term of the model step.
        maxiniter: accepted for interface compatibility; not used.
        c: constant of the sufficient-decrease condition.
        n_batches_per_epoch: accepted for interface compatibility; not used.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 0.5,
        eta: float = 0.99,
        maxiniter:  int = 1,
        c: float = 1e-2,
        n_batches_per_epoch = 500,
    ):
        defaults = dict(lr=lr, eta=eta, c=c)
        super().__init__(params, defaults)

    @staticmethod
    def _model_step(s_old, grad, grad_prev, grad_t, eta):
        """Return the model-based step that replaces the trial step ``s_old``.

        ``grad`` is the current gradient at the starting point, ``grad_t`` the
        gradient at the trial point and ``grad_prev`` the gradient stored by
        the preceding trial phase. When the coefficients are degenerate
        (``sigma`` or ``theta`` not strictly positive, e.g. for an all-zero
        gradient), ``s_old`` is returned unchanged so the parameter keeps its
        plain gradient step instead of being overwritten with NaN.
        """
        y_t = grad_t.sub(grad, alpha=1.0)  # y = g_t - g

        g_prev = torch.flatten(grad_prev)
        g = torch.flatten(grad)
        s = torch.flatten(s_old)
        y = torch.flatten(y_t)

        ys = torch.dot(y, s)
        ss = torch.dot(s, s)
        yy = torch.dot(y, y)
        yg = torch.dot(y, g_prev)
        gg = torch.dot(g, g)
        sg = torch.dot(s, g_prev)

        sigma = 0.5 * (torch.sqrt(ss) * (torch.sqrt(yy) + torch.sqrt(gg) / eta) - ys)
        theta = (ys + 2.0 * sigma) ** 2.0 - ss * yy

        # Both quantities are used as denominators below; substitute 1 when
        # they are unusable and discard the result at the end.
        valid = (sigma > 0) & (theta > 0)
        one = torch.ones_like(sigma)
        sigma = torch.where(valid, sigma, one)
        theta = torch.where(valid, theta, one)

        cg = -ss / (2.0 * sigma)
        cs = cg / theta * (-(ys + 2.0 * sigma) * yg + yy * sg)
        cy = cg / theta * (-(ys + 2.0 * sigma) * sg + ss * yg)

        s_new = s_old.mul(cs).add(grad_prev, alpha=cg).add(y_t, alpha=cy)
        return torch.where(valid, s_new, s_old)

    def step(self, closure):
        """Perform one optimization step (trial phase or model phase).

        Args:
            closure: callable that evaluates the model and returns the loss
                tensor. It is called once at the start and once more per
                parameter group after the trial step. ``backward()`` is called
                here, not expected inside the closure.
                NOTE(review): gradients are never zeroed by this method, so
                the closure is presumably expected to zero them - confirm
                against callers.

        Returns:
            Tuple ``(loss, initer)``: the loss returned by the first closure
            call and the number of model steps taken during this call.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        # Global bookkeeping lives next to the per-parameter state.
        self.state.setdefault('step', 0)
        self.state.setdefault('model_step', False)

        loss = closure()
        loss.backward()

        # Decide the phase once so a rejection in one group cannot push a
        # later group into the model phase within the same call.
        model_phase = self.state['model_step']
        next_model_phase = False
        initer = 0

        for group in self.param_groups:
            lr = group["lr"]
            eta = group["eta"]
            c = group["c"]

            # Parameters without a gradient take no part in the step.
            active = [p for p in group["params"] if p.grad is not None]
            if not active:
                continue

            sq_sum = 0.
            for p in active:
                sq_sum = sq_sum + torch.sum(torch.mul(p.grad, p.grad))
            grad_norm = torch.sqrt(sq_sum)

            # Right-hand side of the sufficient-decrease condition.
            cond = loss.item() - c * lr * grad_norm.pow(2).item()

            if not model_phase:
                # Trial phase: plain gradient step, remembering the gradient.
                for p in active:
                    grad = p.grad.data
                    if grad.is_sparse:
                        raise RuntimeError("SMB does not support sparse gradients")

                    state = self.state[p]
                    state["grad_prev"] = grad.clone().detach()

                    s_new = grad.mul(-lr)
                    state["s_old"] = s_new.clone().detach()

                    p.data.add_(s_new, alpha=1.0)

                loss_next = closure()

                if bool(grad_norm >= 1e-8) and loss_next.item() > cond:
                    # Rejected: undo the step and request a model step on the
                    # next call.
                    for p in active:
                        p.data.sub_(self.state[p]["s_old"], alpha=1.0)
                    next_model_phase = True

            else:
                # Model phase: gradient step with the current gradient ...
                for p in active:
                    grad = p.grad.data
                    if grad.is_sparse:
                        raise RuntimeError("SMB does not support sparse gradients")

                    state = self.state[p]
                    state["grad_old"] = grad.clone().detach()

                    s_new = grad.mul(-lr)
                    state["s_old"] = s_new.clone().detach()

                    p.data.add_(s_new, alpha=1.0)

                loss_next = closure()
                loss_next.backward()

                # ... then replace it by the model-based step. Parameters that
                # have no stored ``grad_prev`` keep the plain gradient step.
                for p in active:
                    state = self.state[p]
                    if p.grad is None or "grad_prev" not in state:
                        continue
                    s_old = state["s_old"]
                    s_new = self._model_step(
                        s_old, state["grad_old"], state["grad_prev"], p.grad.data, eta
                    )
                    p.data.sub_(s_old, alpha=1.0).add_(s_new, alpha=1.0)

                initer += 1

        self.state['model_step'] = next_model_phase
        self.state['step'] += 1

        return loss, initer


##############################################################################