import jax, jax.numpy as jnp
import jax.tree_util as jtu
from typing import NamedTuple
from optax import tree_utils as otu

def kwrap(fn):
    """Decorate *fn* so that the tuple it returns is converted into a ``KahanState``.

    The wrapped function must return a tuple; its items become the positional
    fields of the resulting ``KahanState`` (i.e. ``(sum, c)``).
    ``functools.wraps`` keeps the decorated function's name, docstring and
    ``__wrapped__`` reference, so introspection and tracebacks stay readable.
    """
    import functools

    @functools.wraps(fn)
    def wrap_fn(*args, **kwargs):
        res = fn(*args, **kwargs)
        assert isinstance(res, tuple), type(res)
        return KahanState(*res)

    return wrap_fn

@kwrap
def kahan_init():
    """Return an empty accumulator.

    Both the running total and the compensation term start out as ``None``;
    ``kwrap`` turns the returned pair into a ``KahanState``.
    """
    return (None, None)

@kwrap
def kahan_mul(tup, val):
    """Multiply every non-``None`` component of the accumulator *tup* by *val*.

    ``None`` components (an unset total or compensation term) stay ``None``.
    """
    scaled = []
    for part in tup:
        scaled.append(None if part is None else part * val)
    return tuple(scaled)

@kwrap
def kahan_div(tup, val):
    """Divide every non-``None`` component of the accumulator *tup* by *val*.

    ``None`` components (an unset total or compensation term) stay ``None``.
    """
    quotients = []
    for part in tup:
        quotients.append(None if part is None else part / val)
    return tuple(quotients)

# Debug switch read by kahan_add at call time: when True, values are added
# with plain (uncompensated) addition and the compensation term is left as-is.
IGNORE_KAHAN = False

# Resolve the optimization-barrier primitive once: prefer the public
# ``jax.lax.optimization_barrier`` and fall back to the private symbol on JAX
# versions that do not expose it publicly.
_optimization_barrier = getattr(jax.lax, "optimization_barrier", None)
if _optimization_barrier is None:
    from jax._src.ad_checkpoint import _optimization_barrier


@kwrap
def kahan_add(tup, val):
    """Add *val* to the accumulator ``tup = (sum, c)`` and return the new pair.

    The first value simply seeds ``sum``. After that, ``sum`` holds the rounded
    running total and ``c`` accumulates the rounding error of every addition,
    recovered per element from whichever operand has the larger magnitude
    (Neumaier's variant of Kahan summation). When the module flag
    ``IGNORE_KAHAN`` is true, values are added plainly and ``c`` is untouched.
    ``kwrap`` converts the returned pair into a ``KahanState``.
    """
    total, comp = tup

    # Seed with the first value. This must come before the IGNORE_KAHAN
    # shortcut, otherwise that path would evaluate ``None + val``.
    if total is None:
        return (val, comp)

    if IGNORE_KAHAN:
        return (total + val, comp)

    # Barriers presumably keep the compiler from algebraically simplifying the
    # error terms below (e.g. ``(total - t) + val`` -> 0), which would erase
    # the compensation.
    t = _optimization_barrier(total + val)

    # The branch choice is intentionally elementwise: each element recovers
    # its own rounding error. ``jnp.where`` also broadcasts its operands.
    err = jnp.where(
        jnp.abs(total) >= jnp.abs(val),
        _optimization_barrier(total - t) + val,
        _optimization_barrier(val - t) + total,
    )
    # Out-of-place update so no caller-visible array is mutated.
    comp = err if comp is None else comp + err
    return (t, comp)

def kahan_finish(tup):
    """Collapse an accumulator ``(sum, c)`` into a single value.

    Returns ``sum + c`` when a compensation term is present, otherwise ``sum``
    unchanged (``None`` for an accumulator that never received a value).
    """
    total, comp = tup
    if comp is None:
        return total
    # Out-of-place add: ``+=`` would modify ``total`` in place when it is a
    # mutable array (e.g. NumPy) that the caller still references.
    return total + comp

def kahan(iter):
    """Return the compensated sum of every value produced by *iter*.

    Starts from ``kahan_init()``, folds each value in with ``kahan_add`` and
    collapses the result with ``kahan_finish``; an empty iterable yields ``None``.
    """
    acc = kahan_init()
    for value in iter:
        acc = kahan_add(acc, value)
    return kahan_finish(acc)

def kahan_from_sum(t):
    """Wrap an existing total (any pytree) as a ``KahanState`` with no compensation.

    The compensation field mirrors the structure of *t* with ``None`` at
    every leaf position.
    """
    no_compensation = jtu.tree_map(lambda _leaf: None, t)
    return KahanState(sum=t, c=no_compensation)

def kahan_finish_tree(x):
    """Finish a ``KahanState`` whose ``sum`` and ``c`` are matching pytrees.

    Each ``(sum, c)`` leaf pair is collapsed with ``kahan_finish``; the result
    has the structure of ``x.sum``. ``x.c`` may hold ``None`` at leaf positions
    (as ``kahan_from_sum`` produces). Leaves whose dtype is ``jax.float0`` are
    returned unchanged (presumably zero-size tangent/gradient placeholders
    that cannot be added to; confirm with callers).
    """
    def finish_leaf(total, comp):
        # Plain Python scalar leaves have no ``.dtype``; treat them as
        # ordinary values instead of raising AttributeError.
        dtype = getattr(total, "dtype", None)
        if dtype is not None and dtype == jax.float0:
            return total
        return kahan_finish((total, comp))

    return jtu.tree_map(finish_leaf, x.sum, x.c)

class KahanState(NamedTuple):
    """Compensated-summation accumulator: a ``(sum, c)`` pair.

    Either field may be ``None``: ``kahan_init`` starts both as ``None``, and
    ``kahan_from_sum`` sets ``c`` to a tree of ``None`` leaves. Being a
    NamedTuple, it unpacks and compares like a plain 2-tuple.
    """

    # Running total (or a pytree of totals).
    sum: object
    # Accumulated compensation term matching ``sum``.
    c: object

# class KahanState:
#     def __init__(self, *args, **kw):
#         raise NotImplementedError

# def kahan_zeros_like(t):
#     clone = otu.tree_zeros_like(t)
#     if isinstance(t, KahanState):
#         return clone
#     else:
#         return kahan_from_sum(clone)

def outer_kahany_map(f, tree, *rest, is_leaf=None):
    """Map *f* leaf-wise over *tree* (and *rest*) and split its outputs into several trees.

    ``f`` is called once per leaf position of ``tree``. It receives that leaf
    plus the corresponding subtree of each tree in ``rest`` (via
    ``treedef.flatten_up_to``). It must return ``(outputs, signature)``, where
    ``outputs`` is an indexable sequence. The i-th output of every leaf is
    reassembled, with ``tree``'s structure, into the i-th result tree.

    Returns:
        ``(trees, signature)``: a list with one tree per output slot, and the
        signature reported by ``f`` (required to be the same for every leaf).

    Raises:
        ValueError: if ``tree`` has no leaves (the output layout cannot be
            inferred without calling ``f``), or if ``f`` reports a different
            signature or number of outputs for different leaves.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree, is_leaf)
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    full_results = [f(*xs) for xs in zip(*all_leaves)]
    if not full_results:
        raise ValueError(
            "outer_kahany_map: `tree` has no leaves, so the output signature "
            "cannot be inferred"
        )

    first_outputs, signature = full_results[0]
    num_trees = len(first_outputs)
    for outputs, leaf_signature in full_results[1:]:
        # Without this check, leaves with a different layout would be silently
        # mis-assembled (or fail later with an opaque IndexError).
        if len(outputs) != num_trees or leaf_signature != signature:
            raise ValueError(
                "outer_kahany_map: inconsistent outputs across leaves: "
                f"expected signature {signature!r} with {num_trees} outputs, "
                f"got {leaf_signature!r} with {len(outputs)}"
            )

    trees = [
        treedef.unflatten([outputs[i] for outputs, _ in full_results])
        for i in range(num_trees)
    ]
    return trees, signature

def _flatten_kahans(args):
    """Flatten a sequence of arguments, expanding each ``KahanState`` into two leaves.

    Returns ``(leaves, signature)``. ``signature`` has one tag per input
    argument:
      - ``'k'`` for a ``KahanState``, which contributes ``sum`` then ``c``.
      - ``'v'`` for any other value, which contributes itself.
    Nested ``KahanState`` components are asserted against.
    """
    leaves = []
    signature = []
    for item in args:
        if not isinstance(item, KahanState):
            signature.append('v')
            leaves.append(item)
            continue
        assert not isinstance(item.sum, KahanState)
        assert not isinstance(item.c, KahanState)
        signature.append('k')
        leaves += [item.sum, item.c]

    if 'k' in signature:
        assert len(signature) < len(leaves), (len(signature), len(leaves))

    return leaves, signature

def _unflatten_kahans(leaves, signature):
    if 'k' in signature:
        assert len(signature) < len(leaves), (len(signature), len(leaves))

    out = []
    i = 0
    for s in signature:
        if s == 'k':
            out.append(KahanState(leaves[i], leaves[i + 1]))
            i += 2
        elif s == 'v':
            out.append(leaves[i])
            i += 1
        else:
            raise ValueError(f"Unknown signature {s}")

    return out

def flatten_kahans(args, kwargs):
    """Flatten positional and keyword arguments into a single list of leaves.

    Returns ``(flattened, (args_sig, kwargs_sig, num_args), keys)``:
      - ``num_args`` counts the leaves that came from ``args``.
      - ``keys`` are the keyword names in ``kwargs`` iteration order.
    Pass all three to ``unflatten_kahans`` to rebuild the arguments.
    """
    pos_leaves, pos_sig = _flatten_kahans(args)
    kw_leaves, kw_sig = _flatten_kahans(list(kwargs.values()))
    keys = list(kwargs)
    return pos_leaves + kw_leaves, (pos_sig, kw_sig, len(pos_leaves)), keys

def unflatten_kahans(flattened, signatures, keys):
    """Inverse of ``flatten_kahans``: return ``(args_list, kwargs_dict)``.

    The first ``num_args`` entries of *flattened* are the positional leaves.
    The remainder are keyword leaves, re-paired with *keys* in order.
    """
    args_sig, kwargs_sig, num_args = signatures
    positional = _unflatten_kahans(flattened[:num_args], args_sig)
    keyword_values = _unflatten_kahans(flattened[num_args:], kwargs_sig)
    return positional, dict(zip(keys, keyword_values))

def kahan_treemap(f, *args, **kwargs):
    """Apply *f* leaf-wise to arguments that may be ``KahanState``s of pytrees.

    Each ``KahanState`` argument is split into its ``sum`` and ``c`` trees.
    All resulting trees are mapped together, using the first one's structure.
    At every leaf, *f* is called with KahanStates rebuilt from the
    corresponding leaves. *f* must return a sequence of results, each possibly
    a ``KahanState``. The return value is a list with one entry per result:
    a ``KahanState`` of trees, or a plain tree.
    """
    flat_trees, signatures, keys = flatten_kahans(args, kwargs)

    def per_leaf(*leaf_values):
        # No leaf should itself be a KahanState at this point.
        assert not any(isinstance(v, KahanState) for v in leaf_values)
        call_args, call_kwargs = unflatten_kahans(leaf_values, signatures, keys)
        return _flatten_kahans(f(*call_args, **call_kwargs))

    out_trees, out_signature = outer_kahany_map(per_leaf, *flat_trees)
    return _unflatten_kahans(out_trees, out_signature)
