import numpy as np
import os
import argparse


def generate_multiplication_dataset(
    num_samples=10000,
    min_digits=1,
    max_digits=5,
    output_file="multiplication_dataset.txt",
    inverse=False,
    seed=None,
    output_dir=".",
):
    """
    Generates a multiplication dataset and saves it to a text file.

    Each line of the file has the form ``<a digits> [SEP] <b digits> : <product digits>``,
    where every digit is a separate space-delimited token. Both operands are
    zero-padded to ``max_digits`` tokens and the product is zero-padded to
    ``2 * max_digits`` tokens (a product of two numbers below ``10**max_digits``
    always fits in that width).

    Args:
        num_samples (int): The number of samples (lines) in the dataset.
        min_digits (int): The minimum number of digits for the input numbers.
            When this is 1, the value 0 is included in the sampling range.
        max_digits (int): The maximum number of digits for the input numbers (also used for zero-padding).
        output_file (str): The name of the output file.
        inverse (bool): If True, the output tokens are reversed (least-significant digit first).
        seed (int): The random seed. When given, NumPy's global random state is
            re-seeded with it; when None, the current global state is used as-is.
        output_dir (str): The output directory. It is created if it does not exist.

    Raises:
        ValueError: If num_samples is negative, if min_digits or max_digits is
            not positive, or if min_digits is greater than max_digits.
    """
    # Validate everything up front so an invalid call neither re-seeds the
    # global RNG nor creates or truncates any file.
    if num_samples < 0:
        raise ValueError("num_samples must be a non-negative integer.")
    if min_digits <= 0 or max_digits <= 0:
        raise ValueError("min_digits and max_digits must be positive integers.")
    if min_digits > max_digits:
        raise ValueError("min_digits cannot be greater than max_digits.")

    if seed is not None:
        np.random.seed(seed)

    # Determine the half-open sampling range [min_val, max_val).
    # NOTE(review): np.random.randint draws from a fixed-width integer dtype, so
    # a very large max_digits will be rejected by NumPy - confirm the supported
    # range on the target platform.
    min_val = 0 if min_digits == 1 else 10 ** (min_digits - 1)
    max_val = 10**max_digits

    # Width of the zero-padded product; constant for the whole run.
    output_len = max_digits * 2

    output_path = os.path.join(output_dir, output_file)
    # Create the destination directory (including any sub-path carried by
    # output_file) so a fresh output location does not fail on open().
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        for _ in range(num_samples):
            # Generate two random integers
            num1 = np.random.randint(min_val, max_val)
            num2 = np.random.randint(min_val, max_val)

            # Multiply as Python ints so the product can never wrap around a
            # fixed-width NumPy integer type.
            result = int(num1) * int(num2)

            # Convert the input sequence to digit tokens (zero-padded)
            num1_str = str(num1).zfill(max_digits)
            num2_str = str(num2).zfill(max_digits)
            input_str = " ".join(list(num1_str) + ["[SEP]"] + list(num2_str))

            # Convert the output sequence to digit tokens (zero-padded)
            output_seq = list(str(result).zfill(output_len))
            if inverse:
                output_seq.reverse()
            output_str = " ".join(output_seq)

            # Write to file (: delimited)
            f.write(f"{input_str} : {output_str}\n")


if __name__ == "__main__":
    # Command-line interface. Each flag name matches a keyword argument of
    # generate_multiplication_dataset, so the parsed namespace can be forwarded
    # to it unchanged.
    parser = argparse.ArgumentParser(description="Generate a multiplication dataset.")

    # (flag, add_argument keyword options), registered in display order.
    option_specs = (
        ("--num_samples", {"type": int, "default": 10000, "help": "The number of samples in the dataset."}),
        ("--min_digits", {"type": int, "default": 1, "help": "The minimum number of digits for the input numbers."}),
        ("--max_digits", {"type": int, "default": 5, "help": "The maximum number of digits for the input numbers."}),
        ("--output_file", {"type": str, "default": "multiplication_dataset.txt", "help": "The name of the output file."}),
        ("--output_dir", {"type": str, "default": ".", "help": "The output directory."}),
        ("--seed", {"type": int, "default": None, "help": "The random seed."}),
        ("--inverse", {"action": "store_true", "help": "Reverse the output tokens."}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    # Generate the dataset; a ValueError from the generator is reported as a
    # usage error through the parser.
    try:
        generate_multiplication_dataset(**vars(args))
    except ValueError as e:
        parser.error(str(e))