"""
Parameter-scaled stochastic Polyak step-size
Adapted from https://github.com/IssamLaradji/sps/blob/master/sps/sps.py
"""


import numpy as np
import torch
import time
import copy


class PSSps(torch.optim.Optimizer):
    """Parameter-scaled stochastic Polyak step-size (SPS) optimizer.

    On every :meth:`step` the optimizer

    1. updates a per-parameter exponential moving average of the squared
       gradient and derives an element-wise scale ``alpha`` from it,
    2. multiplies each parameter by ``alpha`` and divides its gradient by
       ``alpha`` (in place),
    3. computes a Polyak step size from the loss and the norm of the scaled
       gradients and applies a plain SGD update in that scaled space,
    4. undoes the scaling of parameters and gradients.

    Args:
        params: iterable of parameters (materialised with ``list``) handed to
            ``torch.optim.Optimizer``.
        lr: stored in the param-group defaults; not read by :meth:`step`.
        weight_decay: stored in the param-group defaults; it is read from the
            groups by ``get_grad_list``.
        n_batches_per_epoch: with ``adapt_flag='smooth_iter'`` the step size
            may grow by at most ``gamma ** (1 / n_batches_per_epoch)`` per
            step.
        init_step_size: initial value of ``state['step_size']``.
        c: multiplies the squared gradient norm in the step-size denominator.
        gamma: growth base used by ``'smooth_iter'``.
        beta: decay rate of the squared-gradient moving average.
        eta_max: optional upper bound on the step size in ``'constant'`` mode;
            ``None`` disables the bound.
        adapt_flag: ``'constant'`` or ``'smooth_iter'``; any other value makes
            :meth:`step` raise ``ValueError``.
        fstar_flag: when truthy, the lower bound ``fstar`` is read from
            ``batch['meta']['fstar'].mean()``; otherwise ``fstar = 0``.
        eps: small constant added to the step-size denominator and inside the
            computation of ``alpha``.
        centralize_grad_norm: forwarded to ``compute_grad_norm``.
        centralize_grad: forwarded to ``get_grad_list``; must not be combined
            with ``centralize_grad_norm``.
    """

    def __init__(self,
                 params,
                 lr=1.,
                 weight_decay=0.,
                 n_batches_per_epoch=500,
                 init_step_size=1e-4,
                 c=0.5,
                 gamma=2.0,
                 beta=0.999,
                 eta_max=None,
                 adapt_flag='smooth_iter',
                 fstar_flag=None,
                 eps=1e-8,
                 centralize_grad_norm=False,
                 centralize_grad=False):
        params = list(params)
        super().__init__(params, {'lr': lr, 'weight_decay': weight_decay})
        self.eps = eps
        self.c = c
        self.centralize_grad_norm = centralize_grad_norm
        self.centralize_grad = centralize_grad

        # The two centralisation modes are mutually exclusive.
        if centralize_grad:
            assert self.centralize_grad_norm is False

        self.eta_max = eta_max
        self.gamma = gamma
        self.beta = beta
        self.init_step_size = init_step_size
        self.adapt_flag = adapt_flag
        self.state['step'] = 0

        self.state['step_size'] = init_step_size
        self.step_size_max = 0.
        self.n_batches_per_epoch = n_batches_per_epoch

        self.state['n_forwards'] = 0
        self.state['n_backwards'] = 0
        self.fstar_flag = fstar_flag

    def step(self, closure=None, loss=None, batch=None):
        """Perform one parameter-scaled SPS update.

        Exactly one of ``closure`` and ``loss`` is expected.

        Args:
            closure: callable returning the loss. ``step`` never calls
                ``backward`` itself, so presumably the closure (or the caller)
                has populated ``p.grad`` -- confirm against the training loop.
            loss: precomputed loss (tensor or number).
            batch: only read when ``fstar_flag`` is truthy, as
                ``batch['meta']['fstar']``.

        Returns:
            The loss as a Python float.

        Raises:
            ValueError: if neither ``closure`` nor ``loss`` is given, if
                ``adapt_flag`` is not supported, or if the first parameter
                contains NaNs after the update.
        """
        if loss is None and closure is None:
            raise ValueError('please specify either closure or loss')

        if loss is not None:
            if not isinstance(loss, torch.Tensor):
                loss = torch.tensor(loss)

        # increment step
        self.state['step'] += 1

        # get fstar
        if self.fstar_flag:
            fstar = float(batch['meta']['fstar'].mean())
        else:
            fstar = 0.

        # get loss and compute gradients
        if loss is None:
            loss = closure()
        else:
            assert closure is None, 'if loss is provided then closure should be None'

        # Move parameters/gradients into the scaled space.
        scaled = self._scale_params()
        try:
            param_list, grad_current = get_grad_list(
                self.param_groups, centralize_grad=self.centralize_grad)
            grad_norm = compute_grad_norm(
                grad_current, centralize_grad_norm=self.centralize_grad_norm)

            if grad_norm < 1e-8:
                # Vanishing gradient: skip the update entirely.
                step_size = 0.
            else:
                step_size = self._compute_step_size(loss, fstar, grad_norm)
                sgd_update(param_list, step_size, grad_current)

            # update state with metrics
            self.state['n_forwards'] += 1
            self.state['n_backwards'] += 1
            self.state['step_size'] = step_size
            self.state['grad_norm'] = grad_norm.item()

            if torch.isnan(param_list[0]).sum() > 0:
                raise ValueError('Got NaNs')
        finally:
            # Always leave the scaled space, even when an error is raised
            # above, so the model is never left with rescaled weights.
            for p, alpha in scaled:
                p.data.div_(alpha)
                p.grad.mul_(alpha)

        return float(loss)

    def _scale_params(self):
        """Update the squared-gradient averages and rescale params/grads in place.

        Parameters whose ``grad`` is ``None`` are left untouched.

        Returns:
            List of ``(parameter, alpha)`` pairs that were rescaled, so the
            caller can undo exactly this scaling.
        """
        # Identical for every parameter, so compute it once.
        bias_correction = 1 - self.beta ** self.state['step']
        scaled = []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                state = self.state[p]
                if 'exp_avg_sqgrad' not in state:
                    state['exp_avg_sqgrad'] = torch.zeros_like(p)
                exp_avg_sqgrad = state['exp_avg_sqgrad']
                exp_avg_sqgrad.mul_(self.beta).addcmul_(
                    p.grad, p.grad, value=1 - self.beta)

                alpha = ((exp_avg_sqgrad / bias_correction).sqrt()
                         + self.eps).sqrt()

                state['alpha'] = alpha
                p.data.mul_(alpha)
                p.grad.div_(alpha)
                scaled.append((p, alpha))
        return scaled

    def _compute_step_size(self, loss, fstar, grad_norm):
        """Return the step size as a Python number for the current ``adapt_flag``."""
        if self.adapt_flag in ['constant']:
            # adjust the step size based on an upper bound and fstar
            step_size = (loss - fstar) / (self.c * (grad_norm)**2 + self.eps)
            if loss < fstar:
                return 0.
            if self.eta_max is None:
                return step_size.item()
            return min(self.eta_max, step_size.item())

        if self.adapt_flag in ['smooth_iter']:
            # smoothly adjust the step size: cap its growth relative to the
            # previous step size (fstar is not used in this mode)
            step_size = loss / (self.c * (grad_norm)**2 + self.eps)
            coeff = self.gamma**(1./self.n_batches_per_epoch)
            return min(coeff * self.state['step_size'], step_size.item())

        raise ValueError('adapt_flag: %s not supported' % self.adapt_flag)

# utils
# ------------------------------
def compute_grad_norm(grad_list, centralize_grad_norm=False):
    """Return the global L2 norm of the gradients in ``grad_list``.

    Entries that are ``None`` or the float ``0.`` (the placeholder used for
    parameters without a gradient) are ignored.

    Args:
        grad_list: iterable of gradient tensors / placeholders.
        centralize_grad_norm: when true, every tensor with more than one
            dimension has its mean over all dimensions but the first
            subtracted **in place** before it contributes to the norm.

    Returns:
        A 0-dim tensor holding the norm. A zero tensor is returned when no
        entry contributes, so callers can always treat the result as a tensor.
    """
    grad_norm = 0.
    for g in grad_list:
        if g is None or (isinstance(g, float) and g == 0.):
            continue

        if g.dim() > 1 and centralize_grad_norm:
            # centralize grads (mutates g)
            g.add_(-g.mean(dim=tuple(range(1, g.dim())), keepdim=True))

        grad_norm += torch.sum(torch.mul(g, g))

    if not isinstance(grad_norm, torch.Tensor):
        # Nothing was accumulated; torch.sqrt rejects a plain Python float.
        grad_norm = torch.tensor(grad_norm)
    return torch.sqrt(grad_norm)

def get_grad_list(param_groups, centralize_grad=False):
    """Collect every parameter together with its (possibly adjusted) gradient.

    For a parameter without a gradient the float ``0.`` is stored as a
    placeholder. Otherwise the gradient tensor is modified **in place**:
    ``weight_decay * param`` is added when the group's ``weight_decay`` is
    positive, and, if ``centralize_grad`` is set and the gradient has more
    than one dimension, its mean over all dimensions but the first is
    subtracted.

    Returns:
        ``(param_list, grad_list)`` -- two lists of equal length, in
        param-group order.
    """
    collected_params, collected_grads = [], []
    for grp in param_groups:
        wd = grp['weight_decay']
        for param in grp['params']:
            if param.grad is None:
                grad = 0.
            else:
                grad = param.grad.data
                if wd > 0:
                    grad.add_(param.data, alpha=wd)
                if centralize_grad and grad.dim() > 1:
                    # centralize grads
                    reduce_dims = tuple(range(1, grad.dim()))
                    grad.add_(-grad.mean(dim=reduce_dims, keepdim=True))

            collected_params.append(param)
            collected_grads.append(grad)

    return collected_params, collected_grads


def sgd_update(params, step_size, grad_current):
    """Apply ``param -= step_size * grad`` in place for each (param, grad) pair.

    Pairs whose gradient is the float placeholder ``0.`` are left untouched.
    """
    for param, grad in zip(params, grad_current):
        is_placeholder = isinstance(grad, float) and grad == 0.
        if not is_placeholder:
            param.data.add_(other=grad, alpha=-step_size)
