# %%
import jax
from ac_moves import _invert, _conjugate, _concatenate, cyclic_reduce
import jax.numpy as jnp
import numpy as np
from ac import EnvParams

# %%
def test_invert():
    """Check ``_invert`` against hand-computed expected arrays.

    Each case is ``(i, x, expected_out)``: ``x`` is the flat relator array,
    ``i`` is the relator index passed to ``_invert``, and ``expected_out`` is
    the full array expected back. ``EnvParams.max_length`` is set to half of
    ``len(x)``, so every input must have even length.

    Raises:
        AssertionError: if an input has odd length or a result differs from
            its expected array (the message shows both arrays).
    """
    test_cases = [
        # Smallest possible case
        (0, jnp.array([1, 0, -1, 0]), jnp.array([-1, 0, -1, 0])),

        # Single nonzero element in each half
        (0, jnp.array([1, 0, 0, 0, -1, 0, 0, 0]), jnp.array([-1, 0, 0, 0, -1, 0, 0, 0])),
        (1, jnp.array([1, 0, 0, 0, -1, 0, 0, 0]), jnp.array([1, 0, 0, 0, 1, 0, 0, 0])),

        # Multiple nonzero elements, only first few
        (0, jnp.array([1, 2, 0, 0, -2, -1, 0, 0]), jnp.array([-2, -1, 0, 0, -2, -1, 0, 0])),
        (1, jnp.array([1, 2, 0, 0, -2, -1, 0, 0]), jnp.array([1, 2, 0, 0, 1, 2, 0, 0])),

        # First half is all zeros
        (0, jnp.array([0, 0, 0, 0, -2, -1, 0, 0]), jnp.array([0, 0, 0, 0, -2, -1, 0, 0])),
        (1, jnp.array([0, 0, 0, 0, -2, -1, 0, 0]), jnp.array([0, 0, 0, 0, 1, 2, 0, 0])),

        # Second half is all zeros
        (0, jnp.array([1, 2, 0, 0, 0, 0, 0, 0]), jnp.array([-2, -1, 0, 0, 0, 0, 0, 0])),
        (1, jnp.array([1, 2, 0, 0, 0, 0, 0, 0]), jnp.array([1, 2, 0, 0, 0, 0, 0, 0])),

        # Nonzero elements but different patterns
        (0, jnp.array([3, 1, 0, 0, -1, -3, 0, 0]), jnp.array([-1, -3, 0, 0, -1, -3, 0, 0])),
        (1, jnp.array([3, 1, 0, 0, -1, -3, 0, 0]), jnp.array([3, 1, 0, 0, 3, 1, 0, 0])),

        # Larger test cases
        (0, jnp.array([5, 4, 3, 2, 1, 0, 0, 0, -1, -2, -3, -4, -5, 0, 0, 0]),
         jnp.array([-1, -2, -3, -4, -5, 0, 0, 0, -1, -2, -3, -4, -5, 0, 0, 0])),
        (1, jnp.array([5, 4, 3, 2, 1, 0, 0, 0, -1, -2, -3, -4, -5, 0, 0, 0]),
         jnp.array([5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0])),

        (0, jnp.array([5, 4, 3, 2, 1, -1, -2, -3, -4, -5]),
         jnp.array([-1, -2, -3, -4, -5, -1, -2, -3, -4, -5])),
        (1, jnp.array([5, 4, 3, 2, 1, -1, -2, -3, -4, -5]),
         jnp.array([5, 4, 3, 2, 1, 5, 4, 3, 2, 1])),
    ]

    for i, x, expected_out in test_cases:
        # max_length is derived as len(x) // 2, which only makes sense for
        # an even-length array.
        assert x.shape[0] % 2 == 0, "input length must be even"
        expected_max_length = x.shape[0] // 2
        result = _invert(i, EnvParams(max_length=expected_max_length), x)
        assert np.array_equal(result, expected_out), (
            f"Test failed for {x} at index {i}; "
            f"expected {expected_out}, got {result}"
        )

    print("All test cases passed!")

# Run the test function when executed as a script or notebook cell, but not
# when the module is merely imported (e.g. during pytest collection, where an
# import-time failure would abort collection of the whole module).
if __name__ == "__main__":
    test_invert()

# %%
def test_cyclic_reduce():
    """Check ``cyclic_reduce`` against hand-computed expected arrays.

    Each case is ``(i, x, expected_out)``: ``x`` is the flat relator array,
    ``i`` is the relator index passed to ``cyclic_reduce``, and
    ``expected_out`` is the full array expected back. ``EnvParams.max_length``
    is set to half of ``len(x)``, so every input must have even length.

    Raises:
        AssertionError: if an input has odd length or a result differs from
            its expected array.
    """
    test_cases = [
        # Smallest possible case: single-letter relators are left unchanged
        (0, jnp.array([1, 0, -1, 0]), jnp.array([1, 0, -1, 0])),

        # A conjugated letter (x y x^-1) in relator i reduces to the single letter y
        (0, jnp.array([2, 1, -2, 0, -1, 0, 0, 0]), jnp.array([1, 0, 0, 0, -1, 0, 0, 0])),
        (1, jnp.array([1, 0, 0, 0, -2, -1, 2, 0]), jnp.array([1, 0, 0, 0, -1, 0, 0, 0])),

        # Nested cancellations at both ends of relator i reduce to a single letter
        (0, jnp.array([-1, 2, 1, -2, 1, 0, -2, -1, 0, 0, 0, 0]),
         jnp.array([1, 0, 0, 0, 0, 0, -2, -1, 0, 0, 0, 0])),
        (1, jnp.array([-2, -1, 0, 0, 0, 0, -1, 2, 1, -2, 1, 0]),
         jnp.array([-2, -1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0])),

        # First half is all zeros; nothing to reduce for either index
        (0, jnp.array([0, 0, 0, 0, -2, -1, 0, 0]), jnp.array([0, 0, 0, 0, -2, -1, 0, 0])),
        (1, jnp.array([0, 0, 0, 0, -2, -1, 0, 0]), jnp.array([0, 0, 0, 0, -2, -1, 0, 0])),

        # Second half is all zeros; nothing to reduce for either index
        (0, jnp.array([1, 2, 0, 0, 0, 0, 0, 0]), jnp.array([1, 2, 0, 0, 0, 0, 0, 0])),
        (1, jnp.array([1, 2, 0, 0, 0, 0, 0, 0]), jnp.array([1, 2, 0, 0, 0, 0, 0, 0])),
    ]

    for i, x, expected_out in test_cases:
        # max_length is derived as len(x) // 2, which requires an even length.
        assert x.shape[0] % 2 == 0, "input length must be even"
        expected_max_length = x.shape[0] // 2
        result = cyclic_reduce(i, EnvParams(max_length=expected_max_length), x)
        assert np.array_equal(result, expected_out), \
            f"Test failed for {x} at index {i}; expected {expected_out}, got {result}"

    print("All test cases passed!")

# Run only when executed directly (script / notebook cell), not on import.
if __name__ == "__main__":
    test_cyclic_reduce()


# %%
def test_conjugate(conjugate_fn=None, use_jax=False):
    """Check a conjugation move against hand-computed expected arrays.

    Args:
        conjugate_fn: Callable invoked as
            ``conjugate_fn(i, j, sign, env_params, rels)``. When ``None``
            (the default), ``_conjugate`` is used. It is looked up at call
            time. The default also lets pytest collect this function
            without searching for a ``conjugate_fn`` fixture.
        use_jax: If true, convert ``rels`` and the expected arrays to
            ``jnp`` arrays before calling ``conjugate_fn``.

    Each case is ``(i, j, sign, rels, expected_rels)``. ``EnvParams.max_length``
    is set to half of ``len(rels)``.

    Raises:
        AssertionError: if any result differs from its expected array; the
            message names the failing case index.
    """
    if conjugate_fn is None:
        conjugate_fn = _conjugate

    test_cases = [
        (0, 2, 1, np.array([1, 2, 0, 2, 0, 0]), np.array([2, 1, 0, 2, 0, 0])),
        (1, 2, 1, np.array([2, 0, 0, 1, 2, 0]), np.array([2, 0, 0, 2, 1, 0])),
        (0, 2, 1, np.array([1, 0, 0, 2, 0, 0]), np.array([1, 0, 0, 2, 0, 0])),  # requires cyclic_reduce
        (0, 2, 1, np.array([1, 1, 0, 2, 0, 0]), np.array([1, 1, 0, 2, 0, 0])),  # requires cyclic_reduce
        (0, 2, 1, np.array([1, -2, 0, 0, 1, 0, 0, 0]), np.array([1, -2, 0, 0, 1, 0, 0, 0])),  # requires cyclic_reduce
        (1, 2, 1, np.array([1, 0, 0, 0, 1, -2, 0, 0]), np.array([1, 0, 0, 0, 1, -2, 0, 0])),
        (0, 2, -1, np.array([2, 1, 0, 1, 0, 0]), np.array([1, 2, 0, 1, 0, 0])),
        (1, 2, 1, np.array([2, 0, 0, 0, 1, 2, 0, 0]), np.array([2, 0, 0, 0, 2, 1, 0, 0])),
        (1, 1, -1, np.array([2, 0, 0, 0, 1, 2, 0, 0]), np.array([2, 0, 0, 0, 2, 1, 0, 0])),
        (0, 1, -1, np.array([1, 2, -1, 2, 0, 0]), np.array([2, 0, 0, 2, 0, 0])),
        (0, 2, 1, np.array([2, 0, 0, 0, 1, 2, 0, 0]), np.array([2, 0, 0, 0, 1, 2, 0, 0])),
        (1, 1, -1, np.array([1, 0, 0, 0, 1, 2, -1, 0]), np.array([1, 0, 0, 0, 2, 0, 0, 0])),
        (0, 2, 1, np.array([1, 2, -1, 2, 0, 0]), np.array([2, 0, 0, 2, 0, 0])),
    ]

    for idx, (i, j, sign, rels, expected_rels) in enumerate(test_cases):
        if use_jax:
            rels = jnp.array(rels)
            expected_rels = jnp.array(expected_rels)
        result_rels = conjugate_fn(i, j, sign, EnvParams(max_length=rels.shape[0] // 2), rels)
        assert np.array_equal(result_rels, expected_rels), \
            f"""Test case {idx} failed: Resulting relators do not match expected results.
                Beginning relator: {rels}
                Expected: {expected_rels}
                Got: {result_rels}"""
    print("All test cases passed!")

# Exercise the JAX path when executed directly (script / notebook cell),
# but not when the module is merely imported.
if __name__ == "__main__":
    test_conjugate(_conjugate, use_jax=True)
# %%
def test_concatenate():
    """Check ``_concatenate`` against hand-computed expected arrays.

    Each case is ``(rels, i, j, expected_rels)``. ``rels`` is the flat relator
    array, and ``i`` and ``j`` are the relator indices passed to
    ``_concatenate``. ``expected_rels`` is the full array expected back.
    ``EnvParams.max_length`` is set to half of ``len(rels)``.

    Raises:
        AssertionError: if any result differs from its expected array; the
            message names the failing case index and its ``i``/``j``.
    """
    test_cases = [
        (np.array([1, 2, 0, 3, 4, 0]), 0, 1, np.array([1, 2, 0, 3, 4, 0])),
        (np.array([1, 2, 0, 0, 3, 4, 0, 0]), 0, 1, np.array([1, 2, 3, 4, 3, 4, 0, 0])),
        (np.array([1, 2, 0, 0, 3, 4, 0, 0]), 1, 0, np.array([1, 2, 0, 0, 3, 4, 1, 2])),
        (np.array([1, 2, 0, 0, -2, 1, 0, 0]), 0, 1, np.array([1, 1, 0, 0, -2, 1, 0, 0])),
        (np.array([1, 2, 0, 0, 1, -2, -1, 0]), 1, 0, np.array([1, 2, 0, 0, 1, 0, 0, 0])),
        (np.array([1, 1, 0, 0, 1, -2, 0, 0]), 0, 1, np.array([1, 1, 1, -2, 1, -2, 0, 0])),
        (np.array([1, 2, 0, 0, 2, -1, 0, 0]), 0, 1, np.array([2, 2, 0, 0, 2, -1, 0, 0])),  # requires cyclical reduce
        (np.array([1, 1, 2, 0, -2, 0, 0, 0]), 1, 0, np.array([1, 1, 2, 0, 1, 1, 0, 0])),  # requires cyclical reduce
        (np.array([1, 2, 0, 0, 0, -2, 1, 1, 2, -1]), 0, 1, np.array([1, 1, 2, 0, 0, -2, 1, 1, 2, -1])),  # requires cyclical reduce
        (np.array([-2, 1, 1, 2, -1, 1, 2, 0, 0, 0]), 1, 0, np.array([-2, 1, 1, 2, -1, 1, 1, 2, 0, 0])),  # requires cyclical reduce
        (np.array([1, 0, 0, 1, 2, 3]), 0, 1, np.array([1, 0, 0, 1, 2, 3])),
        (np.array([1, 2, 3, 4, 5, 6]), 0, 1, np.array([1, 2, 3, 4, 5, 6])),
        (np.array([1, -2, 0, 2, -1, 2]), 0, 1, np.array([2, 0, 0, 2, -1, 2])),
    ]

    for idx, (rels, i, j, expected_rels) in enumerate(test_cases):
        result_rels = _concatenate(i, j, EnvParams(max_length=rels.shape[0] // 2), rels)
        assert np.array_equal(result_rels, expected_rels), \
            f"""Test case {idx} (i={i}, j={j}): relators do not match expected results when
            rels = {rels},
            expected_rels = {expected_rels},
            got result = {result_rels}"""
    print("All test cases passed!")

# Run only when executed directly (script / notebook cell), not on import.
if __name__ == "__main__":
    test_concatenate()

# %%
