import numpy as np
from PIL import Image
import string

# Function to flip randomly chosen bits while skipping protected positions
def flip_bits(input, position, num_flip):
    """Flip ``num_flip`` randomly chosen bits of ``input``, avoiding ``position``.

    Args:
        input: String of '0'/'1' characters.
        position: Indices (0-based) that must never be flipped.
        num_flip: Number of distinct indices to flip.

    Returns:
        ``(flipped, chosen)``. ``flipped`` is the resulting string and
        ``chosen`` is the NumPy array of flipped indices, in the order drawn.

    A '0' becomes '1'; any other character becomes '0'. Randomness comes
    from NumPy's global RNG (``np.random.choice`` without replacement), so
    NumPy raises ``ValueError`` when fewer than ``num_flip`` indices are free.
    """
    bits = np.array(list(input), dtype='<U1')

    # Indices eligible for flipping = every index minus the protected ones.
    candidates = np.setdiff1d(np.arange(bits.shape[0]), position)
    chosen = np.random.choice(candidates, num_flip, replace=False)

    current = bits[chosen]
    bits[chosen] = np.where(current == '0', '1', '0')
    return ''.join(bits), chosen

# position = [0, 1, 2, 3, 4, 5]
# input = '0000000000000000000'
# num_flip = 3
# print(flip_bits(input, position, num_flip))


# Function to reorder the bits of a square grid in one of eight scan orders (order_type 0-7)
def change_order(input, order_type=None):
    if order_type is None:
        raise ValueError("order_type must be specified")
    
    # Check if input length is a perfect square
    size = int(np.sqrt(len(input)))
    if size * size != len(input):
        raise ValueError("Input length must be a perfect square")
    
    # Reshape to square
    square = [list(input[i:i+size]) for i in range(0, len(input), size)]
    
    # Reordering
    rearranged = []
    
    # Top-left to right
    if order_type == 0:
        for row in range(size):
            for col in range(size):
                rearranged.append(square[row][col])
    # Top-left to bottom
    elif order_type == 1:
        for col in range(size):
            for row in range(size):
                rearranged.append(square[row][col])
    # Top-right to left
    elif order_type == 2:
        for row in range(size):
            for col in range(size-1, -1, -1):
                rearranged.append(square[row][col])
    # Top-right to bottom
    elif order_type == 3:
        for col in range(size-1, -1, -1):
            for row in range(size):
                rearranged.append(square[row][col])
    # Bottom-left to right
    elif order_type == 4:
        for row in range(size-1, -1, -1):
            for col in range(size):
                rearranged.append(square[row][col])
    # Bottom-left to top
    elif order_type == 5:
        for col in range(size):
            for row in range(size-1, -1, -1):
                rearranged.append(square[row][col])
    # Bottom-right to left
    elif order_type == 6:
        for row in range(size-1, -1, -1):
            for col in range(size-1, -1, -1):
                rearranged.append(square[row][col])
    # Bottom-right to top
    elif order_type == 7:
        for col in range(size-1, -1, -1):
            for row in range(size-1, -1, -1):
                rearranged.append(square[row][col])
    else:
        raise ValueError("Invalid order_type. Must be 0-7.")

    return ''.join(rearranged)


# input = '000111111'
# print(change_order(input, order_type=0))


# Function to highlight specified bits
def highlight(input, position_2d, scale_factor=10, highlight_color=(255, 0, 0)):
    """Render a square bit string as an RGB image with selected cells highlighted.

    Bits equal to 1 are drawn black and bits equal to 0 white. Every
    ``(row, col)`` in ``position_2d`` is then painted ``highlight_color``. The
    image is finally enlarged by ``scale_factor`` with nearest-neighbour
    resampling, so each cell stays a sharp square.

    Args:
        input: String of '0'/'1' characters whose length is a perfect square.
        position_2d: Iterable of ``(row, col)`` pairs to highlight; may be empty.
        scale_factor: Integer magnification applied to width and height.
        highlight_color: RGB triple used for highlighted cells.

    Returns:
        A PIL image built from an (H, W, 3) uint8 array, scaled by ``scale_factor``.

    Raises:
        ValueError: If ``len(input)`` is not a perfect square.
    """
    size = int(np.sqrt(len(input)))
    if size * size != len(input):
        raise ValueError("Input length must be a perfect square")

    bits = np.array(list(map(int, input))).reshape(size, size)
    # Invert so 1 -> 0 (black) and 0 -> 255 (white), then copy into 3 channels.
    gray = (1 - bits) * 255
    rgb = np.stack([gray] * 3, axis=-1)

    positions = list(position_2d)
    if positions:  # zip(*[]) cannot be unpacked into rows/cols
        rows, cols = zip(*positions)
        rgb[list(rows), list(cols)] = highlight_color

    image = Image.fromarray(rgb.astype(np.uint8))
    # NEAREST keeps cell edges crisp instead of blending neighbouring cells.
    new_size = (image.width * scale_factor, image.height * scale_factor)
    return image.resize(new_size, Image.NEAREST)


# Function to convert bit sequence from box_size=1 to box_size=n
def change_box_size(input, box_size):
    """Upscale a square bit grid so every cell becomes a ``box_size x box_size`` block.

    Args:
        input: Sequence of '0'/'1' characters (or ints) whose length is a
            perfect square; it is read row-major.
        box_size: Non-negative number of output cells per input cell along
            each axis.

    Returns:
        1-D integer NumPy array of length ``len(input) * box_size**2`` holding
        the upscaled grid in row-major order. It is an array of ints, not a
        string.

    Raises:
        ValueError: If ``len(input)`` is not a perfect square or ``box_size``
            is negative.
    """
    size = int(np.sqrt(len(input)))
    if size * size != len(input):
        raise ValueError("Input length must be a perfect square")
    if box_size < 0:
        raise ValueError("box_size must be non-negative")

    # dtype=int keeps the result integer even for empty input.
    grid = np.array(list(map(int, input)), dtype=int).reshape(size, size)
    # Repeat every row, then every column: each cell becomes a solid block.
    return np.repeat(np.repeat(grid, box_size, axis=0), box_size, axis=1).flatten()


# Function to convert bit sequence to image
def convert_to_image(input, add_quiet_zone=False):
    """Render a square bit sequence as a single-channel uint8 PIL image.

    A bit of 1 becomes a black pixel (0) and a bit of 0 a white pixel (255).

    Args:
        input: Sequence of '0'/'1' characters (or ints) whose length is
            expected to be a perfect square; it is read row-major.
        add_quiet_zone: When true, surround the grid with a one-cell border
            of 0 bits, which renders as white.

    Returns:
        A PIL image built from the 2-D uint8 pixel array (one pixel per bit).
    """
    side = int(np.sqrt(len(input)))
    grid = np.array(list(map(int, input))).reshape(side, side)

    if add_quiet_zone:
        # One cell of padding on every side, filled with 0 (white once inverted).
        grid = np.pad(grid, pad_width=1, mode="constant", constant_values=0)

    pixels = (255 * (1 - grid)).astype(np.uint8)
    return Image.fromarray(pixels)


def burst_box(input, position, num_burst, burst_size, force_value=None, *, max_attempts=100):
    """Corrupt non-overlapping rectangular blocks of a square bit grid.

    ``num_burst`` rectangles of shape ``burst_size`` are placed at random
    using NumPy's global RNG. None may contain a protected cell listed in
    ``position``, and none may overlap an earlier rectangle. The chosen cells
    are then corrupted three ways, each applied to a fresh copy of the
    original bits:

    * ``"flip"``   - '0' becomes '1', anything else becomes '0';
    * ``"force0"`` - every chosen cell is set to '0';
    * ``"force1"`` - every chosen cell is set to '1'.

    Args:
        input: String of '0'/'1' characters whose length is a perfect square.
        position: Iterable of flat (row-major) indices that must not change.
        num_burst: Number of rectangles to place.
        burst_size: ``(height, width)`` of each rectangle, or one int for a
            square rectangle.
        force_value: Unused; kept for signature compatibility.
        max_attempts: Random placements tried per rectangle before giving up.

    Returns:
        A dict ``{"flip": (bits, cells), "force0": ..., "force1": ...}``.
        ``bits`` is the corrupted string and ``cells`` is the sorted list of
        corrupted ``(row, col)`` pairs. Returns ``None`` (after printing a
        message) if the rectangles could not all be placed.

    Raises:
        ValueError: If the input is not square or a rectangle cannot fit in
            the grid.
    """
    length = len(input)
    size = int(np.sqrt(length))  # Determine square size
    if size * size != length:
        raise ValueError("Input length must be a perfect square")

    if isinstance(burst_size, (int, np.integer)):
        burst_size = (burst_size, burst_size)
    burst_rows, burst_cols = burst_size[0], burst_size[1]
    if num_burst > 0 and (burst_rows > size or burst_cols > size):
        raise ValueError("burst_size must fit inside the grid")

    protected = {(p // size, p % size) for p in position}
    taken = set()

    for _ in range(num_burst):
        for _ in range(max_attempts):
            row = np.random.randint(0, size - burst_rows + 1)
            col = np.random.randint(0, size - burst_cols + 1)
            cells = {
                (r, c)
                for r in range(row, row + burst_rows)
                for c in range(col, col + burst_cols)
            }
            if cells & protected or cells & taken:
                continue  # Retry: touches a protected or already-corrupted cell
            taken |= cells
            break
        else:  # No placement found within max_attempts for this rectangle
            print("Failed to find valid burst indices")
            return None

    grid = np.array(list(input), dtype="<U1").reshape(size, size)
    all_results = {}
    for force in ("flip", "force0", "force1"):
        # Start every variant from the pristine grid so variants are independent.
        corrupted = grid.copy()
        for r, c in taken:
            if force == "flip":
                corrupted[r, c] = "1" if corrupted[r, c] == "0" else "0"
            elif force == "force0":
                corrupted[r, c] = "0"
            else:
                corrupted[r, c] = "1"
        all_results[force] = "".join(corrupted.flatten()), sorted(taken)
    return all_results


def burst_horizontal(
    input, position, num_burst, burst_size, force_value=None, *, max_attempts=100
):
    """Corrupt non-overlapping horizontal runs of cells in a square bit grid.

    ``num_burst`` runs of ``burst_size`` consecutive cells in a single row are
    placed at random using NumPy's global RNG. None may contain a protected
    cell listed in ``position``, and none may overlap an earlier run. The
    chosen cells are then corrupted three ways, each applied to a fresh copy
    of the original bits:

    * ``"flip"``   - '0' becomes '1', anything else becomes '0';
    * ``"force0"`` - every chosen cell is set to '0';
    * ``"force1"`` - every chosen cell is set to '1'.

    Args:
        input: String of '0'/'1' characters whose length is a perfect square.
        position: Iterable of flat (row-major) indices that must not change.
        num_burst: Number of runs to place.
        burst_size: Length of each run (number of columns).
        force_value: Unused; kept for signature compatibility.
        max_attempts: Random placements tried per run before giving up.

    Returns:
        A dict ``{"flip": (bits, cells), "force0": ..., "force1": ...}``.
        ``bits`` is the corrupted string and ``cells`` is the sorted list of
        corrupted ``(row, col)`` pairs. With ``num_burst == 0`` every variant
        equals the input. Returns ``None`` (after printing a message) if the
        runs could not all be placed.

    Raises:
        ValueError: If the input is not square or ``burst_size`` exceeds the
            grid width.
    """
    length = len(input)
    size = int(np.sqrt(length))  # Determine square size
    if size * size != length:
        raise ValueError("Input length must be a perfect square")
    if num_burst > 0 and burst_size > size:
        raise ValueError("burst_size must not exceed the grid width")

    protected = {(p // size, p % size) for p in position}
    taken = set()

    for _ in range(num_burst):
        for _ in range(max_attempts):
            row = np.random.randint(0, size)
            col = np.random.randint(0, size - burst_size + 1)
            cells = {(row, c) for c in range(col, col + burst_size)}
            if cells & protected or cells & taken:
                continue  # Retry: touches a protected or already-corrupted cell
            taken |= cells
            break
        else:  # No placement found within max_attempts for this run
            print("Failed to find valid burst indices")
            return None

    grid = np.array(list(input), dtype="<U1").reshape(size, size)
    all_results = {}
    for force in ("flip", "force0", "force1"):
        # Each variant starts from the pristine grid and is recorded once,
        # after all runs are applied (also correct when there are no runs).
        corrupted = grid.copy()
        for r, c in taken:
            if force == "flip":
                corrupted[r, c] = "1" if corrupted[r, c] == "0" else "0"
            elif force == "force0":
                corrupted[r, c] = "0"
            else:
                corrupted[r, c] = "1"
        all_results[force] = "".join(corrupted.flatten()), sorted(taken)
    return all_results


def burst_vertical(input, position, num_burst, burst_size, force_value=None, *, max_attempts=100):
    """Corrupt non-overlapping vertical runs of cells in a square bit grid.

    ``num_burst`` runs of ``burst_size`` consecutive cells in a single column
    are placed at random using NumPy's global RNG. None may contain a
    protected cell listed in ``position``, and none may overlap an earlier
    run. The chosen cells are then corrupted three ways, each applied to a
    fresh copy of the original bits:

    * ``"flip"``   - '0' becomes '1', anything else becomes '0';
    * ``"force0"`` - every chosen cell is set to '0';
    * ``"force1"`` - every chosen cell is set to '1'.

    Args:
        input: String of '0'/'1' characters whose length is a perfect square.
        position: Iterable of flat (row-major) indices that must not change.
        num_burst: Number of runs to place.
        burst_size: Length of each run (number of rows).
        force_value: Unused; kept for signature compatibility.
        max_attempts: Random placements tried per run before giving up.

    Returns:
        A dict ``{"flip": (bits, cells), "force0": ..., "force1": ...}``.
        ``bits`` is the corrupted string and ``cells`` is the sorted list of
        corrupted ``(row, col)`` pairs. With ``num_burst == 0`` every variant
        equals the input. Returns ``None`` (after printing a message) if the
        runs could not all be placed.

    Raises:
        ValueError: If the input is not square or ``burst_size`` exceeds the
            grid height.
    """
    length = len(input)
    size = int(np.sqrt(length))  # Determine square size
    if size * size != length:
        raise ValueError("Input length must be a perfect square")
    if num_burst > 0 and burst_size > size:
        raise ValueError("burst_size must not exceed the grid height")

    protected = {(p // size, p % size) for p in position}
    taken = set()

    for _ in range(num_burst):
        for _ in range(max_attempts):
            row = np.random.randint(0, size - burst_size + 1)
            col = np.random.randint(0, size)
            cells = {(r, col) for r in range(row, row + burst_size)}
            if cells & protected or cells & taken:
                continue  # Retry: touches a protected or already-corrupted cell
            taken |= cells
            break
        else:  # No placement found within max_attempts for this run
            print("Failed to find valid burst indices")
            return None

    grid = np.array(list(input), dtype="<U1").reshape(size, size)
    all_results = {}
    for force in ("flip", "force0", "force1"):
        # Each variant starts from the pristine grid and is recorded once,
        # after all runs are applied (also correct when there are no runs).
        corrupted = grid.copy()
        for r, c in taken:
            if force == "flip":
                corrupted[r, c] = "1" if corrupted[r, c] == "0" else "0"
            elif force == "force0":
                corrupted[r, c] = "0"
            else:
                corrupted[r, c] = "1"
        all_results[force] = "".join(corrupted.flatten()), sorted(taken)
    return all_results


def random_replace_alphanum_np(domain: str, num_replacements: int, seed: int = None) -> str:
    """Replace randomly chosen letters/digits of ``domain`` with different ones.

    Exactly ``num_replacements`` distinct positions holding a letter or digit
    are chosen. A letter is swapped for a different ASCII letter of the same
    case. A digit is swapped for a different ASCII digit. Every other
    character (e.g. '.') is left unchanged.

    Args:
        domain: Text to perturb.
        num_replacements: How many alphanumeric positions to change.
        seed: If given, reseeds NumPy's *global* RNG before drawing
            (``None`` means the RNG is left as is).

    Returns:
        The perturbed string.

    Raises:
        ValueError: If ``num_replacements`` exceeds the number of alphanumeric
            characters in ``domain``.
    """
    if seed is not None:
        np.random.seed(seed)  # NOTE: affects every later np.random call too

    eligible = [pos for pos, ch in enumerate(domain) if ch.isalpha() or ch.isdigit()]
    if len(eligible) < num_replacements:
        raise ValueError("Too many replacement characters. Exceeds the number of alphanumeric characters in the domain.")

    chars = list(domain)
    # Draw all target positions first, then one replacement per position.
    for pos in np.random.choice(eligible, size=num_replacements, replace=False):
        old = chars[pos]
        if old.isalpha():
            pool = [letter for letter in string.ascii_lowercase if letter != old.lower()]
            pick = np.random.choice(pool)
            chars[pos] = pick.upper() if old.isupper() else pick
        elif old.isdigit():
            chars[pos] = np.random.choice([d for d in string.digits if d != old])

    return ''.join(chars)


if __name__ == "__main__":
    # Smoke test: swap three random alphanumerics in a sample string and print it.
    sample = "Hello.123"
    print(random_replace_alphanum_np(sample, 3))