from collections import defaultdict
from typing import Any
import numpy as np
from scipy.sparse import csr_matrix, hstack
from time_monitor import time_monitor

@time_monitor
def create_instance_vector(attr_value_index_map: defaultdict[Any, str], instance: dict[str, str]) -> csr_matrix:
    """
    One-hot encode ``instance`` as a 1 x N sparse row vector.

    Every ``(attribute, value)`` pair of ``instance`` is looked up in
    ``attr_value_index_map``; the column at the resulting index is set to 1.
    N is ``len(attr_value_index_map)`` measured *before* any lookup.

    The CSR row is built directly from the set columns, so the cost grows with
    ``len(instance)`` rather than with N (no dense length-N array is allocated).

    Args:
        attr_value_index_map: Maps ``(attribute, value)`` tuples to column
            indices (annotated as ``str``; converted with ``int``).
            NOTE(review): if this is a defaultdict, looking up an unseen pair
            inserts it via the default factory - confirm that is intended.
        instance: Mapping of attribute name to value.

    Returns:
        A ``csr_matrix`` of shape ``(1, N)`` and dtype ``int8`` holding a 1 at
        every looked-up column.

    Raises:
        IndexError: If a looked-up index lies outside ``[0, N)``.
        KeyError: Presumably, if a pair is missing and the map has no default
            factory (plain dict semantics).
    """
    # Snapshot the width first: a defaultdict lookup below may insert new
    # entries, and the vector width must not depend on that.
    width = len(attr_value_index_map)

    columns: set[int] = set()
    for attr, value in instance.items():
        column = int(attr_value_index_map[(attr, value)])
        # Reject out-of-range ids explicitly; a negative id must not silently
        # wrap around to the end of the vector.
        if not 0 <= column < width:
            raise IndexError(f"index {column} is out of bounds for vector of size {width}")
        columns.add(column)

    # Sorted, duplicate-free ids give a canonical CSR row; distinct pairs that
    # map to the same column must still yield 1, not 2.
    indices = np.array(sorted(columns), dtype=np.int64)
    data = np.ones(len(indices), dtype=np.int8)
    indptr = np.array([0, len(indices)], dtype=np.int64)
    return csr_matrix((data, indices, indptr), shape=(1, width))

@time_monitor
def update_attr_to_indices_map(attr_value_index_map: defaultdict[Any, str]) -> dict[str, list[int]]:
    """
    Group the column indices of ``attr_value_index_map`` by attribute.

    Keys of the input are ``(attribute, value)`` pairs; values are column
    indices (annotated as ``str``, converted with ``int``). The result maps each
    attribute to the ascending list of its integer indices. Attributes appear
    in the order in which they are first encountered in the input.
    """
    grouped: dict[str, list[int]] = {}
    for (attribute, _value), raw_index in attr_value_index_map.items():
        grouped.setdefault(attribute, []).append(int(raw_index))

    # Sort each bucket in place so the output does not depend on input order.
    for bucket in grouped.values():
        bucket.sort()
    return grouped


@time_monitor
def pad_csr_vector(matrix: csr_matrix, pad_width: int, constant_value: int = 0) -> csr_matrix:
    """
    Pad a CSR matrix on the right with ``pad_width`` columns of ``constant_value``.

    Both code paths preserve ``matrix.dtype``. No silent cast to int8 is
    performed, so float or wide-int data is never truncated.

    Args:
        matrix: The CSR matrix to pad.
        pad_width: Number of columns to append. If ``<= 0``, ``matrix`` itself
            is returned unchanged.
        constant_value: Value of every appended cell. It is cast to
            ``matrix.dtype`` and so must be representable in it.

    Returns:
        A CSR matrix of shape ``(rows, cols + pad_width)`` with dtype
        ``matrix.dtype``. When ``constant_value == 0`` the result may share its
        ``data``/``indices``/``indptr`` buffers with ``matrix``. Copy it before
        mutating in place if ``matrix`` must stay untouched.
    """
    # Nothing to pad: hand back the original object.
    if pad_width <= 0:
        return matrix

    num_rows, num_cols = matrix.shape

    if constant_value == 0:
        # Zeros are implicit in sparse storage, so widening the shape suffices.
        # No dtype is passed, so the input dtype is kept and the arrays are
        # reused rather than copied.
        return csr_matrix(
            (matrix.data, matrix.indices, matrix.indptr),
            shape=(num_rows, num_cols + pad_width),
        )

    # Non-zero padding must be stored explicitly: build the block in the
    # input's dtype so hstack does not change the result dtype.
    padding = csr_matrix(np.full((num_rows, pad_width), constant_value, dtype=matrix.dtype))
    return hstack([matrix, padding], format="csr")

@time_monitor
def pad_bit_vector(bit_vector: int, pad_width: int, pad_with_one: bool = False) -> int:
    """
    Append ``pad_width`` bits to the RIGHT end (least-significant side) of a bit vector.

    The integer is shifted left by ``pad_width``, which opens up new low-order
    bits. These are zeros by default. With ``pad_with_one`` they are set to
    ones by OR-ing in a mask of ``pad_width`` ones. (``pad_bit_vector2`` pads
    the opposite, most-significant side.)

    Example: 0b110 with pad_width=2 -> 0b11000 (zeros) or 0b11011 (ones).

    Args:
        bit_vector: Integer representation of the vector.
        pad_width: Number of bits to append.
        pad_with_one: If True, pads with 1s; otherwise, pads with 0s.

    Returns:
        The padded integer. Only the value is returned; the new length is
        not tracked.

    Raises:
        ValueError: If ``pad_width`` is negative (raised by the shift).
    """
    # Shift left to open up `pad_width` low-order bits (all zeros).
    padded_int = bit_vector << pad_width

    if pad_with_one:
        # Mask of `pad_width` ones, e.g. 0b11 for pad_width=2, fills the new bits.
        padded_int |= (1 << pad_width) - 1

    return padded_int


def pad_bit_vector2(bit_vector: int, original_len: int, pad_width: int, pad_with_one: bool = False) -> tuple[int, int]:
    """
    Pad a bit vector on the LEFT side (prepend to the beginning).

    This adds bits on the MOST significant side of the integer. Leading zeros
    are invisible in an ``int``, so ``original_len`` is required to know where
    the new bits start.

    Args:
        bit_vector: Non-negative integer representation of the vector; it must
            fit in ``original_len`` bits.
        original_len: The number of bits in the original vector (>= 0).
        pad_width: The number of bits to add (>= 0).
        pad_with_one: If True, pads with 1s; otherwise, pads with 0s.

    Returns:
        A tuple ``(new integer, new length)`` where the new length is
        ``original_len + pad_width``.

    Raises:
        ValueError: If ``original_len`` or ``pad_width`` is negative, or if
            ``bit_vector`` is negative or needs more than ``original_len`` bits.

    Example: [1,1,0] (len 3) padded with 2 zeros -> [0,0,1,1,0] (0b110 -> 0b00110)
    Example: [1,1,0] (len 3) padded with 2 ones -> [1,1,1,1,0] (0b110 -> 0b11110)
    """
    # Validate up front. Without this, a negative width in the zero branch
    # would silently shrink the reported length, while the ones branch raised.
    if original_len < 0:
        raise ValueError(f"original_len must be non-negative, got {original_len}")
    if pad_width < 0:
        raise ValueError(f"pad_width must be non-negative, got {pad_width}")
    # Bits above `original_len` would collide with the padding (and negative
    # ints have infinitely many set high bits), corrupting the result.
    if bit_vector < 0 or bit_vector.bit_length() > original_len:
        raise ValueError(
            f"bit_vector {bit_vector:#b} does not fit in original_len={original_len} bits"
        )

    new_len = original_len + pad_width

    # Zero padding on the MSB side leaves the integer's value unchanged; only
    # its conceptual length grows.
    if not pad_with_one:
        return bit_vector, new_len

    # A block of `pad_width` ones (e.g. 0b11), shifted to sit directly above
    # the original bits.
    ones_block = ((1 << pad_width) - 1) << original_len
    return bit_vector | ones_block, new_len
