import os
import random

import numpy as np


class BitSet:
    """Growable set of non-negative integers backed by a numpy boolean array.

    Integer ``i`` is a member exactly when ``self.bitset[i]`` is True. The
    backing array has numpy dtype ``bool`` (one byte per slot), so memory
    scales with the capacity, not with the number of members.

    Attributes:
        size: current capacity; valid indices are ``range(size)``.
        bitset: backing ``np.ndarray`` of dtype ``bool`` and length ``size``.
    """

    def __init__(self, size=1024**2):
        """Create an empty bitset with room for values ``0 .. size - 1``."""
        self.size = size
        self.bitset = np.zeros(self.size, dtype=bool)

    def _ensure_capacity(self, num):
        """Grow the backing array so that ``num`` is a valid index.

        The capacity at least doubles on each growth so that adding values in
        increasing order does not reallocate on every call. A ``num`` below
        the current size (including negatives) leaves the set unchanged.
        """
        if num >= self.size:
            new_size = max(num + 1, self.size * 2)
            new_bitset = np.zeros(new_size, dtype=bool)
            new_bitset[: self.size] = self.bitset
            self.bitset = new_bitset
            self.size = new_size
            print(f"enlarge size to {self.size}")

    def add(self, num):
        """Add ``num`` to the set, growing the capacity if necessary.

        Raises:
            ValueError: if ``num`` is negative. Without this check numpy's
                negative indexing would silently set a slot counted from the
                end of the array.
        """
        if num < 0:
            raise ValueError(f"BitSet elements must be non-negative, got {num}")
        self._ensure_capacity(num)
        self.bitset[num] = True

    def remove(self, num):
        """Remove ``num`` if present; values outside ``range(size)`` are ignored."""
        # Negative values are never members; guarding them avoids clearing a
        # slot counted from the end of the array.
        if 0 <= num < self.size:
            self.bitset[num] = False

    def contains(self, num):
        """Return True if ``num`` is a member (always False for negatives)."""
        return 0 <= num < self.size and bool(self.bitset[num])

    def __contains__(self, num):
        return self.contains(num)

    def update(self, iterable_or_bitset):
        """Add every element of another BitSet or of an iterable of integers.

        For a BitSet argument the merge is one vectorized OR, and this set's
        capacity grows to at least the other's capacity. For an iterable each
        element goes through :meth:`add` (so negatives raise ``ValueError``).
        """
        if isinstance(iterable_or_bitset, BitSet):
            other = iterable_or_bitset
            # The largest index `other` can hold is other.size - 1; requesting
            # other.size would over-allocate by one slot.
            self._ensure_capacity(other.size - 1)
            self.bitset[: other.size] |= other.bitset
        else:
            for num in iterable_or_bitset:
                self.add(num)

    def __sub__(self, other):
        """Return a new BitSet holding the members of ``self`` not in ``other``.

        The operands may have different capacities. Only the overlapping
        prefix can contain shared members. The result's capacity is the
        larger of the two capacities.
        """
        if not isinstance(other, BitSet):
            return NotImplemented
        result = BitSet(max(self.size, other.size))
        overlap = min(self.size, other.size)
        result.bitset[: self.size] = self.bitset
        # Past `overlap`, `other` has no members, so nothing more to clear.
        result.bitset[:overlap] &= ~other.bitset[:overlap]
        return result

    def __isub__(self, other):
        """Remove, in place, every member of ``other`` from ``self``."""
        if not isinstance(other, BitSet):
            return NotImplemented
        # Members of `other` beyond our capacity cannot be in `self`.
        overlap = min(self.size, other.size)
        self.bitset[:overlap] &= ~other.bitset[:overlap]
        return self

    def __str__(self):
        """Return ``BitSet(i, j, ...)`` listing members in ascending order."""
        indices_str = ", ".join(map(str, np.flatnonzero(self.bitset).tolist()))
        return f"BitSet({indices_str})"

    def __len__(self):
        """Return the number of members as a Python ``int``."""
        return int(np.count_nonzero(self.bitset))

    def capacity(self):
        """Return the current capacity (number of addressable slots)."""
        return self.size

    def density(self):
        """Return the fraction of slots that are set (0.0 when capacity is 0)."""
        if self.size == 0:
            return 0.0
        return len(self) / self.size

    def memory_usage(self):
        """Return the backing array's size as a string in B, KB, MB or GB (1024-based)."""
        bytes_usage = self.bitset.nbytes
        if bytes_usage < 1024:
            return f"{bytes_usage} B"
        elif bytes_usage < 1024**2:
            return f"{bytes_usage / 1024:.2f} KB"
        elif bytes_usage < 1024**3:
            return f"{bytes_usage / 1024**2:.2f} MB"
        else:
            return f"{bytes_usage / 1024**3:.2f} GB"

    def to_list(self):
        """Return all members, ascending, as a list of Python ints."""
        return np.flatnonzero(self.bitset).tolist()

    def save(self, filename):
        """Save the backing array to ``<filename>.<random 64-bit int>.npy``.

        The random suffix presumably exists so repeated saves do not overwrite
        each other; collisions are unlikely but not checked. Missing parent
        directories are created.

        Returns:
            The full path written, suitable for :meth:`load`.
        """
        filename_with_suffix = f"{filename}.{random.randint(0, 2**64 - 1)}.npy"
        dirname = os.path.dirname(filename_with_suffix)
        # A bare filename has no directory component; os.makedirs("") raises.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        np.save(filename_with_suffix, self.bitset)
        return filename_with_suffix

    @classmethod
    def load(cls, filename_with_suffix):
        """Create a BitSet from a file written by :meth:`save`.

        Args:
            filename_with_suffix: the path returned by ``save``.
        """
        bitset_array = np.load(filename_with_suffix)
        # Start from an empty instance rather than allocating a zeroed array
        # of the full size only to discard it.
        bitset = cls(0)
        bitset.bitset = bitset_array.astype(bool, copy=False)
        bitset.size = bitset.bitset.size
        return bitset


def bitset_diff(normal_set, bitset):
    """Collect the elements of *normal_set* that *bitset* does not contain.

    Args:
        normal_set: an iterable of integers (typically a Python ``set``).
        bitset: a ``BitSet``, or any object exposing a ``contains(value)`` method.

    Returns:
        A new ``set`` of every element for which ``bitset.contains`` is falsy.
    """
    missing = set()
    for candidate in normal_set:
        if bitset.contains(candidate):
            continue
        missing.add(candidate)
    return missing


if __name__ == "__main__":
    # Demo: set difference, in-place difference, and merging bitsets whose
    # capacities differ by orders of magnitude.
    left = BitSet(1024)
    left.update([100, 200, 300, 1023])
    right = BitSet(1024)
    right.update([100, 400, 1023])

    difference = left - right
    # Expected output, one per line: False, True, True, False
    for probe in (100, 200, 300, 1023):
        print(probe in difference)

    left -= right
    # Expected: BitSet(200, 300) / BitSet(200, 300) / BitSet(100, 400, 1023)
    for shown in (difference, left, right):
        print(shown)

    # Capacity 1024**3 with a one-byte-per-slot bool array is a 1 GiB allocation.
    huge = BitSet(1024**3)
    print(len(huge), huge.capacity(), huge.density(), left.density())
    print("BitSet memory usage:", huge.memory_usage())

    print(bitset_diff({100, 200}, right))  # {200}

    left.update(right)
    huge.add(52260134)
    right.update(huge)  # grows `right` to at least huge's capacity
    print(left)  # BitSet(100, 200, 300, 400, 1023)
    print(right)  # BitSet(100, 400, 1023, 52260134)
