import builtins
from functools import partial
from pprint import pformat

import jax
import numpy as onp
from jax import jit
from jax import numpy as np
from jax import tree_util

# Thin aliases over jax.tree_util. The ``jax.tree_map`` top-level alias is
# deprecated and removed in recent JAX releases; ``jax.tree_util.tree_map``
# is available across versions.
leaves = jax.tree_util.tree_leaves
# Map a function over the leaves of one pytree.
unary_op = jax.tree_util.tree_map
# Same function under a second name: map over corresponding leaves of two
# pytrees (tree_map accepts additional trees as extra arguments).
binary_op = jax.tree_util.tree_map


def is_scalar(x):
    """Return True if ``x`` should be broadcast as a scalar over a pytree.

    Python numbers (``bool``, ``int``, ``float`` and ``complex``) count as
    scalars. So does any object exposing an empty ``shape``: 0-d arrays,
    NumPy scalars, and 0-d traced values inside ``jit``.

    Args:
        x: Candidate operand.

    Returns:
        ``True`` for scalars, ``False`` otherwise (including pytrees such as
        lists, dicts or ``Params``).
    """
    if isinstance(x, (int, float, complex, bool)):
        return True
    # 0-d arrays expose an empty shape tuple, which is falsy.
    return hasattr(x, "shape") and not x.shape


@partial(jit, static_argnums=(0,))
def broadcasting_binary_op(op, a, b):
    """Combine two operands leaf by leaf with ``op``, broadcasting scalars.

    Args:
        op: Binary callable applied to leaves. It is a static argument of
            ``jit`` (``static_argnums=(0,)``), so it must be hashable.
        a: Left operand, either a pytree or a scalar (per ``is_scalar``).
        b: Right operand, either a pytree or a scalar (per ``is_scalar``).

    Returns:
        - If ``a`` is a scalar: ``b``'s structure with ``op(a, leaf)`` at
          every leaf.
        - Else, if ``b`` is a scalar: ``a``'s structure with ``op(leaf, b)``.
        - Otherwise: ``op`` applied to corresponding leaves of ``a`` and
          ``b``, whose tree structures must match.
    """
    if is_scalar(a):
        def scalar_on_left(leaf):
            return op(a, leaf)

        return unary_op(scalar_on_left, b)

    if is_scalar(b):
        def scalar_on_right(leaf):
            return op(leaf, b)

        return unary_op(scalar_on_right, a)

    # Neither side is a scalar: zip the two trees leaf by leaf.
    return binary_op(op, a, b)


# Handles to the builtins that the tree-wise definitions below shadow
# (min, max, sum, any). They are taken from ``builtins`` so they do not depend
# on statement order or on this module's globals already having been rebound
# (e.g. by importlib.reload).
omin, omax = builtins.min, builtins.max
osum = builtins.sum
_any = builtins.any

# Tree-wise element-wise binary operations. Either operand may be a scalar,
# which is broadcast against every leaf of the other operand.
add = jit(partial(broadcasting_binary_op, np.add))
multiply = jit(partial(broadcasting_binary_op, np.multiply))
subtract = jit(partial(broadcasting_binary_op, np.subtract))
divide = jit(partial(broadcasting_binary_op, np.divide))
logical_or = jit(partial(broadcasting_binary_op, np.logical_or))
logical_and = jit(partial(broadcasting_binary_op, np.logical_and))
logical_xor = jit(partial(broadcasting_binary_op, np.logical_xor))
less = jit(partial(broadcasting_binary_op, np.less))
less_equal = jit(partial(broadcasting_binary_op, np.less_equal))
greater = jit(partial(broadcasting_binary_op, np.greater))
greater_equal = jit(partial(broadcasting_binary_op, np.greater_equal))
equal = jit(partial(broadcasting_binary_op, np.equal))
not_equal = jit(partial(broadcasting_binary_op, np.not_equal))
floor_divide = jit(partial(broadcasting_binary_op, np.floor_divide))
mod = jit(partial(broadcasting_binary_op, np.mod))
power = jit(partial(broadcasting_binary_op, np.power))
left_shift = jit(partial(broadcasting_binary_op, np.left_shift))
right_shift = jit(partial(broadcasting_binary_op, np.right_shift))

# Tree-wise reductions to a single scalar.
# ``osum`` (the builtin) is used explicitly: the module-level ``sum`` defined
# below is the tree-wise reduction, not Python's sum.
vdot = jit(lambda a, b: osum(leaves(binary_op(np.vdot, a, b))))
normsq = jit(lambda a: vdot(a, a))
l2norm = jit(lambda a: np.sqrt(normsq(a)))
# Per-leaf extrema are stacked and reduced with jax.numpy rather than the
# builtin max/min. Inside jit the per-leaf results are tracers, and the
# builtins would have to turn a traced comparison into a Python bool, which
# JAX rejects.
max = jit(lambda a: np.max(np.stack(leaves(unary_op(np.max, a)))))
min = jit(lambda a: np.min(np.stack(leaves(unary_op(np.min, a)))))
sum = jit(lambda a: osum(leaves(unary_op(np.sum, a))))

# Tree-wise element-wise unary operations.
negative = jit(partial(unary_op, np.negative))
positive = jit(partial(unary_op, np.positive))
log = jit(partial(unary_op, np.log))
log10 = jit(partial(unary_op, np.log10))
log2 = jit(partial(unary_op, np.log2))
abs = jit(partial(unary_op, np.abs))
# logical_not takes a single operand, so it maps over the tree directly
# instead of going through the two-operand broadcasting_binary_op.
logical_not = jit(partial(unary_op, np.logical_not))
# NOTE: zeros/ones use only each leaf's shape, so the result has JAX's
# default floating dtype rather than the leaf's own dtype.
zeros = jit(partial(unary_op, lambda x: np.zeros(x.shape)))
ones = jit(partial(unary_op, lambda x: np.ones(x.shape)))

# Not jitted: both produce a Python bool, which needs concrete (non-traced)
# leaf values.
all = lambda x: tree_util.tree_all(unary_op(np.all, x))
any = lambda x: _any(leaves(unary_op(np.any, x)))

# Convert every leaf to a host NumPy array / a JAX array.
to_onp = partial(unary_op, onp.array)
to_jnp = partial(unary_op, np.array)


@tree_util.register_pytree_node_class
class Params:
    """Pytree wrapper around a nested structure of arrays with operator support.

    Arithmetic, comparison and unary operators act leaf-wise through this
    module's tree-math functions. A scalar operand is broadcast to every leaf,
    on either side of the operator. The class is registered as a JAX pytree
    node, so instances can be passed straight through transformations such as
    ``jit``.

    Comparisons (``==``, ``<``, ...) return a ``Params`` of boolean leaves, not
    a Python bool. Because ``__eq__`` is overridden without ``__hash__``,
    instances are unhashable.
    """

    def __init__(self, params):
        # The wrapped pytree (any nesting of lists/tuples/dicts/arrays).
        self.params = params

    def shapes(self):
        """Return a pytree mirroring ``params`` with each leaf's shape as the leaf."""
        return unary_op(lambda x: x.shape, self.params)

    def __repr__(self):
        return f"Params(Shape:{self.shapes()})"

    def __str__(self):
        return f"Params(\n{pformat(self.shapes())}\n)"

    # Comparisons
    __eq__ = equal
    __ne__ = not_equal
    __lt__ = less
    __le__ = less_equal
    __gt__ = greater
    __ge__ = greater_equal

    # Maths
    __add__ = add
    __sub__ = subtract
    __mul__ = multiply
    __truediv__ = divide
    __floordiv__ = floor_divide
    __mod__ = mod
    __pow__ = power
    __lshift__ = left_shift
    __rshift__ = right_shift
    __and__ = logical_and
    __xor__ = logical_xor
    __or__ = logical_or

    # Reflected maths: Python calls these when the left operand is not a
    # Params (e.g. ``2 * params``, ``1 - params``, builtin ``sum([...])``).
    # Commutative operations reuse the forward function; the others swap the
    # operands back into their original order.
    __radd__ = add
    __rmul__ = multiply
    __rand__ = logical_and
    __rxor__ = logical_xor
    __ror__ = logical_or

    def __rsub__(self, other):
        return subtract(other, self)

    def __rtruediv__(self, other):
        return divide(other, self)

    def __rfloordiv__(self, other):
        return floor_divide(other, self)

    def __rmod__(self, other):
        return mod(other, self)

    def __rpow__(self, other):
        return power(other, self)

    def __rlshift__(self, other):
        return left_shift(other, self)

    def __rrshift__(self, other):
        return right_shift(other, self)

    # Unary maths
    __pos__ = positive
    __neg__ = negative
    __abs__ = abs

    def num_params(self):
        """Return the total number of scalar entries across all leaves."""
        # onp.size also handles leaves that are plain Python numbers.
        return osum(onp.size(x) for x in tree_util.tree_leaves(self))

    # Full flatten/unflatten for converting into 1D arrays
    def full_flatten(self):
        """Flatten every leaf into one 1-D array.

        Returns:
            A pair ``(vector, (treedef, shapes))``. ``vector`` concatenates all
            leaves, each raveled, in tree-flattening order. ``(treedef, shapes)``
            is the metadata ``full_unflatten`` needs to rebuild the structure.
        """
        flat, tree = jax.tree_util.tree_flatten(self)
        shapes = [onp.shape(x) for x in flat]
        if not flat:
            # np.concatenate rejects an empty list; an empty tree flattens to
            # an empty vector instead.
            return np.zeros((0,)), (tree, shapes)
        flat_flat = np.concatenate([np.ravel(x) for x in flat])
        return flat_flat, (tree, shapes)

    @classmethod
    def full_unflatten(cls, vector, tree_shapes):
        """Inverse of ``full_flatten``: rebuild the pytree from a 1-D vector.

        Args:
            vector: 1-D array holding the concatenated leaf entries.
            tree_shapes: The ``(treedef, shapes)`` pair returned by
                ``full_flatten``.

        Returns:
            The unflattened pytree (a ``Params`` when the treedef came from one).
        """
        tree, shapes = tree_shapes
        # Compute leaf sizes with an explicit integer dtype. onp.prod(()) for a
        # scalar-shaped leaf is the float 1.0, and jnp.split needs integer
        # split indices.
        sizes = [int(onp.prod(s, dtype=onp.int64)) for s in shapes]
        split_points = onp.cumsum(sizes[:-1], dtype=onp.int64).tolist()
        partial_flat = np.split(vector, split_points)
        flat_change = [x.reshape(s) for x, s in zip(partial_flat, shapes)]
        return jax.tree_util.tree_unflatten(tree, flat_change)

    # Flatten/unflatten for registering with JAX
    def tree_flatten(self):
        """Expose the wrapped structure as the single child; no auxiliary data."""
        return (self.params,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a ``Params`` from the single child produced by ``tree_flatten``."""
        return cls(*children)
