import jax
import jax.lax as lax
import jax.numpy as jnp
from functools import partial
from environment import EnvParams

def _invert(i: int, params: EnvParams, x: jnp.ndarray):
    """Replaces the ith relator by its formal inverse.

    The letters of relator i are reversed and negated (``-k`` is the inverse of
    letter ``k``), e.g. ``a b`` -> ``B A``. Padding slots stay zero.
    """
    width = params.max_length
    offset = (i * width,)
    relator = lax.dynamic_slice(x, offset, (width,))
    flipped = _reverse_nonzero(relator)
    # Negate only the letter slots; padding is forced to an exact zero.
    inverse = jnp.where(relator == 0, jnp.int8(0), -flipped)
    return lax.dynamic_update_slice(x, inverse, offset)  # type: ignore

def _concatenate(i: int, j: int, params: EnvParams, x: jnp.ndarray):
    """Replaces the ith relator with the freely reduced product r_i * r_j.

    Letters at the junction that cancel (a letter at the end of r_i followed by
    its inverse at the start of r_j) are removed before the product is written
    back. If the reduced product would not fit in ``params.max_length`` slots,
    ``x`` is left unchanged. In both cases relator i is then passed through
    ``cyclic_reduce``.

    Args:
        i: index of the relator that is overwritten.
        j: index of the relator appended on the right.
        params: environment parameters; only ``max_length`` is read.
        x: flat array in which relator k occupies
            ``x[k*max_length:(k+1)*max_length]``. Relators are assumed to be
            left-packed (letters first, zero padding after) — TODO confirm
            against the environment.

    Returns:
        A new array with relator i updated.
    """
    max_length = params.max_length
    ith_relator = lax.dynamic_slice(x, (i*max_length,), (max_length,))
    jth_relator = lax.dynamic_slice(x, (j*max_length,), (max_length,))

    ith_len = jnp.count_nonzero(ith_relator)
    jth_len = jnp.count_nonzero(jth_relator)

    # Default integer dtype on purpose: an int8 arange wraps around once
    # max_length exceeds 127, which corrupts every positional mask below.
    positions = jnp.arange(max_length)

    # Example: r_i = a c, r_j = C b a.
    # reversed(r_i) = c a, so mask = (C b a == -(c a 0)) = (C b a == C A 0) = T F F.
    # mask[k] is True when the kth letter of r_j cancels the kth-from-last
    # letter of r_i.
    ith_relator_reversed = _reverse_nonzero(ith_relator)
    mask = jth_relator == -ith_relator_reversed

    # Number of cancelling pairs = index of the first False in mask.
    # If mask has no False at all (r_j == r_i^{-1}, or both relators empty),
    # everything cancels, so fall back to the shorter length. (jnp.argmin would
    # return 0 here and report no cancellation at all.)
    num_cancel = jnp.min(
        jnp.where(~mask, positions, jnp.minimum(ith_len, jth_len))
    )
    new_size = ith_len + jth_len - 2 * num_cancel

    def keep(x, ith_relator, jth_relator, ith_len, num_cancel, new_size):
        return x

    def write_product(x, ith_relator, jth_relator, ith_len, num_cancel, new_size):
        # Surviving prefix of r_i goes to [0, prefix_end);
        # surviving suffix of r_j goes to [prefix_end, new_size).
        prefix_end = ith_len - num_cancel
        from_i = positions < prefix_end
        from_j = jnp.logical_and(positions >= prefix_end, positions < new_size)

        # Shift r_j right by ith_len - 2*num_cancel so that its first
        # non-cancelled letter, jth_relator[num_cancel], lands at prefix_end.
        shifted_j = jnp.roll(jth_relator, ith_len - 2 * num_cancel)
        updated_ith_relator = jnp.where(
            from_j,
            shifted_j,
            jnp.where(from_i, ith_relator, 0),  # type: ignore
        )
        return jax.lax.dynamic_update_slice(x, updated_ith_relator, (i*max_length,))

    out = jax.lax.cond(
        new_size > max_length,
        keep,
        write_product,
        x, ith_relator, jth_relator, ith_len, num_cancel, new_size,
    )

    return cyclic_reduce(i, params, out)

def cyclic_reduce(i: int, params: EnvParams, x: jnp.ndarray):
    """Cyclically reduces the ith relator of ``x``; other relators are untouched.

    Callers invoke this right after modifying relator i. Matching pairs of a
    letter at the front and its inverse at the back are stripped
    (e.g. ``C a b c`` -> ``a b``). The surviving letters are moved to the front
    of the slot, followed by zero padding.

    Args:
        i: index of the relator to reduce.
        params: environment parameters; only ``max_length`` is read.
        x: flat array in which relator k occupies
            ``x[k*max_length:(k+1)*max_length]``. Relators are assumed
            left-packed — TODO confirm against the environment.

    Returns:
        A new array with relator i cyclically reduced.
    """
    max_length = params.max_length
    indices = jnp.arange(max_length)

    ith_relator = lax.dynamic_slice(x, (i*max_length,), (max_length,))  # C a b c
    ith_len = jnp.count_nonzero(ith_relator)

    # Compare the word with its formal inverse, position by position:
    #   C a b c  vs  -(c b a C) = C B A c  ->  T F F T
    # mask[k] is True when the kth letter cancels the kth-from-last letter.
    ith_relator_reversed = _reverse_nonzero(ith_relator)  # c b a C
    mask = ith_relator == -ith_relator_reversed

    # Pairs to strip = index of the first False. If there is none (e.g. an
    # empty relator) this is max_length, and the write below clears the slot.
    num_cancel = jnp.min(jnp.where(~mask, indices, max_length))

    # Keep ith_relator[num_cancel : ith_len - num_cancel], moved to the front.
    # For C a b c: num_cancel = 1, rolled_indices = [1, 2, 3, 0], and the first
    # ith_len - 2 * num_cancel = 2 slots receive a b.
    rolled_indices = (indices + num_cancel) % max_length
    updated_ith_relator = jnp.where(
        indices < ith_len - 2 * num_cancel,
        ith_relator[rolled_indices],
        jnp.zeros_like(ith_relator),
    )

    return jax.lax.dynamic_update_slice(x, updated_ith_relator, (i*max_length,))

def _conjugate(i: int, j: int, s: int, params: EnvParams, x: jnp.ndarray):
    """Conjugates the ith relator by the generator ``gen = j * s``.

    Forms ``gen * r_i * gen^{-1}``. ``gen`` cancels against a leading
    ``gen^{-1}`` of r_i, and ``gen^{-1}`` cancels against a trailing ``gen``.
    If the result would exceed ``params.max_length`` letters, ``x`` is left
    unchanged. In both cases relator i is then passed through ``cyclic_reduce``.

    Args:
        i: index of the relator to conjugate.
        j: generator number (the action setup enumerates 1..n_gen).
        s: +1 to conjugate by the generator, -1 by its inverse.
        params: environment parameters; only ``max_length`` is read.
        x: flat array in which relator k occupies
            ``x[k*max_length:(k+1)*max_length]``. Relators are assumed
            left-packed — TODO confirm against the environment.

    Returns:
        A new array with relator i updated.
    """
    max_length = params.max_length
    ith_relator = lax.dynamic_slice(x, (i*max_length,), (max_length,))
    ith_rel_len = jnp.sum(ith_relator != 0)
    gen = j * s

    # Lay out gen * r_i * gen^{-1} in a buffer with two spare slots, so the
    # un-cancelled word (ith_rel_len + 2 letters) always fits:
    #   new_arr = [gen, r_i[0], ..., r_i[len-1], -gen, 0, ...]
    new_arr = jnp.zeros((max_length + 2,), dtype=x.dtype)
    new_arr = new_arr.at[0].set(gen)
    new_arr = lax.dynamic_update_slice(new_arr, ith_relator, (1,))
    new_arr = lax.dynamic_update_slice(
        new_arr,
        jnp.array([-gen], dtype=new_arr.dtype),
        (ith_rel_len + 1,),
    )

    # Free cancellation can only happen at the two seams:
    #   start: gen, r_i[0] == -gen      -> drop slots 0 and 1
    #   end:   r_i[len-1] == gen, -gen  -> drop slots len and len+1
    # For an empty relator the end check reads ith_relator[-1], a padding zero.
    slots = jnp.arange(max_length + 2)
    drop_start = jnp.logical_and(ith_relator[0] == -gen, slots < 2)
    drop_end = jnp.logical_and(
        ith_relator[ith_rel_len - 1] == gen,
        jnp.logical_or(slots == ith_rel_len, slots == ith_rel_len + 1),
    )
    # True where the letter survives.
    mask = jnp.logical_not(jnp.logical_or(drop_start, drop_end))

    # Each dropped slot removes exactly one of the ith_rel_len + 2 letters.
    new_size = ith_rel_len + 2 - jnp.sum(jnp.logical_not(mask))

    def keep(x, mask, new_arr):
        return x

    def write(x, mask, new_arr):
        result_arr = jnp.where(mask, new_arr, 0)
        # A start cancellation leaves zeros in slots 0 and 1; skip them so the
        # letters stay left-aligned. Otherwise take the first max_length slots.
        new_result = jax.lax.cond(
            result_arr[0] == 0,
            lambda: result_arr[2:],
            lambda: result_arr[0:max_length],
        )
        return lax.dynamic_update_slice(x, new_result, (i*max_length,))

    out = jax.lax.cond(
        new_size > max_length,
        keep,
        write,
        x, mask, new_arr,
    )

    return cyclic_reduce(i, params, out)


def _reverse_nonzero(arr: jnp.ndarray):
    """Returns a copy of ``arr`` whose leading non-zero run is reversed.

    The non-zero entries are expected to form a contiguous prefix (letters
    first, zero padding after). Padding keeps its position, e.g.
    ``[1, 2, 3, 0, 0] -> [3, 2, 1, 0, 0]``.
    """
    idx = jnp.arange(arr.shape[0])
    is_letter = arr != 0
    n_letters = is_letter.sum()
    # Prefix element k moves to slot n_letters - 1 - k; zeros stay where they are.
    target = jnp.where(is_letter, (n_letters - 1) - idx, idx)
    scattered = jnp.zeros_like(arr).at[target].set(arr)
    return scattered.astype(arr.dtype)

def _prime_concatenate(i: int, j: int, params: EnvParams, x: jnp.ndarray):
    """Replaces relator i with the reduced product r_i * r_j^{-1}.

    Relator j is inverted in place, appended to relator i via
    ``_concatenate``, and then inverted back.
    """
    with_j_inverted = _invert(j, params, x)
    concatenated = _concatenate(i, j, params, with_j_inverted)
    return _invert(j, params, concatenated)

def setup_actions(params: EnvParams):
    """Builds the list of jitted moves for ``params``, indexed by move id.

    Useful when you need the moves without constructing an environment, for
    example to replay an ``AC_path`` on a state:
    ```
    state = jnp.array([...]) # representing a presentation
    params = EnvParams(n_gen=2, max_length=len(state)//2)
    actions = setup_actions(params)
    for move_id in AC_path:
        state = actions[move_id](state)
    ```

    The returned list is ordered as:
      1. ``n_gen`` inversions, one per relator ``i``;
      2. concatenations ``r_i <- r_i r_j`` for every ordered pair ``i != j``
         (``i`` outer, ``j`` inner);
      3. conjugations of relator ``i`` by generator ``j`` in ``1..n_gen``,
         sign ``s = +1`` then ``-1`` (``i`` outer, then ``j``, then ``s``).

    Each entry takes the flat presentation array and returns the new array.
    """
    jit_invert = jax.jit(_invert, static_argnames=("i", "params"))
    jit_concatenate = jax.jit(_concatenate, static_argnames=("i", "j", "params"))
    jit_conjugate = jax.jit(_conjugate, static_argnames=("i", "j", "s", "params"))

    n_gen = params.n_gen

    _inverts = [partial(jit_invert, i, params) for i in range(n_gen)]

    _concatenations = [
        partial(jit_concatenate, i, j, params)
        for i in range(n_gen)
        for j in range(n_gen)
        if i != j
    ]

    _conjugations = [
        partial(jit_conjugate, i, j, s, params)
        for i in range(n_gen)
        for j in range(1, n_gen + 1)
        for s in (1, -1)
    ]

    return _inverts + _concatenations + _conjugations


def setup_prime_actions(params: EnvParams):
    """Builds the jitted "primed" move list for ``params``, indexed by move id.

    This is the counterpart of ``setup_actions``: inversions are replaced by
    primed concatenations ``r_i <- r_i r_j^{-1}``. The returned list is ordered:
      1. primed concatenations for every ordered pair ``i != j``;
      2. plain concatenations ``r_i <- r_i r_j`` for the same pairs, same order;
      3. conjugations of relator ``i`` by generator ``j`` in ``1..n_gen``,
         sign +1 then -1 (``i`` outer, then ``j``, then the sign).
    """
    n = params.n_gen
    prime_concat_fn = jax.jit(_prime_concatenate, static_argnames=("i", "j", "params"))
    concat_fn = jax.jit(_concatenate, static_argnames=("i", "j", "params"))
    conj_fn = jax.jit(_conjugate, static_argnames=("i", "j", "s", "params"))

    ordered_pairs = [(a, b) for a in range(n) for b in range(n) if a != b]

    moves = []
    moves.extend(partial(prime_concat_fn, a, b, params) for a, b in ordered_pairs)
    moves.extend(partial(concat_fn, a, b, params) for a, b in ordered_pairs)
    moves.extend(
        partial(conj_fn, rel, g, sign, params)
        for rel in range(n)
        for g in range(1, n + 1)
        for sign in (1, -1)
    )
    return moves