import numpy as np
import os
import argparse


def _running_relu(values):
    """Return the clamped running sum g_i = max(0, g_{i-1} + s_i), with g_{-1} = 0."""
    clamped = []
    previous = 0
    for value in values:
        previous = max(0, previous + value)
        clamped.append(previous)
    return clamped


def generate_relu_dataset(
    num_samples=10000,
    sequence_length=10,
    output_file="relu_dataset.txt",
    inverse=False,
    seed=None,
    output_dir=".",
    permutation=None,
):
    """
    Generates a ReLU dataset and saves it to a text file.

    Each line of the file has the form ``"<inputs> : <outputs>"``. The inputs are
    ``sequence_length`` space-separated integers drawn uniformly from -9 to 9, and
    the outputs are the clamped running sum ``g_i = max(0, g_{i-1} + s_i)``
    (with ``g_0 = max(0, s_0)``), optionally reordered.

    Args:
        num_samples (int): The number of samples (lines) in the dataset.
        sequence_length (int): The length of the sequence. Must be at least 1.
        output_file (str): The name of the output file.
        inverse (bool): If True, the output tokens are reversed (only if permutation is not specified).
        seed (int): The random seed. When given, NumPy's global random state is seeded with it.
        output_dir (str): The output directory. It is created if it does not exist yet.
        permutation (list[int], optional): A list of indices for permuting the output sequence;
            output position ``k`` receives ``g[permutation[k]]``. If specified, inverse is ignored.

    Raises:
        ValueError: If sequence_length is smaller than 1, or if permutation is not a
            rearrangement of ``0 .. sequence_length - 1``.
    """
    if seed is not None:
        np.random.seed(seed)

    # Reject empty/negative lengths up front: otherwise the failure would only
    # surface after the output file had already been opened (and truncated).
    if sequence_length < 1:
        raise ValueError(f"sequence_length must be a positive integer, got {sequence_length}.")

    # Validate permutation
    if permutation is not None:
        if len(permutation) != sequence_length:
            raise ValueError(f"Permutation length ({len(permutation)}) must match sequence length ({sequence_length}).")
        if set(permutation) != set(range(sequence_length)):
            raise ValueError("Permutation must contain unique integers from 0 to sequence_length - 1.")

    # The output ordering is the same for every sample, so decide it once.
    if permutation is not None:
        order = list(permutation)
    elif inverse:
        order = range(sequence_length - 1, -1, -1)
    else:
        order = None

    output_path = os.path.join(output_dir, output_file)
    parent_dir = os.path.dirname(output_path)
    if parent_dir:  # an empty string means "current directory"; nothing to create
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        for _ in range(num_samples):
            # Generate input sequence (randomly from -9 to 9); tolist() yields plain Python ints.
            input_seq = np.random.randint(-9, 10, size=sequence_length).tolist()

            # Calculate output sequence g_i = ReLU(g_{i-1} + s_i)
            output_seq = _running_relu(input_seq)
            if order is not None:
                output_seq = [output_seq[i] for i in order]

            input_str = " ".join(map(str, input_seq))
            output_str = " ".join(map(str, output_seq))

            # Write to file (: delimited)
            f.write(f"{input_str} : {output_str}\n")


if __name__ == "__main__":
    # Command-line options, declared as a table and registered in this exact
    # order so that the generated --help output lists them the same way.
    cli = argparse.ArgumentParser(description="Generate a ReLU dataset.")
    option_table = [
        ("--num_samples", {"type": int, "default": 10000, "help": "The number of samples in the dataset."}),
        ("--sequence_length", {"type": int, "default": 10, "help": "The length of the sequence."}),
        ("--output_file", {"type": str, "default": "relu_dataset.txt", "help": "The name of the output file."}),
        ("--output_dir", {"type": str, "default": ".", "help": "The output directory."}),
        ("--seed", {"type": int, "default": None, "help": "The random seed."}),
        (
            "--inverse",
            {"action": "store_true", "help": "Reverse the output tokens (only if permutation is not specified)."},
        ),
        (
            "--permutation",
            {
                "type": str,
                "default": None,
                "help": 'A comma-separated list of permutations specifying the order of the output sequence (e.g., "1,0,2").',
            },
        ),
    ]
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    opts = cli.parse_args()

    # Turn the raw "--permutation" string into a list of ints. Only the syntax
    # is checked here; whether it is a valid permutation for the requested
    # sequence length is decided by generate_relu_dataset.
    index_order = None
    if opts.permutation:
        raw_tokens = opts.permutation.split(",")
        try:
            index_order = [int(token.strip()) for token in raw_tokens]
        except ValueError:
            cli.error("Permutation must be a comma-separated list of integers.")

    # Run the generator; validation failures (ValueError) are reported through
    # argparse so the user sees a usage message instead of a traceback.
    generator_kwargs = {
        "num_samples": opts.num_samples,
        "sequence_length": opts.sequence_length,
        "output_file": opts.output_file,
        "inverse": opts.inverse,
        "seed": opts.seed,
        "output_dir": opts.output_dir,
        "permutation": index_order,
    }
    try:
        generate_relu_dataset(**generator_kwargs)
    except ValueError as validation_error:
        cli.error(str(validation_error))