from typing import NamedTuple
from functools import cache
from optax import tree_utils as otu
from optax._src import utils as optax_utils
import jax.numpy as jnp
from optax._src import numerics
from jax import tree_util as jtu
from flax import struct
import jax
from functools import partial
from .kahantools import kahan_from_sum, kahan_treemap
from .kahantools import kahan_add, kahan_mul, kahan_div, kahan_finish
from .kahantools import kahan_finish_tree, KahanState
from .old_optimizers import old_adam_step
from typing import Any

class LoganOptimizer:
    """Abstract optimizer interface used by ``TrainState``.

    Every method raises ``NotImplementedError``; concrete optimizers override
    the methods they support.  Signatures below match the way ``TrainState``
    in this module invokes them.
    """

    def __init__(self, *args, **kwargs):
        """Set up the optimizer with whatever hyperparameters it needs."""
        raise NotImplementedError

    def opt_state_to_kahan(self, params):
        """Convert a plain optimizer state into its Kahan form.

        ``TrainState.kahanify`` and ``TrainState.create`` pass the plain
        optimizer state (not the model params) as the single positional
        argument and store the return value as the new ``opt_state``.
        """
        raise NotImplementedError

    def kahan_to_opt_state(self, params, state=None, gradient=None):
        """Convert a Kahan-form optimizer state back to a plain one.

        ``TrainState.unkahanify`` calls this with a single positional
        argument (the Kahan-form optimizer state) and stores the return value
        as the new plain ``opt_state``.  ``state`` and ``gradient`` are not
        supplied by that caller, hence their ``None`` defaults.
        """
        raise NotImplementedError

    def initial_state(self, params):
        """Return the initial optimizer state for ``params``.

        ``TrainState.create`` stores the result as ``opt_state``.
        """
        raise NotImplementedError

    def auto_step(self, state, gradient, lr_factor=None, wd_factor=None):
        """Take one plain (non-Kahan) step and return the new ``TrainState``.

        ``TrainState.apply_grads_auto`` passes ``lr_factor`` and
        ``wd_factor`` as keyword arguments and asserts that the result is a
        ``TrainState`` whose ``opt_state.count`` is not ``None``.
        """
        raise NotImplementedError

    def kahan_dual_and_primal(self, state, dstate, updates, dupdates,
                              calculate_dual):
        """Kahan-mode step, optionally together with its tangent ("dual").

        ``TrainState.jvp_kahan`` calls this with ``calculate_dual=True``, and
        ``TrainState.jvp`` unpacks that result as ``(primal, tangent)``.
        ``TrainState.apply_grads_kahan`` calls it with ``calculate_dual=False``
        and ``dstate``/``dupdates`` set to ``None``.
        """
        raise NotImplementedError

    def infer_gradient_from(self, prev_state, next_state):
        """Recover the gradient that turned ``prev_state`` into ``next_state``."""
        raise NotImplementedError

    def trim(self):
        """Optional hook; no base implementation."""
        raise NotImplementedError

class AdamState(NamedTuple):
    """Optimizer state for ``AdamOptimizer``: a step count plus two pytrees."""
    # Step counter; AdamOptimizer.initial_state creates it as an int32 scalar.
    count: object
    # Pytree initialised to zeros like params; recover_grads treats it as the
    # b1-weighted running average of gradients.
    mu: object
    # Pytree initialised to zeros like params (presumably the b2-weighted
    # second-moment average; the update lives in old_adam_step — confirm).
    nu: object

def safe_zeros_like(x, make_tangent):
    """Return a pytree shaped like ``x`` with every leaf replaced by zeros.

    Args:
        x: pytree of arrays (``None`` subtrees are left as ``None``).
        make_tangent: when true, the result is meant to be used as a tangent,
            so leaves of integer or boolean dtype — which JAX only allows to
            carry ``float0`` tangents — are mapped to ``None`` instead of
            integer zeros.

    Returns:
        A pytree with the same structure as ``x``. ``float0`` leaves are
        returned unchanged; every other leaf becomes ``jnp.zeros_like(leaf)``
        unless dropped by the ``make_tangent`` rule above.
    """
    def _zero_leaf(leaf):
        if leaf is None:
            return None

        # Python scalars have no ``.dtype``; derive one instead of crashing.
        dtype = getattr(leaf, 'dtype', None)
        if dtype is None:
            dtype = jnp.result_type(leaf)

        # float0 leaves are already "zero" tangents; keep them as-is.
        if dtype == jax.float0:
            return leaf

        # Cover every integer width/signedness and bool, not just int32/int64.
        if make_tangent and (jnp.issubdtype(dtype, jnp.integer)
                             or jnp.issubdtype(dtype, jnp.bool_)):
            return None

        return jnp.zeros_like(leaf)

    return jtu.tree_map(_zero_leaf, x)

@cache
def get_one():
    """Return a cached scalar placed on the first GPU, or the default device.

    Used as the default ``lr_factor`` / ``wd_factor`` for ``TrainState``.

    NOTE(review): despite the name, the value is ``0.0`` (``jnp.zeros``).
    That value is preserved here because how ``old_adam_step`` consumes the
    factors is not visible in this file — confirm whether ``ones`` was
    intended.
    """
    try:
        device = jax.devices('gpu')[0]
    except RuntimeError:
        # No GPU backend (e.g. CPU-only host): fall back to the default
        # device instead of failing, since this runs at import time when
        # used as a default argument.
        device = jax.devices()[0]
    return jnp.zeros((), device=device)

class TrainState(struct.PyTreeNode):
    """Bundle of model params, optimizer state, optimizer and batch stats.

    ``params`` is either a plain pytree or a ``KahanState``; the ``kahan``
    property reports which mode the state is in.  Plain-mode steps go through
    ``optimizer.auto_step`` (tangents via ``jax.jvp``); Kahan-mode tangent
    steps go through ``optimizer.kahan_dual_and_primal``.
    """
    params: object = struct.field(pytree_node=True)
    opt_state: object = struct.field(pytree_node=True)
    # The optimizer is a pytree node (not a static field), so it is flattened
    # and traced together with params and opt_state.
    optimizer: object = struct.field(pytree_node=True)
    batch_stats: Any

    def zeros_like(self, is_tangent):
        """Return a copy whose params and opt_state are zero-filled.

        Leaves are produced by ``safe_zeros_like(..., is_tangent)``; see that
        function for how integer and float0 leaves are treated.
        ``optimizer`` and ``batch_stats`` are carried over unchanged.
        """
        return self.replace(
            params=safe_zeros_like(self.params, is_tangent),
            opt_state=safe_zeros_like(self.opt_state, is_tangent),
        )

    def apply_grads_kahan(self, grads):
        """Primal-only Kahan step via ``optimizer.kahan_dual_and_primal``.

        NOTE(review): not reachable from ``apply_grads``, which raises
        ``NotImplementedError`` in Kahan mode.
        """
        assert self.kahan
        return self.optimizer.kahan_dual_and_primal(
            state=self,
            dstate=None,
            updates=grads,
            dupdates=None,
            calculate_dual=False,
        )

    def apply_grads_auto(self, state, grads, lr_factor=None, wd_factor=None):
        """Take one plain (non-Kahan) optimizer step from ``state``.

        ``state`` is an explicit argument (rather than ``self``) so that
        ``jvp_auto`` can differentiate with respect to it.

        Args:
            state: the ``TrainState`` to step from; must not be in Kahan mode.
            grads: gradient pytree (presumably shaped like ``state.params``).
            lr_factor: forwarded to ``optimizer.auto_step``; ``None`` means
                ``get_one()``.
            wd_factor: forwarded to ``optimizer.auto_step``; ``None`` means
                ``get_one()``.

        Returns:
            The new ``TrainState`` produced by ``optimizer.auto_step``.
        """
        assert not self.kahan
        assert not isinstance(state.params, KahanState)
        assert isinstance(state, TrainState)
        if lr_factor is None:
            lr_factor = get_one()
        if wd_factor is None:
            wd_factor = get_one()
        res = self.optimizer.auto_step(state, grads, lr_factor=lr_factor,
                                       wd_factor=wd_factor)
        assert isinstance(res, TrainState)
        assert res.optimizer == self.optimizer
        assert res.opt_state.count is not None
        return res

    def jvp_kahan(self, *, grads, dgrads, dstate):
        """Kahan-mode step together with its tangent.

        A missing ``dstate`` is replaced by ``self.zeros_like(is_tangent=True)``.
        Returns whatever ``optimizer.kahan_dual_and_primal`` returns with
        ``calculate_dual=True``; ``jvp`` unpacks it as ``(primal, tangent)``.
        """
        assert self.kahan
        assert isinstance(self.params, KahanState)
        assert dstate is None or isinstance(dstate.params, KahanState)

        if dstate is None:
            dstate = self.zeros_like(is_tangent=True)

        return self.optimizer.kahan_dual_and_primal(
            state=self,
            dstate=dstate,
            updates=grads,
            dupdates=dgrads,
            calculate_dual=True,
        )

    def trim(self):
        """Return ``opt_state.mu``."""
        return self.opt_state.mu

    def infer_gradient_from(self, next_state):
        """Delegate to ``optimizer.infer_gradient_from(self, next_state)``."""
        return self.optimizer.infer_gradient_from(self, next_state)

    def jvp_auto(self, *, grads, dgrads, dstate, lr_factor=None,
                 wd_factor=None):
        """Plain-mode step and its tangent, computed with ``jax.jvp``.

        Args:
            grads: primal gradients.
            dgrads: tangent of ``grads``.
            dstate: tangent of ``self``, or ``None`` to differentiate with
                respect to ``grads`` only.
            lr_factor: forwarded to ``apply_grads_auto``; ``None`` means
                ``get_one()``.
            wd_factor: forwarded to ``apply_grads_auto``; ``None`` means
                ``get_one()``.

        Returns:
            ``(out, dout)``: the stepped ``TrainState`` and its tangent.
        """
        assert not self.kahan
        assert not isinstance(self.params, KahanState)
        assert dstate is None or not isinstance(dstate.params, KahanState)

        if lr_factor is None:
            lr_factor = get_one()
        if wd_factor is None:
            wd_factor = get_one()
        # apply_grads_auto's factors must be bound here; jax.jvp only passes
        # the differentiated positional arguments.
        step = partial(self.apply_grads_auto, lr_factor=lr_factor,
                       wd_factor=wd_factor)

        if dstate is None:
            out, dout = jax.jvp(partial(step, self), (grads,), (dgrads,))
        else:
            out, dout = jax.jvp(step, (self, grads), (dstate, dgrads))

        assert isinstance(dout, TrainState)
        return out, dout

    def apply_grads(self, grads, lr_factor=None, wd_factor=None):
        """Apply one optimizer step and return the new ``TrainState``.

        ``lr_factor`` / ``wd_factor`` default to ``get_one()``, resolved at
        call time (not at class-definition time), so importing this module
        does not touch any device.

        Raises:
            NotImplementedError: if the state is in Kahan mode.
        """
        if self.kahan:
            # Kahan-mode stepping is not supported through this entry point.
            raise NotImplementedError

        return self.apply_grads_auto(self, grads, lr_factor=lr_factor,
                                     wd_factor=wd_factor)

    def jvp(self, grads, dgrads, dstate):
        """Step and tangent, dispatching on Kahan vs. plain mode.

        Returns:
            ``(primal, tangent)``. The primal's step count must be present;
            the tangent's count must be ``None`` or ``float0``.
        """
        if self.kahan:
            res = self.jvp_kahan(grads=grads, dgrads=dgrads, dstate=dstate)
        else:
            res = self.jvp_auto(grads=grads, dgrads=dgrads, dstate=dstate)

        primal, tangent = res
        assert primal.opt_state.count is not None
        tangent_count = tangent.opt_state.count
        assert tangent_count is None or tangent_count.dtype == jax.float0
        return primal, tangent

    def kahanify(self, force=False):
        """Return this state converted to Kahan mode.

        Raises:
            ValueError: if already in Kahan mode and ``force`` is false (with
                ``force=True`` the state is returned unchanged).
        """
        if self.kahan:
            if not force:
                raise ValueError('WARNING - already in kahan mode')

            return self

        assert not isinstance(self.params, KahanState)
        kahan_params = kahan_from_sum(self.params)
        opt_state = self.optimizer.opt_state_to_kahan(self.opt_state)
        return self.replace(params=kahan_params, opt_state=opt_state)

    @property
    def kahan(self):
        """True when ``params`` is a ``KahanState``."""
        return isinstance(self.params, KahanState)

    def unkahanify(self, force=False):
        """Return this state converted back to plain (non-Kahan) mode.

        Raises:
            ValueError: if not in Kahan mode and ``force`` is false (with
                ``force=True`` the state is returned unchanged).
        """
        if not self.kahan:
            if not force:
                raise ValueError('WARNING - already in non-kahan mode')

            return self

        params = kahan_finish_tree(self.params)
        opt_state = self.optimizer.kahan_to_opt_state(self.opt_state)
        return self.replace(params=params, opt_state=opt_state)

    @staticmethod
    def create(optimizer, params, kahan, batch_stats):
        """Build a fresh ``TrainState`` from plain ``params``.

        Args:
            optimizer: provides ``initial_state`` (and ``opt_state_to_kahan``
                when ``kahan`` is true).
            params: plain (non-Kahan) parameter pytree.
            kahan: if true, convert params and optimizer state to Kahan form.
            batch_stats: stored as-is.
        """
        assert not isinstance(params, KahanState)
        state_0 = optimizer.initial_state(params)

        if kahan:
            params = kahan_from_sum(params)
            state_0 = optimizer.opt_state_to_kahan(state_0)

        return TrainState(params=params, opt_state=state_0,
                          optimizer=optimizer, batch_stats=batch_stats)

class AdamOptimizer(struct.PyTreeNode):
    """Adam hyperparameters packaged as a flax pytree node.

    ``lr`` is a static field (``pytree_node=False``); every other
    hyperparameter is a pytree leaf.  The update rule itself is delegated to
    ``old_adam_step``.
    """
    lr: object = struct.field(pytree_node=False)
    wd: object = struct.field(pytree_node=True)
    b1: object = struct.field(pytree_node=True)
    b2: object = struct.field(pytree_node=True)
    eps: object = struct.field(pytree_node=True)
    eps_root: object = struct.field(pytree_node=True)
    selective_wd: object = struct.field(pytree_node=True)
    factored_lr_wd: object = struct.field(pytree_node=True)
    max_lr: object = struct.field(pytree_node=True)

    def initial_state(self, params):
        """Return an ``AdamState`` with zero ``mu``/``nu`` and an int32 count of 0."""
        mu = otu.tree_zeros_like(params)
        nu = otu.tree_zeros_like(params)
        count = jnp.zeros([], jnp.int32)
        return AdamState(count=count, mu=mu, nu=nu)

    def auto_step(self, state, updates, lr_factor, wd_factor):
        """Take one plain Adam step and return the updated ``TrainState``.

        NOTE(review): the printed message says the LR is ignored, yet
        ``lr_factor`` is still forwarded to ``old_adam_step``; whether it is
        used there cannot be confirmed from this file.
        """
        print('>> New LR is ignored for Adam')
        assert not isinstance(state.params, KahanState)
        assert not isinstance(state.opt_state.mu, KahanState)
        assert not isinstance(state.opt_state.nu, KahanState)
        assert not isinstance(updates, KahanState)
        new_params, new_adam_state = old_adam_step(
            self, state.params, state.opt_state, updates, AdamState,
            lr_factor=lr_factor, wd_factor=wd_factor)
        return state.replace(params=new_params, opt_state=new_adam_state,
                             optimizer=self)

    def infer_gradient_from(self, prev_state, next_state):
        """Recover the gradient from two consecutive ``mu`` values.

        Args:
            prev_state: state whose ``opt_state.mu`` is the earlier value.
            next_state: either a state (its ``opt_state.mu`` is used) or the
                later ``mu`` pytree itself.
        """
        this_mu = prev_state.opt_state.mu
        if hasattr(next_state, 'opt_state'):
            next_mu = next_state.opt_state.mu
        else:
            next_mu = next_state

        return recover_grads(this_mu, next_mu, self.b1)
        
@jax.jit
def recover_grads(this_mu, next_mu, b1):
    """Solve ``next = b1 * this + (1 - b1) * g`` for ``g``, leaf by leaf.

    This assumes ``mu`` is updated as that exponential moving average.

    Args:
        this_mu: earlier ``mu`` pytree.
        next_mu: later ``mu`` pytree, same structure as ``this_mu``.
        b1: decay coefficient of the moving average.

    Returns:
        A pytree of recovered gradients.  A leaf pair where either side is
        ``None`` or ``float0`` yields that side unchanged (``next_mu``'s leaf
        is checked first).
    """
    def _invert(later, earlier):
        for leaf in (later, earlier):
            if leaf is None or leaf.dtype == jax.float0:
                return leaf
        return (later - earlier * b1) / (1 - b1)

    return jtu.tree_map(_invert, next_mu, this_mu)