import numpy as np
import torch
from torch.optim.optimizer import Optimizer


class SGD_TLR(Optimizer):
    """SGD (optional momentum / Nesterov) whose learning rate is adapted online.

    Every ``self.meta_update`` calls to :meth:`step`, :meth:`update` measures how
    far each parameter moved since the previous meta-update, divides that
    displacement by the current learning rate, and takes its dot product with the
    displacement measured at the previous meta-update. The learning rate is then
    changed based on that dot product (see :meth:`_adapt_lr`).

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: initial learning rate.
        momentum, dampening, weight_decay, nesterov: standard SGD settings; the
            momentum buffer follows ``torch.optim.SGD``.
        batches_per_epoch: number of ``step()`` calls per epoch.
        meta_update: fraction of an epoch between meta-updates; the period in
            steps is ``max(1, int(meta_update * batches_per_epoch))``.
        meta_adjust: if True, scale the lr multiplicatively with an automatically
            chosen meta step size (growth capped at ``meta_bound``); if False,
            add ``meta_lr * dot`` and floor the result at 1e-6.
        meta_level: 0 keeps one lr per parameter group; 1 keeps one lr per
            parameter tensor.
        meta_bound: maximum multiplicative factor applied per meta-update when
            ``meta_adjust`` is True.
        meta_lr: additive meta step size, used when ``meta_adjust`` is False.

    Raises:
        ValueError: if Nesterov is requested without momentum / with dampening,
            or if ``meta_level`` is not 0 or 1.
    """

    def __init__(self, params, lr=1e-4, momentum=0, dampening=0, weight_decay=0, nesterov=False,
                 batches_per_epoch=100,
                 meta_update=.33, meta_adjust=True, meta_level=0, meta_bound=1.25, meta_lr=1e-3):
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        if meta_level not in (0, 1):
            # step()/update() only implement these two levels; any other value
            # would leave the learning rate undefined inside step().
            raise ValueError(f"meta_level must be 0 or 1, got {meta_level!r}")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        meta_lr=meta_lr)

        self.meta_adjust = meta_adjust
        self.meta_level = meta_level
        # Meta-update period in steps. Clamped to >= 1 so that
        # ``self.cnt % self.meta_update`` can never divide by zero.
        self.meta_update = max(1, int(meta_update * batches_per_epoch))
        self.bound = meta_bound

        self.cnt = 0  # number of step() calls so far
        super(SGD_TLR, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD_TLR, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one SGD step, then a meta-update every ``self.meta_update`` steps.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None if no closure was given.
        """
        self.cnt += 1

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue

                param_state = self.state[p]

                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)

                # Momentum as in torch.optim.SGD. Skipped entirely for plain SGD so
                # no buffer is stored and dampening is never applied without momentum.
                if momentum != 0:
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                if self.meta_level == 0:
                    lr = group['lr']
                else:
                    # Per-tensor lr, lazily initialised from the group's lr.
                    lr = param_state.setdefault('lr', group['lr'])

                p.add_(d_p, alpha=-lr)

        if self.cnt % self.meta_update == 0:
            self.update()

        return loss

    @torch.no_grad()
    def update(self):
        """Adapt the learning rate(s) from the correlation of successive displacements.

        For each parameter with a gradient, ``new = (previous - p) / lr`` is the
        displacement since the last meta-update scaled by the current lr, and
        ``old`` is the value of ``new`` stored at the previous meta-update. With
        ``T = self.meta_update`` the per-tensor quantities are::

            dot   = <old, new> / T**2
            norm2 = <new, new> / T**2      (only computed when meta_adjust)

        With ``meta_level == 0`` they are summed over a group's parameters and the
        group's ``lr`` is updated; with ``meta_level == 1`` each tensor's own ``lr``
        is updated. The first call for a parameter only records its current value
        (and a zero ``old``) and contributes nothing to the sums.
        """
        level = self.meta_level
        period_sq = self.meta_update ** 2

        for group in self.param_groups:
            dot_total = 0.0
            norm2_total = 0.0

            for p in group['params']:
                if p.grad is None:
                    continue

                param_state = self.state[p]

                if 'previous' not in param_state:
                    param_state['previous'] = p.detach().clone()
                    param_state['stored_buf'] = torch.zeros_like(p)
                    continue

                lr = group['lr'] if level == 0 else param_state.setdefault('lr', group['lr'])

                old_grad = param_state['stored_buf']
                new_grad = (param_state['previous'] - p) / lr

                param_state['stored_buf'] = new_grad
                param_state['previous'] = p.detach().clone()

                dot = (old_grad * new_grad).sum().item() / period_sq
                norm2 = (new_grad ** 2).sum().item() / period_sq if self.meta_adjust else 0.0

                if level == 0:
                    dot_total += dot
                    norm2_total += norm2
                else:
                    param_state['lr'] = self._adapt_lr(param_state['lr'], dot, norm2,
                                                       group['meta_lr'])

            if level == 0:
                group['lr'] = self._adapt_lr(group['lr'], dot_total, norm2_total,
                                             group['meta_lr'])

    def _adapt_lr(self, lr, dot, norm2, meta_lr):
        """Return ``lr`` after one meta-update driven by ``dot`` and ``norm2``.

        With ``meta_adjust``: ``eta = 0.25 / max(norm2 - dot, 1e-10)`` and the lr is
        multiplied by ``min(1 + eta * dot, self.bound)``. Otherwise ``meta_lr * dot``
        is added and the result is floored at 1e-6.
        """
        if self.meta_adjust:
            eta = .25 / max(norm2 - dot, 1e-10)
            return lr * min(1 + eta * dot, self.bound)
        return max(lr + meta_lr * dot, 1e-6)

    def get_lr(self):
        """Return the current learning rate(s).

        Returns:
            With ``meta_level == 0``: a list with one lr per parameter group.
            With ``meta_level == 1``: per group, a list of the per-tensor lrs of
            the parameters that currently have a gradient; tensors that have not
            been stepped yet report their group's lr.
        """
        if self.meta_level == 0:
            return [group['lr'] for group in self.param_groups]
        return [
            [self.state[p].get('lr', group['lr']) for p in group['params'] if p.grad is not None]
            for group in self.param_groups
        ]
