# src/data_handler.py
import torch
import math
import random
from tqdm import tqdm
import numpy as np
import logging # Import logging
from typing import List, Dict, Any, Callable, Tuple
from typing import List, Dict, Any, Set, Tuple

from target_functions import TARGET_FUNCTIONS

logger = logging.getLogger(__name__)  # Module-level logger shared by all generators below.

# sympy supplies `nextprime` / `isprime` for the sympy-based generators.
# A missing dependency raises ImportError (like the gmpy2 guard below) instead of
# printing and calling the site-provided `exit()`: that exited with status 0,
# hiding the failure, and `exit()` is undefined when Python runs without `site`.
try:
    from sympy import nextprime, isprime
except ImportError as e:
    raise ImportError("sympy library not found. Please run 'pip install sympy'") from e

# Seed the global `random` module once at import time so generated datasets are
# reproducible; note this reseeds `random` for every user in the process.
random.seed(42)

# gmpy2 supplies the fast `next_prime` / `is_prime` used by FastPrimeDataGenerator.
try:
    import gmpy2
except ImportError as e:
    raise ImportError("gmpy2 is required for this fast generator. Install with `pip install gmpy2`.") from e


class FastPrimeDataGenerator:
    """Generates decimal input vectors and corresponding prime/non-prime outputs (50/50) using gmpy2.

    Every sample is a distinct integer with exactly `sequence_length` decimal digits
    (for sequence_length == 1 the range is 1..9). Labels are the strings '1' (prime)
    and '0' (non-prime).
    """
    def __init__(self,
                 sequence_length: int,
                 num_samples: int,
                 device: str = 'cpu'):  # device kept for API consistency
        """
        Args:
            sequence_length: number of decimal digits per number (> 0).
            num_samples: total number of samples (> 0); odd values give one
                more non-prime than primes.
            device: stored only; not otherwise used by this class.

        Raises:
            ValueError: on invalid arguments or when more samples are requested
                than distinct numbers of that length exist.
        """
        if not isinstance(sequence_length, int) or sequence_length <= 0:
            logger.error(f"Sequence length must be a positive integer, but got {sequence_length}.")
            raise ValueError("Sequence length must be a positive integer.")

        if not isinstance(num_samples, int) or num_samples <= 0:
            logger.error(f"Number of samples must be a positive integer, but got {num_samples}.")
            raise ValueError("Number of samples must be a positive integer.")

        if num_samples % 2 != 0:
            logger.warning(f"num_samples is odd ({num_samples}). For a 50/50 split, it's recommended to be even.")

        self.sequence_length = sequence_length
        self.num_samples = num_samples
        self.device = device

        # Define the search range for numbers with the given length
        self.start_range = 10**(self.sequence_length - 1) if self.sequence_length > 1 else 1
        self.end_range = 10**self.sequence_length  # exclusive upper bound

        max_possible = self.end_range - self.start_range
        if self.num_samples > max_possible:
            logger.error(f"Requested {self.num_samples} samples, but only {max_possible} unique numbers of length {self.sequence_length} exist.")
            raise ValueError(f"Requested {self.num_samples} samples, but only {max_possible} unique numbers are possible.")

        logger.info(f"PrimeDataGenerator (gmpy2) initialized for len={sequence_length}, samples={num_samples}, device='{device}'")

    def _rand_ndigit(self) -> int:
        """Uniform random integer with exactly sequence_length digits."""
        return random.randrange(self.start_range, self.end_range)

    def _generate_unique_numbers(self) -> Tuple[List[int], List[int]]:
        """
        Generate unique sets of prime and non-prime numbers using gmpy2.

        Uses random sampling + primality testing (no multiprocessing).

        Returns:
            (prime_list, non_prime_list), each shuffled.

        Raises:
            RuntimeError: if the shared attempt budget runs out first, e.g. when more
                primes (or non-primes) are requested than exist with
                `sequence_length` digits. Without this cap such requests loop forever.
        """
        num_primes_needed = self.num_samples // 2
        num_non_primes_needed = self.num_samples - num_primes_needed

        logger.info(f"(gmpy2) Generating {num_primes_needed} unique primes and {num_non_primes_needed} unique non-primes...")

        primes_found: Set[int] = set()
        non_primes_found: Set[int] = set()
        all_found: Set[int] = set()

        # Total draw budget across both phases (same policy as the sibling generators).
        max_attempts = (self.num_samples * 80) + 2000
        attempts = 0

        # Progress logging cadence
        last_logged = 0
        log_step = max(1000, self.num_samples // 100)  # ~1% steps or at least 1k

        def _spend_attempt() -> None:
            """Count one draw and raise once the attempt budget is exhausted."""
            nonlocal attempts
            attempts += 1
            if attempts > max_attempts:
                msg = (f"(gmpy2) Could not generate required samples after {max_attempts} attempts. "
                       f"Found {len(primes_found)}/{num_primes_needed} primes and "
                       f"{len(non_primes_found)}/{num_non_primes_needed} non-primes.")
                logger.error(msg)
                raise RuntimeError(msg)

        def _log_progress() -> None:
            """Emit a progress line each time `log_step` more numbers have been collected."""
            nonlocal last_logged
            total = len(all_found)
            if total - last_logged >= log_step:
                logger.info(f"(gmpy2) Progress: {total}/{self.num_samples} numbers found "
                            f"({len(primes_found)} primes, {len(non_primes_found)} non-primes).")
                last_logged = total

        # --- Collect primes ---
        while len(primes_found) < num_primes_needed:
            _spend_attempt()
            # Start at a random point and jump to the next prime; average gap ~ ln(10^L)
            start = self._rand_ndigit()
            candidate = int(gmpy2.next_prime(start))
            if candidate >= self.end_range:
                continue  # fell out of digit-length range; resample

            if candidate not in all_found:
                primes_found.add(candidate)
                all_found.add(candidate)

            _log_progress()

        # --- Collect non-primes ---
        while len(non_primes_found) < num_non_primes_needed:
            _spend_attempt()
            n = self._rand_ndigit()
            # Reject primes; for large digit counts almost every draw is composite.
            if gmpy2.is_prime(n):
                continue
            if n not in all_found:
                non_primes_found.add(n)
                all_found.add(n)

            _log_progress()

        logger.info(f"(gmpy2) Successfully generated {len(primes_found)} primes and {len(non_primes_found)} non-primes.")

        # Shuffle within each split (the final dataset is shuffled again anyway)
        prime_list = list(primes_found)
        non_prime_list = list(non_primes_found)
        random.shuffle(prime_list)
        random.shuffle(non_prime_list)
        return (prime_list, non_prime_list)

    def generate_data(self) -> List[Dict[str, Any]]:
        """Generate the dataset as a shuffled list of {'Input': np.array of digit chars, 'Output': '1'|'0'}."""
        primes, non_primes = self._generate_unique_numbers()

        dataset = []
        logger.info("Formatting dataset...")

        # Keep exact output format compatible with previous code:
        for p in primes:
            input_vector = np.array(list(str(p)))
            dataset.append({'Input': input_vector, 'Output': '1'})

        for np_num in non_primes:
            input_vector = np.array(list(str(np_num)))
            dataset.append({'Input': input_vector, 'Output': '0'})

        random.shuffle(dataset)
        logger.info("Final dataset shuffled and creation complete.")
        return dataset



class CoprimeRandomSplitDecimalDataGenerator:
    """
    Build a balanced dataset of fixed-length DECIMAL digit vectors labelled by coprimality.

    Label semantics:
      - '1' -> the two pieces obtained by cutting the decimal string at a RANDOM
               position are coprime
      - '0' -> those two pieces share a common factor

    Per-sample procedure:
      1) Draw an integer from the configured range and left-pad it with zeros to
         `sequence_length` characters.
      2) Draw a cut position s uniformly from [1, sequence_length-1].
      3) a = int(digits[:s]), b = int(digits[s:]); the label is '1' exactly when gcd(a, b) == 1.

    Notes:
      - sequence_length must be >= 2 so that a cut exists.
      - With allow_leading_zeros=True a piece may be 0; the usual gcd rules then
        apply (gcd(x, 0) == |x|).
      - The cut position is not stored in the output.
      - Every full decimal sequence in the dataset is distinct.
    """

    def __init__(self,
                 sequence_length: int,
                 num_samples: int,
                 device: str = 'cpu',
                 allow_leading_zeros: bool = False):
        def reject(log_text: str, error_text: str) -> None:
            # Record the detailed reason, then surface it to the caller as ValueError.
            logger.error(log_text)
            raise ValueError(error_text)

        if not (isinstance(sequence_length, int) and sequence_length > 0):
            reject(f"Sequence length must be a positive integer, but got {sequence_length}.",
                   "Sequence length must be a positive integer.")
        if sequence_length < 2:
            split_text = "sequence_length must be at least 2 to allow a nontrivial split."
            reject(split_text, split_text)
        if not (isinstance(num_samples, int) and num_samples > 0):
            reject(f"Number of samples must be a positive integer, but got {num_samples}.",
                   "Number of samples must be a positive integer.")
        if num_samples % 2:
            reject(f"num_samples must be even to guarantee 50/50 split, but got {num_samples}.",
                   "num_samples must be even to guarantee a 50/50 split.")

        self.sequence_length = sequence_length
        self.num_samples = num_samples
        self.device = device
        self.allow_leading_zeros = allow_leading_zeros

        # Candidate integers live in [start_range, end_range); the upper bound is exclusive.
        self.end_range = 10 ** sequence_length
        self.start_range = 0 if allow_leading_zeros else 10 ** (sequence_length - 1)

        capacity = self.end_range - self.start_range
        if num_samples > capacity:
            logger.error(
                f"Requested {self.num_samples} samples, but only {capacity} unique "
                f"decimal sequences of length {self.sequence_length} exist under current settings."
            )
            raise ValueError(
                f"Requested {self.num_samples} samples, but only {capacity} unique sequences are possible."
            )

        logger.info(
            f"CoprimeRandomSplitDecimalDataGenerator initialized for len={sequence_length}, "
            f"samples={num_samples}, device='{device}', allow_leading_zeros={allow_leading_zeros}"
        )

    @staticmethod
    def _is_coprime(a: int, b: int) -> bool:
        """True when a and b have no common factor other than 1."""
        return 1 == math.gcd(a, b)

    def _generate_unique_numbers(self) -> Tuple[List[int], List[int]]:
        """
        Draw distinct integers until half are coprime-labelled and half are not.

        Returns:
            (coprime_numbers, non_coprime_numbers).

        Raises:
            RuntimeError: if the draw budget is used up before both halves are full.
        """
        want_pos = self.num_samples // 2
        want_neg = self.num_samples - want_pos

        logger.info(f"Generating {want_pos} coprime and {want_neg} non-coprime (via random split)...")

        coprime_set: Set[int] = set()
        shared_factor_set: Set[int] = set()

        budget = (self.num_samples * 80) + 2000  # generous cap on draws
        width = self.sequence_length
        log_every = max(1, int(self.num_samples * 0.1))
        tries = 0
        logged_at = -1

        while len(coprime_set) < want_pos or len(shared_factor_set) < want_neg:
            tries += 1
            if tries > budget:
                msg = (f"Could not generate required samples after {budget} attempts. "
                       f"Found {len(coprime_set)}/{want_pos} positives and {len(shared_factor_set)}/{want_neg} negatives.")
                logger.error(msg)
                raise RuntimeError(msg)

            value = random.randint(self.start_range, self.end_range - 1)
            # Already used in either class -> redraw (the cut position is not drawn).
            if value in coprime_set or value in shared_factor_set:
                continue

            digits = str(value).zfill(width)
            cut = random.randint(1, width - 1)
            if self._is_coprime(int(digits[:cut]), int(digits[cut:])):
                if len(coprime_set) < want_pos:
                    coprime_set.add(value)
            elif len(shared_factor_set) < want_neg:
                shared_factor_set.add(value)

            found = len(coprime_set) + len(shared_factor_set)
            if found - logged_at >= log_every:
                logger.info(f"Generation progress (coprime random split): {found}/{self.num_samples} found.")
                logged_at = found

        logger.info(
            f"Generated {len(coprime_set)} positives (coprime) and {len(shared_factor_set)} negatives (not coprime)."
        )
        return list(coprime_set), list(shared_factor_set)

    def generate_data(self) -> List[Dict[str, Any]]:
        """Produce a shuffled list of {'Input': np.array of digit chars, 'Output': '1' or '0'}."""
        coprime_numbers, non_coprime_numbers = self._generate_unique_numbers()

        logger.info("Formatting decimal dataset (coprime via random split)...")

        width = self.sequence_length
        # Positives first, then negatives; the shuffle below mixes them.
        dataset: List[Dict[str, Any]] = [
            {'Input': np.array(list(str(number).zfill(width))), 'Output': label}
            for group, label in ((coprime_numbers, '1'), (non_coprime_numbers, '0'))
            for number in group
        ]

        random.shuffle(dataset)
        logger.info("Final decimal dataset (50/50, coprime random split) created and shuffled.")
        return dataset



class PrimeDecimalNonPrimeNoSmallDivisorsDataGenerator:
    """
    Generates fixed-length DECIMAL input vectors with a strict 50/50 split:
      - label '1' -> prime numbers
      - label '0' -> non-prime numbers that have NO divisor in the given list
                     (default: 2,3,5,11,13,17,19,23)

    Each input is a zero-padded decimal string of length `sequence_length`,
    returned as np.array(list(str_digits)).

    NOTE(review): the default divisor list skips 7, so negatives such as 49 are
    accepted. Confirm this is intentional.
    """

    def __init__(self,
                 sequence_length: int,
                 num_samples: int,
                 device: str = 'cpu',
                 allow_leading_zeros: bool = False,
                 forbidden_divisors: Tuple[int, ...] = (2, 3, 5, 11, 13, 17, 19, 23)):
        """
        Args:
            sequence_length: decimal digits per input (> 0); values are zero-padded.
            num_samples: total samples; must be positive and even.
            device: stored only; not otherwise used by this class.
            allow_leading_zeros: if True, sample from [0, 10**L); otherwise from
                [10**(L-1), 10**L) (for L == 1: [1, 10)).
            forbidden_divisors: integers > 1; a negative must not be divisible by
                any of them. Duplicates are dropped and the values are sorted.

        Raises:
            ValueError: on invalid arguments or when more samples are requested
                than distinct sequences exist.
        """
        if not isinstance(sequence_length, int) or sequence_length <= 0:
            logger.error(f"Sequence length must be a positive integer, but got {sequence_length}.")
            raise ValueError("Sequence length must be a positive integer.")
        if not isinstance(num_samples, int) or num_samples <= 0:
            logger.error(f"Number of samples must be a positive integer, but got {num_samples}.")
            raise ValueError("Number of samples must be a positive integer.")
        if num_samples % 2 != 0:
            logger.error(f"num_samples must be even to guarantee 50/50 split, but got {num_samples}.")
            raise ValueError("num_samples must be even to guarantee a 50/50 split.")

        # Validate divisor list
        clean_divs = sorted(set(int(d) for d in forbidden_divisors))
        if any(d <= 1 for d in clean_divs):
            logger.error(f"forbidden_divisors must all be > 1, but got {tuple(clean_divs)}.")
            raise ValueError("forbidden_divisors must all be > 1.")
        self.forbidden_divisors: Tuple[int, ...] = tuple(clean_divs)

        self.sequence_length = sequence_length
        self.num_samples = num_samples
        self.device = device
        self.allow_leading_zeros = allow_leading_zeros

        # Numeric range for fixed-length DECIMAL numbers
        if allow_leading_zeros:
            self.start_range = 0
            self.end_range = 10 ** self.sequence_length  # exclusive
        else:
            if self.sequence_length == 1:
                self.start_range = 1
                self.end_range = 10
            else:
                self.start_range = 10 ** (self.sequence_length - 1)
                self.end_range = 10 ** self.sequence_length

        max_possible = self.end_range - self.start_range
        if self.num_samples > max_possible:
            logger.error(
                f"Requested {self.num_samples} samples, but only {max_possible} unique "
                f"decimal sequences of length {self.sequence_length} exist under current settings."
            )
            raise ValueError(
                f"Requested {self.num_samples} samples, but only {max_possible} unique sequences are possible."
            )

        # Product of the forbidden divisors. Kept for backward compatibility.
        # Note that gcd(n, product) == 1 is NOT equivalent to "no divisor divides n"
        # when a divisor is composite (e.g. 4 and n = 2), so filtering uses
        # _has_forbidden_divisor instead.
        self._forbidden_product = 1
        for d in self.forbidden_divisors:
            self._forbidden_product *= d

        logger.info(
            f"PrimeDecimalNonPrimeNoSmallDivisorsDataGenerator initialized: "
            f"len={sequence_length}, samples={num_samples}, device='{device}', "
            f"allow_leading_zeros={allow_leading_zeros}, "
            f"forbidden_divisors={self.forbidden_divisors}"
        )

    def _has_forbidden_divisor(self, n: int) -> bool:
        """Return True iff at least one forbidden divisor divides `n` exactly."""
        return any(n % d == 0 for d in self.forbidden_divisors)

    def _valid_negative(self, n: int) -> bool:
        """
        Return True iff `n` is non-prime and divisible by NONE of the forbidden divisors.

        The cheap checks run first. The primality test is by far the most expensive
        step for large `n`, so it only runs on candidates that pass the divisor filter.
        """
        # 0 is divisible by every positive integer; reject it even if the list is empty.
        if n == 0:
            return False
        if self._has_forbidden_divisor(n):
            return False
        return not isprime(n)

    def _generate_unique_numbers(self) -> Tuple[List[int], List[int]]:
        """
        Generate exactly half primes and half restricted non-primes.

        Returns:
            (primes, restricted_non_primes), all values distinct.

        Raises:
            RuntimeError: if the attempt budget runs out before both halves are full.
        """
        num_pos = self.num_samples // 2
        num_neg = self.num_samples - num_pos

        logger.info(f"Generating {num_pos} primes and {num_neg} non-primes with no small divisors...")

        primes_found: Set[int] = set()
        non_primes_found: Set[int] = set()
        all_found: Set[int] = set()

        # Slightly higher cap due to stricter negative filter
        max_attempts = (self.num_samples * 80) + 2000
        attempts = 0
        last_logged = -1

        while len(primes_found) < num_pos or len(non_primes_found) < num_neg:
            attempts += 1
            if attempts > max_attempts:
                msg = (f"Could not generate required samples after {max_attempts} attempts. "
                       f"Found {len(primes_found)}/{num_pos} primes and "
                       f"{len(non_primes_found)}/{num_neg} restricted non-primes.")
                logger.error(msg)
                raise RuntimeError(msg)

            # --- Prime search (if needed) ---
            if len(primes_found) < num_pos:
                # nextprime from a random start; clamp so randint always has a valid range
                rnd_start = random.randint(self.start_range, max(self.start_range, self.end_range - 2))
                candidate = nextprime(rnd_start)
                if self.start_range <= candidate < self.end_range and candidate not in all_found:
                    primes_found.add(candidate)
                    all_found.add(candidate)

            # --- Restricted non-prime search (if needed) ---
            if len(non_primes_found) < num_neg:
                candidate = random.randint(self.start_range, self.end_range - 1)
                if candidate not in all_found and self._valid_negative(candidate):
                    non_primes_found.add(candidate)
                    all_found.add(candidate)

            # Periodic progress logging
            current = len(all_found)
            if current - last_logged >= max(1, int(self.num_samples * 0.1)):
                logger.info(f"Generation progress (decimal no-small-divisors): {current}/{self.num_samples} found.")
                last_logged = current

        logger.info(
            f"Generated {len(primes_found)} primes and {len(non_primes_found)} restricted non-primes (no small divisors)."
        )
        return list(primes_found), list(non_primes_found)

    def generate_data(self) -> List[Dict[str, Any]]:
        """Return a list of dicts with 'Input' (np.array of decimal chars) and 'Output' ('1' or '0')."""
        primes, non_primes = self._generate_unique_numbers()

        dataset: List[Dict[str, Any]] = []
        logger.info("Formatting decimal dataset (no small divisors)...")

        def fmt_dec(n: int) -> str:
            return str(n).zfill(self.sequence_length)

        # Positives: primes
        for p in primes:
            s = fmt_dec(p)
            dataset.append({'Input': np.array(list(s)), 'Output': '1'})

        # Negatives: non-primes with no divisor in forbidden_divisors
        for n in non_primes:
            s = fmt_dec(n)
            # Defensive invariants (stripped under `python -O`)
            assert not self._has_forbidden_divisor(n)
            assert not isprime(n)
            dataset.append({'Input': np.array(list(s)), 'Output': '0'})

        random.shuffle(dataset)
        logger.info("Final decimal dataset (50/50, no small divisors) created and shuffled.")
        return dataset



class PrimeDecimalTailRestrictedDataGenerator:
    """
    Generates fixed-length DECIMAL input vectors with a strict 50/50 split:
      - label '1' -> prime numbers
      - label '0' -> non-prime numbers whose LAST DECIMAL DIGIT is in
                     `allowed_nonprime_last_digits` (default {1, 3, 7, 9})

    Input vectors are arrays of characters for the zero-padded decimal string.
    """

    def __init__(self,
                 sequence_length: int,
                 num_samples: int,
                 device: str = 'cpu',
                 allow_leading_zeros: bool = False,
                 allowed_nonprime_last_digits: Tuple[int, ...] = (1, 3, 7, 9)):
        """
        Args:
            sequence_length: decimal digits per input (> 0); values are zero-padded.
            num_samples: total samples; must be positive and even.
            device: stored only; not otherwise used by this class.
            allow_leading_zeros: if True, sample from [0, 10**L); otherwise from
                [10**(L-1), 10**L) (for L == 1: [1, 10)).
            allowed_nonprime_last_digits: non-empty collection of digits 0..9.
                Any iterable is accepted (including a one-shot iterator). It is
                stored de-duplicated and sorted.

        Raises:
            ValueError: on invalid arguments, an empty or out-of-range digit set,
                or more samples than distinct sequences exist.
        """
        if not isinstance(sequence_length, int) or sequence_length <= 0:
            logger.error(f"Sequence length must be a positive integer, but got {sequence_length}.")
            raise ValueError("Sequence length must be a positive integer.")
        if not isinstance(num_samples, int) or num_samples <= 0:
            logger.error(f"Number of samples must be a positive integer, but got {num_samples}.")
            raise ValueError("Number of samples must be a positive integer.")
        if num_samples % 2 != 0:
            logger.error(f"num_samples must be even to guarantee 50/50 split, but got {num_samples}.")
            raise ValueError("num_samples must be even to guarantee a 50/50 split.")

        # Materialize once: validation and the stored tuple both iterate the argument,
        # and a one-shot iterator would otherwise leave the stored tuple empty.
        digits = tuple(allowed_nonprime_last_digits)
        if not digits:
            # With no allowed digit, no negative could ever be accepted.
            logger.error("allowed_nonprime_last_digits is empty; no non-prime could ever be accepted.")
            raise ValueError("allowed_nonprime_last_digits must contain at least one decimal digit 0..9")
        if not all(d in {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} for d in digits):
            logger.error(f"allowed_nonprime_last_digits must be decimal digits 0..9, but got {digits}.")
            raise ValueError("allowed_nonprime_last_digits must be decimal digits 0..9")

        self.sequence_length = sequence_length
        self.num_samples = num_samples
        self.device = device
        self.allow_leading_zeros = allow_leading_zeros
        self.allowed_nonprime_last_digits = tuple(sorted(set(digits)))

        # Define numeric range for fixed-length DECIMAL numbers
        if allow_leading_zeros:
            self.start_range = 0
            self.end_range = 10 ** self.sequence_length  # exclusive
        else:
            if self.sequence_length == 1:
                self.start_range = 1
                self.end_range = 10
            else:
                self.start_range = 10 ** (self.sequence_length - 1)
                self.end_range = 10 ** self.sequence_length

        max_possible = self.end_range - self.start_range
        if self.num_samples > max_possible:
            logger.error(
                f"Requested {self.num_samples} samples, but only {max_possible} unique "
                f"decimal sequences of length {self.sequence_length} exist under current settings."
            )
            raise ValueError(
                f"Requested {self.num_samples} samples, but only {max_possible} unique sequences are possible."
            )

        logger.info(
            f"PrimeDecimalTailRestrictedDataGenerator initialized for len={sequence_length}, "
            f"samples={num_samples}, device='{device}', allow_leading_zeros={allow_leading_zeros}, "
            f"allowed_nonprime_last_digits={self.allowed_nonprime_last_digits}"
        )

    def _generate_unique_numbers(self) -> Tuple[List[int], List[int]]:
        """
        Generate exactly half primes and half NON-primes ending with allowed digits.

        Returns:
            (primes, restricted_non_primes), all values distinct.

        Raises:
            RuntimeError: if the attempt budget runs out before both halves are full.
        """
        num_pos = self.num_samples // 2
        num_neg = self.num_samples - num_pos

        logger.info(f"Generating {num_pos} primes and {num_neg} restricted non-primes (decimal domain)...")

        primes_found: Set[int] = set()
        non_primes_found: Set[int] = set()
        all_found: Set[int] = set()

        # Plenty of headroom for random sampling in large spaces
        max_attempts = (self.num_samples * 50) + 1000
        attempts = 0
        last_logged = -1

        while len(primes_found) < num_pos or len(non_primes_found) < num_neg:
            attempts += 1
            if attempts > max_attempts:
                msg = (f"Could not generate required samples after {max_attempts} attempts. "
                       f"Found {len(primes_found)}/{num_pos} primes and "
                       f"{len(non_primes_found)}/{num_neg} restricted non-primes.")
                logger.error(msg)
                raise RuntimeError(msg)

            # --- Prime search (if needed) ---
            if len(primes_found) < num_pos:
                # Clamp so randint always has a valid range; nextprime is strictly greater.
                rnd_start = random.randint(self.start_range, max(self.start_range, self.end_range - 2))
                candidate = nextprime(rnd_start)
                if self.start_range <= candidate < self.end_range and candidate not in all_found:
                    primes_found.add(candidate)
                    all_found.add(candidate)

            # --- Restricted non-prime search (if needed) ---
            if len(non_primes_found) < num_neg:
                candidate = random.randint(self.start_range, self.end_range - 1)
                # Cheapest checks first; the primality test is the expensive step.
                if (candidate % 10 in self.allowed_nonprime_last_digits
                        and candidate not in all_found
                        and not isprime(candidate)):
                    non_primes_found.add(candidate)
                    all_found.add(candidate)

            # Periodic progress logging
            current = len(all_found)
            if current - last_logged >= max(1, int(self.num_samples * 0.1)):
                logger.info(f"Generation progress (decimal): {current}/{self.num_samples} sequences found.")
                last_logged = current

        logger.info(
            f"Generated {len(primes_found)} primes and {len(non_primes_found)} restricted non-primes."
        )
        return list(primes_found), list(non_primes_found)

    def generate_data(self) -> List[Dict[str, Any]]:
        """Return a list of dicts with 'Input' (np.array of decimal chars) and 'Output' ('1' or '0')."""
        primes, non_primes = self._generate_unique_numbers()

        dataset: List[Dict[str, Any]] = []
        logger.info("Formatting decimal dataset...")

        def fmt_dec(n: int) -> str:
            # Zero-pad to fixed length for consistent sequence length
            return str(n).zfill(self.sequence_length)

        # Positives (label '1'): primes (no tail restriction)
        for p in primes:
            dec_str = fmt_dec(p)
            input_vector = np.array(list(dec_str))
            dataset.append({'Input': input_vector, 'Output': '1'})

        # Negatives (label '0'): non-primes whose last digit is in allowed_nonprime_last_digits
        for n in non_primes:
            dec_str = fmt_dec(n)
            # Defensive assert to guarantee invariant (stripped under `python -O`)
            assert int(dec_str[-1]) in self.allowed_nonprime_last_digits and not isprime(n)
            input_vector = np.array(list(dec_str))
            dataset.append({'Input': input_vector, 'Output': '0'})

        random.shuffle(dataset)
        logger.info("Final decimal dataset (50/50) created and shuffled.")
        return dataset


class PrimeBinaryDataGenerator:
    """
    Generates fixed-length binary input vectors and prime/non-prime outputs.

    A binary sequence of length L represents the integer value int(seq, 2).
    The label is '1' iff that integer is prime, else '0'.

    Notes:
    - By default (allow_leading_zeros=True) every bit pattern of length L is allowed,
      zero-padded to L bits (e.g. '00101' for L=5). Set allow_leading_zeros=False to
      force MSB=1, so that each sequence is truly a length-L number.
    - Output labels are strings '1' / '0' (binary labels), same as the decimal generators.
    - Configurations that cannot yield both classes are rejected in __init__:
      L=1 when a prime is required ('0' and '1' are not prime), and L=2 without
      leading zeros ('10'=2 and '11'=3 are both prime, so no non-prime exists).
    """

    def __init__(self,
                 sequence_length: int,
                 num_samples: int,
                 device: str = 'cpu',
                 allow_leading_zeros: bool = True):  # optional, defaults to leading zeros
        """
        Args:
            sequence_length: bits per input (> 0).
            num_samples: total samples (> 0); odd values give one more non-prime.
            device: stored only; not otherwise used by this class.
            allow_leading_zeros: if True, sample from [0, 2**L); otherwise from
                [2**(L-1), 2**L).

        Raises:
            ValueError: on invalid arguments, on configurations that cannot
                produce the required primes / non-primes, or when more samples
                are requested than distinct sequences exist.
        """
        if not isinstance(sequence_length, int) or sequence_length <= 0:
            logger.error(f"Sequence length must be a positive integer, but got {sequence_length}.")
            raise ValueError("Sequence length must be a positive integer.")
        if not isinstance(num_samples, int) or num_samples <= 0:
            logger.error(f"Number of samples must be a positive integer, but got {num_samples}.")
            raise ValueError("Number of samples must be a positive integer.")
        if num_samples % 2 != 0:
            logger.warning(f"num_samples is odd ({num_samples}). For a 50/50 split, it's recommended to be even.")

        self.sequence_length = sequence_length
        self.num_samples = num_samples
        self.device = device
        self.allow_leading_zeros = allow_leading_zeros

        # Define numeric range for fixed-length binary numbers
        if allow_leading_zeros:
            # Values are {0, 1} for L=1: no prime exists, so fail fast whenever one is needed
            # (num_samples >= 2) instead of exhausting the attempt budget later.
            if self.sequence_length == 1 and num_samples // 2 > 0:
                logger.error("sequence_length=1 cannot produce primes ('0' and '1' are not prime).")
                raise ValueError("sequence_length=1 cannot support a prime/non-prime 50/50 split.")
            # All bit patterns of length L are allowed; we zfill to L bits.
            self.start_range = 0
            self.end_range = 1 << self.sequence_length  # exclusive upper bound
        else:
            # Enforce MSB=1 so the integer really has bit-length L (no leading zeros).
            if self.sequence_length == 1:
                # Only '1' exists; cannot produce any primes -> not feasible for 50/50
                logger.error("sequence_length=1 with no leading zeros cannot produce primes and a 50/50 split.")
                raise ValueError("sequence_length=1 with no leading zeros cannot support prime/non-prime 50/50.")
            if self.sequence_length == 2:
                # Only '10' (2) and '11' (3) exist and both are prime; at least one
                # non-prime is always required (num_samples - num_samples // 2 >= 1).
                logger.error("sequence_length=2 with no leading zeros cannot produce non-primes ('10'=2 and '11'=3 are both prime).")
                raise ValueError("sequence_length=2 with no leading zeros cannot support prime/non-prime 50/50.")
            self.start_range = 1 << (self.sequence_length - 1)  # e.g., 1000...0b
            self.end_range = 1 << self.sequence_length          # exclusive upper bound

        max_possible = self.end_range - self.start_range
        if self.num_samples > max_possible:
            logger.error(f"Requested {self.num_samples} samples, but only {max_possible} unique sequences of length {self.sequence_length} exist under current settings.")
            raise ValueError(f"Requested {self.num_samples} samples, but only {max_possible} unique sequences are possible.")

        logger.info(f"PrimeBinaryDataGenerator initialized for len={sequence_length}, samples={num_samples}, "
                    f"device='{device}', allow_leading_zeros={allow_leading_zeros}")

    def _generate_unique_numbers(self) -> Tuple[List[int], List[int]]:
        """
        Generate unique sets of prime and non-prime integers in the configured range.

        Returns:
            (primes, non_primes), all values distinct.

        Raises:
            RuntimeError: if the attempt budget runs out before both sets are full.
        """
        num_primes_needed = self.num_samples // 2
        num_non_primes_needed = self.num_samples - num_primes_needed

        logger.info(f"Generating {num_primes_needed} unique primes and {num_non_primes_needed} unique non-primes (binary domain)...")

        primes_found: Set[int] = set()
        non_primes_found: Set[int] = set()
        all_found: Set[int] = set()

        # Safety cap on loop iterations.
        max_attempts = (self.num_samples * 10) + 100
        attempts = 0
        last_logged_count = -1

        while len(primes_found) < num_primes_needed or len(non_primes_found) < num_non_primes_needed:
            attempts += 1
            if attempts > max_attempts:
                msg = (f"Could not generate required samples after {max_attempts} attempts. "
                       f"Found {len(primes_found)}/{num_primes_needed} primes and "
                       f"{len(non_primes_found)}/{num_non_primes_needed} non-primes.")
                logger.error(msg)
                raise RuntimeError(msg)

            # --- Prime search (if needed) ---
            if len(primes_found) < num_primes_needed:
                # nextprime is strictly greater than its argument, so the start stops at end_range - 2.
                rnd_start = random.randint(self.start_range, self.end_range - 2)
                candidate = nextprime(rnd_start)
                if candidate < self.end_range and candidate not in all_found:
                    primes_found.add(candidate)
                    all_found.add(candidate)

            # --- Non-prime search (if needed) ---
            if len(non_primes_found) < num_non_primes_needed:
                candidate = random.randint(self.start_range, self.end_range - 1)
                if not isprime(candidate) and candidate not in all_found:
                    non_primes_found.add(candidate)
                    all_found.add(candidate)

            # Periodic progress logging
            current_count = len(all_found)
            if current_count - last_logged_count >= self.num_samples * 0.1:
                logger.info(f"Generation progress (binary): {current_count}/{self.num_samples} sequences found.")
                last_logged_count = current_count

        logger.info(f"Successfully generated {len(primes_found)} primes and {len(non_primes_found)} non-primes in binary domain.")
        return list(primes_found), list(non_primes_found)

    def generate_data(self) -> List[Dict[str, Any]]:
        """Returns a list of dicts with 'Input' (np.array of '0'/'1') and 'Output' ('1' or '0')."""
        primes, non_primes = self._generate_unique_numbers()

        dataset: List[Dict[str, Any]] = []
        logger.info("Formatting binary dataset...")

        # Helper to format an integer as a zero-padded binary string of fixed length
        def fmt_bin(n: int) -> str:
            return format(n, 'b').zfill(self.sequence_length)

        # Add primes
        for p in primes:
            bin_str = fmt_bin(p)
            input_vector = np.array(list(bin_str))  # array(['0','1',...])
            dataset.append({'Input': input_vector, 'Output': '1'})

        # Add non-primes
        for n in non_primes:
            bin_str = fmt_bin(n)
            input_vector = np.array(list(bin_str))
            dataset.append({'Input': input_vector, 'Output': '0'})

        random.shuffle(dataset)
        logger.info("Final binary dataset shuffled and creation complete.")
        return dataset

class PrimeDataGenerator:
    """Generates decimal input vectors and corresponding prime/non-prime outputs.

    Every sample is a distinct ``sequence_length``-digit decimal number drawn from
    ``[10**(sequence_length - 1), 10**sequence_length)`` (``[1, 10)`` when
    ``sequence_length == 1``). ``num_samples // 2`` samples are primes labelled
    '1'; the remaining samples are non-primes labelled '0'.
    """
    def __init__(self,
                 sequence_length: int,
                 num_samples: int,
                 device: str = 'cpu'): # device is kept for API consistency
        """
        Args:
            sequence_length: Number of decimal digits per sample; positive int.
            num_samples: Total number of samples; positive int. An odd value is
                accepted (with a warning) and yields one extra non-prime.
            device: Stored only; generation itself does not use torch.

        Raises:
            ValueError: If ``sequence_length`` or ``num_samples`` is not a positive
                int, or ``num_samples`` exceeds the count of numbers with
                ``sequence_length`` digits.
        """
        if not isinstance(sequence_length, int) or sequence_length <= 0:
            logger.error(f"Sequence length must be a positive integer, but got {sequence_length}.")
            raise ValueError("Sequence length must be a positive integer.")

        if not isinstance(num_samples, int) or num_samples <= 0:
            logger.error(f"Number of samples must be a positive integer, but got {num_samples}.")
            raise ValueError("Number of samples must be a positive integer.")

        if num_samples % 2 != 0:
            logger.warning(f"num_samples is odd ({num_samples}). For a 50/50 split, it's recommended to be even.")

        self.sequence_length = sequence_length
        self.num_samples = num_samples
        self.device = device

        # Half-open search range [start_range, end_range) of numbers with exactly
        # sequence_length digits (single digits start at 1, not 0).
        self.start_range = 10**(self.sequence_length - 1) if self.sequence_length > 1 else 1
        self.end_range = 10**self.sequence_length

        max_possible = self.end_range - self.start_range
        if self.num_samples > max_possible:
            logger.error(f"Requested {self.num_samples} samples, but only {max_possible} unique numbers of length {self.sequence_length} exist.")
            raise ValueError(f"Requested {self.num_samples} samples, but only {max_possible} unique numbers are possible.")

        logger.info(f"PrimeDataGenerator initialized for len={sequence_length}, samples={num_samples}, device='{device}'")

    def _generate_unique_numbers(self) -> Tuple[List[int], List[int]]:
        """
        Draw distinct primes and non-primes from ``[start_range, end_range)``.

        Returns:
            ``(primes, non_primes)`` with ``num_samples // 2`` and
            ``num_samples - num_samples // 2`` entries respectively. No number
            appears twice across both lists. The lists are NOT shuffled;
            ``generate_data`` shuffles the combined dataset.

        Raises:
            RuntimeError: If the quotas are not met within
                ``num_samples * 10 + 10`` loop iterations (e.g. when the range
                holds too few primes or non-primes for the request).
        """
        num_primes_needed = self.num_samples // 2
        num_non_primes_needed = self.num_samples - num_primes_needed

        logger.info(f"Generating {num_primes_needed} unique primes and {num_non_primes_needed} unique non-primes...")

        primes_found: Set[int] = set()
        non_primes_found: Set[int] = set()
        all_found: Set[int] = set()

        max_attempts = (self.num_samples*10) + 10 # Safety break for sparse ranges
        attempts = 0
        last_logged_count = -1  # -1 so the first iteration always logs

        while len(primes_found) < num_primes_needed or len(non_primes_found) < num_non_primes_needed:
            attempts += 1
            if attempts > max_attempts:
                msg = (f"Could not generate required samples after {max_attempts} attempts. "
                       f"Found {len(primes_found)}/{num_primes_needed} primes and "
                       f"{len(non_primes_found)}/{num_non_primes_needed} non-primes.")
                logger.error(msg)
                raise RuntimeError(msg)

            # --- Find a Prime Number (if needed) ---
            if len(primes_found) < num_primes_needed:
                # NOTE: nextprime() of a uniformly random start is not a uniformly
                # random prime. nextprime(n) is the smallest prime > n, so each prime
                # is picked with probability proportional to the gap preceding it.
                # Primes that follow short gaps are under-represented.
                random_start = random.randint(self.start_range, self.end_range - 2)
                candidate = nextprime(random_start)

                # nextprime may step past the last number with this many digits.
                if candidate < self.end_range and candidate not in all_found:
                    primes_found.add(candidate)
                    all_found.add(candidate)

            # --- Find a Non-Prime Number (if needed) ---
            if len(non_primes_found) < num_non_primes_needed:
                candidate = random.randint(self.start_range, self.end_range - 1)

                # Ensure the candidate is not prime and is unique
                if not isprime(candidate) and candidate not in all_found:
                    non_primes_found.add(candidate)
                    all_found.add(candidate)

            # Log progress roughly every 10% of the requested samples
            current_count = len(all_found)
            if current_count - last_logged_count >= self.num_samples * 0.1:
                logger.info(f"Generation progress: {current_count}/{self.num_samples} numbers found.")
                last_logged_count = current_count

        logger.info(f"Successfully generated {len(primes_found)} primes and {len(non_primes_found)} non-primes.")

        # The previous "Shuffled generated numbers." log was removed: nothing is shuffled here.
        return (list(primes_found), list(non_primes_found))

    def generate_data(self) -> List[Dict[str, Any]]:
        """Generate the dataset as a list of dictionaries.

        Each entry has 'Input' (a NumPy array with one string per decimal digit)
        and 'Output' ('1' for prime, '0' for non-prime). The list is shuffled
        with ``random.shuffle`` before it is returned.
        """
        primes, non_primes = self._generate_unique_numbers()

        logger.info("Formatting dataset...")

        # Primes first, then non-primes. The shuffle below mixes them.
        dataset: List[Dict[str, Any]] = [
            {'Input': np.array(list(str(p))), 'Output': '1'}  # 1 indicates a prime number
            for p in primes
        ]
        dataset.extend(
            {'Input': np.array(list(str(n))), 'Output': '0'}  # 0 indicates a non-prime number
            for n in non_primes
        )

        random.shuffle(dataset)
        logger.info("Final dataset shuffled and creation complete.")

        return dataset
    

# class BinaryDataGenerator:
#     """Generates binary input vectors and corresponding outputs from a target function."""
#     def __init__(self,
#                  target_function_name: str,
#                  sequence_length: int,
#                  num_samples: int,
#                  device: str = 'cpu'):

#         if target_function_name not in TARGET_FUNCTIONS:
#              logger.error(f"Unknown target function: {target_function_name}. Available: {list(TARGET_FUNCTIONS.keys())}")
#              raise ValueError(f"Unknown target function: {target_function_name}.")
#         self.target_function = TARGET_FUNCTIONS[target_function_name]
#         self.target_function_name = target_function_name
#         self.sequence_length = sequence_length
#         self.num_samples = num_samples
#         self.device = device
#         logger.info(f"DataGenerator initialized for target='{target_function_name}', len={sequence_length}, samples={num_samples}, device='{device}'")

#     def _generate_unique_binary_samples(self) -> torch.Tensor:
#         """Generates unique binary sequences efficiently."""
#         max_possible = 2**self.sequence_length
#         if self.num_samples > max_possible:
#              logger.error(f"Requested {self.num_samples} samples, but only {max_possible} unique sequences possible.")
#              raise ValueError(f"Requested {self.num_samples} samples, but only {max_possible} unique sequences possible.")

#         logger.info(f"Generating {self.num_samples} unique binary samples...")
#         # Use optimized method based on sequence length
#         if self.sequence_length <= 60:
#             try:
#                 unique_integers = random.sample(range(max_possible), self.num_samples)
#                 samples_tensor = torch.tensor(unique_integers, dtype=torch.long, device=self.device)
#                 powers_of_2 = 2**torch.arange(self.sequence_length - 1, -1, -1, device=self.device, dtype=torch.long)
#                 binary_samples = (samples_tensor.unsqueeze(-1) // powers_of_2) % 2
#                 logger.info("Generated samples via integer sampling.")
#             except OverflowError:
#                  logger.warning("Integer range too large for standard sampling, falling back to direct generation.")
#                  binary_samples = self._generate_directly()
#             except MemoryError:
#                  logger.warning("Memory error during integer sampling/conversion, falling back to direct generation.")
#                  binary_samples = self._generate_directly()
#         else:
#             logger.info("Sequence length > 60, using direct generation method.")
#             binary_samples = self._generate_directly()

#         # Shuffle samples
#         shuffled_indices = torch.randperm(binary_samples.size(0), device=self.device)
#         logger.info("Shuffled generated samples.")
#         return binary_samples[shuffled_indices].long()

#     def _generate_directly(self) -> torch.Tensor:
#         """Directly generates unique binary sequences with collision handling."""
#         logger.debug("Starting direct generation of unique samples.")
#         binary_samples_cpu = torch.empty((0, self.sequence_length), dtype=torch.long)
#         attempts = 0
#         max_attempts = 30
#         # Use logging for progress instead of tqdm if running non-interactively or prefer logs
#         # pbar = tqdm(total=self.num_samples, desc="Generating Unique Samples (Direct)", unit="samples")
#         last_logged_count = 0
#         while binary_samples_cpu.size(0) < self.num_samples and attempts < max_attempts:
#             current_count = binary_samples_cpu.size(0)
#             if current_count - last_logged_count >= self.num_samples * 0.1 or last_logged_count==0: # Log progress every 10%
#                  logger.info(f"Direct generation progress: {current_count}/{self.num_samples} samples found.")
#                  last_logged_count = current_count

#             needed = self.num_samples - current_count
#             batch_size = min(max(int(needed * 1.2), 5000), 2**17)
#             new_samples = torch.randint(0, 2, (batch_size, self.sequence_length), dtype=torch.long, device='cpu')
#             combined = torch.cat((binary_samples_cpu, new_samples), dim=0)
#             unique_samples = torch.unique(combined, dim=0)

#             if unique_samples.size(0) == current_count and needed > 0:
#                 attempts += 1
#                 logger.debug(f"No new unique samples found. Attempt {attempts}/{max_attempts}.")
#             else:
#                 attempts = 0 # Reset attempts if progress is made

#             binary_samples_cpu = unique_samples
#             if binary_samples_cpu.size(0) > self.num_samples:
#                 binary_samples_cpu = binary_samples_cpu[:self.num_samples]
#             # pbar.n = binary_samples_cpu.size(0)
#             # pbar.refresh()

#         # pbar.close()
#         final_count = binary_samples_cpu.size(0)
#         logger.info(f"Direct generation finished. Found {final_count}/{self.num_samples} samples.")
#         if final_count < self.num_samples:
#              logger.error(f"Could not generate {self.num_samples} unique samples after {max_attempts} attempts.")
#              raise RuntimeError(f"Could not generate {self.num_samples} unique samples.")
#         logger.info("Generated samples via direct generation.")
#         return binary_samples_cpu.to(self.device)

#     def generate_data(self) -> List[Dict[str, Any]]:
#         """Generates the dataset as a list of dictionaries."""
#         binary_inputs = self._generate_unique_binary_samples()

#         logger.info(f"Calculating outputs using target function '{self.target_function_name}'...")
#         outputs = self.target_function(binary_inputs.to(self.device), self.device)

#         dataset = []
#         logger.info("Formatting dataset...")
#         # Can optionally use tqdm here if desired, but logging might suffice
#         for i in range(self.num_samples):
#             input_str_tensor = np.array( [str(bit.item()) for bit in binary_inputs[i]] )
#             dataset.append({
#                 'Input': input_str_tensor,
#                 'Output': str(outputs[i].item())
#             })
#         logger.info("Data generation complete.")
#         return dataset

class BinaryDataGenerator:
    """
    Generates binary input vectors and corresponding outputs from a target function,
    ensuring an exact 50/50 split between output labels '0' and '1'.

    Inputs are distinct random bit vectors of length ``sequence_length``. Labels are
    the values returned by ``TARGET_FUNCTIONS[target_function_name]`` for them.
    """
    def __init__(self,
                 target_function_name: str,
                 sequence_length: int,
                 num_samples: int,
                 device: str = 'cpu'):
        """
        Args:
            target_function_name: Key into ``TARGET_FUNCTIONS`` selecting the labeller.
            sequence_length: Number of bits per input vector; must be positive.
            num_samples: Total number of samples; must be positive and even.
            device: Torch device used for candidate generation and labelling.

        Raises:
            ValueError: On an unknown target function, a non-positive or odd
                ``num_samples``, a non-positive ``sequence_length``, or a request for
                more samples than there are distinct bit vectors.
        """
        if target_function_name not in TARGET_FUNCTIONS:
            logger.error(f"Unknown target function: {target_function_name}. Available: {list(TARGET_FUNCTIONS.keys())}")
            raise ValueError(f"Unknown target function: {target_function_name}.")
        self.target_function = TARGET_FUNCTIONS[target_function_name]
        self.target_function_name = target_function_name

        if num_samples <= 0 or num_samples % 2 != 0:
            logger.error(f"num_samples must be a positive, even number for a 50/50 split, but got {num_samples}.")
            raise ValueError("num_samples must be a positive, even number for a 50/50 split.")
        self.num_samples = num_samples

        # Without this check, a non-positive length only fails through the capacity
        # check below (2**0 == 1, 2**-k is fractional), with a misleading message.
        if sequence_length <= 0:
            logger.error(f"Sequence length must be a positive integer, but got {sequence_length}.")
            raise ValueError("Sequence length must be a positive integer.")

        max_possible = 2**sequence_length
        if num_samples > max_possible:
            logger.error(f"Requested {num_samples} samples, but only {max_possible} unique sequences possible.")
            raise ValueError(f"Requested {num_samples} samples, but only {max_possible} unique sequences are possible.")
        self.sequence_length = sequence_length

        self.device = device
        logger.info(f"DataGenerator initialized for target='{target_function_name}', len={sequence_length}, samples={num_samples}, device='{device}'")

    @staticmethod
    def _take_unseen_rows(candidates: torch.Tensor,
                          outputs: torch.Tensor,
                          label: int,
                          seen: Set[Tuple[int, ...]]) -> List[List[int]]:
        """Return the rows of ``candidates`` whose output equals ``label`` and that are not in ``seen``.

        Accepted rows are added to ``seen``, so duplicates within this batch and
        rows collected in earlier batches are both skipped. Rows come back in batch
        order as lists of Python ints.
        """
        row_indices = (outputs == label).nonzero(as_tuple=True)[0]
        if row_indices.numel() == 0:
            return []

        fresh_rows: List[List[int]] = []
        # One .tolist() per batch instead of one per row keeps this Python loop cheap.
        for row in candidates[row_indices].cpu().tolist():
            key = tuple(row)
            if key not in seen:
                seen.add(key)
                fresh_rows.append(row)
        return fresh_rows

    def _generate_balanced_samples(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate distinct random bit vectors until each label has ``num_samples // 2`` of them.

        Uses a "generate, label, collect" loop. Each iteration draws a batch of
        random candidates, labels them with the target function, and keeps
        previously unseen rows for every label that is still short.

        Returns:
            ``(samples_0, samples_1)``: long tensors of shape
            ``(num_samples // 2, sequence_length)`` on ``self.device``.

        Raises:
            RuntimeError: If more than 50 consecutive batches add no new sample,
                e.g. when the target function is too skewed for this length.
        """
        needed_per_label = self.num_samples // 2
        logger.info(f"Generating {needed_per_label} unique samples for label '0' and {needed_per_label} for label '1'...")

        # Collected samples per label, kept on CPU.
        collected: Dict[int, torch.Tensor] = {
            label: torch.empty((0, self.sequence_length), dtype=torch.long)
            for label in (0, 1)
        }

        # Bit tuples of every collected sample (both labels), for O(1) uniqueness checks.
        found_samples_set: Set[Tuple[int, ...]] = set()

        attempts = 0
        max_attempts_without_progress = 50 # Safety break

        while any(collected[label].size(0) < needed_per_label for label in (0, 1)):
            shortfall = {label: needed_per_label - collected[label].size(0) for label in (0, 1)}

            # Over-sample to absorb skewed label distributions and duplicates.
            batch_size = int(max(shortfall[0] + shortfall[1], 1000) * 1.5)

            candidates = torch.randint(0, 2, (batch_size, self.sequence_length), dtype=torch.long, device=self.device)
            outputs = self.target_function(candidates, self.device)

            progress_made = False
            for label in (0, 1):
                if shortfall[label] <= 0:
                    continue
                fresh_rows = self._take_unseen_rows(candidates, outputs, label, found_samples_set)
                if fresh_rows:
                    new_block = torch.tensor(fresh_rows, dtype=torch.long)
                    collected[label] = torch.cat((collected[label], new_block), dim=0)
                    progress_made = True

            if progress_made:
                attempts = 0
            else:
                attempts += 1

            if attempts > max_attempts_without_progress:
                msg = (f"Failed to generate balanced dataset after {attempts} attempts without progress. "
                       f"Found {collected[0].size(0)}/{needed_per_label} for label 0 and "
                       f"{collected[1].size(0)}/{needed_per_label} for label 1. "
                       "The target function might be too skewed for this sequence length.")
                logger.error(msg)
                raise RuntimeError(msg)

            logger.debug(f"Progress: {collected[0].size(0)}/{needed_per_label} (label 0), "
                         f"{collected[1].size(0)}/{needed_per_label} (label 1)")

        # A single batch can overshoot the quota, so trim before moving to the target device.
        final_samples_0 = collected[0][:needed_per_label].to(self.device)
        final_samples_1 = collected[1][:needed_per_label].to(self.device)

        logger.info(f"Successfully generated {final_samples_0.size(0)} samples for label '0' and {final_samples_1.size(0)} for label '1'.")

        return final_samples_0, final_samples_1

    def generate_data(self) -> List[Dict[str, Any]]:
        """Generate the balanced dataset as a list of dictionaries.

        Each entry has 'Input' (a NumPy array of '0'/'1' strings) and 'Output'
        ('0' or '1'). Samples are shuffled with ``torch.randperm``.
        """
        samples_0, samples_1 = self._generate_balanced_samples()

        outputs_0 = torch.zeros(samples_0.size(0), dtype=torch.long, device=self.device)
        outputs_1 = torch.ones(samples_1.size(0), dtype=torch.long, device=self.device)

        all_inputs = torch.cat((samples_0, samples_1), dim=0)
        all_outputs = torch.cat((outputs_0, outputs_1), dim=0)

        logger.info("Shuffling the combined dataset...")
        shuffled_indices = torch.randperm(all_inputs.size(0))
        all_inputs = all_inputs[shuffled_indices]
        all_outputs = all_outputs[shuffled_indices]

        dataset: List[Dict[str, Any]] = []
        logger.info("Formatting dataset...")
        # Convert once to nested Python lists instead of calling .item() per bit.
        input_rows = all_inputs.cpu().tolist()
        labels = all_outputs.cpu().tolist()

        for bits, label in zip(input_rows, labels):
            dataset.append({
                'Input': np.array([str(bit) for bit in bits]),
                'Output': str(label),
            })

        logger.info("Data generation complete.")
        return dataset
    

class ScalarSortedDataGenerator:
    """
    Generates 1D input vectors with scalar integer values (0-100 by default).
    The dataset is balanced such that 50% of samples are sorted ascending
    (non-decreasing) sequences (label 1) and 50% are unsorted sequences (label 0).
    """
    def __init__(self,
                 sequence_length: int,
                 num_samples: int,
                 min_value: int = 0,
                 max_value: int = 100,
                 device: str = 'cpu'):
        """
        Initializes the ScalarSortedDataGenerator.

        Args:
            sequence_length: The length of each input vector. Must be > 1 for
                             unsorted sequences to be possible.
            num_samples: The total number of samples to generate. Must be a positive,
                         even number for a perfect 50/50 split.
            min_value: The minimum scalar value for input elements (inclusive, >= 0).
            max_value: The maximum scalar value for input elements (inclusive). Must be
                       strictly greater than min_value so that unsorted sequences exist.
            device: The torch device ('cpu' or 'cuda') to perform generation on.

        Raises:
            ValueError: If any of the constraints above is violated.
        """
        if num_samples <= 0:
            raise ValueError("num_samples must be positive.")
        if num_samples % 2 != 0:
            # Forcing even numbers for simplicity, as the core requirement is 50/50.
            raise ValueError("num_samples must be even for a perfect 50/50 split.")
        if sequence_length <= 1:
            # All sequences of length 0 or 1 are trivially sorted.
            raise ValueError("sequence_length must be greater than 1 to generate distinct sorted/unsorted samples.")
        if not (0 <= min_value <= max_value):
            raise ValueError(f"Invalid range: min_value ({min_value}) must be <= max_value ({max_value}) and non-negative.")
        if min_value == max_value:
            # A single-value range only yields constant (hence sorted) sequences, so the
            # unsorted half could never be produced and generation would loop forever.
            raise ValueError(f"Invalid range: max_value ({max_value}) must be greater than min_value ({min_value}) "
                             f"for unsorted sequences to exist.")

        self.sequence_length = sequence_length
        self.num_samples = num_samples
        self.min_value = min_value
        # torch.randint's upper bound is exclusive, so add 1
        self.max_value_exclusive = max_value + 1
        self.device = device
        logger.info(f"ScalarSortedDataGenerator initialized: "
                    f"len={sequence_length}, samples={num_samples}, range=[{min_value},{max_value}], "
                    f"device='{device}'")

    @staticmethod
    def _rows_sorted(samples: torch.Tensor) -> torch.Tensor:
        """Return a bool mask, True for each row of a 2-D tensor that is non-decreasing."""
        # A row equals its sorted copy exactly when no element exceeds its successor.
        return (samples[:, 1:] >= samples[:, :-1]).all(dim=1)

    def _random_samples(self, count: int) -> torch.Tensor:
        """Draw a (count, sequence_length) long tensor of uniform values in [min_value, max_value]."""
        return torch.randint(self.min_value, self.max_value_exclusive,
                             (count, self.sequence_length),
                             dtype=torch.long, device=self.device)

    def _generate_sorted_samples(self, num_to_generate: int) -> torch.Tensor:
        """Generate sequences that are guaranteed to be sorted ascending (non-decreasing)."""
        logger.debug(f"Generating {num_to_generate} sorted samples...")
        # Sort each random sequence along the sequence dimension (dim=1).
        sorted_samples, _ = torch.sort(self._random_samples(num_to_generate), dim=1)
        logger.debug(f"Finished generating {num_to_generate} sorted samples.")
        return sorted_samples

    def _generate_unsorted_samples(self, num_to_generate: int) -> torch.Tensor:
        """
        Generate sequences that are guaranteed NOT to be sorted ascending.

        Uses rejection sampling: random sequences are drawn, and any row that happens
        to be non-decreasing is redrawn until none remain. The result is uniform over
        unsorted sequences. Termination requires at least two distinct values, which
        ``__init__`` enforces.
        """
        logger.debug(f"Generating {num_to_generate} unsorted samples (via random generation)...")
        unsorted_samples = self._random_samples(num_to_generate)

        sorted_mask = self._rows_sorted(unsorted_samples)
        while bool(sorted_mask.any()):
            num_bad = int(sorted_mask.sum())
            logger.warning(f"Found {num_bad} randomly generated sequences that were sorted. Regenerating...")
            unsorted_samples[sorted_mask] = self._random_samples(num_bad)
            sorted_mask = self._rows_sorted(unsorted_samples)

        logger.debug(f"Finished generating {num_to_generate} unsorted samples.")
        return unsorted_samples

    def generate_data(self) -> List[Dict[str, Any]]:
        """
        Generates the final balanced dataset as a list of dictionaries.

        Returns:
            A list of dictionaries, where each dictionary represents a sample
            and has keys 'Input' (numpy array of strings) and 'Output' ('1' for sorted, '0' for unsorted).
        """
        num_sorted = self.num_samples // 2
        num_unsorted = self.num_samples - num_sorted # Handles potential future odd num_samples

        logger.info(f"Generating {num_sorted} sorted and {num_unsorted} unsorted samples.")

        sorted_inputs = self._generate_sorted_samples(num_sorted)
        unsorted_inputs = self._generate_unsorted_samples(num_unsorted)

        # Label 1 for sorted, Label 0 for unsorted
        sorted_outputs = torch.ones(num_sorted, dtype=torch.long, device=self.device)
        unsorted_outputs = torch.zeros(num_unsorted, dtype=torch.long, device=self.device)

        all_inputs = torch.cat((sorted_inputs, unsorted_inputs), dim=0)
        all_outputs = torch.cat((sorted_outputs, unsorted_outputs), dim=0)

        logger.info("Shuffling the combined dataset...")
        shuffled_indices = torch.randperm(all_inputs.size(0), device='cpu') # Shuffle on CPU
        all_inputs = all_inputs[shuffled_indices]
        all_outputs = all_outputs[shuffled_indices]
        logger.info("Shuffling complete.")

        dataset: List[Dict[str, Any]] = []
        logger.info("Formatting dataset into list of dictionaries...")
        # Move to CPU and convert to Python lists once, instead of one .item() per element.
        input_rows = all_inputs.cpu().tolist()
        labels = all_outputs.cpu().tolist()
        log_every = max(1, self.num_samples // 10)

        for i, (row, label) in enumerate(zip(input_rows, labels)):
            dataset.append({
                'Input': np.array([str(val) for val in row]),
                'Output': str(label),  # '1' or '0'
            })
            if (i + 1) % log_every == 0: # Log progress every 10%
                logger.debug(f"Formatting progress: {i+1}/{self.num_samples}")

        logger.info(f"Data generation and formatting complete. Returning {len(dataset)} samples.")
        return dataset


class Dyck2DataGenerator:
    """
    Generates sequences for the Dyck-2 language problem (e.g., '()[]').

    Each parenthesis is encoded by 2 bits ('00' -> '(', '01' -> ')',
    '10' -> '[', '11' -> ']'). The dataset is balanced with 50% valid
    sequences (label 1) and 50% invalid sequences (label 0).

    Because every one of the four 2-bit pairs is assigned a parenthesis, any
    bit string of even length decodes to *some* parenthesis string; whether it
    is a valid Dyck-2 word is decided purely by bracket matching.
    """
    def __init__(self,
                 sequence_length: int,
                 num_samples: int,
                 device: str = 'cpu'):
        """
        Initializes the Dyck2DataGenerator.

        Args:
            sequence_length: The length of the bit sequence. Must be a multiple of 4
                (2 bits per parenthesis, and a Dyck word has an even number of parentheses).
            num_samples: The total number of samples to generate. Must be an even number.
            device: The torch device ('cpu' or 'cuda') the returned tensors are moved to.

        Raises:
            ValueError: If ``sequence_length`` is not a positive multiple of 4, or
                ``num_samples`` is not a positive even number.
        """
        if sequence_length <= 0 or sequence_length % 4 != 0:
            raise ValueError("sequence_length must be a positive multiple of 4 for Dyck-2 (2 bits per parenthesis, even length).")
        if num_samples <= 0 or num_samples % 2 != 0:
            raise ValueError("num_samples must be a positive, even number for a perfect 50/50 split.")

        self.sequence_length = sequence_length
        self.paren_seq_length = sequence_length // 2 # Length of the string of '()[]'
        self.num_samples = num_samples
        self.device = device

        # Mappings from bits to parentheses and vice-versa
        self.paren_map = {"00": "(", "01": ")", "10": "[", "11": "]"}
        self.paren_to_bit_str = {v: k for k, v in self.paren_map.items()}
        self.open_to_close = {'(': ')', '[': ']'}
        self.close_to_open = {v: k for k, v in self.open_to_close.items()}

        logger.info(f"Dyck2DataGenerator initialized: len={sequence_length}, samples={num_samples}, device='{device}'")

    def _is_valid_paren_seq(self, paren_seq: str) -> bool:
        """Return True iff ``paren_seq`` is a well-matched Dyck-2 word over '()[]'.

        Uses a stack of open brackets; any unexpected character, a closing
        bracket with no/mismatched opener, or leftover openers make it invalid.
        """
        stack = []
        for char in paren_seq:
            if char in self.open_to_close:  # opening bracket
                stack.append(char)
            elif char in self.close_to_open:  # closing bracket
                if not stack or stack.pop() != self.close_to_open[char]:
                    return False  # mismatch, or closing with an empty stack
            else:
                return False  # character outside the alphabet
        return not stack  # valid only if every opener was closed

    def _paren_seq_to_bits(self, paren_seq: str) -> str:
        """Encode a parenthesis string as its '0'/'1' bit string (2 bits per character)."""
        return "".join(self.paren_to_bit_str[c] for c in paren_seq)

    def _bits_to_paren_seq(self, bit_string: str) -> str:
        """Decode a bit string of even length back into a parenthesis string.

        Every 2-bit pair is a key of ``paren_map``, so this never fails for
        strings made of '0'/'1' with even length.
        """
        return "".join(self.paren_map[bit_string[i:i + 2]] for i in range(0, len(bit_string), 2))

    @staticmethod
    def _bits_to_tensor(bit_string: str) -> torch.Tensor:
        """Convert a '0'/'1' string into a 1-D ``torch.long`` tensor (created on CPU)."""
        return torch.tensor([int(b) for b in bit_string], dtype=torch.long)

    def _stack_samples(self, samples: List[torch.Tensor]) -> torch.Tensor:
        """Stack 1-D sample tensors into a (N, sequence_length) tensor on ``self.device``.

        Returns an empty (0, sequence_length) tensor when ``samples`` is empty,
        where ``torch.stack`` would otherwise raise.
        """
        if not samples:
            return torch.empty((0, self.sequence_length), dtype=torch.long, device=self.device)
        return torch.stack(samples, dim=0).to(self.device)

    def _generate_one_valid_paren_seq(self) -> str:
        """Generate a single valid Dyck-2 parenthesis string of length ``paren_seq_length``.

        Loop invariant: ``len(seq) + len(stack)`` is the final length the
        string would have if all pending openers were closed now. Opening adds 2
        to it, closing adds 0, so with an even target it lands exactly on it.
        """
        open_close_pairs = list(self.open_to_close.items())  # loop-invariant
        stack, seq = [], []
        while (len(stack) + len(seq)) < self.paren_seq_length:
            # A pair is drawn every iteration (even when closing) so the random
            # stream is consumed the same way regardless of the branch taken.
            opener, closer = random.choice(open_close_pairs)
            # Must open when nothing is pending; otherwise open or close 50/50.
            if not stack or random.random() < 0.5:
                seq.append(opener)
                stack.append(closer)
            else:
                seq.append(stack.pop())
        # Close every still-open bracket, innermost first.
        seq.extend(reversed(stack))
        return "".join(seq)

    def _generate_valid_samples(self, num_to_generate: int) -> torch.Tensor:
        """Generate ``num_to_generate`` valid Dyck-2 sequences as a (N, sequence_length) bit tensor."""
        logger.debug(f"Generating {num_to_generate} valid Dyck-2 samples...")
        valid_samples = []
        for _ in tqdm(range(num_to_generate), desc="Generating Valid Dyck-2", unit="seq", leave=False):
            paren_seq = self._generate_one_valid_paren_seq()
            valid_samples.append(self._bits_to_tensor(self._paren_seq_to_bits(paren_seq)))
        return self._stack_samples(valid_samples)

    def _generate_invalid_samples(self, num_to_generate: int) -> torch.Tensor:
        """
        Generate invalid sequences by creating valid ones and then corrupting them.

        1 to 5 distinct bit positions of a valid sequence are flipped; the result
        is decoded and kept only if it is NOT a valid Dyck-2 word, so every
        returned sample is guaranteed invalid ("near-miss" negatives).

        Returns:
            A (num_to_generate, sequence_length) long tensor on ``self.device``.
        """
        logger.debug(f"Generating {num_to_generate} invalid Dyck-2 samples...")
        invalid_samples = []
        # Context manager closes the progress bar even if generation raises.
        with tqdm(total=num_to_generate, desc="Generating Invalid Dyck-2", unit="seq", leave=False) as pbar:
            while len(invalid_samples) < num_to_generate:
                # 1. Start from a valid sequence.
                bits = list(self._paren_seq_to_bits(self._generate_one_valid_paren_seq()))

                # 2. Corrupt it by flipping a few distinct bit positions.
                num_flips = random.randint(1, 5)
                positions_to_flip = random.sample(range(self.sequence_length), k=min(num_flips, self.sequence_length))
                for pos in positions_to_flip:
                    bits[pos] = '1' if bits[pos] == '0' else '0'
                corrupted_bit_str = "".join(bits)

                # 3. Every 2-bit pair decodes to a parenthesis, so the corrupted
                #    string always decodes; reject it if the flips happened to
                #    produce another valid Dyck-2 word, and retry.
                if self._is_valid_paren_seq(self._bits_to_paren_seq(corrupted_bit_str)):
                    continue
                invalid_samples.append(self._bits_to_tensor(corrupted_bit_str))
                pbar.update(1)
        return self._stack_samples(invalid_samples)

    def generate_data(self) -> List[Dict[str, Any]]:
        """
        Generates the final balanced dataset as a list of dictionaries.

        Returns:
            A shuffled list of dictionaries, each with 'Input' (numpy array of
            '0'/'1' strings, length ``sequence_length``) and 'Output'
            ('1' for valid, '0' for invalid).
        """
        num_valid = self.num_samples // 2
        num_invalid = self.num_samples - num_valid
        logger.info(f"Generating {num_valid} valid and {num_invalid} invalid Dyck-2 samples.")

        # Generate inputs
        valid_inputs = self._generate_valid_samples(num_valid)
        invalid_inputs = self._generate_invalid_samples(num_invalid)

        # Generate corresponding outputs (labels)
        valid_outputs = torch.ones(num_valid, dtype=torch.long, device=self.device)
        invalid_outputs = torch.zeros(num_invalid, dtype=torch.long, device=self.device)

        # Combine and shuffle
        all_inputs = torch.cat((valid_inputs, invalid_inputs), dim=0)
        all_outputs = torch.cat((valid_outputs, invalid_outputs), dim=0)

        logger.info("Shuffling the combined dataset...")
        shuffled_indices = torch.randperm(all_inputs.size(0))
        all_inputs = all_inputs[shuffled_indices]
        all_outputs = all_outputs[shuffled_indices]
        logger.info("Shuffling complete.")

        # Format into the required list of dictionaries. A single transfer to
        # CPU + tolist() avoids a per-element .item() (a device sync on CUDA).
        logger.info("Formatting dataset into list of dictionaries...")
        input_rows = all_inputs.cpu().tolist()
        labels = all_outputs.cpu().tolist()
        dataset = [
            {
                'Input': np.array([str(bit) for bit in row]),
                'Output': str(label),
            }
            for row, label in zip(input_rows, labels)
        ]

        logger.info(f"Data generation and formatting complete. Returning {len(dataset)} samples.")
        return dataset


def _find_all_occurrences(sequence: torch.Tensor, pattern: torch.Tensor) -> List[int]:
    """Finds all starting indices of a pattern in a sequence."""
    seq_len, pattern_len = len(sequence), len(pattern)
    occurrences = []
    for i in range(seq_len - pattern_len + 1):
        if torch.equal(sequence[i:i+pattern_len], pattern):
            occurrences.append(i)
    return occurrences

class PatternBasedDataGenerator:
    """
    Generates a balanced and unique dataset of binary sequences.

    This generator employs efficient, batch-based methods. For sequences with the
    pattern, it strategically inserts the pattern into random data. For sequences
    without the pattern, it uses an intelligent "generate and repair" strategy
    to ensure the pattern is absent, which is significantly more efficient than
    a simple "generate and discard" approach.
    """
    def __init__(self,
                 sequence_length: int,
                 total_size: int,
                 device: str = 'cpu'):
        """
        Initializes the PatternBasedDataGenerator.

        Args:
            sequence_length (int): The length of each binary sequence.
            total_size (int): The total number of samples to generate. Must be an even number.
            device (str): The device for tensor operations ('cpu' or 'cuda').
        """
        if total_size % 2 != 0:
            logger.error("total_size must be an even number for a balanced dataset.")
            raise ValueError("total_size must be an even number for a balanced dataset.")
        
        self.sequence_length = sequence_length
        self.total_size = total_size
        self.device = device
        logger.info(
            f"PatternBasedDataGenerator initialized for len={self.sequence_length}, "
            f"total_size={self.total_size}, device='{self.device}'"
        )
        
    # def generate_data(self, pattern_string: str = '00111111') -> List[Dict[str, Any]]:
    def generate_data(self, pattern_string: str = '10101010') -> List[Dict[str, Any]]:
        """
        The main method to generate the dataset based on a given pattern.

        Args:
            pattern_string (str): A string of '0's and '1's representing the pattern.

        Returns:
            List[Dict[str, Any]]: A shuffled list of dictionaries, where each dictionary
                                 represents a sample with 'Input' and 'Output' keys.
                                 'Output' is '1' if the pattern is present, '0' otherwise.
        """
        if not pattern_string:
            raise ValueError("pattern_string cannot be empty.")
        pattern_len = len(pattern_string)
        if pattern_len > self.sequence_length:
            raise ValueError(
                f"Pattern length ({pattern_len}) cannot exceed sequence length ({self.sequence_length})."
            )

        logger.info(f"Starting data generation for pattern: '{pattern_string}'")
        
        try:
            pattern_tensor = torch.tensor([int(bit) for bit in pattern_string], dtype=torch.long, device=self.device)
        except ValueError:
            raise ValueError("pattern_string must contain only '0's and '1's.")

        num_with_pattern = self.total_size // 2
        num_without_pattern = self.total_size - num_with_pattern

        # --- Generate samples WITH and WITHOUT the pattern ---
        samples_with = self._generate_samples_with_pattern(num_with_pattern, pattern_tensor)
        
        # Pass samples_with to ensure dataset-wide uniqueness
        samples_without = self._generate_samples_without_pattern(num_without_pattern, pattern_tensor, samples_with)

        # --- Combine, Format, and Shuffle ---
        return self._format_and_shuffle_dataset(samples_with, samples_without)

    def _generate_samples_with_pattern(self, num_samples: int, pattern_tensor: torch.Tensor) -> Set[Tuple[int, ...]]:
        """Generates a set of unique samples containing the pattern in one go."""
        logger.info(f"Generating {num_samples} unique samples WITH the pattern...")
        unique_samples = set()
        attempts = 0
        max_attempts = 30
        pattern_len = len(pattern_tensor)

        while len(unique_samples) < num_samples and attempts < max_attempts:
            needed = num_samples - len(unique_samples)
            batch_size = int(needed * 1.5) + 10 
            
            sequences = torch.randint(0, 2, (batch_size, self.sequence_length), dtype=torch.long, device=self.device)
            
            insertion_indices = torch.randint(0, self.sequence_length - pattern_len + 1, (batch_size,))
            
            for i in range(batch_size):
                sequences[i, insertion_indices[i] : insertion_indices[i] + pattern_len] = pattern_tensor
            
            new_tuples = {tuple(seq.tolist()) for seq in sequences}
            
            if len(new_tuples.union(unique_samples)) > len(unique_samples):
                attempts = 0
            else:
                attempts +=1

            unique_samples.update(new_tuples)

        if len(unique_samples) < num_samples:
            raise RuntimeError(f"Failed to generate {num_samples} unique samples WITH pattern. Try different parameters.")

        return set(random.sample(list(unique_samples), num_samples))

    def _generate_samples_without_pattern(self, num_samples: int, pattern_tensor: torch.Tensor, existing_samples: Set[Tuple[int, ...]]) -> Set[Tuple[int, ...]]:
        """Generates a set of unique samples guaranteed NOT to contain the pattern."""
        logger.info(f"Generating {num_samples} unique samples WITHOUT the pattern (intelligent method)...")
        unique_samples = set()
        attempts = 0
        max_attempts = 30
        pattern_len = len(pattern_tensor)

        while len(unique_samples) < num_samples and attempts < max_attempts:
            needed = num_samples - len(unique_samples)
            batch_size = int(needed * 1.5) + 10
            
            sequences = torch.randint(0, 2, (batch_size, self.sequence_length), dtype=torch.long, device=self.device)
            
            repaired_sequences = set()
            for seq in sequences: 
                while True:
                    occurrences = _find_all_occurrences(seq, pattern_tensor)
                    if not occurrences:
                        break
                    
                    idx_to_break = random.choice(occurrences)
                    bit_to_flip = random.randint(0, pattern_len - 1)
                    seq[idx_to_break + bit_to_flip] = 1 - seq[idx_to_break + bit_to_flip]
                
                seq_tuple = tuple(seq.tolist())
                if seq_tuple not in existing_samples:
                    repaired_sequences.add(seq_tuple)
            
            if len(repaired_sequences.union(unique_samples)) > len(unique_samples):
                attempts = 0
            else:
                attempts += 1
            
            unique_samples.update(repaired_sequences)

        if len(unique_samples) < num_samples:
            raise RuntimeError(f"Failed to generate {num_samples} unique samples WITHOUT pattern.")

        return set(random.sample(list(unique_samples), num_samples))
        
    def _format_and_shuffle_dataset(self, samples_with: Set, samples_without: Set) -> List[Dict[str, Any]]:
        """Converts sets of samples to the final shuffled list of dictionaries."""
        logger.info("Formatting and shuffling final dataset...")
        
        samples_with_tensor = torch.tensor(list(samples_with), dtype=torch.long, device=self.device)
        samples_without_tensor = torch.tensor(list(samples_without), dtype=torch.long, device=self.device)
        
        labels_with = torch.ones(samples_with_tensor.size(0), dtype=torch.long, device=self.device)
        labels_without = torch.zeros(samples_without_tensor.size(0), dtype=torch.long, device=self.device)

        all_samples = torch.cat((samples_with_tensor, samples_without_tensor), dim=0)
        all_labels = torch.cat((labels_with, labels_without), dim=0)

        indices = torch.randperm(all_samples.size(0), device=self.device)
        shuffled_samples = all_samples[indices]
        shuffled_labels = all_labels[indices]
        
        dataset = []
        for i in range(shuffled_samples.size(0)):
            # CORRECTED LINE: Convert the input to a NumPy array of strings.
            input_numpy_array = np.array([str(bit.item()) for bit in shuffled_samples[i]])
            dataset.append({
                'Input': input_numpy_array,
                'Output': str(shuffled_labels[i].item())
            })
            
        logger.info("Pattern-based data generation complete.")
        return dataset


class PalindromeDataGenerator:
    """Generates binary input vectors and corresponding outputs for a palindrome task.

    The generated dataset contains 50% palindromic sequences (label '1') and 50%
    non-palindromic sequences (label '0'). Each non-palindrome is derived from a
    generated palindrome by flipping one bit at a position that has a distinct
    mirror partner, which guarantees the symmetry is broken. (Flipping the middle
    element of an odd-length palindrome would leave it a palindrome.)
    """
    def __init__(self,
                 sequence_length: int,
                 num_samples: int,
                 device: str = 'cpu'):
        """
        Args:
            sequence_length: Length of each binary sequence; must be at least 2.
            num_samples: Total number of samples; must be even for a 50/50 split.
            device: Torch device used for generation.

        Raises:
            ValueError: If ``num_samples`` is odd or ``sequence_length`` < 2.
        """
        if num_samples % 2 != 0:
            logger.error(f"Number of samples must be even for a 50/50 split, but got {num_samples}.")
            raise ValueError("Number of samples must be even for palindrome generation.")
        # With fewer than 2 positions there is no mirrored pair, so every
        # sequence is trivially a palindrome and no non-palindrome exists.
        if sequence_length < 2:
            logger.error(f"Sequence length must be at least 2 to produce non-palindromes, but got {sequence_length}.")
            raise ValueError("Sequence length must be at least 2 for palindrome generation.")

        self.sequence_length = sequence_length
        self.num_samples = num_samples
        self.device = device
        logger.info(f"DataGenerator initialized for target='is_palindrome', len={sequence_length}, samples={num_samples}, device='{device}'")

    def _generate_palindromes(self, count: int) -> torch.Tensor:
        """Return a (count, sequence_length) tensor of random binary palindromes.

        A palindrome is determined by its first ceil(L/2) elements (including the
        middle element when L is odd); the rest is the mirror of the first floor(L/2).
        """
        half_len = (self.sequence_length + 1) // 2
        first_halves = torch.randint(0, 2, size=(count, half_len), device=self.device, dtype=torch.long)
        mirrored = torch.flip(first_halves[:, :self.sequence_length // 2], dims=[1])
        return torch.cat([first_halves, mirrored], dim=1)

    def _break_symmetry(self, palindromes: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``palindromes`` with one mirrored bit flipped in every row.

        The flip position is drawn from [0, L // 2), i.e. only positions whose
        mirror partner is a different position, so every output row is
        guaranteed not to be a palindrome. The input tensor is not modified.
        """
        count = palindromes.size(0)
        non_palindromes = palindromes.clone()
        row_indices = torch.arange(count, device=self.device)
        col_indices_to_flip = torch.randint(0, self.sequence_length // 2, size=(count,), device=self.device)
        non_palindromes[row_indices, col_indices_to_flip] = 1 - non_palindromes[row_indices, col_indices_to_flip]
        return non_palindromes

    def _generate_split_samples(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate a shuffled dataset with 50% palindromes and 50% non-palindromes.

        Returns:
            (inputs, outputs): inputs of shape (num_samples, sequence_length) and
            outputs of shape (num_samples,), with 1 for palindrome, 0 otherwise.
        """
        half_samples = self.num_samples // 2

        # 1. Palindromes
        logger.info(f"Generating {half_samples} palindromic samples...")
        palindromes = self._generate_palindromes(half_samples)
        logger.info("Generated palindromic samples.")

        # 2. Non-palindromes, one per palindrome
        logger.info(f"Generating {half_samples} non-palindromic samples...")
        non_palindromes = self._break_symmetry(palindromes)
        logger.info("Generated non-palindromic samples.")

        # 3. Combine inputs and create corresponding outputs
        binary_inputs = torch.cat([palindromes, non_palindromes], dim=0)
        palindrome_outputs = torch.ones(half_samples, dtype=torch.long, device=self.device)
        non_palindrome_outputs = torch.zeros(half_samples, dtype=torch.long, device=self.device)
        outputs = torch.cat([palindrome_outputs, non_palindrome_outputs], dim=0)

        # 4. Shuffle inputs and outputs together
        shuffled_indices = torch.randperm(binary_inputs.size(0), device=self.device)
        logger.info("Shuffled generated samples.")

        return binary_inputs[shuffled_indices], outputs[shuffled_indices]

    def generate_data(self) -> List[Dict[str, Any]]:
        """Generate the dataset as a list of {'Input': np.ndarray of '0'/'1' strings, 'Output': '1'|'0'}."""
        binary_inputs, outputs = self._generate_split_samples()

        logger.info("Formatting dataset...")
        # One transfer to CPU + tolist() instead of a per-element .item().
        input_rows = binary_inputs.cpu().tolist()
        labels = outputs.cpu().tolist()
        dataset = [
            {
                'Input': np.array([str(bit) for bit in row]),
                'Output': str(label),
            }
            for row, label in zip(input_rows, labels)
        ]
        logger.info("Data generation complete.")
        return dataset
