#!/usr/bin/env python
# -*- coding: utf-8 -*-

import functools
from functools import partial
import inspect
import itertools

import jax
import jax.numpy as jnp
from jax._src.util import safe_zip, safe_map, unzip2
from jax import tree_util
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, register_pytree_node
from jax import jvp, linearize, jacrev, vmap, hessian, grad

import optimizers
from optimizers import OptimizerState


def transform(ty_args, ty_res):
    """Lift a per-leaf function into one that operates on whole pytrees.

    Returns a decorator. The decorated function is called once per leaf, and
    the per-leaf results are packed according to ``ty_res``.

    Argument tags (one per positional argument, in ``ty_args``):

    * ``"Tree"``  - a pytree; each call receives one flattened leaf.
    * ``"State"`` - a ``(states_flat, tree, subtrees)`` triple (the layout
      built for ``OptimizerState`` below); each call receives one sub-state
      rebuilt with ``tree_unflatten``.
    * ``"raw"``   - passed unchanged to every per-leaf call.

    Result tags (``ty_res``):

    * ``"Tree"`` / ``"State"`` - repack the per-leaf results with the shared
      tree structure.
    * ``"Reduce_Sum"`` - add the per-leaf results together.
    * a non-string iterable of result tags - the function returns a tuple,
      and position ``i`` is packed with ``ty_res[i]``; a list is returned.

    All ``"Tree"``/``"State"`` arguments must share one tree structure.

    Raises:
        TypeError: on mismatched tree structures, on unknown argument/result
            tags, or when no argument is tagged ``"Tree"`` or ``"State"``.
    """
    from collections.abc import Iterable

    def check_same_tree(trees, names):
        """Return the first tree after checking that all ``trees`` are equal."""
        first, rest = trees[0], trees[1:]
        first_name, rest_names = names[0], names[1:]
        for tree, tree_name in zip(rest, rest_names):
            if tree != first:
                msg = "Got dismatch parameter tree: {} tree {} and {} tree {}."
                raise TypeError(msg.format(first_name, first, tree_name, tree))
        return first

    def check_args(fun, args):
        """Return one display name per positional argument in ``args``."""
        spec = inspect.getfullargspec(fun)
        if spec.varargs is None:
            all_names = spec.args
            assert len(args) == len(
                all_names
            ), "optimizer got wrong number of parameters. Expect {} total {} parameters but got {}.".format(
                str(all_names), len(all_names), len(args)
            )
            return all_names
        else:
            # ``*args`` functions: every argument is reported by the varargs name.
            all_names = [spec.varargs] * len(args)
            return all_names

    def pack_result(res, ty_res, tree):
        """Pack the list ``res`` (one entry per leaf) according to ``ty_res``."""
        if ty_res == "State":
            states_flat, subtrees = unzip2(map(tree_flatten, res))
            return OptimizerState(states_flat, tree, subtrees)
        elif ty_res == "Tree":
            return tree_unflatten(tree, res)
        elif ty_res == "Reduce_Sum":
            return sum(res)
        elif isinstance(ty_res, Iterable) and not isinstance(ty_res, str):
            # Strings are iterable too; without the ``str`` guard an unknown
            # string tag would recurse character by character forever.
            return [
                pack_result([r[i] for r in res], n_ty_res, tree)
                for i, n_ty_res in enumerate(ty_res)
            ]
        else:
            raise TypeError("optimizer got unknown result type {}.".format(ty_res))

    def inner(opt_fun):
        @functools.wraps(opt_fun)
        def transformed_opt_fun(*args):
            all_names = check_args(opt_fun, args)

            def extract_tree(arg, ty_arg):
                """Return ``(treedef, leaf_count)`` for a Tree/State argument."""
                if ty_arg == "Tree":
                    arg_flat, tree = tree_flatten(arg)
                    return tree, len(arg_flat)
                elif ty_arg == "State":
                    _, tree, subtrees = arg
                    return tree, len(subtrees)
                else:
                    raise TypeError("Cannot extract tree for type {}.".format(ty_arg))

            structured = [
                (*extract_tree(arg, ty_arg), name)
                for (arg, ty_arg, name) in zip(args, ty_args, all_names)
                if ty_arg in ["Tree", "State"]
            ]
            if not structured:
                # Without a structured argument there is no tree to iterate over.
                raise TypeError(
                    "optimizer needs at least one argument of type 'Tree' or 'State'."
                )
            trees, leafnums, tree_names = zip(*structured)
            tree = check_same_tree(trees, tree_names)
            # Equal tree structures imply equal leaf counts.
            leafnum = leafnums[0]

            def to_seq(arg, ty_arg):
                """Turn one argument into a per-leaf sequence of call values."""
                if ty_arg == "raw":
                    return [arg] * leafnum
                elif ty_arg == "Tree":
                    arg_flat, _ = tree_flatten(arg)
                    return arg_flat
                elif ty_arg == "State":
                    states_flat, _, subtrees = arg
                    return map(tree_unflatten, subtrees, states_flat)
                else:
                    raise TypeError(
                        "optimizer got unknown argument type {}.".format(ty_arg)
                    )

            res = [
                opt_fun(*real_args) for real_args in zip(*map(to_seq, args, ty_args))
            ]
            return pack_result(res, ty_res, tree)

        return transformed_opt_fun

    return inner


def mean_f(f, subkeys, **kwargs):
    """Average ``f`` over the elements of ``subkeys`` along the leading axis.

    ``f`` is evaluated on every element of ``subkeys`` via ``vec_f`` (all
    ``kwargs`` are forwarded to it). The stacked results are then averaged
    over axis 0. The reduction is applied leaf-wise, so ``f`` may return a
    pytree of arrays as well as a single array.
    """
    results = vec_f(f, subkeys, **kwargs)
    return tree_map(lambda leaf: leaf.mean(0), results)


def vec_f(f, subkeys, loop=False, batch_num=None, batch_size=None):
    """Evaluate ``f`` on every element of ``subkeys`` and stack the outputs.

    Args:
        f: function applied to a single element ``subkeys[i]``. It may return
            an array or a pytree of arrays.
        subkeys: array whose leading axis indexes the inputs.
        loop: if falsy, vectorize over the whole leading axis at once with
            ``vmap``. If truthy, evaluate sequentially with ``lax.scan``:
            one element per step, or one ``vmap``-ed batch per step when
            ``batch_num`` or ``batch_size`` is given.
        batch_num: number of batches. Only used when ``loop`` is truthy, and
            takes precedence over ``batch_size``.
        batch_size: number of elements per batch. Only used when ``loop`` is
            truthy and ``batch_num`` is None.

    Returns:
        ``f``'s outputs stacked (leaf-wise) along a leading axis of length
        ``subkeys.shape[0]``.

    The batched path reshapes ``subkeys``, so its leading axis must split
    evenly into the requested batches.
    """
    if not loop:
        return vmap(f)(subkeys)

    if batch_num is None and batch_size is None:
        # One element per scan step.
        def step(_carry, key):
            return (), f(key)

        _, results = jax.lax.scan(step, (), subkeys)
        return results

    batched_f = vmap(f)

    def batch_step(_carry, keys):
        return (), batched_f(keys)

    if batch_num is not None:
        new_shape = (batch_num, -1, *subkeys.shape[1:])
    else:
        new_shape = (-1, batch_size, *subkeys.shape[1:])
    _, stacked = jax.lax.scan(batch_step, (), jnp.reshape(subkeys, new_shape))
    # Every leaf is (num_batches, batch, ...); merge the two leading axes.
    return tree_map(jnp.concatenate, stacked)
