import torch
from torch.optim import Optimizer

class AGM_base(Optimizer):
    """Common constructor logic for the AGM optimizers defined in this file.

    Validates the hyper-parameters, registers them as per-group defaults and
    attaches two buffers directly to every optimized parameter tensor:

    * ``p.z`` -- a clone of the parameter's values at registration time
    * ``p.m`` -- a zero tensor with the parameter's shape, dtype and device

    The ``step`` methods of the subclasses read and update these attributes.

    NOTE: the buffers live on the parameter objects, not in ``self.state``,
    so they are not part of what ``Optimizer.state_dict()`` serializes.

    Args:
        params: iterable of tensors, or iterable of dicts defining parameter
            groups (anything ``torch.optim.Optimizer`` accepts).
        lr_tau: non-negative coefficient stored as ``group['lr_tau']``.
        lr_eta: non-negative coefficient stored as ``group['lr_eta']``.
        lr_alpha: non-negative coefficient stored as ``group['lr_alpha']``.
        lr_beta1: stored as ``group['lr_beta1']`` (not range-checked).
        lr_beta2: stored as ``group['lr_beta2']`` (not range-checked).
        momentum: non-negative value stored as ``group['momentum']``.
        weight_decay: non-negative value stored as ``group['weight_decay']``.
        dampening: stored as ``group['dampening']``.
        q: exponent parameter; must not equal 1 because the subclasses
            divide by ``q - 1``.
        eps: non-negative constant stored as ``group['eps']``.
        debug: stored as ``group['debug']``.

    Raises:
        ValueError: if ``lr_tau``, ``lr_eta``, ``lr_alpha``, ``momentum``,
            ``weight_decay`` or ``eps`` is negative, or if ``q == 1``.
    """

    def __init__(self, params, lr_tau, lr_eta, lr_alpha, lr_beta1, lr_beta2, momentum, weight_decay, dampening, q, eps, debug):
        if not 0.0 <= lr_tau:
            raise ValueError("Invalid learning rate Tau: {}".format(lr_tau))
        if not 0.0 <= lr_eta:
            raise ValueError("Invalid learning rate Eta: {}".format(lr_eta))
        if not 0.0 <= lr_alpha:
            raise ValueError("Invalid learning rate Alpha: {}".format(lr_alpha))
        if not 0.0 <= momentum:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        # Every step() in this file evaluates 1 / (q - 1) or (2 - q) / (q - 1);
        # fail here instead of with a ZeroDivisionError in the training loop.
        if q == 1:
            raise ValueError("Invalid q value: {} (q - 1 is used as a divisor)".format(q))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))

        defaults = dict(
            lr_tau=lr_tau,
            lr_eta=lr_eta,
            lr_alpha=lr_alpha,
            lr_beta1=lr_beta1,
            lr_beta2=lr_beta2,
            momentum=momentum,
            weight_decay=weight_decay,
            dampening=dampening,
            q=q,
            eps=eps,
            debug=debug,
        )

        # Optimizer.__init__ normalizes `params` (flat iterable or list of
        # group dicts) and registers each group through add_param_group(),
        # which is where the p.z / p.m buffers are created.
        super().__init__(params, defaults)

    def add_param_group(self, param_group):
        """Register a parameter group and create ``p.z`` / ``p.m`` for it.

        Args:
            param_group: dict with a ``'params'`` entry plus optional
                per-group overrides of the defaults.
        """
        super().add_param_group(param_group)
        # The base class has appended the normalized group; its 'params'
        # entry is now always a list of tensors.
        for p in self.param_groups[-1]['params']:
            p.z = p.data.clone()
            p.m = torch.zeros_like(p.data)

    def __setstate__(self, state):
        """Restore pickled optimizer state (delegates to ``Optimizer``)."""
        super().__setstate__(state)

class AGM(AGM_base):
    """AGM optimizer variant that applies the raw averaged direction.

    For each parameter ``x`` with gradient ``g`` and the buffers ``m`` (``p.m``)
    and ``z`` (``p.z``), one call to :meth:`step` performs::

        c = beta1 * m + (1 - beta1) * g
        u = |z|^(q - 2) * z - alpha * c
        z = sign(u) * |u|^(1 / (q - 1))
        x = (1 - tau - eta * weight_decay) * x + eta * (tau - 1) * c + tau * z
        m = beta2 * m + (1 - beta2) * g

    Args:
        params: iterable of tensors or of parameter-group dicts.
        lr_tau: ``tau`` above (default 0.001).
        lr_eta: ``eta`` above (default 0.001).
        lr_alpha: ``alpha`` above (default 0.001).
        lr_beta1: ``beta1`` above (default 0.9).
        lr_beta2: ``beta2`` above (default 0.99).
        momentum: stored in the group defaults; not read by ``step`` (default 0).
        weight_decay: ``weight_decay`` above (default 5e-4).
        dampening: stored in the group defaults; not read by ``step`` (default 0).
        q: exponent parameter ``q`` above (default 3).
        eps: stored in the group defaults; not read by ``step`` (default 1e-8).
        debug: stored in the group defaults; not read by ``step`` (default False).
    """

    def __init__(
        self,
        params,
        lr_tau=0.001,
        lr_eta=0.001,
        lr_alpha=0.001,
        lr_beta1=0.9,
        lr_beta2=0.99,
        momentum=0,
        weight_decay=5e-4,
        dampening=0,
        q=3,
        eps=1e-8,
        debug=False,
    ):
        super().__init__(params, lr_tau, lr_eta, lr_alpha, lr_beta1, lr_beta2, momentum, weight_decay, dampening, q, eps, debug)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model. It is
                invoked with gradients enabled; its return value is discarded
                because this method returns the tracking dict instead.

        Returns:
            dict: keys ``z_dist``, ``y_dist``, ``update_dist``, ``grad_dist``,
            ``momentum_dist`` and ``pdata_dist`` (all ``0.0``) and ``debug``
            (``False``). No statistics are collected, so the values are
            constant.
        """
        if closure is not None:
            with torch.enable_grad():
                closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            lr_beta1 = group['lr_beta1']
            lr_beta2 = group['lr_beta2']
            lr_tau = group['lr_tau']
            lr_eta = group['lr_eta']
            lr_alpha = group['lr_alpha']
            q = group['q']

            # Scalar coefficients of the parameter update; identical for
            # every tensor in the group.
            keep = 1 - lr_tau - lr_eta * weight_decay
            c_scale = lr_eta * (lr_tau - 1)

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # Averaged direction built from the buffer and the gradient.
                c = torch.add(lr_beta1 * p.m, grad, alpha=1 - lr_beta1)

                # Map z through u = |z|^(q-2) * z, take the step in u, then
                # map back with sign(u) * |u|^(1/(q-1)).
                u = torch.mul(torch.pow(torch.abs(p.z), q - 2), p.z)
                u.add_(c, alpha=-lr_alpha)
                p.z = torch.mul(torch.sign(u), torch.pow(torch.abs(u), 1 / (q - 1)))
                del u  # release the temporary before the next allocation

                p.data.mul_(keep).add_(c, alpha=c_scale).add_(p.z, alpha=lr_tau)
                del c

                p.m.mul_(lr_beta2).add_(grad, alpha=1 - lr_beta2)

        return dict(z_dist=0., y_dist=0., update_dist=0., grad_dist=0., momentum_dist=0., pdata_dist=0., debug=False)
    

class AGM_NE(AGM_base):
    """AGM optimizer variant that rescales the averaged direction element-wise.

    For each parameter ``x`` with gradient ``g`` and the buffers ``m`` (``p.m``)
    and ``z`` (``p.z``), one call to :meth:`step` performs::

        c  = beta1 * m + (1 - beta1) * g
        u  = |z|^(q - 2) * z - alpha * c
        z  = sign(u) * |u|^(1 / (q - 1))
        c' = c * (|c| + eps)^((2 - q) / (q - 1))
        x  = (1 - tau - eta * weight_decay) * x + tau * z + eta * (tau - 1) * c'
        m  = beta2 * m + (1 - beta2) * g

    Args:
        params: iterable of tensors or of parameter-group dicts.
        lr_tau: ``tau`` above (default 0.001).
        lr_eta: ``eta`` above (default 0.001).
        lr_alpha: ``alpha`` above (default 0.001).
        lr_beta1: ``beta1`` above (default 0.9).
        lr_beta2: ``beta2`` above (default 0.99).
        momentum: stored in the group defaults; not read by ``step`` (default 0).
        eps: ``eps`` above (default 1e-8).
        weight_decay: ``weight_decay`` above (default 5e-4).
        dampening: stored in the group defaults; not read by ``step`` (default 0).
        q: exponent parameter ``q`` above (default 3).
        debug: stored in the group defaults; not read by ``step`` (default False).
    """

    def __init__(
        self,
        params,
        lr_tau=0.001,
        lr_eta=0.001,
        lr_alpha=0.001,
        lr_beta1=0.9,
        lr_beta2=0.99,
        momentum=0,
        eps=1e-8,
        weight_decay=5e-4,
        dampening=0,
        q=3,
        debug=False,
    ):
        # The base constructor takes weight_decay/dampening/q before eps.
        super().__init__(params, lr_tau, lr_eta, lr_alpha, lr_beta1, lr_beta2, momentum, weight_decay, dampening, q, eps, debug)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Runs under ``torch.no_grad()`` so that the in-place buffer updates
        never record autograd history, even if ``p.grad`` itself carries a
        graph.

        Args:
            closure: optional callable that re-evaluates the model. It is
                invoked with gradients enabled; its return value is discarded
                because this method returns the tracking dict instead.

        Returns:
            dict: keys ``z_dist``, ``y_dist``, ``update_dist``, ``grad_dist``,
            ``momentum_dist`` and ``pdata_dist`` (all ``0.0``) and ``debug``
            (``False``). No statistics are collected, so the values are
            constant.
        """
        if closure is not None:
            with torch.enable_grad():
                closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            lr_beta1 = group['lr_beta1']
            lr_beta2 = group['lr_beta2']
            lr_tau = group['lr_tau']
            lr_eta = group['lr_eta']
            lr_alpha = group['lr_alpha']
            q = group['q']
            eps = group['eps']

            # Scalar coefficients of the parameter update; identical for
            # every tensor in the group.
            keep = 1 - lr_tau - lr_eta * weight_decay
            c_scale = lr_eta * (lr_tau - 1)

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # Averaged direction built from the buffer and the gradient.
                c = torch.add(lr_beta1 * p.m, grad, alpha=1 - lr_beta1)

                # Map z through u = |z|^(q-2) * z, take the step in u, then
                # map back with sign(u) * |u|^(1/(q-1)).
                u = torch.mul(torch.pow(torch.abs(p.z), q - 2), p.z)
                u.add_(c, alpha=-lr_alpha)
                p.z = torch.mul(torch.sign(u), torch.pow(torch.abs(u), 1 / (q - 1)))
                del u  # release the temporary before the next allocation

                # Element-wise rescaling of c; eps keeps the base of the
                # power away from zero.
                c_scaled = torch.mul(c, torch.pow(torch.abs(c) + eps, (2 - q) / (q - 1)))

                p.data.mul_(keep).add_(p.z, alpha=lr_tau).add_(c_scaled, alpha=c_scale)

                p.m.mul_(lr_beta2).add_(grad, alpha=1 - lr_beta2)

                del c_scaled, c

        return dict(z_dist=0., y_dist=0., update_dist=0., grad_dist=0., momentum_dist=0., pdata_dist=0., debug=False)


class AGM_HASP(AGM_base):
    """AGM optimizer variant with an additive ``z`` update.

    For each parameter ``x`` with gradient ``g`` and the buffers ``m`` (``p.m``)
    and ``z`` (``p.z``), one call to :meth:`step` performs::

        c  = beta1 * m + (1 - beta1) * g
        z  = z - alpha * c
        c' = c * (|c| + eps)^((2 - q) / (q - 1))
        x  = (1 - tau - eta * weight_decay) * x + tau * z + eta * (tau - 1) * c'
        m  = beta2 * m + (1 - beta2) * g

    After the update, ``g``, ``m``, ``c``, ``z`` and ``c'`` are asserted to be
    free of NaNs.

    Args:
        params: iterable of tensors or of parameter-group dicts.
        lr_tau: ``tau`` above (default 0.001).
        lr_eta: ``eta`` above (default 0.001).
        lr_alpha: ``alpha`` above (default 0.001).
        lr_beta1: ``beta1`` above (default 0.9).
        lr_beta2: ``beta2`` above (default 0.99).
        momentum: stored in the group defaults; not read by ``step`` (default 0).
        eps: ``eps`` above (default 1e-4).
        weight_decay: ``weight_decay`` above (default 5e-4).
        dampening: stored in the group defaults; not read by ``step`` (default 0).
        q: exponent parameter ``q`` above (default 3).
        debug: stored in the group defaults; not read by ``step`` (default False).
    """

    def __init__(
        self,
        params,
        lr_tau=0.001,
        lr_eta=0.001,
        lr_alpha=0.001,
        lr_beta1=0.9,
        lr_beta2=0.99,
        momentum=0,
        eps=1e-4,
        weight_decay=5e-4,
        dampening=0,
        q=3,
        debug=False,
    ):
        # The base constructor takes weight_decay/dampening/q before eps.
        super().__init__(params, lr_tau, lr_eta, lr_alpha, lr_beta1, lr_beta2, momentum, weight_decay, dampening, q, eps, debug)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Runs under ``torch.no_grad()`` so that the in-place updates of
        ``p.z`` and ``p.m`` never record autograd history, even if ``p.grad``
        itself carries a graph.

        Args:
            closure: optional callable that re-evaluates the model. It is
                invoked with gradients enabled; its return value is discarded
                because this method returns the tracking dict instead.

        Returns:
            dict: keys ``z_dist``, ``y_dist``, ``update_dist``, ``grad_dist``,
            ``momentum_dist`` and ``pdata_dist`` (all ``0.0``) and ``debug``
            (``False``). No statistics are collected, so the values are
            constant.

        Raises:
            AssertionError: if the gradient or any of the tensors computed
                from it contains a NaN (message names the offending tensor).
        """
        if closure is not None:
            with torch.enable_grad():
                closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            lr_beta1 = group['lr_beta1']
            lr_beta2 = group['lr_beta2']
            lr_tau = group['lr_tau']
            lr_eta = group['lr_eta']
            lr_alpha = group['lr_alpha']
            q = group['q']
            eps = group['eps']

            # Scalar coefficients of the parameter update; identical for
            # every tensor in the group.
            keep = 1 - lr_tau - lr_eta * weight_decay
            c_scale = lr_eta * (lr_tau - 1)

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # Averaged direction built from the buffer and the gradient.
                c = torch.add(lr_beta1 * p.m, grad, alpha=1 - lr_beta1)

                # z is updated in place (the tensor object is kept).
                p.z.add_(c, alpha=-lr_alpha)

                # Element-wise rescaling of c; eps keeps the base of the
                # power away from zero.
                c_scaled = torch.mul(c, torch.pow(torch.abs(c) + eps, (2 - q) / (q - 1)))

                p.data.mul_(keep).add_(p.z, alpha=lr_tau).add_(c_scaled, alpha=c_scale)

                p.m.mul_(lr_beta2).add_(grad, alpha=1 - lr_beta2)

                # Sanity checks, in the same order and with the same
                # messages as before.
                checks = (("grad", grad), ("m", p.m), ("c", c), ("z", p.z), ("c_tmp", c_scaled))
                for label, tensor in checks:
                    assert torch.isnan(tensor).count_nonzero() == 0, label

                del c_scaled, c

        return dict(z_dist=0., y_dist=0., update_dist=0., grad_dist=0., momentum_dist=0., pdata_dist=0., debug=False)

# class AGM_NE(AGM_base):
#     def __init__(
#         self                ,
#         params              ,
#         lr_tau      = 0.001 ,
#         lr_eta      = 0.001 ,
#         lr_alpha    = 0.001 , 
#         lr_beta1    = 0.9   ,
#         lr_beta2    = 0.99  ,
#         momentum    = 0     ,
#         eps         = 1e-8  ,
#         weight_decay= 5e-4  ,
#         dampening   = 0     ,
#         q           = 3     ,
#         debug       = False
#     ):
#         super(AGM_NE, self).__init__(params, lr_tau, lr_eta, lr_alpha, lr_beta1, lr_beta2, momentum, weight_decay, dampening, q, eps, debug)
    
#     def step(self, closure=None):
#         if closure is not None:
#             loss = closure()

#         params_track = dict(z_dist=0., y_dist=0., update_dist=0., grad_dist=0., momentum_dist=0., pdata_dist=0., debug=False)
        
#         for group in self.param_groups:
#             weight_decay = group['weight_decay']
#             lr_beta1 = group['lr_beta1']
#             lr_beta2 = group['lr_beta2']
#             lr_tau = group['lr_tau']
#             lr_eta = group['lr_eta']
#             lr_alpha = group['lr_alpha']
#             q = group['q']
#             eps = group['eps']

#             for p in group['params']:
#                 if p.grad is None: continue

#                 # torch.nn.utils.clip_grad_norm_(p.grad.data, max_norm=5.0, norm_type=q)

#                 c = lr_beta1 * p.m + (1.0 - lr_beta1) * p.grad.data

#                 # y_p = p.data - group['lr_eta'] * torch.div(m_p, torch.abs(m_p)**((group['q']-2)/(group['q']-1)))
#                 # update = (group['q'])* (torch.abs(p.data)**(group['q']-1)) * torch.sign(p.data) - group['lr_alpha'] * d_p

#                 # tmp = m_p * (torch.abs(m_p + eps)**((2.0 - group['q'])/(group['q']-1)))
#                 # tmp = tmp / (torch.norm(tmp, p=2))

#                 # y_p = p.data - group['lr_eta'] * torch.norm(tmp, p=1) * tmp
#                 y_p = p.data - lr_eta * c * (torch.abs(c + eps)**((2 - q)/(q - 1)))
#                 update = q * (torch.abs(p.z)**(q - 1)) * torch.sign(p.z) - lr_alpha * c
#                 p.z = (torch.abs(update/q)**(1 / (q - 1))) * torch.sign(update)
#                 p.data = lr_tau * p.z + (1 - lr_tau) * y_p - lr_eta * weight_decay * p.data
#                 del y_p, update

#                 if group['debug'] is True:
#                     params_track['debug'] = True
#                     params_track['grad_dist'] += torch.sum((p.grad_pre - d_p)**2)
#                     params_track['pdata_dist'] += torch.sum((p.data_pre - p.data) ** 2)
#                     params_track['update_dist'] += torch.sum((p.update_pre - update) ** 2)
#                     params_track['y_dist'] += torch.sum((p.y_pre - y_p) ** 2)
#                     params_track['z_dist'] += torch.sum((p.z_pre - z_p) **2 )
#                     params_track['momentum_dist'] += torch.sum((p.m_pre - m_p) ** 2)

#                     p.y_pre = y_p.clone()
#                     p.update_pre = update.clone()
#                     p.grad_pre = d_p.clone()
#                     p.data_pre = p.data.clone()

#                 p.m = lr_beta2 * c + (1 - lr_beta2) * p.grad.data
#                 del c

#         if params_track['debug']:
#             params_track['grad_dist'] = torch.sqrt(params_track['grad_dist'])
#             params_track['pdata_dist'] = torch.sqrt(params_track['pdata_dist'])
#             params_track['update_dist'] = torch.sqrt(params_track['update_dist'])
#             params_track['y_dist'] = torch.sqrt(params_track['y_dist'])
#             params_track['z_dist'] = torch.sqrt(params_track['z_dist'])
#             params_track['momentum_dist'] = torch.sqrt(params_track['momentum_dist'])

#         return params_track