"""
Contains diverse helper functions that do not yet have
a canonical location in the code.

"""
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax import jit, vmap

def flat(unravel, argnum=0):
    """Build a decorator that passes one positional argument through ``unravel``.

    The decorated function is called with the same arguments it was given,
    except that the positional argument at index ``argnum`` is replaced by
    ``unravel(args[argnum])``. Keyword arguments are forwarded untouched.

    Args:
        unravel: Callable applied to the selected positional argument.
        argnum: Index into the positional arguments of the value to convert.

    Returns:
        A decorator. The wrapper it produces keeps the metadata (``__name__``,
        ``__doc__``, ...) of the function it wraps.

    Raises:
        IndexError: At call time, if fewer than ``argnum + 1`` positional
            arguments are supplied.
    """
    # Local import so the block does not depend on a file-level import.
    from functools import wraps

    def flatten(func):
        @wraps(func)
        def flattened(*args, **kwargs):
            # ``args`` arrives as a tuple; copy it to a list to replace one slot.
            args = list(args)
            args[argnum] = unravel(args[argnum])
            return func(*args, **kwargs)
        return flattened
    return flatten

# Can be used as a decorator:
# @flat(unravel, argnum=0)
# def test(params, x):
#    return params

# print(test(f_params, None))

def flatten_pytrees(pytree_1, pytree_2):
    """Ravel two pytrees into a single 1-D array and return its inverse.

    NOTE: a variadic ``flatten_pytrees(*pytrees)`` defined further down in
    this file rebinds the same name, so this two-argument version is the one
    in effect only until that later definition is executed.

    Args:
        pytree_1: First pytree; its raveled leaves come first in the result.
        pytree_2: Second pytree; its raveled leaves follow those of the first.

    Returns:
        A pair ``(flat, retrieve_pytrees)``. ``flat`` is the concatenation of
        both raveled pytrees; ``retrieve_pytrees(flat)`` slices such a vector
        back apart and returns the tuple ``(pytree_1_like, pytree_2_like)``.
    """
    raveled = [ravel_pytree(tree) for tree in (pytree_1, pytree_2)]
    (vec_a, restore_a), (vec_b, restore_b) = raveled

    # Slice boundaries are plain Python ints fixed at flatten time.
    split = len(vec_a)
    stop = split + len(vec_b)

    def retrieve_pytrees(flat):
        head, tail = flat[:split], flat[split:stop]
        return restore_a(head), restore_b(tail)

    joined = jnp.concatenate([vec_a, vec_b], axis=0)
    return joined, retrieve_pytrees


# def flatten_pytrees(*pytrees):
#     flat_pytrees = []
#     unravels = []
#     lengths = []

#     # Flatten each pytree and store the unravel functions and lengths
#     for pytree in pytrees:
#         f_pytree, unravel = ravel_pytree(pytree)
#         flat_pytrees.append(f_pytree)
#         unravels.append(unravel)
#         lengths.append(len(f_pytree))

#     # Concatenate all flat representations into a single flat array
#     flat = jnp.concatenate(flat_pytrees, axis=0)

#     def retrieve_pytrees(flat):
#         # Starting index of the flat representation for each pytree
#         starts = jnp.cumsum(jnp.array([0] + lengths[:-1]))
#         # End index of the flat representation for each pytree
#         ends = starts + jnp.array(lengths)
#         # Unravel each section of the flat array back into a pytree
#         pytrees = [unravels[i](flat[starts[i]:ends[i]]) for i in range(len(lengths))]
#         return pytrees

#     return flat, retrieve_pytrees


def flatten_pytrees(*pytrees):
    """Ravel any number of pytrees into one 1-D array and return its inverse.

    Each pytree is raveled with ``ravel_pytree`` and the resulting vectors are
    concatenated in argument order. The slice boundaries of every pytree are
    recorded as plain Python ints when this function runs, so the returned
    ``retrieve_pytrees`` only performs static slicing.

    Args:
        *pytrees: The pytrees to flatten. May be empty.

    Returns:
        A pair ``(flat, retrieve_pytrees)``. ``flat`` is the concatenated 1-D
        array (an empty float32 array when no pytrees are given).
        ``retrieve_pytrees(flat)`` returns a list with one rebuilt pytree per
        input, in the same order.
    """
    flat_pytrees = []
    # One (start, end, unravel) triple per input pytree.
    segments = []
    offset = 0

    for pytree in pytrees:
        f_pytree, unravel = ravel_pytree(pytree)
        end = offset + len(f_pytree)
        flat_pytrees.append(f_pytree)
        segments.append((offset, end, unravel))
        offset = end

    if flat_pytrees:
        flat_all = jnp.concatenate(flat_pytrees, axis=0)
    else:
        # jnp.concatenate rejects an empty list; return an empty vector instead.
        flat_all = jnp.zeros((0,), dtype=jnp.float32)

    def retrieve_pytrees(flat):
        return [unravel(flat[start:end]) for start, end, unravel in segments]

    return flat_all, retrieve_pytrees


def grid_line_search_factory(loss, steps):
    """Build a jitted update that picks the best step size from a fixed grid.

    For every candidate step ``s`` in ``steps`` the parameters are moved to
    ``p - s * dp`` and ``loss`` is evaluated there; the step with the smallest
    loss is then applied to both parameter sets.

    Args:
        loss: Callable invoked as ``loss(params_u, params_v)`` on the updated
            parameters. It is evaluated under ``vmap`` over the steps and its
            results are passed to ``jnp.argmin``, so it presumably returns one
            scalar per call -- TODO confirm against callers.
        steps: 1-D collection of candidate step sizes. It is converted with
            ``jnp.asarray``, so Python sequences and NumPy arrays are accepted
            as well as JAX arrays.

    Returns:
        ``grid_line_search_update(params_u, params_v, tangent_params_u,
        tangent_params_v)``, a jitted function returning
        ``(new_params_u, new_params_v, step_size)``. Parameters and tangents
        are iterated as sequences of ``(w, b)`` / ``(dw, db)`` pairs and the
        new parameters are returned as lists of ``(w, b)`` tuples.
    """
    # ``steps`` is indexed below with a traced argmin inside ``jit``; that
    # needs a JAX array rather than a Python list or NumPy array.
    steps = jnp.asarray(steps)

    def descend(params, tangent_params, step):
        # Move every (w, b) pair by ``step`` against its (dw, db) direction.
        # ``zip`` stops at the shorter of the two sequences.
        return [(w - step * dw, b - step * db)
                for (w, b), (dw, db) in zip(params, tangent_params)]

    def loss_at_step(
            step,
            params_u,
            params_v,
            tangent_params_u,
            tangent_params_v,
        ):
        return loss(
            descend(params_u, tangent_params_u, step),
            descend(params_v, tangent_params_v, step),
        )

    # Only the step is batched; parameters and tangents are shared. No extra
    # ``jit`` here: this is only ever called from the jitted update below.
    v_loss_at_steps = vmap(loss_at_step, (0, None, None, None, None))

    @jit
    def grid_line_search_update(
            params_u,
            params_v,
            tangent_params_u,
            tangent_params_v,
        ):
        losses = v_loss_at_steps(
            steps,
            params_u,
            params_v,
            tangent_params_u,
            tangent_params_v,
        )
        step_size = steps[jnp.argmin(losses)]

        new_params_u = descend(params_u, tangent_params_u, step_size)
        new_params_v = descend(params_v, tangent_params_v, step_size)

        return new_params_u, new_params_v, step_size
    return grid_line_search_update

