import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer
import math

class PolyakOptimizer(Optimizer):
    """SGD with the Polyak step size ``lr = min(loss / ||g||^2, max_lr)``.

    ``||g||^2`` is the squared L2 norm of all gradients, taken jointly over
    every parameter of every param group. Each group's ``lr`` entry (default
    1.0) is an extra multiplier applied on top of the Polyak step, so an
    external LR scheduler can modulate it.

    Args:
        params: iterable of parameters or param-group dicts.
        max_lr: upper bound on the Polyak step size (default: unbounded).
    """

    def __init__(self, params, max_lr=math.inf):
        self.max_lr = max_lr
        # Most recent Polyak step size (before the per-group multiplier).
        self.old_lr = 0.0
        defaults = dict(lr=1.0)
        super(PolyakOptimizer, self).__init__(params, defaults)

    def _grad_sq_norm(self):
        """Return the summed squared gradient entries (0.0 if no grads exist)."""
        total = 0.0
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    total += p.grad.pow(2).sum()
        return total

    @torch.no_grad()
    def step(self, loss):
        """Apply one Polyak update and return ``loss`` unchanged.

        Args:
            loss: loss value (Python number or 0-d tensor); expected to be the
                loss whose gradients are currently stored in ``p.grad``.

        When every gradient is zero the parameters are left untouched: the
        update would be zero anyway, and multiplying the infinite/undefined
        ratio by a zero gradient would otherwise write NaN into them.
        """
        grad_sq_norm = self._grad_sq_norm()
        # A zero gradient makes the Polyak ratio infinite; max_lr then decides.
        ratio = math.inf if grad_sq_norm == 0 else loss / grad_sq_norm
        lr = min(ratio, self.max_lr)
        self.old_lr = lr

        if grad_sq_norm != 0:
            for group in self.param_groups:
                scale = group["lr"] * lr
                for p in group["params"]:
                    if p.grad is not None:
                        p.sub_(scale * p.grad)

        return loss
    
    
class AdaSPSOptimizer(Optimizer):
    """SGD with an AdaSPS-style step size.

    ``lr_t = min(loss_t / (||g_t||^2 * sqrt(sum_{s<=t} loss_s)), lr_{t-1})``
    with ``lr_{-1} = 1.0``, so the step size never increases. ``||g||^2`` is
    the squared L2 norm of all gradients across every param group, and each
    group's ``lr`` entry (default 1.0) is an extra multiplier on the step.
    Losses are assumed to be non-negative.
    """

    def __init__(self, params):
        # Previous step size; it caps the next one. Starts at 1.0.
        self.old_lr = 1.0
        # Running sum of every loss passed to step().
        self.total_loss = 0.0

        defaults = dict(lr=1.0)
        super(AdaSPSOptimizer, self).__init__(params, defaults)

    def _grad_sq_norm(self):
        """Return the summed squared gradient entries (0.0 if no grads exist)."""
        total = 0.0
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    total += p.grad.pow(2).sum()
        return total

    @torch.no_grad()
    def step(self, loss):
        """Apply one AdaSPS update and return ``loss`` unchanged.

        Args:
            loss: loss value (Python number or 0-d tensor); expected to be the
                loss whose gradients are currently stored in ``p.grad``.

        The loss is always added to the running total. If the gradient is all
        zero, or the running total is not positive, the step-size formula is
        undefined (x/0 or 0/0); the parameters and ``old_lr`` are then left
        unchanged instead of being overwritten with NaN.
        """
        grad_sq_norm = self._grad_sq_norm()
        # Rebind rather than add in place so the caller's tensor is never aliased.
        self.total_loss = self.total_loss + loss

        if grad_sq_norm == 0 or self.total_loss <= 0:
            return loss

        # math.sqrt(float(...)) accepts both Python numbers and 0-d tensors.
        ratio = loss / grad_sq_norm / math.sqrt(float(self.total_loss))
        lr = min(ratio, self.old_lr)
        self.old_lr = lr

        for group in self.param_groups:
            scale = group["lr"] * lr
            for p in group["params"]:
                if p.grad is not None:
                    p.sub_(scale * p.grad)

        return loss

    
class DecSPSOptimizer(Optimizer):
    """SGD with a DecSPS-style (decreasing Polyak) step size.

    With ``r_k = loss_k / ||g_k||^2``:

    * ``lr_0 = r_0``
    * ``lr_k = min(r_k, sqrt(k) * lr_{k-1}) / sqrt(k + 1)`` for ``k >= 1``

    ``||g||^2`` is the squared L2 norm of all gradients across every param
    group, and each group's ``lr`` entry (default 1.0) is an extra multiplier
    on the step.
    """

    def __init__(self, params):
        # Previous step size, used by the recurrence above.
        self.old_lr = 1.0
        # Number of step() calls made so far (k in the recurrence).
        self.n_itr = 0

        defaults = dict(lr=1.0)
        super(DecSPSOptimizer, self).__init__(params, defaults)

    def _grad_sq_norm(self):
        """Return the summed squared gradient entries (0.0 if no grads exist)."""
        total = 0.0
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    total += p.grad.pow(2).sum()
        return total

    @torch.no_grad()
    def step(self, loss):
        """Apply one DecSPS update and return ``loss`` unchanged.

        Args:
            loss: loss value (Python number or 0-d tensor); expected to be the
                loss whose gradients are currently stored in ``p.grad``.

        A zero gradient makes ``r_k`` infinite, so the ``min`` selects the
        other term. The step-size recurrence and the iteration counter still
        advance, but the parameter update (which would be zero) is skipped,
        so that ``inf * 0`` cannot write NaN into the parameters.
        """
        grad_sq_norm = self._grad_sq_norm()
        ratio = math.inf if grad_sq_norm == 0 else loss / grad_sq_norm

        if self.n_itr == 0:
            lr = ratio
        else:
            capped = min(ratio, math.sqrt(self.n_itr) * self.old_lr)
            lr = capped / math.sqrt(self.n_itr + 1.0)
        self.old_lr = lr

        if grad_sq_norm != 0:
            for group in self.param_groups:
                scale = group["lr"] * lr
                for p in group["params"]:
                    if p.grad is not None:
                        p.sub_(scale * p.grad)

        self.n_itr += 1
        return loss


class InexactPolyakOptimizer(Optimizer):
    """SGD with the step size ``lr = loss / (||g||^2 * sqrt(total_iteration))``.

    ``||g||^2`` is the squared L2 norm of all gradients across every param
    group, and each group's ``lr`` entry (default 1.0) is an extra multiplier
    on the step.

    Args:
        params: iterable of parameters or param-group dicts.
        total_iteration: positive number whose square root divides the
            Polyak ratio.

    Raises:
        ValueError: if ``total_iteration`` is not positive.
    """

    def __init__(self, params, total_iteration):
        if total_iteration <= 0:
            raise ValueError(
                f"total_iteration must be positive, got {total_iteration}"
            )
        self.total_iteration = total_iteration
        # Most recent step size (before the per-group multiplier).
        self.old_lr = 0.0

        defaults = dict(lr=1.0)
        super(InexactPolyakOptimizer, self).__init__(params, defaults)

    def _grad_sq_norm(self):
        """Return the summed squared gradient entries (0.0 if no grads exist)."""
        total = 0.0
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    total += p.grad.pow(2).sum()
        return total

    @torch.no_grad()
    def step(self, loss):
        """Apply one update and return ``loss`` unchanged.

        Args:
            loss: loss value (Python number or 0-d tensor); expected to be the
                loss whose gradients are currently stored in ``p.grad``.

        When every gradient is zero the parameters are left untouched; the
        update would be zero, and ``inf * 0`` would otherwise produce NaN.
        """
        grad_sq_norm = self._grad_sq_norm()
        ratio = math.inf if grad_sq_norm == 0 else loss / grad_sq_norm
        lr = ratio / math.sqrt(self.total_iteration)
        self.old_lr = lr

        if grad_sq_norm != 0:
            for group in self.param_groups:
                scale = group["lr"] * lr
                for p in group["params"]:
                    if p.grad is not None:
                        p.sub_(scale * p.grad)

        return loss


class LInexactPolyakOptimizer(Optimizer):
    """Per-parameter variant of the inexact Polyak step.

    Each parameter tensor ``p`` gets its own step size
    ``lr_p = loss / (||g_p||^2 * sqrt(total_iteration))``, where ``||g_p||^2``
    is the squared L2 norm of that tensor's gradient only. Each group's ``lr``
    entry (default 1.0) is an extra multiplier on the step.

    Args:
        params: iterable of parameters or param-group dicts.
        total_iteration: positive number whose square root divides the
            Polyak ratio.

    Raises:
        ValueError: if ``total_iteration`` is not positive.
    """

    def __init__(self, params, total_iteration):
        if total_iteration <= 0:
            raise ValueError(
                f"total_iteration must be positive, got {total_iteration}"
            )
        self.total_iteration = total_iteration
        # Not updated by step(): the step size here is per-parameter.
        self.old_lr = 0.0

        defaults = dict(lr=1.0)
        super(LInexactPolyakOptimizer, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, loss):
        """Apply one per-parameter update and return ``loss`` unchanged.

        Args:
            loss: loss value (Python number or 0-d tensor); expected to be the
                loss whose gradients are currently stored in ``p.grad``.

        Parameters whose gradient is entirely zero are skipped: their update
        would be zero, and ``inf * 0`` would otherwise write NaN into them.
        """
        sqrt_total = math.sqrt(self.total_iteration)

        for group in self.param_groups:
            group_lr = group["lr"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad_sq_norm = p.grad.pow(2).sum()
                if grad_sq_norm == 0:
                    continue
                lr = loss / grad_sq_norm / sqrt_total
                p.sub_(group_lr * lr * p.grad)

        return loss
