import contextlib
import functools
import inspect
import re
import threading
from functools import partial as bind

import jax
import jax.numpy as jnp
import equinox as eqx

# Version string of this module.
__version__ = "0.9.0"


###############################################################################
# State
###############################################################################


# When running an impure function that accesses state, it will find the state
# in this global variable. The pure() wrapper populates this global variable
# with the provided state, calls the inner function, and then takes the
# resulting state out of the global variable to return it back to the user.
# To allow multi-threaded programs to use impure functions in parallel, the
# context is a dictionary with a slot for each thread identifier.
CONTEXT = {}


class Context(dict):
    """Dictionary of state entries plus the flags that guard writes to it.

    The initial entries are stored through the plain dict constructor and are
    therefore not subject to the write checks. Every later write through item
    assignment or update() is validated against the `modify`, `ignore`, and
    `create` flags.
    """

    def __init__(self, entries, rng, create, modify, ignore, reserve, name):
        super().__init__(entries)
        self.create = create  # Allow creating new state entries.
        self.modify = modify  # Allow modifying existing state entries.
        self.ignore = ignore  # Ignore modifications to existing state entries.
        self.rng = rng  # Current random key that rng() splits from.
        self.reserve = reserve  # Pool of pre-split keys handed out by rng().
        self.name = name  # Name of the function this context was created for.

    @staticmethod
    def _describe(value):
        """Summarize a value for error messages without assuming it is an
        array. State values may be containers that have no shape or dtype."""
        shape = getattr(value, "shape", None)
        dtype = getattr(value, "dtype", None)
        if shape is None and dtype is None:
            return f"a value of type {type(value).__name__}"
        return f"shape {shape} and dtype {dtype}"

    def update(self, entries):
        """Write all entries one by one so each goes through the checks of
        item assignment."""
        for key, value in dict(entries).items():
            self[key] = value

    def __setitem__(self, key, value):
        """Store an entry if the flags allow it. Raises RuntimeError if
        modification is disabled, or if the key is new and creation is
        disabled. Silently keeps the existing value if `ignore` is set."""
        if not self.modify:
            raise RuntimeError(
                "Cannot modify state entries here. If you want to modify "
                "state inside of scan() set modify=True. "
                + f"You were trying to set {key} to {self._describe(value)}."
            )
        if self.ignore and key in self:
            return  # Do not overwrite existing entries.
        if not self.create and key not in self:
            raise RuntimeError(
                "Can only create state entries during first call. "
                + f"You were trying to set {key} to {self._describe(value)}."
            )
        super().__setitem__(key, value)


def pure(fun, nested=False):
    """Wrap an impure function that uses global state to explicitly pass the
    state in and out. The result is a pure function that is composable with JAX
    transformations. The pure function can be used as follows:
    `out, state = fun(state, rng, *args, **kwargs)`.

    The returned function also accepts the keyword arguments `create`,
    `modify`, and `ignore`. When left as None they are inherited from the
    context that is already active on the current thread, or default to
    create=True, modify=True, ignore=False when no context is active.

    Calling the returned function while another context is active raises a
    RuntimeError unless `nested=True` was given. A non-dict state raises a
    ValueError."""
    # Not every callable has a __name__ (e.g. functools.partial objects).
    fun_name = getattr(fun, "__name__", repr(fun))

    def purified(state, rng, *args, create=None, modify=None, ignore=None, **kwargs):
        context = CONTEXT.get(threading.get_ident(), None)
        # Context is a dict subclass and thus falsy while it holds no entries,
        # so an active context must be detected by comparing against None.
        if context is not None:
            create = create if create is not None else context.create
            modify = modify if modify is not None else context.modify
            ignore = ignore if ignore is not None else context.ignore
            assert context.create or not create, "Parent context disabled create."
            assert context.modify or not modify, "Parent context disabled modify."
            assert not context.ignore or ignore, "Parent context enabled ignore."
        else:
            create = create if create is not None else True
            modify = modify if modify is not None else True
            ignore = ignore if ignore is not None else False
        if not isinstance(state, dict):
            raise ValueError("Must provide a dict as state.")
        if context is not None and (not nested):
            raise RuntimeError(
                f"You are trying to call pure {fun_name}() inside pure "
                f"{context.name}(). Is that intentional? If you want to nest pure "
                f"functions, use pure(..., nested=True) for the inner function."
            )
        before = context
        try:
            # A scalar rng is treated as a seed and turned into a key.
            if rng.shape == ():
                rng = jax.random.PRNGKey(rng)
            # The function runs on a copy so the caller's dict is not mutated.
            context = Context(state.copy(), rng, create, modify, ignore, [], fun_name)
            CONTEXT[threading.get_ident()] = context
            out = fun(*args, **kwargs)
            state = dict(context)
            return out, state
        finally:
            # Restore the surrounding context even if the function raised.
            CONTEXT[threading.get_ident()] = before

    purified.pure = True
    return purified


def context():
    """Return the context that is active on the current thread so it can be
    read and modified from within an impure function. For advanced users only.
    Prefer module methods for state access and rng() for random keys. Raises
    a RuntimeError if no context is active."""
    active = CONTEXT.get(threading.get_ident())
    if active is not None:
        return active
    raise RuntimeError("Wrap impure functions in pure() before running them.")


@jax.named_scope("rng")
def rng(amount=None, reserve=16):
    """Split the global RNG key and return a new local key.

    With `amount=None` a single key is returned. Single keys are served from a
    pool stored on the context. When the pool is empty, the context key is
    split into `reserve` keys, of which the first becomes the new context key
    and the remaining ones refill the pool.

    With an integer `amount`, a batch of that many keys is returned and the
    context key is advanced. `amount=0` yields an empty batch."""
    ctx = context()
    # Compare against None so that amount=0 is honored as an empty batch
    # instead of silently falling through to the single key path.
    if amount is not None:
        keys = jax.random.split(ctx.rng, amount + 1)
        ctx.rng = keys[0]
        return keys[1:]
    if not ctx.reserve:
        keys = jax.random.split(ctx.rng, reserve)
        ctx.rng = keys[0]
        ctx.reserve = list(keys[1:])
    return ctx.reserve.pop(0)


def creating():
    """Report whether the active context currently allows creating state
    entries. Can be used for initialization logic that should be excluded from
    compiled functions."""
    ctx = context()
    return ctx.create


###############################################################################
# Transformations
###############################################################################


@jax.named_scope("grad")
def grad(fun, keys, has_aux=False):
    """Compute the gradient of an impure function with respect to the specified
    state entries or modules. The transformed function returns a tuple containing
    the computed value, selected state entries, their gradients, and if
    applicable auxiliary outputs of the function.

    `keys` is a single state key or module, or a sequence of them. With
    `has_aux=True` the function must return a `(value, aux)` pair. Raises a
    ValueError if `fun` was already wrapped by pure()."""
    # A single string also has __len__, so it has to be wrapped explicitly.
    # Otherwise it would be iterated character by character below.
    if isinstance(keys, str) or not hasattr(keys, "__len__"):
        keys = (keys,)
    if getattr(fun, "pure", False):
        raise ValueError("Use plain jax.grad() for pure functions.")
    if not has_aux:
        # Normalize to the (value, aux) convention using an empty aux.
        fun = lambda *args, _fun=fun, **kwargs: (_fun(*args, **kwargs), {})
    fun = pure(fun, nested=True)

    def forward(x1, x2, rng, *args, **kwargs):
        # x1 holds the entries to differentiate, x2 all remaining entries.
        (y, aux), state = fun({**x1, **x2}, rng, *args, create=False, **kwargs)
        return y, (aux, state)

    backward = jax.value_and_grad(forward, has_aux=True)

    @functools.wraps(backward)
    def wrapper(*args, **kwargs):
        # Give the function a chance to create its state entries first.
        _prerun(fun, *args, **kwargs)
        assert all(isinstance(x, (str, Module)) for x in keys)
        strs = [x for x in keys if isinstance(x, str)]
        mods = [x for x in keys if isinstance(x, Module)]
        for mod in mods:
            strs += mod.getm()
        selected = set(strs)
        x1 = {k: v for k, v in context().items() if k in selected}
        x2 = {k: v for k, v in context().items() if k not in selected}
        (y, (aux, state)), dx = backward(x1, x2, rng(), *args, **kwargs)
        context().update(state)
        return (y, x1, dx, aux) if has_aux else (y, x1, dx)

    return wrapper


def jit(fun, static=None, **kwargs):
    """Compile a pure function for fast execution. State entries can only be
    created during the first call of the compiled function.

    `static` lists names of keyword arguments that are treated as static and
    must be passed by keyword. Remaining keyword arguments of jit() are
    forwarded to jax.jit. Calling the result with `init_only=True` returns
    just the state without applying the function."""
    if not getattr(fun, "pure", False):
        raise ValueError("Use pure() before applying jit().")
    static = static or ()
    compiled = bind(jax.jit, static_argnums=[0], **kwargs)

    @compiled
    def init(statics, rng, *args, **kw):
        # Return only state so JIT can remove dead code for fast initialization.
        _, created = fun({}, rng, *args, ignore=True, **dict(statics), **kw)
        return created

    @compiled
    def apply(statics, state, rng, *args, **kw):
        return fun(state, rng, *args, create=False, **dict(statics), **kw)

    @functools.wraps(fun)
    def wrapper(state, rng, *args, init_only=False, **kw):
        for name in static:
            if name not in kw:
                raise ValueError("Please pass all static arguments by keyword.")
        state = state.copy()
        # Sorted tuple of pairs so the static arguments form a hashable value.
        statics = tuple(sorted((k, v) for k, v in kw.items() if k in static))
        dynamic = {k: v for k, v in kw.items() if k not in static}
        if not hasattr(wrapper, "keys"):
            # First call: find out which entries the function creates and
            # remember their names on the wrapper for all later calls.
            created = init(statics, rng, *args, **dynamic)
            wrapper.keys = set(created.keys())
            for key, value in created.items():
                if key in state:
                    continue  # Entries provided by the caller take priority.
                state[key] = value
        if init_only:
            return state
        # Only pass the entries that the function is known to use.
        selected = {k: v for k, v in state.items() if k in wrapper.keys}
        out, updated = apply(statics, selected, rng, *args, **dynamic)
        return out, {**state, **updated}

    return wrapper


def pmap(fun, axis_name=None, static=None, **kwargs):
    """Compiles a pure function for fast execution across multiple devices. Only
    the first call of the function is allowed to create state entries.

    `static` lists names of keyword arguments that are treated as static and
    must be passed by keyword. `axis_name` and the remaining keyword arguments
    of pmap() are forwarded to jax.pmap. Calling the result with
    `init_only=True` returns just the state without applying the function."""
    if not getattr(fun, "pure", False):
        raise ValueError("Use pure() before applying pmap().")
    static = static or ()

    @bind(jax.pmap, axis_name=axis_name, static_broadcasted_argnums=[0], **kwargs)
    def init(statics, rng, *args, **kw):
        # Return only state so JIT can remove dead code for fast initialization.
        return fun({}, rng, *args, ignore=True, **dict(statics), **kw)[1]

    @bind(jax.pmap, axis_name=axis_name, static_broadcasted_argnums=[0], **kwargs)
    def apply(statics, state, rng, *args, **kw):
        return fun(state, rng, *args, create=False, **dict(statics), **kw)

    @functools.wraps(fun)
    def wrapper(state, rng, *args, init_only=False, **kw):
        if any(name not in kw for name in static):
            raise ValueError("Please pass all static arguments by keyword.")
        state = state.copy()
        # Sorted tuple of pairs so the static arguments form a hashable value.
        statics = tuple(sorted([(k, v) for k, v in kw.items() if k in static]))
        kw = {k: v for k, v in kw.items() if k not in static}
        if not hasattr(wrapper, "keys"):
            # First call: find out which entries the function creates and
            # remember their names on the wrapper for all later calls.
            created = init(statics, rng, *args, **kw)
            wrapper.keys = set(created.keys())
            for key, value in created.items():
                if key not in state:
                    state[key] = value
        if init_only:
            return state
        else:
            # Only pass the entries that the function is known to use.
            selected = {k: v for k, v in state.items() if k in wrapper.keys}
            out, updated = apply(statics, selected, rng, *args, **kw)
            return out, {**state, **updated}

    return wrapper


@jax.named_scope("cond")
def cond(pred, true_fun, false_fun, *operands):
    """Stateful counterpart of jax.lax.cond for impure branch functions. Both
    branches are passed through _prerun() first. The conditional then runs on
    a snapshot of the current state with a separate random key per branch,
    and the resulting state is written back into the active context."""
    on_true = pure(true_fun, nested=True)
    on_false = pure(false_fun, nested=True)
    for branch in (on_true, on_false):
        _prerun(branch, *operands)

    def run_true(state, key_true, key_false, *args):
        return on_true(state, key_true, *args)

    def run_false(state, key_true, key_false, *args):
        return on_false(state, key_false, *args)

    snapshot = dict(context())
    key_true, key_false = rng(2)
    out, state = jax.lax.cond(
        pred, run_true, run_false, snapshot, key_true, key_false, *operands
    )
    context().update(state)
    return out


@jax.named_scope("scan")
def scan(fun, carry, xs, reverse=False, unroll=1, modify=False):
    """Stateful counterpart of jax.lax.scan for an impure step function that
    maps `(carry, x)` to `(carry, y)`. The function is first passed through
    _prerun() with the first slice of `xs`. Every step receives its own
    random key. With `modify=True` the state is threaded through the steps and
    written back afterwards; otherwise each step sees the state of the active
    context with modification disabled."""
    fun = pure(fun, nested=True)
    first = jax.tree_util.tree_map(lambda x: x[0], xs)
    _prerun(fun, carry, first)
    # The number of steps is taken from the first leaf of the inputs.
    length = len(jax.tree_util.tree_leaves(xs)[0])
    rngs = rng(length)

    if not modify:

        def readonly_step(carry, inputs):
            x, key = inputs
            (carry, y), _ = fun(
                dict(context()), key, carry, x, create=False, modify=False
            )
            return carry, y

        carry, ys = jax.lax.scan(
            readonly_step, carry, (xs, rngs), length, reverse, unroll
        )
        return carry, ys

    def stateful_step(carry_and_state, inputs):
        carry, state = carry_and_state
        x, key = inputs
        (carry, y), state = fun(state, key, carry, x, create=False)
        return (carry, state), y

    initial = (carry, dict(context()))
    (carry, state), ys = jax.lax.scan(
        stateful_step, initial, (xs, rngs), length, reverse, unroll
    )
    context().update(state)
    return carry, ys


@jax.named_scope("_prerun")
def _prerun(fun, *args, **kwargs):
    """Run a pure function once with ignore=True so that it can create its
    state entries, and merge the resulting state into the active context. The
    output of the function is discarded. Does nothing when the active context
    does not allow creating entries."""
    ctx = context()
    if ctx.create:
        _, state = fun(dict(ctx), rng(), *args, ignore=True, **kwargs)
        ctx.update(state)


###############################################################################
# Modules
###############################################################################


# Name scope that is currently entered, with nested names joined by "/". The
# empty string denotes the top level.
# NOTE(review): unlike CONTEXT this is a single module-level value that is not
# kept per thread.
SCOPE = ""


@contextlib.contextmanager
def scope(name, absolute=False):
    """Enter a relative or absolute name scope. Name scopes are used to make
    names of state entries unique.

    A relative name is appended to the current scope, separated by a slash.
    An absolute name replaces the current scope. The entered scope is yielded
    and the previous scope is restored on exit, also when the body raises."""
    global SCOPE
    if SCOPE is None:
        raise RuntimeError(
            "Purify stateful functions with fn = pure(fn) before running them."
        )
    outside = SCOPE
    if absolute or outside == "":
        SCOPE = name
    else:
        SCOPE = outside + "/" + name
    try:
        yield SCOPE
    finally:
        # Without this, an exception inside the body would leave the global
        # scope stuck on the inner name for all later callers.
        SCOPE = outside


class ModuleMeta(type):
    """Meta class that creates a unique path for each module instance and wraps
    the methods and properties of the module to enter the name scope."""

    def __new__(mcs, name, bases, clsdict):
        """This runs once per user module class definition. It wraps the methods of
        the module class to automatically enter the name scope of the module."""
        method_names = []
        for key, value in clsdict.items():
            # Dunder attributes are left alone, except for __call__.
            if key.startswith("__") and key != "__call__":
                continue
            elif isinstance(value, property):
                # Replacing the value of an existing key keeps the size of the
                # dict unchanged, so this is fine while iterating over it.
                clsdict[key] = property(
                    value.fget if not value.fget else _scope_method(value.fget),
                    value.fset if not value.fset else _scope_method(value.fset),
                    value.fdel if not value.fdel else _scope_method(value.fdel),
                    doc=value.__doc__,
                )
            elif inspect.isfunction(value):
                method_names.append(key)
        cls = super(ModuleMeta, mcs).__new__(mcs, name, bases, clsdict)
        for method_name in method_names:
            method = getattr(cls, method_name)
            method = _scope_method(method)
            setattr(cls, method_name, method)
        return cls

    def __call__(cls, *args, name=None, **kwargs):
        """This runs once per user module instance creation. It derives a unique
        name and path for the module instance. Raises a KeyError if the name is
        missing or contains characters other than letters, numbers, and
        underscores."""
        if not isinstance(name, str):
            raise KeyError(
                "Please provide a module name via Module(..., name='example')."
            )
        # The whole name has to match. Checking only a prefix would let names
        # containing the scope separator "/" or other characters slip through.
        if not re.fullmatch(r"[A-Za-z0-9_]+", name):
            raise KeyError(
                "Only letters, numbers, and underscores are allowed in scope names."
            )
        obj = cls.__new__(cls)
        # The path is the name relative to the scope active at creation time.
        with scope(name) as path:
            obj._path = path
        obj._submodules = {}
        # __init__ is skipped by the wrapping in __new__, so it is wrapped here
        # to run inside the scope of the new instance.
        init = _scope_method(cls.__init__)
        init(obj, *args, **kwargs)
        return obj


def _scope_method(method):
    """Wrap a method so that it runs inside the absolute name scope given by
    the `_path` of its instance, and inside a JAX named scope labeled with the
    last component of that path."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        path = self._path
        label = path.rsplit("/", 1)[-1]
        with scope(path, absolute=True), jax.named_scope(label):
            return method(self, *args, **kwargs)

    return wrapper


class Module(object, metaclass=ModuleMeta):
    """Base class that user modules inherit from. The meta class takes care of
    name scoping, and the methods below give access to the state entries that
    live under the path of the module."""

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, self.path)

    @property
    def path(self):
        """Full name scope of this module instance as a string."""
        return self._path

    @property
    def name(self):
        """Last component of the path of this module instance."""
        return self._path.rsplit("/", 1)[-1]

    def get(self, name, *args, **kwargs):
        """Return the state entry or submodule with the given name, creating
        it on first use. For creation, the first extra positional argument is
        the constructor and the remaining arguments are passed on to it. The
        name is passed along if the constructor has a `name` parameter."""
        key = self.path + "/" + name
        # Submodules are looked up first and do not need an active context.
        if name in self._submodules:
            return self._submodules[name]
        state = context()
        if key in state:
            return state[key]
        ctor, *ctor_args = args
        if "name" in inspect.signature(ctor).parameters:
            kwargs["name"] = name
        value = ctor(*ctor_args, **kwargs)
        arrays = eqx.filter(value, eqx.is_array)
        if arrays is None:
            # Nothing array-like to store, so keep the object on the module.
            self._submodules[name] = value
        else:
            state[key] = arrays
        return value

    def put(self, name, value):
        """Write one state entry of this module and return the value."""
        self.putm({self.path + "/" + name: value})
        return value

    def getm(self, pattern=r".*", allow_empty=True):
        """Return the state entries below the path of this module whose
        relative name matches the regex. Raises a KeyError if nothing matches
        and `allow_empty` is False."""
        pattern = re.compile(pattern)
        prefix = self.path + "/"
        skip = len(prefix)
        results = {
            key: value
            for key, value in context().items()
            if key.startswith(prefix) and pattern.match(key[skip:])
        }
        if not (allow_empty or results):
            raise KeyError(f"Pattern {pattern} matched no state keys.")
        return results

    def putm(self, mapping):
        """Write several state entries at once. All keys must be full paths
        below the path of this module, otherwise a KeyError is raised."""
        prefix = self.path + "/"
        foreign = [key for key in mapping if not key.startswith(prefix)]
        if foreign:
            key = foreign[0]
            raise KeyError(f"Key {key} does not belong to module {self.path}.")
        context().update(mapping)


class Variable(Module):
    """Module that holds a single state entry named "value". The entry is
    created on first read by calling the stored constructor with the stored
    arguments."""

    def __init__(self, ctor, *args, **kwargs):
        self.ctor, self.args, self.kwargs = ctor, args, kwargs

    def read(self):
        """Return the value, creating it on first use."""
        ctor, args, kwargs = self.ctor, self.args, self.kwargs
        return self.get("value", ctor, *args, **kwargs)

    def write(self, value):
        """Store a new value and return it."""
        stored = self.put("value", value)
        return stored


###############################################################################
# Integrations
###############################################################################


class HaikuModule(Module):
    """Module that wraps a Haiku network. The network is built from the given
    constructor inside a function that is passed to hk.transform, and the
    result of its init function is kept as the "state" entry."""

    def __init__(self, ctor, *args, **kwargs):
        import haiku as hk

        def net(*call_args, **call_kwargs):
            network = ctor(*args, **kwargs)
            return network(*call_args, **call_kwargs)

        self.transformed = hk.transform(net)

    def __call__(self, *args, **kwargs):
        transformed = self.transformed
        # One key is drawn for init and a second one for apply, in that order.
        init_key = rng()
        state = self.get("state", transformed.init, init_key, *args, **kwargs)
        apply_key = rng()
        return transformed.apply(state, apply_key, *args, **kwargs)


class FlaxModule(Module):
    """Module that wraps an object built by the given constructor and uses
    its `init` and `apply` methods. The result of `init` is kept as the
    "state" entry."""

    def __init__(self, ctor, *args, **kwargs):
        module = ctor(*args, **kwargs)
        self.module = module

    def __call__(self, *args, **kwargs):
        module = self.module
        init_key = rng()
        state = self.get("state", module.init, init_key, *args, **kwargs)
        return module.apply(state, *args, **kwargs)


def init_fn(module, args, kwargs):
    """Call `module` with the given positional and keyword arguments and
    return the result."""
    return module(*args, **kwargs)


class EquinoxModule(Module):
    """Module that wraps a model constructor. The constructor and its
    arguments are stored, and the model is obtained through the "state" entry
    on every call."""

    def __init__(self, ctor, *args, **kwargs):
        self.module, self.args, self.kwargs = ctor, args, kwargs

    def __call__(self, *args, **kwargs):
        ctor, ctor_args, ctor_kwargs = self.module, self.args, self.kwargs
        model = self.get("state", init_fn, ctor, ctor_args, ctor_kwargs)
        return model(*args, **kwargs)


class ESCNNModule(Module):
    """Module that wraps a network object. The inexact array leaves of the
    network are kept as the "state" entry and are recombined with the
    remaining parts of the network on every call."""

    def __init__(self, ctor, *args, **kwargs):
        # A ready-made network passed as `net` takes priority over building
        # one with the constructor.
        self.module = kwargs["net"] if "net" in kwargs else ctor(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        def init_fn(module):
            return eqx.partition(module, eqx.is_inexact_array)[0]

        params = self.get("state", init_fn, self.module)
        self.static = eqx.partition(self.module, eqx.is_inexact_array)[1]
        self.module = eqx.combine(params, self.static)
        return self.module(*args, **kwargs)


class OptaxModule(Module):
    """Module that wraps an optimizer built by the given constructor. Calling
    it computes gradients of a loss function with respect to the given keys,
    updates the optimizer state and the selected state entries, and returns a
    dict with the mean loss and the global gradient norm."""

    def __init__(self, ctor, *args, **kwargs):
        optimizer = ctor(*args, **kwargs)
        self.opt = optimizer

    def __call__(self, loss, keys, *args, **kwargs):
        import optax

        value, params, grads = grad(loss, keys)(*args, **kwargs)
        # The optimizer state is created from the parameters on first use.
        opt_state = self.get("state", self.opt.init, params)
        updates, opt_state = self.opt.update(grads, opt_state)
        self.put("state", opt_state)
        new_params = optax.apply_updates(params, updates)
        context().update(new_params)
        metrics = {"loss": value.mean(), "grad_norm": optax.global_norm(grads)}
        return metrics
