import torch
import math
import itertools
from tqdm import tqdm


def make_perm_family(L: int, n_exponent: int) -> torch.Tensor:
    """
    Return a family of 2**n_exponent cyclic-shift permutation matrices of size L x L.

    The first half of the family are forward cyclic shifts of the identity ordering;
    the second half are the same shifts applied to the reversed ordering. Index 0 is
    always Id, and index 2**(n_exponent-1) (when n_exponent >= 1) is always Rev.
    Shift amounts are spread evenly over the sequence:
    shift_k = k * (L // 2**(n_exponent-1)) for k = 0 .. 2**(n_exponent-1) - 1.
    If n_exponent == 0, only Id is returned.

    Args:
        L: Number of elements to be permuted (target length).
        n_exponent: The family contains 2**n_exponent permutations. Must be >= 0, and
            when >= 1 must satisfy 2**(n_exponent-1) <= L.

    Returns:
        torch.Tensor: float32 tensor of shape (2**n_exponent, L, L). Row i of matrix k
        is the identity row selected by the k-th index permutation.

    Raises:
        ValueError: if n_exponent is negative.
        AssertionError: if 2**(n_exponent-1) > L. This is raised explicitly rather than
            via `assert`, so the check also runs under `python -O`. The exception type
            is kept because existing callers catch AssertionError.
    """
    if n_exponent < 0:
        raise ValueError(f"n_exponent must be non-negative, but got {n_exponent}.")
    if n_exponent == 0:
        # Id only
        return torch.eye(L, dtype=torch.float).unsqueeze(0)

    num_perms_half = 2 ** (n_exponent - 1)
    if num_perms_half > L:
        raise AssertionError(
            f"n_exponent ({n_exponent}) is too large. It must satisfy 2**({n_exponent}-1) <= L ({L})."
        )

    # num_perms_half <= L guarantees step >= 1. The shifts 0, step, 2*step, ... are
    # therefore distinct and all < L.
    step = L // num_perms_half
    positions = torch.arange(L)
    identity = torch.eye(L, dtype=torch.float)

    perm_mats = []
    # False: forward-based shifts (starts with Id); True: reverse-based shifts (starts with Rev).
    for reverse_based in (False, True):
        for s_count in range(num_perms_half):
            shifted = (positions + s_count * step) % L
            perm_indices = (L - 1) - shifted if reverse_based else shifted
            perm_mats.append(identity[perm_indices])

    return torch.stack(perm_mats)


def get_permutations(target_len: int, permutation_select_num: int) -> torch.Tensor:
    """
    Return `permutation_select_num` permutation matrices of size target_len x target_len.

    The count must be a positive power of two (1 included). It is converted to the
    exponent n = log2(permutation_select_num) and forwarded to make_perm_family.

    Raises:
        ValueError: if permutation_select_num is not positive or not a power of two.
        AssertionError: may propagate from make_perm_family's length check when
            2**(n-1) > target_len (e.g. target_len=3, permutation_select_num=8).
    """
    if permutation_select_num <= 0:
        raise ValueError("permutation_select_num must be positive.")

    # A positive integer is a power of two iff it has exactly one set bit.
    has_single_bit = (permutation_select_num & (permutation_select_num - 1)) == 0
    if not has_single_bit:
        raise ValueError(f"permutation_select_num must be a power of 2, but got {permutation_select_num}")

    # log2(1) == 0, so the identity-only case needs no special branch.
    exponent = int(math.log2(permutation_select_num))
    return make_perm_family(L=target_len, n_exponent=exponent)


def generate_all_permutation_matrices(N: int):
    """
    Enumerate every N x N permutation matrix.

    Permutations follow itertools.permutations(range(N)) order, and a tqdm progress
    bar is shown over the N! items. Time and memory grow as N!, so this is only
    practical for small N.

    Returns:
        A tensor of shape (N!, N, N) in torch's default float dtype. Entry [k][i, j]
        is 1.0 iff the k-th permutation has perm[i] == j, and 0.0 otherwise.
    """
    all_orders = list(itertools.permutations(range(N)))  # N! index tuples
    identity = torch.eye(N)
    # Selecting identity rows by the permutation yields the one-hot rows directly.
    return torch.stack([identity[torch.tensor(order, dtype=torch.long)] for order in tqdm(all_orders)])


def generate_random_permutation(N: int, num_samples: int = 32, seed: int = 42):
    """
    Draw `num_samples` random N x N permutation matrices, reproducibly from `seed`.

    NOTE: this reseeds torch's *global* RNG via torch.manual_seed(seed), so it also
    affects random draws the caller makes afterwards.

    Args:
      N (int): Length of the permuted sequence (each matrix is N x N).
      num_samples (int): Number of permutation matrices to draw.
      seed (int): Seed passed to torch.manual_seed before sampling.

    Returns:
      torch.Tensor: int32 tensor of shape (num_samples, N, N). Sample k is the
      identity matrix with its rows reordered by the k-th torch.randperm(N).
    """
    torch.manual_seed(seed)

    # The identity does not depend on the sample. Advanced indexing below returns
    # a fresh copy for each sample, so sharing it is safe.
    eye_int = torch.eye(N, dtype=torch.int)
    return torch.stack([eye_int[torch.randperm(N)] for _ in range(num_samples)])


if __name__ == "__main__":
    # Smoke tests / demo output for get_permutations (and make_perm_family behind it).
    for L_val in (5, 3):
        print(f"--- L={L_val} ---")
        for num_p in [1, 2, 4, 8]:
            # make_perm_family requires 2**(n_exp-1) <= L. Skip combinations that
            # violate it (e.g. L=3, num_p=8: n_exp=3 and 2**2=4 > 3).
            n_exp = int(math.log2(num_p))
            if n_exp > 0 and L_val < 2 ** (n_exp - 1):
                print(f"Skipping num_p={num_p} for L={L_val} due to assertion in make_perm_family")
                continue
            try:
                perms = get_permutations(L_val, num_p)
                print(f"num_p={num_p}, shape={perms.shape}")
                if L_val == 5 and num_p == 2:
                    # With two permutations the family is [Id, Rev].
                    print("Permutation 0 (Identity):")
                    print(perms[0])
                    print("Permutation 1 (Reverse):")
                    print(perms[1])
            except ValueError as e:
                print(f"Error for num_p={num_p}: {e}")
            except AssertionError as e:
                print(f"AssertionError for num_p={num_p}: {e}")

    # relu n=50, square mod19 n=50, index n=13 m=4
    # Target lengths: 50, 50, 4
    perms_relu_2 = get_permutations(target_len=50, permutation_select_num=2)
    print(f"relu, num_p=2, shape={perms_relu_2.shape}")
    perms_relu_16 = get_permutations(target_len=50, permutation_select_num=16)
    print(f"relu, num_p=16, shape={perms_relu_16.shape}")

    perms_index_2 = get_permutations(target_len=4, permutation_select_num=2)
    print(f"index m=4, num_p=2, shape={perms_index_2.shape}")

    # Boundary case that must succeed: L=4, num_p=8 -> n_exp=3 and 2**(3-1)=4 <= 4.
    # Here step = 4 // 4 = 1, and forward/reverse shifts 0..3 give 8 permutations.
    perms_index_8_L4 = get_permutations(target_len=4, permutation_select_num=8)
    print(f"index m=4, num_p=8 (L=4), shape={perms_index_8_L4.shape}")

    # L < 2**(n_exp-1) must be rejected: L=3, num_p=16 -> n_exp=4 and 2**(4-1)=8 > 3.
    print("Test L=3, num_p=16 (n_exp=4)")
    try:
        get_permutations(target_len=3, permutation_select_num=16)
    except AssertionError as e:
        print(f"Correctly caught AssertionError: {e}")
    else:
        # Report loudly rather than pass silently (e.g. if the check was stripped by `python -O`).
        print("ERROR: expected AssertionError for L=3, num_p=16 was not raised")


# Added for Hierarchical Search v1
def factorial_util(n):  # Named distinctly so it never shadows math.factorial
    """Return n! for a non-negative integer n (a stand-in for math.factorial).

    Raises:
        ValueError: if n is negative.
    """
    if n < 0:
        raise ValueError("Factorial is not defined for negative numbers")
    if n == 0:
        return 1
    # Product of 1..n; range() rejects non-integral n just as the explicit loop would.
    return math.prod(range(1, n + 1))


def generate_next_level_permutations_v1(base_permutations: list, num_blocks_to_create: int, num_elements: int) -> list:
    """
    Generate the next level of permutations based on the v1 policy.

    Each base permutation is split into `num_blocks_to_create` contiguous blocks. These
    blocks are then concatenated in all `num_blocks_to_create!` orders, in
    itertools.permutations order. The first `num_elements % num_blocks_to_create`
    blocks hold one extra element.

    Args:
        base_permutations (list): Permutations to expand (each a sequence of ints).
        num_blocks_to_create (int): Number of blocks to split each permutation into.
            Each valid base permutation yields num_blocks_to_create! outputs.
        num_elements (int): Required length of each permutation (e.g., M).

    Returns:
        list: New permutations (lists). Base permutations that are empty or whose
        length differs from num_elements are skipped silently. If
        num_blocks_to_create <= 1, list copies of all base permutations are returned
        without any length filtering. When num_blocks_to_create > num_elements, some
        blocks are empty and the output contains duplicate permutations.
    """
    if num_blocks_to_create <= 1:
        return [list(p) for p in base_permutations]

    # Block boundaries depend only on (num_elements, num_blocks_to_create), so compute them once.
    block_size_base, remainder = divmod(num_elements, num_blocks_to_create)
    bounds = []
    start = 0
    for i in range(num_blocks_to_create):
        end = start + block_size_base + (1 if i < remainder else 0)
        bounds.append((start, end))
        start = end

    next_level_permutations = []
    for p_base in base_permutations:
        if not p_base or len(p_base) != num_elements:
            continue

        blocks = [p_base[s:e] for s, e in bounds]
        # Block sizes sum to num_elements, so every reordering is a full-length permutation.
        for block_order in itertools.permutations(blocks):
            new_perm = []
            for block in block_order:
                new_perm.extend(block)
            next_level_permutations.append(new_perm)

    return next_level_permutations


def generate_next_level_permutations_v2(base_permutations: list, num_blocks_to_create: int, num_elements: int) -> list:
    """
    Generate the next level of permutations using the v2 (adjacent block swap) policy.

    Each base permutation is split into `num_blocks_to_create` contiguous blocks. The
    first `num_elements % num_blocks_to_create` blocks hold one extra element. For
    each base permutation this emits, in order:
      1. the permutation with its blocks in their original order, then
      2. for i = 0 .. num_blocks_to_create - 2, the permutation with blocks i and i+1
         swapped.
    Each valid base permutation therefore yields exactly `num_blocks_to_create`
    outputs (linear, unlike v1's factorial expansion).

    Args:
        base_permutations (list): Permutations to expand (each a sequence of ints).
        num_blocks_to_create (int): Number of blocks to split each permutation into.
        num_elements (int): Required length of each permutation (e.g., M).

    Returns:
        list: New permutations (lists). Base permutations that are empty or whose
        length differs from num_elements are skipped silently. If
        num_blocks_to_create <= 1, list copies of all base permutations are returned
        without any length filtering. Swapping an empty block (possible when
        num_blocks_to_create > num_elements) can produce duplicate permutations.
    """
    if num_blocks_to_create <= 1:
        return [list(p) for p in base_permutations]

    # Block boundaries depend only on (num_elements, num_blocks_to_create), so compute them once.
    block_size_base, remainder = divmod(num_elements, num_blocks_to_create)
    bounds = []
    start = 0
    for i in range(num_blocks_to_create):
        end = start + block_size_base + (1 if i < remainder else 0)
        bounds.append((start, end))
        start = end

    next_level_permutations = []
    for p_base in base_permutations:
        if not p_base or len(p_base) != num_elements:
            continue

        blocks = [p_base[s:e] for s, e in bounds]

        # Original block order first, then one variant per adjacent pair swap.
        block_orders = [blocks]
        for i in range(num_blocks_to_create - 1):
            swapped = list(blocks)  # shallow copy; block contents are not mutated
            swapped[i], swapped[i + 1] = swapped[i + 1], swapped[i]
            block_orders.append(swapped)

        # Block sizes sum to num_elements, so every variant is a full-length permutation.
        for order in block_orders:
            new_perm = []
            for block in order:
                new_perm.extend(block)
            next_level_permutations.append(new_perm)

    return next_level_permutations