#!/usr/bin/env python3
"""
gen_kcm_dataset.py — generate datasets using KCM output traces.

Languages (canonical binary encodings):

  add  : bin(i)/bin(j)/bin(k) with k = i + j
  mul  : bin(i)/bin(j)/bin(k) with k = i * j
  gcd  : bin(i)/bin(j)/bin(k) with k = gcd(i, j)
  exp  : bin(i)/bin(j)/bin(k) with k = i ** j
  prime: bin(n) with n prime
  dvd  : bin(w)/bin(v) with w >= 1 and w divides v

We use your KCMs (build_*_binary_kcm) and their `output_generator`
to produce the target traces.

Files written in OUTDIR:

  input.txt,       target.txt        (TRAIN)
  input_val0.txt,  target_val0.txt   (VAL0)
  input_val1.txt,  target_val1.txt   (VAL1)
  input_val2.txt,  target_val2.txt   (VAL2)

SPLIT BY INPUT WORD LENGTH L = len(input_string):

  train:  L ∈ [1, 50]
  val0:   L ∈ [1, 50]      (same length range as train)
  val1:   L ∈ [51, 100]
  val2:   L ∈ [101, 150]

Inputs are also enforced to be UNIQUE across all splits.
"""

from __future__ import annotations
from dataclasses import dataclass
from typing import List, Callable, Tuple, Set, Optional
import os
import random
import argparse
from tqdm import tqdm

# ---------------------------------------------------------------------------
# IMPORTANT: adjust "kcm_binary" if your KCM code is in a differently named file.
# ---------------------------------------------------------------------------
from kcm_binary import (
    KCM,
    build_add_binary_kcm,
    build_mul_binary_kcm,
    build_gcd_binary_kcm,
    build_exp_binary_kcm,
    build_prime_binary_kcm,
    build_dvd_binary_kcm,
    _is_probable_prime,
)

# ============================ basic utils ====================================

@dataclass
class Sample:
    """One generated example: an input word together with its membership label."""
    s_in: str   # canonical-binary input string (fields separated by '/' where applicable)
    label: int  # 0/1 membership (for sanity checking; not written to file)


def bin_canon(n: int) -> str:
    """Return the canonical binary spelling of a non-negative integer.

    Zero is rendered as '0'; every positive value is rendered without
    leading zeros.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"bin_canon only supports n >= 0, got {n}")
    # The 'b' presentation type emits plain base-2 digits with no '0b' prefix,
    # and renders 0 as '0', so no special case is needed.
    return format(n, "b")


def rand_int_with_bits(bit_lo: int, bit_hi: int) -> int:
    """Draw a non-negative integer whose bit width lies in [bit_lo, bit_hi].

    The width is picked uniformly first, then a value is picked uniformly
    among the integers of exactly that width. Width 1 covers both 0 and 1.
    """
    assert 1 <= bit_lo <= bit_hi
    width = random.randint(bit_lo, bit_hi)
    if width > 1:
        # Integers of exactly `width` bits span [2^(width-1), 2^width - 1].
        smallest = 1 << (width - 1)
        largest = (1 << width) - 1
        return random.randint(smallest, largest)
    # Single-bit case: either 0 or 1.
    return random.randint(0, 1)


def rand_int_with_bits_pos(bit_lo: int, bit_hi: int) -> int:
    """Draw a strictly positive integer with bit width in [bit_lo, bit_hi].

    Redraws from ``rand_int_with_bits`` until the result is non-zero.
    """
    value = rand_int_with_bits(bit_lo, bit_hi)
    while value <= 0:
        value = rand_int_with_bits(bit_lo, bit_hi)
    return value


def rand_z_near(z_true: int, bit_lo: int, bit_hi: int, max_bits: int,
                valid: Callable[[int], bool]) -> int:
    """Pick a "near miss" integer: different from z_true and accepted by `valid`.

    Up to 1000 candidates are tried at a random offset of at most
    max(1, |z_true| // 4) above or below z_true (clamped at 0 from below).
    A candidate is returned only if its bit width (0 counts as width 1) is
    at most `max_bits`, lies in [bit_lo, bit_hi], and `valid` accepts it.

    If none of those attempts succeeds, values are drawn with
    ``rand_int_with_bits(bit_lo, bit_hi)`` until one differs from z_true and
    passes `valid`; that fallback has no iteration limit.
    """
    spread = max(1, abs(z_true) // 4)  # same for every attempt
    attempts_left = 1000
    while attempts_left > 0:
        attempts_left -= 1

        step = random.randint(1, spread)
        move_up = random.random() < 0.5
        candidate = z_true + step if move_up else max(0, z_true - step)

        # Clamping at 0 can land back on z_true (when z_true is 0).
        if candidate == z_true:
            continue

        width = 1 if candidate == 0 else candidate.bit_length()
        fits = width <= max_bits and bit_lo <= width <= bit_hi
        if fits and valid(candidate):
            return candidate

    # Fallback: random within bit range
    while True:
        candidate = rand_int_with_bits(bit_lo, bit_hi)
        if candidate == z_true:
            continue
        if valid(candidate):
            return candidate


def within_length_range(s: str, lo: int, hi: int) -> bool:
    """Tell whether len(s) falls inside the inclusive interval [lo, hi]."""
    size = len(s)
    return lo <= size and size <= hi

# ============================ sample generators ==============================

def gen_add_sample(bit_lo: int, bit_hi: int, max_bits: int, pos_frac: float,
                   L_lo: int, L_hi: int) -> Sample:
    """Produce one ADD word ``bin(i)/bin(j)/bin(k)`` of length in [L_lo, L_hi].

    With probability `pos_frac` the word is positive (k = i + j, label 1);
    otherwise k is replaced by a nearby wrong value (label 0). Operands are
    drawn with 1 to min(bit_hi, max_bits - 1) bits; `bit_lo` only constrains
    the wrong third field of negatives. Draws are repeated until the word
    length fits.
    """
    operand_bits = min(bit_hi, max_bits - 1)
    while True:
        a = rand_int_with_bits(1, operand_bits)
        b = rand_int_with_bits(1, operand_bits)
        total = a + b
        if max(total.bit_length(), 1) > max_bits:
            continue

        is_member = random.random() < pos_frac
        if is_member:
            third = total
        else:
            # Any third field that is not the true sum makes a negative.
            third = rand_z_near(total, bit_lo, bit_hi, max_bits,
                                lambda cand: cand != a + b)

        word = "/".join((bin_canon(a), bin_canon(b), bin_canon(third)))
        if within_length_range(word, L_lo, L_hi):
            return Sample(s_in=word, label=1 if is_member else 0)


def gen_mul_sample(bit_lo: int, bit_hi: int, max_bits: int, pos_frac: float,
                   L_lo: int, L_hi: int) -> Sample:
    """Produce one MUL word ``bin(i)/bin(j)/bin(k)`` of length in [L_lo, L_hi].

    With probability `pos_frac` the word is positive (k = i * j, label 1);
    otherwise k is replaced by a nearby wrong value (label 0). Each factor
    gets at most min(bit_hi, max_bits // 2) bits. Draws are repeated until
    the word length fits.
    """
    factor_bits = min(bit_hi, max_bits // 2)
    while True:
        left = rand_int_with_bits(1, factor_bits)
        right = rand_int_with_bits(1, factor_bits)
        product = left * right
        product_bits = product.bit_length() if product > 0 else 1
        if product_bits > max_bits:
            continue

        positive = random.random() < pos_frac
        if not positive:
            def differs_from_product(cand: int) -> bool:
                return cand != left * right

            shown = rand_z_near(product, bit_lo, bit_hi, max_bits,
                                differs_from_product)
        else:
            shown = product

        word = f"{bin_canon(left)}/{bin_canon(right)}/{bin_canon(shown)}"
        if not within_length_range(word, L_lo, L_hi):
            continue
        return Sample(s_in=word, label=int(positive))


def gen_gcd_sample(bit_lo: int, bit_hi: int, max_bits: int, pos_frac: float,
                   L_lo: int, L_hi: int) -> Sample:
    """Produce one GCD word ``bin(i)/bin(j)/bin(k)`` of length in [L_lo, L_hi].

    With probability `pos_frac` the word is positive (k = gcd(i, j), label 1);
    otherwise k is replaced by a nearby wrong value (label 0). Both operands
    get at most min(bit_hi, max_bits) bits. Draws are repeated until the word
    length fits.
    """
    import math

    operand_bits = min(bit_hi, max_bits)
    while True:
        a = rand_int_with_bits(1, operand_bits)
        b = rand_int_with_bits(1, operand_bits)
        divisor = math.gcd(a, b)
        # gcd(0, 0) is 0, which is written as the single digit '0'.
        if max(divisor.bit_length(), 1) > max_bits:
            continue

        want_member = random.random() < pos_frac
        if want_member:
            third = divisor
        else:
            third = rand_z_near(divisor, bit_lo, bit_hi, max_bits,
                                lambda cand: cand != math.gcd(a, b))

        word = "/".join([bin_canon(a), bin_canon(b), bin_canon(third)])
        if within_length_range(word, L_lo, L_hi):
            return Sample(s_in=word, label=1 if want_member else 0)


def gen_exp_sample(bit_lo: int, bit_hi: int, max_bits: int, pos_frac: float,
                   L_lo: int, L_hi: int) -> Sample:
    """EXP: bin(i)/bin(j)/bin(k) with k = i ** j.

    With probability `pos_frac` the sample is positive (k = i ** j, label 1);
    otherwise k is replaced by a nearby wrong value (label 0). Sampling is
    repeated until the word length lies in [L_lo, L_hi].

    The base gets at most min(16, max_bits // 8) bits, but never fewer than
    one, and the exponent is drawn from [0, 16]. `bit_lo` and `bit_hi` only
    constrain the wrong third field of negative samples.
    """
    # Clamp to >= 1: for max_bits < 8 the integer division yields 0, which
    # would violate rand_int_with_bits' requirement that bit_hi >= bit_lo.
    base_bits = max(1, min(16, max_bits // 8))
    while True:
        # Keep exponent small to avoid gigantic integers
        x = rand_int_with_bits(1, base_bits)
        y = random.randint(0, 16)
        z = pow(x, y)
        bz = z.bit_length() if z > 0 else 1
        if bz > max_bits:
            continue

        if random.random() < pos_frac:
            label = 1
            z_use = z
        else:
            label = 0

            def valid(zcand: int) -> bool:
                # Any value other than the true power is a valid negative.
                return pow(x, y) != zcand

            z_use = rand_z_near(z, bit_lo, bit_hi, max_bits, valid)

        s = f"{bin_canon(x)}/{bin_canon(y)}/{bin_canon(z_use)}"
        if within_length_range(s, L_lo, L_hi):
            return Sample(s_in=s, label=label)


def gen_prime_sample(bit_lo: int, bit_hi: int, max_bits: int, pos_frac: float,
                     L_lo: int, L_hi: int) -> Sample:
    """Produce one PRIME word ``bin(n)`` of length in [L_lo, L_hi].

    Each attempt first decides (with probability `pos_frac`) whether a prime
    is wanted, then draws n with 1 to bit_hi bits and keeps it only if
    ``_is_probable_prime(n)`` agrees with that decision. `bit_lo` is not
    used; the word length filter does the real bucketing.
    """
    assert bit_hi <= max_bits
    while True:
        want_prime = random.random() < pos_frac
        n = rand_int_with_bits(1, bit_hi)
        # Reject draws whose primality does not match the wanted class.
        if bool(_is_probable_prime(n)) != want_prime:
            continue

        word = bin_canon(n)
        if within_length_range(word, L_lo, L_hi):
            return Sample(s_in=word, label=1 if want_prime else 0)


def gen_dvd_sample(bit_lo: int, bit_hi: int, max_bits: int, pos_frac: float,
                   L_lo: int, L_hi: int) -> Sample:
    """
    Produce one DVD word ``bin(w)/bin(v)`` (w >= 1) of length in [L_lo, L_hi].

    Positives (probability `pos_frac`, label 1) use v = w * t with t >= 0.
    Negatives (label 0) use a v >= 1 that w does not divide.

    `max_bits` only loosely bounds the numbers involved; the actual split
    buckets are determined by the word length.
    """
    cap = min(bit_hi, max_bits)
    while True:
        want_divisible = random.random() < pos_frac
        w = rand_int_with_bits_pos(1, cap)  # divisor, always >= 1

        if want_divisible:
            # The multiplier may be 0, in which case v is 0.
            v = w * rand_int_with_bits(1, cap)
        else:
            v = rand_int_with_bits_pos(1, cap)
            if v % w == 0:
                v += 1  # step off the multiple
            if v % w == 0:
                # Still divisible after the nudge (only possible for w == 1):
                # no negative exists for this w, so start over.
                continue

        word = "/".join((bin_canon(w), bin_canon(v)))
        if within_length_range(word, L_lo, L_hi):
            return Sample(s_in=word, label=1 if want_divisible else 0)

# ============================ KCM + alphabet mapping =========================

def get_kcm_and_alphabet(language: str) -> Tuple[KCM, List[str]]:
    """Build the KCM for `language` and return it with its input alphabet.

    "prime" words use only the digits '0' and '1'; every other language
    additionally uses '/' as the field separator.

    Raises:
        ValueError: if `language` is not one of the supported names.
    """
    supported = ("prime", "add", "mul", "gcd", "exp", "dvd")
    if language not in supported:
        raise ValueError(f"Unknown language: {language}")

    if language == "prime":
        return build_prime_binary_kcm(), ['0', '1']

    # Only the selected builder is invoked.
    builders = {
        "add": build_add_binary_kcm,
        "mul": build_mul_binary_kcm,
        "gcd": build_gcd_binary_kcm,
        "exp": build_exp_binary_kcm,
        "dvd": build_dvd_binary_kcm,
    }
    return builders[language](), ['0', '1', '/']

# ============================ dataset writing ================================

def _write_dataset(dirpath: str, split_name: str,
                   samples: List[Sample],
                   kcm: KCM,
                   alphabet: List[str]) -> None:
    os.makedirs(dirpath, exist_ok=True)

    if split_name == "train":
        fin = os.path.join(dirpath, "input.txt")
        ftg = os.path.join(dirpath, "target.txt")
    else:
        fin = os.path.join(dirpath, f"input_{split_name}.txt")
        ftg = os.path.join(dirpath, f"target_{split_name}.txt")

    # Write with a progress bar so long writes are visible.
    with open(fin, "w") as fi, open(ftg, "w") as ft:
        for idx, sm in enumerate(tqdm(samples, desc=f"Writing {split_name}") if samples else []):
            w = sm.s_in
            out = kcm.output_generator(w, alphabet)  # your full trace
            fi.write(w + "\n")
            ft.write(out + "\n")
            # flush periodically to ensure data hits disk and to make
            # external tools (like `ls`/`du`/`scp`) see progress.
            if (idx + 1) % 100 == 0:
                fi.flush()
                ft.flush()
                os.fsync(fi.fileno())
                os.fsync(ft.fileno())

# ============================ split driver ===================================

def generate_split(language: str, n_samples: int,
                   bit_lo: int, bit_hi: int, max_bits: int,
                   pos_frac: float,
                   L_lo: int, L_hi: int,
                   seen: Optional[Set[str]] = None) -> List[Sample]:
    """
    Collect `n_samples` samples of `language` whose inputs are pairwise distinct.

    When `seen` is given, it is treated as the set of inputs already used
    elsewhere: samples whose input is in it are rejected, and every accepted
    input is added to it (the caller's set is mutated). Without `seen`,
    uniqueness is enforced only within this call.

    Returns an empty list right away when `n_samples` <= 0. Raises ValueError
    for an unsupported `language`. Sampling continues until enough unique
    inputs have been found.
    """
    if n_samples <= 0:
        return []

    used_inputs: Set[str] = seen if seen is not None else set()

    generators = {
        "add":   gen_add_sample,
        "mul":   gen_mul_sample,
        "gcd":   gen_gcd_sample,
        "exp":   gen_exp_sample,
        "prime": gen_prime_sample,
        "dvd":   gen_dvd_sample,
    }
    make_sample = generators.get(language)
    if make_sample is None:
        raise ValueError(f"Unsupported language: {language}")

    collected: List[Sample] = []
    with tqdm(total=n_samples,
              desc=f"Generating len∈[{L_lo},{L_hi}] for {language}") as bar:
        while len(collected) < n_samples:
            candidate = make_sample(bit_lo, bit_hi, max_bits, pos_frac, L_lo, L_hi)
            word = candidate.s_in
            if word not in used_inputs:
                # First occurrence of this input: keep it and record it.
                used_inputs.add(word)
                collected.append(candidate)
                bar.update(1)

    return collected
 


# ============================ main ===========================================

def main():
    """Parse the command line, then generate and write all four splits.

    Splits are bucketed by input word length:

      train: [1, 50]    val0: [1, 50]    val1: [51, 100]    val2: [101, 150]

    A single `seen` set is shared by every split so that no input word
    appears in more than one split. This matters most for train and val0,
    which draw from the same length range.

    Raises:
        ValueError: if --max-bits is smaller than 3.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--language", type=str, required=True,
                    choices=["add", "mul", "gcd", "exp", "prime", "dvd"])
    ap.add_argument("--outdir", type=str, required=True)
    ap.add_argument("--n_train", type=int, default=2000)
    ap.add_argument("--n_val0", type=int, default=200)
    ap.add_argument("--n_val1", type=int, default=200)
    ap.add_argument("--n_val2", type=int, default=200)
    ap.add_argument("--max-bits", type=int, default=300,
                    help="Upper bound on bit-length of underlying integers.")
    ap.add_argument("--pos-frac", type=float, default=0.5,
                    help="Approximate fraction of positive samples.")
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    random.seed(args.seed)

    max_bits = args.max_bits
    if max_bits < 3:
        raise ValueError("max-bits must be at least 3")

    # Build the KCM + alphabet once
    kcm, alphabet = get_kcm_and_alphabet(args.language)

    # Inputs already emitted by earlier splits; generate_split both consults
    # and extends this set, which keeps all splits mutually disjoint.
    seen: Set[str] = set()

    # (split name, number of samples, min word length, max word length)
    splits = [
        ("train", args.n_train, 1, 50),
        ("val0", args.n_val0, 1, 50),
        ("val1", args.n_val1, 51, 100),
        ("val2", args.n_val2, 101, 150),
    ]

    for split_name, n_samples, L_lo, L_hi in splits:
        # The message is derived from the real bounds so it cannot drift.
        print(f"[INFO] Generating {split_name.upper()} (len {L_lo}-{L_hi})...")
        samples = generate_split(
            args.language, n_samples,
            bit_lo=1, bit_hi=max_bits, max_bits=max_bits,
            pos_frac=args.pos_frac,
            L_lo=L_lo, L_hi=L_hi,
            seen=seen,
        )
        # Write each split as soon as it is generated.
        _write_dataset(args.outdir, split_name, samples, kcm, alphabet)


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
