import argparse
import os
from typing import Dict, List, Optional

import numpy as np


def compute_f(x: int, y: int) -> int:
    """Return f(x, y) = ((x² + y²) mod 19) - 9.

    Python's ``%`` with a positive modulus always yields a value in
    [0, 18], so the result always lies in [-9, 9]. ``forward_task`` applies
    this same expression inline for its recurrence.
    """
    return (x ** 2 + y ** 2) % 19 - 9


def generate_sequence(length: int) -> List[int]:
    """Draw ``length`` random integers in [-9, 9] from NumPy's global RNG.

    The values are returned as plain Python ints, not NumPy scalars.
    """
    low, high = -9, 10  # randint's upper bound is exclusive
    draws = np.random.randint(low, high, size=length)
    return [int(value) for value in draws]

def forward_task(X: List[int]) -> List[int]:
    """Compute the forward-task target sequence for input ``X``.

    The output has the same length as ``X`` and follows the recurrence::

        Y[0] = X[0]
        Y[i] = ((X[i]² + Y[i-1]²) mod 19) - 9    for i >= 1

    This is the same expression as ``compute_f(X[i], Y[i-1])``. An empty
    input yields an empty list.
    """
    if len(X) == 0:
        return []
    Y: List[int] = [X[0]]
    # Each output depends on the previous *output*, not the previous input.
    for x in X[1:]:
        Y.append((x ** 2 + Y[-1] ** 2) % 19 - 9)
    return Y

def inverse_task(X: List[int]) -> List[int]:
    """Return the forward-task outputs for ``X`` in reverse order.

    Despite the name, this does not invert ``forward_task``. It computes
    the same sequence and reverses it, so the last target comes first.
    """
    forward_outputs = forward_task(X)
    return list(reversed(forward_outputs))

def save_dataset(data: List[Dict], output_file: str, output_dir: str = "."):
    """Write ``data`` to ``output_dir/output_file``, one sample per line.

    Each sample must have ``"input"`` and ``"output"`` sequences. They are
    written as space-separated values joined by `` : ``, e.g.
    ``"1 -2 : 1 3"``. ``output_dir`` is created if missing. An empty
    ``output_dir`` means the current working directory. An existing file
    is overwritten.
    """
    # os.makedirs("") raises FileNotFoundError, but os.path.join("", name)
    # is a valid CWD-relative path, so only create real directories.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_file)

    def _join(values) -> str:
        # Space-separated string form of one sequence.
        return " ".join(map(str, values))

    lines = (f"{_join(sample['input'])} : {_join(sample['output'])}\n" for sample in data)
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(lines)

def generate_dataset(num_samples: int, seq_length: int, output_dir: str, seed: int, is_test: bool = False) -> None:
    """Generate one split and write it in normal and reversed-target form.

    Reseeds NumPy's global RNG with ``seed``, then draws ``num_samples``
    input sequences of length ``seq_length``. Two files are written to
    ``output_dir``:

    * ``data.<split>``: each input with its ``forward_task`` targets.
    * ``data-inv.<split>``: the same inputs with the targets reversed.

    ``<split>`` is ``"test"`` when ``is_test`` is true, else ``"train"``.
    """
    np.random.seed(seed)

    def _make_record() -> Dict:
        # Draw one input sequence and pair it with its forward targets.
        inputs = generate_sequence(seq_length)
        return {"input": inputs, "output": forward_task(inputs)}

    records = [_make_record() for _ in range(num_samples)]

    if is_test:
        split = "test"
    else:
        split = "train"

    # Normal order dataset
    save_dataset(records, f"data.{split}", output_dir)

    # Reversed order dataset (inputs unchanged, targets reversed)
    reversed_records = [
        {"input": record["input"], "output": list(reversed(record["output"]))}
        for record in records
    ]
    save_dataset(reversed_records, f"data-inv.{split}", output_dir)

def main(argv: Optional[List[str]] = None) -> None:
    """Parse command-line options and generate the train and test splits.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which keeps the original CLI behavior.

    Negative sample counts or sequence length are rejected through
    ``parser.error``. This prints usage and exits with status 2, before
    any file is written.
    """
    parser = argparse.ArgumentParser(description="Dataset generation script")
    parser.add_argument("--train_samples", type=int, default=100000, help="Number of training data samples")
    parser.add_argument("--test_samples", type=int, default=1000, help="Number of test data samples")
    parser.add_argument("--sequence_length", type=int, default=5, help="Sequence length")
    parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
    parser.add_argument("--train_seed", type=int, default=42, help="Random seed for training data")
    parser.add_argument("--test_seed", type=int, default=43, help="Random seed for test data")
    args = parser.parse_args(argv)

    # Fail fast with a usage message rather than a NumPy traceback
    # ("negative dimensions are not allowed") deep inside generation.
    for option in ("train_samples", "test_samples", "sequence_length"):
        value = getattr(args, option)
        if value < 0:
            parser.error(f"--{option} must be non-negative (got {value})")

    # Training split first, then test split; each reseeds the RNG itself.
    for num_samples, seed, is_test in (
        (args.train_samples, args.train_seed, False),
        (args.test_samples, args.test_seed, True),
    ):
        generate_dataset(
            num_samples=num_samples,
            seq_length=args.sequence_length,
            output_dir=args.output_dir,
            seed=seed,
            is_test=is_test,
        )


if __name__ == "__main__":
    # Run the CLI only when executed as a script. main() returns None, so
    # SystemExit(None) exits with status 0, just like falling off the end.
    raise SystemExit(main())