import jax
from jax import tree_util as jtu
from .optimizers import LoganOptimizer
from optax._src import numerics
from optax import tree_utils as otu
import jax.numpy as jnp
from typing import NamedTuple
from flax import struct

class SGDState(NamedTuple):
    """Per-run optimizer state for SGD with momentum.

    Being a ``NamedTuple``, instances are immutable; a new state is built
    for every optimizer step.
    """

    # Number of optimizer steps taken so far.
    count: object
    # Momentum buffer; a pytree with the same structure as the parameters.
    b: object

class SGDOptimizer(struct.PyTreeNode):
    """SGD with momentum, optional Nesterov look-ahead and weight decay.

    Every step updates a momentum buffer ``b <- m * b + g``, shrinks the
    non-excluded parameters by ``1 - wd * lr / peak_lr`` and then subtracts
    the learning-rate-scaled update from them.
    """

    # Callable mapping the step count to a learning rate (static field).
    lr: object = struct.field(pytree_node=False)
    # Reference learning rate; ``lr / peak_lr`` scales weight decay and,
    # when ``schedule_momentum`` is set, the momentum.
    peak_lr: object = struct.field(pytree_node=True)
    # Weight-decay coefficient.
    wd: object = struct.field(pytree_node=True)
    # Base momentum coefficient (static field).
    momentum: object = struct.field(pytree_node=False)
    # Pytree matching the parameters; a true leaf exempts that parameter
    # from weight decay.
    should_exclude_tree: object = struct.field(pytree_node=True)
    # Boolean predicate: use the Nesterov form of the update.
    nesterov: object = struct.field(pytree_node=True)
    # Boolean predicate: scale the momentum by ``1 - lr / peak_lr``.
    schedule_momentum: object = struct.field(pytree_node=True)

    def _momentum_for(self, decay_rate):
        """Return the momentum coefficient to use for ``decay_rate``.

        ``decay_rate`` is ``lr / peak_lr``. With ``schedule_momentum`` set the
        result is ``momentum * (1 - decay_rate)``, otherwise it is the plain
        ``momentum``.
        """
        return jax.lax.cond(self.schedule_momentum,
                            lambda: self.momentum * (1 - decay_rate),
                            lambda: self.momentum)

    def initial_state(self, params):
        """Return the initial optimizer state for ``params``.

        The step counter starts as an int32 scalar zero and the momentum
        buffer as zeros with the same tree structure as ``params``.
        """
        return SGDState(count=jnp.zeros([], jnp.int32),
                        b=otu.tree_zeros_like(params))

    def auto_step(self, state, gradient, lr_factor=0.0, wd_factor=1.0, **kwargs):
        """Apply one optimizer step and return the updated state.

        Args:
            state: Object exposing ``params``, ``opt_state`` (an ``SGDState``)
                and ``replace(...)``.
            gradient: Pytree of gradients matching ``state.params``.
            lr_factor: Value added to the scheduled learning rate.
            wd_factor: Accepted but not used by this optimizer.
                NOTE(review): confirm whether it should scale ``self.wd``.
            **kwargs: Ignored.

        Returns:
            ``state`` with ``params``, ``opt_state`` and ``optimizer`` replaced.
        """
        step = state.opt_state.count
        lr = self.lr(step) + lr_factor
        count_inc = numerics.safe_int32_increment(step)

        # Debug aid: print the effective learning rate once, at step 100.
        jax.lax.cond(
            step == 100,
            lambda _: jax.debug.print('{lr}', lr=lr),
            lambda _: None,
            operand=None
        )

        decay_rate = lr / self.peak_lr
        this_mom = self._momentum_for(decay_rate)

        # Momentum buffer update: b <- m * b + g.
        new_b = jtu.tree_map(lambda b, g: this_mom * b + g,
                             state.opt_state.b, gradient)
        # Nesterov uses g + m * b_new as the step direction, otherwise b_new.
        update = jax.lax.cond(
            self.nesterov,
            lambda: jtu.tree_map(lambda g, b: g + this_mom * b, gradient, new_b),
            lambda: new_b)

        # Weight decay acts directly on the parameters (not via the gradient)
        # and is skipped for leaves flagged in should_exclude_tree.
        decayed = jtu.tree_map(
            lambda p, q: jax.lax.cond(
                q, lambda: p, lambda: p * (1 - self.wd * decay_rate)),
            state.params, self.should_exclude_tree)

        new_params = jtu.tree_map(lambda p, u: p - lr * u, decayed, update)
        new_opt_state = SGDState(count=count_inc, b=new_b)
        return state.replace(params=new_params, opt_state=new_opt_state,
                             optimizer=self)

    def infer_gradient_from(self, prev_state, next_state, lr_factor=0.0):
        """Recover the gradient that took ``prev_state`` to ``next_state``.

        Inverts the buffer update ``b_next = m * b_prev + g`` performed by
        ``auto_step``.

        Args:
            prev_state: State before the step.
            next_state: State after the step.
            lr_factor: The ``lr_factor`` that was passed to ``auto_step`` for
                that step. It matters only when ``schedule_momentum`` is set,
                because the momentum then depends on the learning rate.

        Returns:
            Pytree ``b_next - m * b_prev`` with the structure of the buffers.
        """
        lr = self.lr(prev_state.opt_state.count) + lr_factor
        this_mom = self._momentum_for(lr / self.peak_lr)
        return jtu.tree_map(lambda a, b: a - this_mom * b,
                            next_state.opt_state.b,
                            prev_state.opt_state.b)