# gen_sigma.py — dataset generators with sigma-encoding inputs
# Adds semantic negatives (well-formed but arithmetically wrong words) alongside malformed negatives.

from __future__ import annotations
from dataclasses import dataclass
from typing import List, Tuple, Optional
import random, os
from tqdm import tqdm

# ============================ basic utils ====================================

def bin_canon(n: int) -> str:
    """Return the canonical binary spelling of a non-negative integer.

    0 maps to '0'; every n > 0 maps to its binary digits with no leading
    zeros (e.g. 5 -> '101').

    Raises:
        ValueError: if ``n`` is negative. (``bin(-5)[2:]`` would otherwise
            yield the meaningless string 'b101'.)
    """
    if n < 0:
        raise ValueError(f"bin_canon expects a non-negative integer, got {n}")
    # format(n, 'b') has no '0b' prefix and already renders zero as '0'.
    return format(n, 'b')

@dataclass
class Sample:
    """One dataset row.

    Attributes:
        s_in: the serialized model input (sigma encoding, comma-separated
            integers), written verbatim to the input file.
        s_can: the canonical word; the KCM is evaluated on it to produce the
            target line.
    """

    s_in: str
    s_can: str

def _write_dataset(dirpath: str, split_name: str,
                   samples: List[Sample], kcm, alphabet: List[str]) -> None:
    os.makedirs(dirpath, exist_ok=True)
    with open(os.path.join(dirpath, f"input{split_name}.txt"), "w") as fi, \
         open(os.path.join(dirpath, f"target{split_name}.txt"), "w") as ft:
        for smp in samples:
            fi.write(smp.s_in + "\n")
            ft.write(kcm.output_generator(smp.s_can, alphabet) + "\n")

# =========================== sigma and β^{-1} ====================================

def beta_inv(mask: str) -> int:
    """
    Map a 0/1 mask b back to an integer x with β(x) = b.

    The preimage chosen is x = int('10' + b, 2), the trick used in the
    paper's encoding; e.g. '' -> 2, '1' -> 5, '01' -> 9.

    Raises ValueError if the mask holds any symbol other than '0' or '1'.
    """
    if not all(bit in '01' for bit in mask):
        raise ValueError("beta_inv expects a 0/1 mask")
    prefixed = '10' + mask
    return int(prefixed, 2)

def sigma_numbers_from_word(w: str, alphabet: List[str]) -> List[int]:
    """
    Compute sigma(w): one integer per alphabet symbol, in alphabet order.

    For symbol a, the indicator mask of w has '1' wherever w holds a and '0'
    elsewhere; that mask is mapped through beta_inv. A character of w that is
    not in the alphabet contributes '0' to every mask.
    """
    def indicator(symbol: str) -> str:
        return ''.join('1' if ch == symbol else '0' for ch in w)

    return [beta_inv(indicator(symbol)) for symbol in alphabet]

def sigma_input_string(w: str, alphabet: List[str], sep: str = ',') -> str:
    """
    Serialize sigma(w) as 'x0<sep>x1<sep>...', one decimal field per symbol.

    ``sep`` should not be a symbol of sigma and must not contain decimal
    digits; otherwise the integer fields cannot be split apart again.
    """
    numbers = sigma_numbers_from_word(w, alphabet)
    return sep.join(map(str, numbers))

# ====================== malformed word helpers (for negatives) ===============

def _malformed_word_01(L: int) -> str:
    return ''.join(random.choice('01') for _ in range(max(1, L)))

def _malformed_word_01_slash(total_lo: int, total_hi: int) -> str:
    """Malformed word over {0,1,'/'}."""
    L = random.randint(max(1, total_lo), max(1, total_hi))
    mode = random.choice(["missing_slash", "extra_slash", "lead0_piece", "empty_side", "random01"])
    def rand_pos(len_hint: int):
        bl = max(1, len_hint)
        return '1' + ''.join(random.choice('01') for _ in range(max(0, bl - 1)))
    if mode == "missing_slash":
        return ''.join(random.choice('01') for _ in range(L))
    if mode == "extra_slash":
        if L < 7: L = 7
        a = rand_pos(L//3); b = rand_pos(L//3)
        c = rand_pos(max(1, L - len(a) - len(b) - 2))
        return a + '/' + b + '/' + c + '/'
    if mode == "lead0_piece":
        if L < 7: L = 7
        a = "0" + ''.join(random.choice('01') for _ in range(max(0, L//3 - 1)))
        b = rand_pos(L//3)
        c = rand_pos(max(1, L - len(a) - len(b) - 2))
        return a + '/' + b + '/' + c
    if mode == "empty_side":
        if L < 3: L = 3
        choice = random.choice([0,1,2])
        if choice == 0:
            b = rand_pos(L//2); c = rand_pos(max(1, L - len(b) - 2))
            return '/' + b + '/' + c
        elif choice == 1:
            a = rand_pos(L//2); c = rand_pos(max(1, L - len(a) - 2))
            return a + '//' + c
        else:
            a = rand_pos(L//2); b = rand_pos(max(1, L - len(a) - 2))
            return a + '/' + b + '/'
    return ''.join(random.choice(['0','1','/']) for _ in range(L))

# =========================== length helpers (canonical) ======================

def length_add_can(i: int, j: int) -> int:
    """Length of the canonical ADD word 'bin(i)/bin(j)/bin(i+j)', both '/' included."""
    total = i + j
    return sum(len(bin_canon(n)) for n in (i, j, total)) + 2

def length_mul_can(i: int, j: int) -> int:
    """Length of the canonical MUL word 'bin(i)/bin(j)/bin(i*j)', both '/' included."""
    product = i * j
    return sum(len(bin_canon(n)) for n in (i, j, product)) + 2

def length_gcd_can(i: int, j: int) -> int:
    """Length of the canonical GCD word 'bin(i)/bin(j)/bin(gcd(i, j))', both '/' included."""
    import math
    g = math.gcd(i, j)
    piece_lengths = [len(bin_canon(n)) for n in (i, j, g)]
    return sum(piece_lengths) + 2

def length_exp_can(a: int, i: int, b: int, max_result_bits: int = 1 << 16) -> float:
    """Length of the canonical word bin(a)/bin(i)/bin(b)/bin(result), result = a^i * b^(2^i).

    Returns the exact length (an int) when it is cheap to compute. Returns
    ``float('inf')`` instead when:
      * any argument is negative (outside the word's domain), or
      * an upper bound on result's bit-length exceeds ``max_result_bits``,
        which keeps callers from materialising astronomically large integers.

    Callers compare against ``float('inf')`` to reject a draw.
    """
    inf = float('inf')
    if a < 0 or i < 0 or b < 0:
        return inf

    def blen(n: int) -> int:
        # len(bin_canon(n)) for n >= 0: zero is spelled '0' (length 1).
        return max(1, n.bit_length())

    # bits(x * y) <= bits(x) + bits(y) and bits(x ** e) <= e * bits(x).
    bound = i * a.bit_length()
    if b >= 2:
        # b ** (2 ** i) has more than 2 ** i bits; stop before even building
        # 2 ** i once that alone is over budget.
        if i >= max_result_bits.bit_length():
            return inf
        bound += (1 << i) * b.bit_length()
    if bound > max_result_bits:
        return inf

    # For b in {0, 1}, b ** (2 ** i) == b, so the exponent is never built.
    b_term = b if b <= 1 else b ** (1 << i)
    result = a ** i * b_term
    return blen(a) + 1 + blen(i) + 1 + blen(b) + 1 + blen(result)

def length_dvd_can(w: int, v: int) -> int:
    """Length of the two-piece word 'bin(w)/bin(v)' (one '/' included).

    NOTE(review): training_set_dvd emits three-piece words 'a/b/q'
    (q = a // b), so this measure omits the '/q' suffix. Confirm that callers
    really want the two-piece length.
    """
    return sum(len(bin_canon(x)) for x in (w, v)) + 1

# ========================= band samplers (pairs) =============================

def _draw_pos(bits: int) -> int:
    """Random positive with exact bit-length (>=1)."""
    if bits <= 1: return 1
    lo = 1 << (bits - 1)
    hi = (1 << bits) - 1
    return random.randint(lo, hi)

def _pairs_for_band(length_fn, lo: int, hi: int, count: int,
                    allow_zero_i: bool = True, allow_zero_j: bool = True) -> List[Tuple[int,int]]:
    """
    Generic pair generator that tries to hit canonical total length in [lo,hi].
    Duplicates allowed (we top up if needed).
    """
    pairs: List[Tuple[int,int]] = []
    attempts = 0
    max_attempts = count * 60
    while len(pairs) < count and attempts < max_attempts:
        attempts += 1
        # choose rough bit budgets
        bi = random.randint(0, max(1, hi // 3))
        bj = random.randint(0, max(1, hi // 3))
        def maybe_zero(bits: int, allow_zero: bool) -> int:
            if allow_zero and random.random() < 0.15:
                return 0
            return _draw_pos(max(1, bits))
        i = maybe_zero(bi, allow_zero_i)
        j = maybe_zero(bj, allow_zero_j)
        L = length_fn(i, j)
        if lo <= L <= hi:
            pairs.append((i, j))
    # top up by resampling existing ones (duplicates OK)
    while len(pairs) < count and pairs:
        pairs.append(random.choice(pairs))
    return pairs[:count]

# ============================ helpers: semantic negatives =====================

def _perturb_nonneg(x: int) -> int:
    """Return a different nonnegative integer, small perturbation."""
    if x == 0:
        return 1
    for d in (1, -1, 2, -2, 3, -3):
        y = x + d
        if y >= 0 and y != x:
            return y
    return x + 1

# ============================ ADD (sigma inputs) ==============================

def training_set_add(max_num: int, malformed_ratio: float = 0.15, semantic_neg_ratio: float = 0.15) -> None:
    """
    Build the ADD dataset (sigma-encoded inputs) and write it to ``dataset/add``.

    Canonical words are ``bin(i)/bin(j)/bin(i+j)`` over {'0','1','/'}. Pairs
    are drawn by ``_pairs_for_band`` with zero operands allowed. Bands are on
    the total canonical word length:
      - train / val0: 80/20 split of the samples built from the 1..100 band
        (``max_num`` pairs);
      - val1: 101..200 and val2: 201..500, each from ``max(1, len(val0))``
        pairs, where len(val0) counts samples (negatives included).

    Per band, with n = number of pairs:
      - n positives;
      - ``int(n * semantic_neg_ratio)`` semantic negatives: a random pair whose
        sum is replaced by ``_perturb_nonneg(i + j)``;
      - ``int(n * malformed_ratio / (1 - malformed_ratio))`` malformed words,
        i.e. malformed words make up about ``malformed_ratio`` of
        positives + malformed.
        NOTE(review): malformed_ratio == 1.0 raises ZeroDivisionError, and the
        other builders in this file size their malformed sets differently.

    Writes input*/target* files for the "", "_val0", "_val1" and "_val2" splits.
    """
    from kcm_binary import build_add_binary_kcm
    auto = build_add_binary_kcm()
    alphabet = ['0', '1', '/']

    # The relation the canonical word encodes: third piece = i + j.
    def REL(i, j): return i + j

    def build_from_pairs(pairs, lo, hi, desc):
        # Positives, semantic negatives and malformed words for one band, shuffled.
        S: List[Sample] = []
        malformed = int(len(pairs) * malformed_ratio / (1 - malformed_ratio))
        semnegs = int(len(pairs) * semantic_neg_ratio)

        # Positives
        for i, j in tqdm(pairs, desc=f"{desc} valid"):
            k = REL(i, j)
            s_can = f"{bin_canon(i)}/{bin_canon(j)}/{bin_canon(k)}"
            s_in  = sigma_input_string(s_can, alphabet, sep=',')
            S.append(Sample(s_in, s_can))

        # Semantic negatives
        # Fixed seed: every band uses the same stream for choosing pairs.
        rng = random.Random(101)
        for _ in tqdm(range(semnegs), desc=f"{desc} semantic-neg"):
            i, j = rng.choice(pairs)
            k_bad = _perturb_nonneg(REL(i, j))
            s_can_bad = f"{bin_canon(i)}/{bin_canon(j)}/{bin_canon(k_bad)}"
            s_in_bad  = sigma_input_string(s_can_bad, alphabet, sep=',')
            S.append(Sample(s_in_bad, s_can_bad))

        # Malformed negatives
        for _ in tqdm(range(malformed), desc=f"{desc} malformed"):
            bad = _malformed_word_01_slash(lo, hi)
            S.append(Sample(sigma_input_string(bad, alphabet, sep=','), bad))

        random.shuffle(S); return S

    short  = _pairs_for_band(length_add_can, 1, 100, max_num, True, True)
    all_short = build_from_pairs(short, 1, 100, "add-short")
    # NOTE(review): the split happens after positives and negatives are mixed,
    # so val0 may hold a semantic negative whose pair's positive is in train.
    split = int(0.8 * len(all_short)); train, val0 = all_short[:split], all_short[split:]

    val1_pairs = _pairs_for_band(length_add_can, 101, 200, max(1, len(val0)), True, True)
    val2_pairs = _pairs_for_band(length_add_can, 201,500, max(1, len(val0)), True, True)
    val1 = build_from_pairs(val1_pairs, 101, 200, "add-val1")
    val2 = build_from_pairs(val2_pairs, 201,500, "add-val2")

    _write_dataset("dataset/add", "", train, auto, alphabet)
    _write_dataset("dataset/add", "_val0", val0, auto, alphabet)
    _write_dataset("dataset/add", "_val1", val1, auto, alphabet)
    _write_dataset("dataset/add", "_val2", val2, auto, alphabet)

# ============================ MUL (sigma inputs) ==============================
def training_set_mul(max_num: int,
                     malformed_ratio: float = 0.1,  # Reduced from 0.2
                     semantic_neg_ratio: float = 0.05):  # Reduced from 0.2
    """
    Build the MUL dataset (sigma-encoded inputs) and write it to ``dataset/mul``.

    Canonical words are ``bin(a)/bin(b)/bin(a*b)`` with a, b >= 1, over the
    alphabet {'0', '1', '/'}. Bands are on the total canonical word length:
      - train / val0: 80/20 split of the samples built from the 1..50 band
      - val1: 51..100
      - val2: 101..150
    val1 and val2 each get ``max(1, len(val0))`` positive pairs.

    Per band, with n = number of positive pairs, the builder also adds
    ``int(n * semantic_neg_ratio)`` semantic negatives (product perturbed by a
    small delta) and ``int(n * malformed_ratio / (1 + malformed_ratio))``
    malformed words.
    """
    from kcm_binary import build_mul_binary_kcm
    auto = build_mul_binary_kcm()
    alphabet = ['0', '1', '/']

    # Shortest possible MUL word: '1/1/1'.
    min_word_len = 5

    def nums_for_band(lo: int, hi: int, count: int):
        """Sample ``count`` (a, b) pairs whose canonical MUL word length is in [lo, hi].

        With bit-lengths La, Lb, the product has La+Lb-1 or La+Lb bits, so the
        word length is 2*(La+Lb)+1 or 2*(La+Lb)+2. Each attempt picks a target
        length in the band and derives S = La + Lb from it. It then splits S
        at random and keeps the pair if its exact length lands in the band.
        Fixed-size 6..8-bit operands would cap the length at 34, leaving the
        51..100 and 101..150 bands empty.
        """
        t_lo = max(lo, min_word_len)
        if t_lo > hi:
            return []  # no MUL word fits in this band
        xs, attempts = [], 0
        rng = random.Random(2025)
        max_attempts = count * 100

        while len(xs) < count and attempts < max_attempts:
            attempts += 1
            target = rng.randint(t_lo, hi)
            # La + Lb such that the word length is target or a neighbour of it.
            total_bits = max(2, (target - 1) // 2)
            L_a = rng.randint(1, total_bits - 1)
            L_b = total_bits - L_a
            a = rng.randint(1 << (L_a - 1), (1 << L_a) - 1)
            b = rng.randint(1 << (L_b - 1), (1 << L_b) - 1)
            if lo <= length_mul_can(a, b) <= hi:
                xs.append((a, b))

        # Fill with duplicates if rejection sampling fell short
        while len(xs) < count and xs:
            xs.append(rng.choice(xs))

        return xs[:count]

    def build_from_pairs(pairs, lo: int, hi: int, desc: str):
        """Turn (a, b) pairs into shuffled Samples: positives, semantic negatives, malformed."""
        S: List[Sample] = []
        # Reduce negative examples
        malformed = int(len(pairs) * malformed_ratio / (1 + malformed_ratio))
        semnegs = int(len(pairs) * semantic_neg_ratio)

        # --- Positives ---
        for a, b in tqdm(pairs, desc=f"{desc} sigma-valid"):
            c = a * b
            s_can = f"{bin_canon(a)}/{bin_canon(b)}/{bin_canon(c)}"
            s_in  = sigma_input_string(s_can, alphabet, sep=',')
            S.append(Sample(s_in, s_can))

        # --- Semantic negatives (smaller perturbations) ---
        rng = random.Random(404)
        for _ in tqdm(range(semnegs), desc=f"{desc} semantic-neg"):
            a, b = rng.choice(pairs)
            c_true = a * b
            # Use smaller perturbations for more realistic learning
            if c_true <= 10:
                delta = rng.choice([-1, 1])
            else:
                max_delta = max(2, min(10, c_true//10))
                delta = rng.choice([-1, 1, rng.randint(2, max_delta)])
            c_bad = max(1, c_true + delta)
            if c_bad == c_true:
                c_bad += 1
            s_can_bad = f"{bin_canon(a)}/{bin_canon(b)}/{bin_canon(c_bad)}"
            s_in_bad  = sigma_input_string(s_can_bad, alphabet, sep=',')
            S.append(Sample(s_in_bad, s_can_bad))

        # --- Malformed ---
        for _ in tqdm(range(malformed), desc=f"{desc} malformed"):
            # Lengths come from the band itself. A cap of min(hi, 50) would make
            # the length range empty (randint ValueError) for any band with lo > 50.
            bad = _malformed_word_01_slash(lo, hi)
            S.append(Sample(sigma_input_string(bad, alphabet, sep=','), bad))

        random.shuffle(S)
        return S

    # --- Splits ---
    short = nums_for_band(1, 50, max_num)
    all_short = build_from_pairs(short, 1, 50, "mul-short")
    split = int(0.8 * len(all_short))
    train, val0 = all_short[:split], all_short[split:]

    # Length-generalisation bands
    val1 = build_from_pairs(nums_for_band(51, 100, max(1, len(val0))), 51, 100, "mul-val1")
    val2 = build_from_pairs(nums_for_band(101, 150, max(1, len(val0))), 101, 150, "mul-val2")

    _write_dataset("dataset/mul", "", train, auto, alphabet)
    _write_dataset("dataset/mul", "_val0", val0, auto, alphabet)
    _write_dataset("dataset/mul", "_val1", val1, auto, alphabet)
    _write_dataset("dataset/mul", "_val2", val2, auto, alphabet)



# ============================ GCD (sigma inputs) ==============================
def training_set_gcd(
    max_num: int,
    malformed_ratio: float = 0.1,
    semantic_neg_ratio: float = 0.1
):
    """
    GCD dataset (binary encoding), written to ``dataset/gcd``.

      - Canonical words: a/b/c where c = gcd(a, b), alphabet {'0','1','/'}
      - Bands are on the total canonical word length:
          Train: 1..50 (Val0: 20% holdout of the train-band samples)
          Val1: 51..100
          Val2: 101..150
      - Semantic negatives: the true gcd shifted by -2, -1, +1 or +2 (+3
        instead when that would be <= 0). The word stays well-formed but wrong.
      - Malformed: ``int(len(pairs) * malformed_ratio)`` words per band from
        ``_malformed_word_01_slash``

    ``sample_pairs`` also has a strategy for bands with lo > 150. The splits
    below never request such a band.
    """
    import math
    from kcm_binary import build_gcd_binary_kcm
    auto = build_gcd_binary_kcm()
    alphabet = ['0', '1', '/']

    def random_num_with_len(L: int) -> int:
        """Random positive integer with exactly bit-length L (L>=1)."""
        L = max(1, L)
        return random.randint(1 << (L - 1), (1 << L) - 1)

    def sample_pairs(lo, hi, count):
        """Sample GCD pairs with total canonical length in [lo, hi]."""
        pairs = []
        attempts = 0
        max_attempts = count * 500  # Much higher limit for difficult ranges

        # For very large ranges (lo > 150), use a length-targeting strategy
        if lo > 150:
            while len(pairs) < count and attempts < max_attempts:
                attempts += 1

                # Target a specific length in the range
                target_len = random.randint(lo, hi)

                # Work backwards: estimate operand bit sizes needed
                est_bits_per_part = max(10, target_len // 4)

                La = random.randint(est_bits_per_part - 5, est_bits_per_part + 15)
                Lb = random.randint(est_bits_per_part - 5, est_bits_per_part + 15)
                La = max(1, min(La, 60))  # Cap at reasonable size
                Lb = max(1, min(Lb, 60))

                a = random_num_with_len(La)
                b = random_num_with_len(Lb)
                k = math.gcd(a, b)

                # Check total canonical length
                s_can = f"{bin_canon(a)}/{bin_canon(b)}/{bin_canon(k)}"
                L = len(s_can)

                if lo <= L <= hi:
                    pairs.append((a, b))

                # If we're having trouble, try creating pairs with specific GCD
                elif len(pairs) < count // 2 and attempts > count * 10:
                    # Force a specific GCD size to control total length
                    gcd_bits = max(1, target_len - 2 * est_bits_per_part - 2)
                    if gcd_bits > 0 and gcd_bits <= 30:
                        k = random_num_with_len(gcd_bits)
                        # Create multiples
                        m1 = random.randint(2, 20)
                        m2 = random.randint(2, 20)
                        a = k * m1
                        b = k * m2

                        s_can = f"{bin_canon(a)}/{bin_canon(b)}/{bin_canon(k)}"
                        L = len(s_can)
                        if lo <= L <= hi:
                            pairs.append((a, b))
        else:
            # Strategy for the bands actually used (lo <= 150).
            # NOTE: with La, Lb <= 50 the 'a/b' prefix is at most 101 characters,
            # so 101..150 words need a long gcd. Val2 words will likely cluster
            # near the band's lower edge, and acceptance there is slow.
            while len(pairs) < count and attempts < max_attempts:
                attempts += 1
                # Sample operand bit-lengths
                La = random.randint(1, min(50, hi // 3))  # Reasonable operand size
                Lb = random.randint(1, min(50, hi // 3))

                a = random_num_with_len(La)
                b = random_num_with_len(Lb)
                k = math.gcd(a, b)

                # Check total canonical length
                s_can = f"{bin_canon(a)}/{bin_canon(b)}/{bin_canon(k)}"
                L = len(s_can)

                if lo <= L <= hi:
                    pairs.append((a, b))

        # If we still don't have enough, generate minimal samples
        while len(pairs) < count:
            if lo > 150:
                # For large ranges, create a large pair sharing a common factor
                # (its length is not re-checked against the band).
                base_bits = max(10, lo // 5)
                a = random_num_with_len(base_bits)
                b = random_num_with_len(base_bits)
                gcd_val = random.randint(2, min(a, b))
                a = (a // gcd_val) * gcd_val if a >= gcd_val else a
                b = (b // gcd_val) * gcd_val if b >= gcd_val else b
                pairs.append((max(a, 1), max(b, 1)))
            else:
                # Duplicate existing pairs for smaller ranges
                if pairs:
                    pairs.append(random.choice(pairs))
                else:
                    # NOTE(review): '1/1/1' (length 5) lies outside every band
                    # except the train band; this only happens if sampling
                    # found nothing at all.
                    pairs.append((1, 1))

        return pairs[:count]

    def build_from_pairs(pairs, lo, hi, desc):
        """Turn (a, b) pairs into shuffled Samples: positives, semantic negatives, malformed."""
        S: List[Sample] = []
        # σ-valid
        for a, b in tqdm(pairs, desc=f"{desc} sigma-valid"):
            k = math.gcd(a, b)
            s_can = f"{bin_canon(a)}/{bin_canon(b)}/{bin_canon(k)}"
            s_in = sigma_input_string(s_can, alphabet, sep=',')
            S.append(Sample(s_in, s_can))

        # Semantic negatives
        semnegs = int(len(pairs) * semantic_neg_ratio)
        rng = random.Random(123)
        for _ in tqdm(range(semnegs), desc=f"{desc} semantic-neg"):
            a, b = rng.choice(pairs)
            k = math.gcd(a, b)
            # Perturb the GCD (but not to the true value)
            k_bad = k + rng.choice([-2, -1, 1, 2])
            if k_bad <= 0 or k_bad == k:
                k_bad = k + 3
            s_can_bad = f"{bin_canon(a)}/{bin_canon(b)}/{bin_canon(k_bad)}"
            s_in_bad = sigma_input_string(s_can_bad, alphabet, sep=',')
            S.append(Sample(s_in_bad, s_can_bad))

        # Malformed (the length is drawn inside _malformed_word_01_slash; an
        # extra unused randint here would only perturb the RNG stream)
        m = int(len(pairs) * malformed_ratio)
        for _ in tqdm(range(m), desc=f"{desc} malformed"):
            bad = _malformed_word_01_slash(lo, hi)
            S.append(Sample(sigma_input_string(bad, alphabet, sep=','), bad))
        random.shuffle(S)
        return S

    # Splits
    print("Generating training pairs...")
    short = sample_pairs(1, 50, max_num)
    all_short = build_from_pairs(short, 1, 50, "gcd-train")
    split = int(0.8 * len(all_short))
    train, val0 = all_short[:split], all_short[split:]

    print("Generating val1 pairs...")
    val1 = build_from_pairs(sample_pairs(51, 100, max(1, len(val0))),
                            51, 100, "gcd-val1")
    print("Generating val2 pairs...")
    val2 = build_from_pairs(sample_pairs(101, 150, max(1, len(val0))),
                            101, 150, "gcd-val2")

    _write_dataset("dataset/gcd", "", train, auto, alphabet)
    _write_dataset("dataset/gcd", "_val0", val0, auto, alphabet)
    _write_dataset("dataset/gcd", "_val1", val1, auto, alphabet)
    _write_dataset("dataset/gcd", "_val2", val2, auto, alphabet)
# ============================ DVD (sigma inputs) ==============================
def training_set_dvd(max_num: int,
                     malformed_ratio: float = 0.1,
                     semantic_neg_ratio: float = 0.1):
    """
    DVD (integer division) dataset, written to ``dataset/dvd``.

      - Canonical word: bin_canon(a) '/' bin_canon(b) '/' bin_canon(q),
        where q = a // b and b >= 1; alphabet {'0','1','/'}.
      - Bands are on the length of that full three-piece word:
          Train: 1..50 (Val0: 20% holdout of the train-band samples)
          Val1: 51..100
          Val2: 101..150
      - Semantic negatives replace q with a different value: ±1, or a random
        value of the same bit-length, and q + 1 whenever that would repeat q.
      - Malformed: ``int(len(pairs) * malformed_ratio)`` words per band from
        ``_malformed_word_01_slash``.

    Pair sampling and semantic negatives share one ``random.Random(2025)``
    stream. The shuffles and malformed words use the module-level ``random``.
    """
    from kcm_binary import build_dvd_binary_kcm
    auto = build_dvd_binary_kcm()
    alphabet = ['0', '1', '/']

    rng = random.Random(2025)

    def rand_num_with_len(L: int) -> int:
        """Random positive integer with bit-length exactly max(1, L)."""
        L = max(1, L)
        return rng.randint(1 << (L - 1), (1 << L) - 1)

    def word_len(a: int, b: int) -> int:
        """Length of the full canonical word a/b/q (q = a // b, b >= 1)."""
        # length_dvd_can measures only 'a/b'. Band checks against it would let
        # emitted words overrun their band by 1 + len(q) characters.
        return (len(bin_canon(a)) + 1 + len(bin_canon(b)) + 1
                + len(bin_canon(a // b)))

    def sample_pairs(lo_len: int, hi_len: int, count: int):
        """Sample (a,b) pairs whose full canonical word length is in [lo_len, hi_len]."""
        pairs = []
        attempts = 0
        max_attempts = count * 500  # Increase attempts for difficult ranges

        # Operand bit-length range for this band (the same for every attempt).
        if hi_len > 100:
            min_bits = max(1, lo_len // 3)
            max_bits = min(60, hi_len // 2)
        elif hi_len > 50:
            min_bits = max(1, lo_len // 4)
            max_bits = min(40, hi_len // 2)
        else:
            min_bits = 1
            max_bits = min(20, hi_len // 2)
        # Keep randint's range non-empty for unusual bands.
        max_bits = max(min_bits, max_bits)

        def keep_if_in_band(a: int, b: int) -> None:
            if lo_len <= word_len(a, b) <= hi_len:
                pairs.append((a, b))

        while len(pairs) < count and attempts < max_attempts:
            attempts += 1
            mode = rng.random()

            if mode < 0.3:
                # exact multiple case: a = b * q
                bL = rng.randint(min_bits, max_bits)
                qL = rng.randint(min_bits, max_bits)
                b = rand_num_with_len(bL)
                q = rand_num_with_len(qL)
                keep_if_in_band(b * q, b)

            elif mode < 0.6:
                # random a, b with independently chosen bit-lengths
                aL = rng.randint(min_bits, max_bits)
                bL = rng.randint(min_bits, max_bits)
                a = rand_num_with_len(aL)
                b = rand_num_with_len(bL)
                keep_if_in_band(a, b)

            elif mode < 0.8:
                # a < b (q = 0 case)
                bL = rng.randint(min_bits, max_bits)
                b = max(2, rand_num_with_len(bL))
                a = rng.randint(1, b - 1)
                keep_if_in_band(a, b)

            else:
                # Direct targeting. For a >= b, q has len(a) - len(b) or
                # len(a) - len(b) + 1 bits. The word length is therefore
                # 2*len(a) + 2 or 2*len(a) + 3, whatever len(b) is.
                target_len = rng.randint(lo_len, hi_len)
                aL = min(60, max(1, (target_len - 2) // 2))
                bL = rng.randint(1, aL)
                a = rand_num_with_len(aL)
                b = rand_num_with_len(bL)
                keep_if_in_band(a, b)

        # Fill remaining with duplicates if needed
        while len(pairs) < count and pairs:
            pairs.append(rng.choice(pairs))

        return pairs[:count]

    def build_from_pairs(pairs, lo, hi, desc):
        """Turn (a, b) pairs into shuffled Samples: positives, semantic negatives, malformed."""
        S: List[Sample] = []

        # --- positives ---
        for a, b in tqdm(pairs, desc=f"{desc} sigma-valid"):
            q = a // b
            s_can = f"{bin_canon(a)}/{bin_canon(b)}/{bin_canon(q)}"
            s_in = sigma_input_string(s_can, alphabet, sep=',')
            S.append(Sample(s_in, s_can))

        # --- semantic negatives ---
        semnegs = int(len(pairs) * semantic_neg_ratio)
        for _ in tqdm(range(semnegs), desc=f"{desc} semantic-neg"):
            a, b = rng.choice(pairs)
            q_true = a // b
            # pick wrong q close in value or same bit-length
            if rng.random() < 0.5:
                q_bad = max(0, q_true + rng.choice([-1, 1]))
            else:
                Lq = max(1, q_true.bit_length())
                q_bad = rand_num_with_len(Lq)
            if q_bad == q_true:
                # q_true == 0 with a -1 step clamps back to 0, and a
                # same-length redraw can repeat q_true. Force a wrong value.
                q_bad = q_true + 1
            s_can_bad = f"{bin_canon(a)}/{bin_canon(b)}/{bin_canon(q_bad)}"
            s_in_bad = sigma_input_string(s_can_bad, alphabet, sep=',')
            S.append(Sample(s_in_bad, s_can_bad))

        # --- malformed ---
        m = int(len(pairs) * malformed_ratio)
        for _ in tqdm(range(m), desc=f"{desc} malformed"):
            bad = _malformed_word_01_slash(lo, hi)
            S.append(Sample(sigma_input_string(bad, alphabet, sep=','), bad))

        random.shuffle(S)
        return S

    # === splits ===
    train_pairs = sample_pairs(1, 50, max_num)
    all_train = build_from_pairs(train_pairs, 1, 50, "dvd-train")
    split = int(0.8 * len(all_train))
    train, val0 = all_train[:split], all_train[split:]

    val1_pairs = sample_pairs(51, 100, max(1, len(val0)))
    val2_pairs = sample_pairs(101, 150, max(1, len(val0)))
    val1 = build_from_pairs(val1_pairs, 51, 100, "dvd-val1")
    val2 = build_from_pairs(val2_pairs, 101, 150, "dvd-val2")

    _write_dataset("dataset/dvd", "", train, auto, alphabet)
    _write_dataset("dataset/dvd", "_val0", val0, auto, alphabet)
    _write_dataset("dataset/dvd", "_val1", val1, auto, alphabet)
    _write_dataset("dataset/dvd", "_val2", val2, auto, alphabet)

# ============================ PRIME (sigma inputs) ============================
def training_set_prime(max_num: int,
                       malformed_ratio: float = 0.1,
                       semantic_neg_ratio: float = 0.1):
    """
    PRIME dataset (binary encoding), written to ``dataset/prime``.

      - Canonical words over {'0','1'} with no leading zeros (``bin_canon``).
      - Bands are on the bit-length of the number, which equals the word length:
          Train: 1..50 (Val0: 20% holdout of the train-band samples)
          Val1: 51..100
          Val2: 101..150
      - Each band: ``count // 2`` primes and the rest composites (fewer if
        rejection sampling finds none).
      - Semantic negatives: extra composites whose two top bits are set.
      - Malformed: ``int(len(nums) * malformed_ratio)`` words from
        ``_malformed_word_01``.
    """
    from sympy import isprime
    from kcm_binary import build_prime_binary_kcm

    auto = build_prime_binary_kcm()
    alphabet = ['0', '1']

    # ----------------------
    # Helpers
    # ----------------------
    def random_num_with_len(L: int) -> int:
        """Sample a random integer of exactly bit-length L (1 if L <= 0)."""
        if L <= 0:
            return 1
        return random.randint(1 << (L - 1), (1 << L) - 1)

    def sample_primes(lo: int, hi: int, count: int):
        """Sample ~count primes with bit-length in [lo,hi]."""
        xs, attempts = [], 0
        max_attempts = count * 200
        while len(xs) < count and attempts < max_attempts:
            attempts += 1
            L = random.randint(lo, hi)
            if L <= 1:
                continue
            x = random_num_with_len(L)
            if isprime(x):
                xs.append(x)
        # If not enough, pad with duplicates drawn from xs
        while len(xs) < count and xs:
            xs.append(random.choice(xs))
        return xs[:count]

    def sample_mix(lo: int, hi: int, count: int):
        """Return count // 2 primes followed by composites, bit-length in [lo,hi]."""
        half = count // 2
        primes = sample_primes(lo, hi, half)
        comps, attempts = [], 0
        max_attempts = count * 200
        while len(comps) < (count - half) and attempts < max_attempts:
            attempts += 1
            L = random.randint(lo, hi)
            if L <= 1:
                continue
            x = random_num_with_len(L)
            if not isprime(x):
                comps.append(x)
        while len(comps) < (count - half) and comps:
            comps.append(random.choice(comps))
        return primes + comps

    def build_from_nums(nums, lo: int, hi: int, desc: str):
        """Turn numbers into shuffled Samples: positives, semantic negatives, malformed."""
        S: List[Sample] = []

        # σ-valid words
        for x in tqdm(nums, desc=f"{desc} sigma-valid"):
            s_can = bin_canon(x)
            s_in = sigma_input_string(s_can, alphabet)
            S.append(Sample(s_in, s_can))

        # Semantic negatives (deliberate composites)
        semnegs = int(len(nums) * semantic_neg_ratio)
        rng = random.Random(50)
        attempts, added = 0, 0
        max_attempts = semnegs * 50
        while added < semnegs and attempts < max_attempts:
            attempts += 1
            L = rng.randint(lo, hi)
            if L <= 1:
                continue
            top = 1 << (L - 1)
            second = 1 << (L - 2)
            # getrandbits(L - 2) is already < 2**(L - 2), so no mask is needed.
            tail = rng.getrandbits(max(0, L - 2))
            x_bad = top | second | tail
            if isprime(x_bad):
                continue  # ensure composite
            s_can_bad = bin_canon(x_bad)
            s_in_bad = sigma_input_string(s_can_bad, alphabet)
            S.append(Sample(s_in_bad, s_can_bad))
            added += 1

        # Malformed negatives
        m = int(len(nums) * malformed_ratio)
        for _ in tqdm(range(m), desc=f"{desc} malformed"):
            L = random.randint(lo, hi)
            bad = _malformed_word_01(L)
            S.append(Sample(sigma_input_string(bad, alphabet), bad))

        random.shuffle(S)
        return S

    # ----------------------
    # Splits
    # ----------------------
    short = sample_mix(1, 50, max_num)
    all_short = build_from_nums(short, 1, 50, "prime-train")
    split = int(0.8 * len(all_short))
    train, val0 = all_short[:split], all_short[split:]

    val1 = build_from_nums(sample_mix(51, 100, max(1, len(val0))),
                           51, 100, "prime-val1")
    val2 = build_from_nums(sample_mix(101, 150, max(1, len(val0))),
                           101, 150, "prime-val2")

    _write_dataset("dataset/prime", "", train, auto, alphabet)
    _write_dataset("dataset/prime", "_val0", val0, auto, alphabet)
    _write_dataset("dataset/prime", "_val1", val1, auto, alphabet)
    _write_dataset("dataset/prime", "_val2", val2, auto, alphabet)

# ============================ EXP (sigma inputs) ==============================
def training_set_exp(max_num: int,
                     malformed_ratio: float = 0.2,
                     semantic_neg_ratio: float = 0.2):
    """
    Generate the EXP dataset (relation a^i * b^(2^i)) with sigma-encoded inputs.

    Canonical words have the form ``bin(a)/bin(i)/bin(b)/bin(result)`` with
    ``result = a**i * b**(2**i)``, over the alphabet {'0', '1', '/'}.
    Every split mixes three kinds of samples:
      - positives: the correct result;
      - semantic negatives: well-formed words whose result is shifted by one
        of -2, -1, +1, +2 (clamped at 0), hence arithmetically wrong;
      - malformed words produced by ``_malformed_word_01_slash``.
    Targets are produced by ``auto.output_generator`` when the splits are
    written with ``_write_dataset``.

    Splits written under ``dataset/exp``:
      - train / val0: 80/20 split of samples whose canonical length is 1..50;
      - val1: canonical length 51..100;
      - val2: canonical length 101..150.
    val1 and val2 each sample ``max(1, len(val0))`` tuples.

    Exponents are capped at 4 and bases at 5 bits to keep results small.
    NOTE(review): under these caps a word is at most ~116 characters long
    (assuming ``length_exp_can`` counts the '/' separators), so val2 can only
    draw from the low end of 101..150; ``sample_exp_tuples`` pads any
    shortfall with duplicates.
    NOTE(review): in the 1..50 band a, b <= 15 and i <= 4, i.e. at most 1125
    distinct (a, i, b) tuples, so for a large ``max_num`` train and val0 share
    many identical samples -- consider deduplicating if val0 is meant to be
    held out.

    Args:
        max_num: number of (a, i, b) tuples sampled for the train+val0 band.
        malformed_ratio: malformed negatives added, as a fraction of tuples.
        semantic_neg_ratio: semantic negatives added, as a fraction of tuples.
    """
    from kcm_binary import build_exp_binary_kcm
    auto = build_exp_binary_kcm()
    alphabet = ['0', '1', '/']

    # Private RNG with a fixed seed: tuple sampling and semantic negatives are
    # reproducible regardless of the global `random` state.
    rng = random.Random(2025)

    def rand_num_with_len(L: int) -> int:
        """Uniform integer whose binary form has exactly max(1, L) bits."""
        L = max(1, L)
        return rng.randint(1 << (L - 1), (1 << L) - 1)

    def base_bit_cap(hi_len: int) -> int:
        """Largest bit-length allowed for the bases a and b in a band ending at hi_len."""
        if hi_len <= 30:
            cap = min(3, hi_len // 4)
        elif hi_len <= 60:
            cap = min(4, hi_len // 6)
        else:
            cap = min(5, hi_len // 8)
        # For hi_len < 4 the formula yields 0 and rng.randint(1, 0) would
        # raise ValueError ("empty range"); always allow at least 1 bit.
        return max(1, cap)

    def exp_word(a: int, i: int, b: int, result: int) -> str:
        """Canonical EXP word bin(a)/bin(i)/bin(b)/bin(result)."""
        return f"{bin_canon(a)}/{bin_canon(i)}/{bin_canon(b)}/{bin_canon(result)}"

    def sample_exp_tuples(lo_len: int, hi_len: int, count: int):
        """Sample (a, i, b) tuples with total canonical length in [lo_len, hi_len]."""
        tuples = []
        attempts = 0
        max_attempts = count * 200

        # These caps depend only on hi_len, so compute them once per call.
        i_cap = min(4, max(0, hi_len // 10))  # keep exponents very small
        bits_cap = base_bit_cap(hi_len)

        while len(tuples) < count and attempts < max_attempts:
            attempts += 1
            # Draw order (i, a_bits, b_bits, a, b) is kept stable so a given
            # seed keeps producing the same dataset.
            i = rng.randint(0, i_cap)
            a_bits = rng.randint(1, bits_cap)
            b_bits = rng.randint(1, bits_cap)
            a = rand_num_with_len(a_bits)
            b = rand_num_with_len(b_bits)

            total_len = length_exp_can(a, i, b)
            if total_len != float('inf') and lo_len <= total_len <= hi_len:
                tuples.append((a, i, b))

        # Rejection sampling may fall short for bands that are hard to hit;
        # pad with duplicates of what was found (nothing if nothing was found).
        while len(tuples) < count and tuples:
            tuples.append(rng.choice(tuples))

        return tuples[:count]

    def build_from_tuples(tuples, lo, hi, desc):
        """Build positives, semantic negatives and malformed samples for one split."""
        S: List[Sample] = []

        # --- positives ---
        for a, i, b in tqdm(tuples, desc=f"{desc} sigma-valid"):
            try:
                result = (a ** i) * (b ** (2 ** i))
                s_can = exp_word(a, i, b, result)
                s_in = sigma_input_string(s_can, alphabet, sep=',')
                S.append(Sample(s_in, s_can))
            except (OverflowError, ValueError):
                continue  # skip tuples that cannot be encoded

        # --- semantic negatives: correct shape, wrong result ---
        semnegs = int(len(tuples) * semantic_neg_ratio)
        for _ in tqdm(range(semnegs), desc=f"{desc} semantic-neg"):
            a, i, b = rng.choice(tuples)
            try:
                result_true = (a ** i) * (b ** (2 ** i))
                result_bad = max(0, result_true + rng.choice([-1, 1, 2, -2]))
                # Defensive guard: clamping at 0 could in principle map back
                # onto the true result.
                if result_bad == result_true:
                    result_bad += 3
                s_can_bad = exp_word(a, i, b, result_bad)
                s_in_bad = sigma_input_string(s_can_bad, alphabet, sep=',')
                S.append(Sample(s_in_bad, s_can_bad))
            except (OverflowError, ValueError):
                continue

        # --- malformed ---
        m = int(len(tuples) * malformed_ratio)
        for _ in tqdm(range(m), desc=f"{desc} malformed"):
            bad = _malformed_word_01_slash(lo, hi)
            S.append(Sample(sigma_input_string(bad, alphabet, sep=','), bad))

        # Uses the module-level RNG (seeded by the CLI's --seed), not `rng`.
        random.shuffle(S)
        return S

    # === splits ===
    train_tuples = sample_exp_tuples(1, 50, max_num)
    all_train = build_from_tuples(train_tuples, 1, 50, "exp-train")
    split = int(0.8 * len(all_train))
    train, val0 = all_train[:split], all_train[split:]

    val1_tuples = sample_exp_tuples(51, 100, max(1, len(val0)))
    val2_tuples = sample_exp_tuples(101, 150, max(1, len(val0)))
    val1 = build_from_tuples(val1_tuples, 51, 100, "exp-val1")
    val2 = build_from_tuples(val2_tuples, 101, 150, "exp-val2")

    _write_dataset("dataset/exp", "", train, auto, alphabet)
    _write_dataset("dataset/exp", "_val0", val0, auto, alphabet)
    _write_dataset("dataset/exp", "_val1", val1, auto, alphabet)
    _write_dataset("dataset/exp", "_val2", val2, auto, alphabet)

# =============================== CLI =========================================

if __name__ == "__main__":
    import argparse

    # Dispatch table: --task value -> generator function defined in this module.
    tasks = {
        "add": training_set_add,
        "mul": training_set_mul,
        "gcd": training_set_gcd,
        "dvd": training_set_dvd,
        "prime": training_set_prime,
        "exp": training_set_exp,
    }

    parser = argparse.ArgumentParser(description="Generate sigma-encoded datasets.")
    parser.add_argument("--task",
                        choices=list(tasks),
                        required=True)
    # NOTE(review): the prime/exp generators sample their TRAIN+VAL0 band from
    # canonical lengths 1..50; confirm the "1..100" in this help text.
    parser.add_argument("--max_num", type=int, default=10000,
                        help="Number of TRAIN+VAL0 candidates in the 1..100 band (before split).")
    # default=None (not 0) so that an explicit `--seed 0` is honoured instead
    # of being treated as "no seed" by a truthiness check.
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for the global `random` module; omit to leave it unseeded.")
    args = parser.parse_args()

    # Only the module-level `random` functions are affected; generators such
    # as training_set_prime/training_set_exp also use private
    # random.Random instances with fixed seeds.
    if args.seed is not None:
        random.seed(args.seed)

    tasks[args.task](args.max_num)
