import argparse
import os
import random

# numpy and pickle are not needed


def generate_self_ref_index_sample(n, m):
    """
    Generates a single sample for the self-referential index task.

    x is a list of n random integers, each drawn from [1, n - 1].
    y is built one element at a time, every element being a lookup into x:
      - y[0] = x[x[0] % n]
      - y[i] = x[(sum of the previous min(i, m) y values) % n]  for i >= 1

    n: Length of both sequences.
    m: The number of recent y values to use for index calculation.
       For m <= 0 the window is empty, so every index after the first is 0.

    Returns a tuple (x, y) of two lists of ints, each of length n
    (both lists are empty when n <= 0).
    Raises ValueError when n == 1, because the value range [1, n - 1] is empty.
    """
    if n == 1:
        # random.randint(1, 0) would otherwise fail with an unhelpful
        # "empty range" message that does not mention n.
        raise ValueError("n must be at least 2: x values are drawn from [1, n - 1], " "which is empty for n=1")

    x = [random.randint(1, n - 1) for _ in range(n)]
    y = []
    for i in range(n):
        if i == 0:
            # No previous y exists yet, so the first index comes from x itself.
            idx = x[0] % n
        else:
            # Sum of the last m y's (or all of them if fewer than m exist).
            recent = y[max(0, i - m):i]
            idx = sum(recent) % n
        y.append(x[idx])
    return x, y


def save_combined_sequences(lines, filepath):
    """
    Writes every entry of lines to filepath, one entry per row.

    lines: Iterable of strings; a newline is appended to each one.
    filepath: Destination path; an existing file is overwritten.
    """
    with open(filepath, "w") as out_file:
        # Newline-terminate each entry lazily instead of issuing one write per line.
        out_file.writelines(entry + "\n" for entry in lines)
    print(f"Dataset saved to {filepath}")


def generate_split_dataset(n, num_samples, seed, output_dir, split_name, m):
    """
    Generates a dataset for the specified split (train or test) and
    saves it as text files in the "X : Y" format.

    Two files are written into output_dir:
      - data.<split_name>:     each line is "x values : y values"
      - data-inv.<split_name>: same x values, with the y values reversed

    n: Sequence length passed to the sample generator.
    num_samples: Number of lines to write to each file.
    seed: Seed applied to the module-level random generator before sampling.
    output_dir: Full path to the output directory (created if missing).
    split_name: Used as the file extension, e.g. "train" or "test".
    m: Window size passed to the sample generator.
    """
    random.seed(seed)

    # Create the directory before sampling so a missing path cannot throw away
    # a full generation run at save time. An empty output_dir means the current
    # directory, which needs no creation (and os.makedirs("") would fail).
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Generating {split_name} data for n={n}, m={m} into {output_dir}...")

    data_lines = []
    data_inv_lines = []

    for _ in range(num_samples):
        x, y = generate_self_ref_index_sample(n, m)
        x_str = " ".join(map(str, x))
        y_str = " ".join(map(str, y))
        y_inv_str = " ".join(map(str, reversed(y)))  # Reversed Y

        data_lines.append(f"{x_str} : {y_str}")
        data_inv_lines.append(f"{x_str} : {y_inv_str}")

    # File names use the split name as extension (.train / .test)
    path_data = os.path.join(output_dir, f"data.{split_name}")
    path_data_inv = os.path.join(output_dir, f"data-inv.{split_name}")

    save_combined_sequences(data_lines, path_data)
    save_combined_sequences(data_inv_lines, path_data_inv)


# def generate_self_ref_index_task(num_samples, n, m):
#     """
#     Generates a specified number of samples for the self-referential index task.
#     num_samples: Number of samples to generate
#     n: Length of the list
#     m: Number of recent y values to use for index calculation
#     """
#     for _ in range(num_samples):
#         x, y = generate_self_ref_index_sample(n, m)
#         x_str = " ".join(map(str, x))
#         y_str = " ".join(map(str, y))
#         y_inv_str = " ".join(map(str, y[::-1]))  # Reversed Y
#
#         print(f"{x_str} : {y_str}")
#         print(f"{x_str} : {y_inv_str}")


def main(argv=None):
    """
    Command-line entry point: parses the options and writes both splits.

    argv: Optional list of argument strings; when None, argparse reads
          the process command line.

    Exits through parser.error (status 2) when --n is smaller than 2, since
    x values are drawn from [1, n - 1] and that range is empty below 2.
    """
    parser = argparse.ArgumentParser(description="Generate Self-Ref Index Task Datasets (train/test splits).")
    parser.add_argument("--n", type=int, required=True, help="Sequence length.")
    parser.add_argument(
        "--m",
        type=int,
        required=True,
        help="Number of previous y values to sum for index.",
    )
    parser.add_argument(
        "--num_train_samples",
        type=int,
        default=100000,
        help="Number of train samples. (%(default)s)",
    )
    parser.add_argument(
        "--num_test_samples",
        type=int,
        default=1000,
        help="Number of test samples. (%(default)s)",
    )
    parser.add_argument(
        "--seed_train",
        type=int,
        default=42,
        help="Random seed for train. (%(default)s)",
    )
    parser.add_argument(
        "--seed_test",
        type=int,
        default=43,
        help="Random seed for test. (%(default)s)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Full output directory path (e.g., dataset/index/n=31_m=2).",
    )
    args = parser.parse_args(argv)

    # Reject sequence lengths the generator cannot handle before any work is
    # done: n == 1 makes the sampling range empty, n <= 0 yields empty samples.
    if args.n < 2:
        parser.error("--n must be at least 2 (x values are drawn from [1, n - 1]).")

    # Train first, then test; each split reseeds the generator with its own seed.
    splits = (
        ("train", args.num_train_samples, args.seed_train),
        ("test", args.num_test_samples, args.seed_test),
    )
    for split_name, num_samples, seed in splits:
        generate_split_dataset(
            n=args.n,
            m=args.m,
            num_samples=num_samples,
            seed=seed,
            output_dir=args.output_dir,
            split_name=split_name,
        )

    print(f"\nDataset generation complete for n={args.n}, m={args.m} in {args.output_dir}.")


if __name__ == "__main__":
    main()