import math
import torch
import numpy as np
from torch.optim.optimizer import Optimizer
import bisect

class SGD_LRBand(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum)
    with several step size decay band schemes:
        1. 1/t band decay
        2. 1/sqrt(t) band decay
        3. step-decay band

    The milestones split training into consecutive stages ("bands"). Before
    the first milestone the learning rate follows the base schedule of the
    chosen scheme. Every time the step counter hits a milestone, two
    coefficients (``band_weight1``/``band_weight2``) are refitted so that the
    learning rate decays, with the shape given by ``step_mode``, from an
    upper value (scaled by ``ratio``) towards a lower value over the band
    delimited by ``left_barrier`` and ``right_barrier``.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        scheme (str): the decay scheme, one of {'1t_band', '1sqrt_band', 'step_band'}.
        step_mode (str): shape of the decay inside a band, one of
            {'1t', '1sqrt', 'linear'}; 'step_band' additionally accepts 'cosine'.
        epoch_mode (int): if 1, the schedule is indexed by epoch
            (``1 + step // n_batches_per_epoch``) and milestones are converted
            to epochs by floor division; any other value indexes by step.
        eta0 (float): initial learning rate.
        alpha (float): decay factor.
        ratio (float): multiplier applied to the upper end of each band.
        milestones (list, optional): step counts at which a new band starts.
            Sorted on construction. If empty or None, the base schedule of
            the scheme is used for the whole run (default: None).
        T_max: total number of steps. Stored on the instance but not used by
            the schedule.
        n_batches_per_epoch (int, optional): number of optimizer steps per
            epoch, used to derive the epoch counter (default: 1).
        momentum (float, optional): momentum factor (default: 0).
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0).
        dampening (float, optional): dampening for momentum (default: 0).
        nesterov (bool, optional): enables Nesterov momentum (default: False).

    Raises:
        ValueError: on negative ``eta0``/``alpha``/``momentum``/``weight_decay``
            or non-positive ``n_batches_per_epoch`` at construction time, and
            on an unknown ``scheme`` or ``step_mode`` when the learning rate
            is computed.
    """

    # Decay shapes accepted inside a band by the 1/t and 1/sqrt(t) schemes.
    _CURVE_STEP_MODES = ('1t', '1sqrt', 'linear')
    # The step-decay band additionally supports a cosine shape.
    _STEP_BAND_STEP_MODES = ('1t', '1sqrt', 'linear', 'cosine')

    def __init__(self, params, scheme, step_mode, epoch_mode, eta0, alpha, ratio, milestones=None, T_max=0,
                 n_batches_per_epoch=1,
                 momentum=0, dampening=0, weight_decay=0, nesterov=False):

        if eta0 < 0.0:
            raise ValueError("Invalid eta0 value: {}".format(eta0))
        if alpha < 0.0:
            raise ValueError("Invalid alpha value: {}".format(alpha))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight decay: {}".format(weight_decay))
        if n_batches_per_epoch <= 0:
            # The epoch counter divides by this value on every step.
            raise ValueError("Invalid n_batches_per_epoch value: {}".format(n_batches_per_epoch))

        defaults = dict(momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        # NOTE: unlike torch.optim.SGD, the nesterov flag is not validated
        # against momentum/dampening; with momentum == 0 it has no effect
        # because the Nesterov branch in step() is only reached with momentum.
        super(SGD_LRBand, self).__init__(params, defaults)

        self.eta0 = eta0
        self.alpha = alpha
        self.ratio = ratio
        # bisect in get_lr_func() requires the milestones to be sorted.
        self.milestones = sorted(int(x) for x in milestones) if milestones is not None else []
        self.last_round = -1
        self.cur_round = 0
        # Same formula as in step(), evaluated for cur_round == 0, so the
        # get_lr_* methods can be called before the first step.
        self.cur_epoch = 1
        self.cur_lr = eta0
        self.T_max = T_max
        self.left_barrier = 0
        self.right_barrier = 0
        self.width = 0
        self.upper_ratio = ratio * eta0
        self.n_batches_per_epoch = n_batches_per_epoch
        self.scheme = scheme
        self.step_mode = step_mode
        self.band_weight1 = eta0
        self.band_weight2 = 1
        self.num_count = 0
        self.idx = 0
        self.epoch_mode = epoch_mode

    def __setstate__(self, state):
        super(SGD_LRBand, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def _in_first_stage(self):
        """Return True while no milestone has been reached (or none exist)."""
        return not self.milestones or self.cur_round < self.milestones[0]

    def _require_step_mode(self, allowed):
        """Raise ValueError if ``self.step_mode`` is not one of ``allowed``."""
        if self.step_mode not in allowed:
            raise ValueError("Invalid step mode: {}".format(self.step_mode))

    def _fit_band(self, eta_max, eta_min, linear_uses_width):
        """Refit ``band_weight1``/``band_weight2`` for the current band.

        The coefficients are chosen for the shape selected by ``step_mode``
        from the upper value ``eta_max`` (associated with ``left_barrier``)
        and the lower value ``eta_min`` (reached at ``right_barrier`` for the
        '1t', '1sqrt' and 'linear' shapes).

        ``linear_uses_width`` selects between the two linear variants used by
        the schemes: slope ``(eta_min - eta_max) / width`` anchored at the
        right barrier (curve bands), or the exact line through both barriers
        (step-decay band).
        """
        left = self.left_barrier
        right = self.right_barrier
        mode = self.step_mode

        if mode == '1t':
            # lr(t) = w1 / (t + w2) with lr(left) = eta_max, lr(right) = eta_min.
            self.band_weight2 = (right * eta_min - left * eta_max) / (eta_max - eta_min)
            self.band_weight1 = eta_min * (right + self.band_weight2)
        elif mode == '1sqrt':
            # lr(t) = w1 / (sqrt(t) + w2) with the same two endpoint conditions.
            self.band_weight2 = (right ** 0.5 * eta_min - left ** 0.5 * eta_max) / (eta_max - eta_min)
            self.band_weight1 = eta_min * (right ** 0.5 + self.band_weight2)
        elif mode == 'linear':
            if linear_uses_width:
                self.band_weight1 = -(eta_max - eta_min) / self.width
                self.band_weight2 = eta_min - self.band_weight1 * right
            else:
                self.band_weight1 = (eta_max - eta_min) / (left - right)
                self.band_weight2 = eta_max - self.band_weight1 * left
        elif mode == 'cosine':
            # Midpoint and amplitude of the cosine between eta_min and eta_max.
            self.band_weight1 = (eta_max + eta_min) / 2
            self.band_weight2 = (eta_max - eta_min) / 2
        else:
            raise ValueError("Invalid step mode: {}".format(mode))

    def _eval_band(self):
        """Evaluate the fitted band curve at the current counter ``num_count``."""
        mode = self.step_mode
        if mode == '1t':
            return self.band_weight1 / (self.num_count + self.band_weight2)
        if mode == '1sqrt':
            return self.band_weight1 / (self.num_count ** 0.5 + self.band_weight2)
        if mode == 'linear':
            return self.band_weight1 * self.num_count + self.band_weight2
        if mode == 'cosine':
            return self.band_weight1 + self.band_weight2 * math.cos(
                (self.num_count - self.left_barrier) * math.pi / self.width)
        raise ValueError("Invalid step mode: {}".format(mode))

    def _curve_band_lr(self, base):
        """Shared implementation of the 1/t and 1/sqrt(t) band schemes.

        ``base`` maps a counter value to the lower-bound learning rate of the
        scheme. Inside a band the upper value is ``ratio * base(left_barrier)``
        and the lower value is ``base(right_barrier)``.
        """
        if self._in_first_stage():
            # NOTE(review): the first stage is indexed by the epoch counter
            # even when epoch_mode != 1, while later bands use num_count.
            # Kept as in the original 1/sqrt(t) scheme - confirm intent.
            self.cur_lr = base(self.cur_epoch)
            return self.cur_lr

        self._require_step_mode(self._CURVE_STEP_MODES)
        if self.cur_round in self.milestones:
            self._fit_band(self.ratio * base(self.left_barrier),
                           base(self.right_barrier),
                           linear_uses_width=True)
        self.cur_lr = self._eval_band()
        return self.cur_lr

    def get_lr_1t_band(self):
        """Return the learning rate of the 1/t band scheme.

        The lower-bound curve is ``eta0 / (1 + alpha * t)``. This method is
        built by direct analogy with ``get_lr_1sqrt_band`` (same band fitting,
        different base curve).
        """
        return self._curve_band_lr(lambda t: self.eta0 / (1 + self.alpha * t))

    def get_lr_1sqrt_band(self):
        """Return the learning rate of the 1/sqrt(t) band scheme.

        The lower-bound curve is ``eta0 / (1 + alpha * sqrt(t))``.
        """
        return self._curve_band_lr(lambda t: self.eta0 / (1 + self.alpha * t ** 0.5))

    def get_lr_stepdecay_band(self):
        """Return the learning rate of the step-decay band scheme.

        Before the first milestone the rate is ``eta0``. At a milestone with
        ``idx`` milestones reached so far, the band is refitted between
        ``ratio * eta0 * alpha ** (idx - 1)`` and ``eta0 * alpha ** idx``.
        """
        if self._in_first_stage():
            self.cur_lr = self.eta0
            return self.cur_lr

        self._require_step_mode(self._STEP_BAND_STEP_MODES)
        if self.cur_round in self.milestones:
            eta_max = self.ratio * self.eta0 * self.alpha ** (self.idx - 1)
            eta_min = self.eta0 * self.alpha ** self.idx
            self._fit_band(eta_max, eta_min, linear_uses_width=False)
        self.cur_lr = self._eval_band()
        return self.cur_lr

    def get_lr_func(self):
        """Update the band bookkeeping and return the current learning rate.

        Reads ``cur_round``/``cur_epoch`` (set by ``step``), refreshes
        ``idx``, ``num_count`` and, on the first step or at a milestone that
        is not the last one, the band barriers and width; then dispatches to
        the method implementing ``scheme``.

        Raises:
            ValueError: if ``scheme`` is not one of the supported schemes.
        """
        # Number of milestones that are <= cur_round.
        self.idx = bisect.bisect(self.milestones, self.cur_round)

        if self.epoch_mode == 1:
            # Decay once per epoch: milestones are converted from steps to epochs.
            self.num_count = self.cur_epoch
            per_unit = self.n_batches_per_epoch
        else:
            self.num_count = self.cur_round
            per_unit = 1

        if self.milestones:
            new_band = None
            if self.last_round == -1 and self.cur_round == 1:
                # Very first step: the first band runs up to the first milestone.
                new_band = (1, self.milestones[0] // per_unit)
            elif self.cur_round in self.milestones and self.cur_round < self.milestones[-1]:
                # A milestone other than the last one opens the next band.
                new_band = (1 + self.milestones[self.idx - 1] // per_unit,
                            self.milestones[self.idx] // per_unit)
            if new_band is not None:
                self.left_barrier, self.right_barrier = new_band
                self.width = 1 + self.right_barrier - self.left_barrier

        self.last_round = self.cur_round

        if self.scheme == '1t_band':
            self.cur_lr = self.get_lr_1t_band()
        elif self.scheme == '1sqrt_band':
            self.cur_lr = self.get_lr_1sqrt_band()
        elif self.scheme == 'step_band':
            self.cur_lr = self.get_lr_stepdecay_band()
        else:
            raise ValueError("Invalid scheme: {}".format(self.scheme))

        return self.cur_lr

    def step(self, closure=None):
        """Performs a single optimization step.

        Advances the step and epoch counters, recomputes the learning rate
        and applies an SGD update (with optional weight decay, momentum,
        dampening and Nesterov momentum) to every parameter with a gradient.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None if no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        self.cur_round += 1
        self.cur_epoch = 1 + self.cur_round // self.n_batches_per_epoch
        self.cur_lr = self.get_lr_func()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    # Out of place so the stored gradient is left untouched
                    # (matters when gradients are accumulated across steps).
                    d_p = d_p.add(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.data.add_(d_p, alpha=-self.cur_lr)

        return loss
