"""
Generates sequences for transformer experiments.
"""
import os
import random
import argparse

def generate_random_sequence(min_value, max_value, length, is_sorted):
    """Draw distinct random integers from the inclusive range [min_value, max_value].

    Args:
        min_value: Smallest integer that may appear in the sequence.
        max_value: Largest integer that may appear in the sequence.
        length: Requested number of elements. Values are drawn without
            replacement, so the request is capped at the number of
            integers available in the range.
        is_sorted: When truthy, the result is returned in ascending order;
            otherwise it keeps the order in which the values were drawn.

    Returns:
        A list of unique integers.
    """
    available = max_value - min_value + 1
    # Sampling without replacement cannot yield more items than the range holds.
    sample_size = min(length, available)
    drawn = random.sample(range(min_value, max_value + 1), sample_size)
    return sorted(drawn) if is_sorted else drawn

def apply_permutation(input_sequence, permutation_type="reversal"):
    """Return a permuted copy of the input sequence.

    Args:
        input_sequence: The sequence to permute; it is not modified.
        permutation_type: "reversal" yields the elements in reverse order.
            "copy" and any other value yield an unchanged copy.

    Returns:
        A new sequence holding the permuted elements.
    """
    if permutation_type != "reversal":
        # "copy" and unrecognised types are treated the same way.
        return input_sequence.copy()
    return [*reversed(input_sequence)]

def format_sequence(original_seq, permuted_seq):
    """Render a sequence pair as one newline-terminated text line.

    Each sequence is written as space-separated values, and the two halves
    are joined by " % " (original first, permuted second).
    """
    source_text = " ".join(str(value) for value in original_seq)
    target_text = " ".join(str(value) for value in permuted_seq)
    return f"{source_text} % {target_text}\n"

def generate_datasets(args):
    """Build the training and validation example lists.

    Generates ``args.num_sequences`` formatted examples, shuffles them and
    splits them so that the first ``int(num_sequences * chance_in_train)``
    go to training and the rest to validation. Every training example is
    then repeated ``args.num_copies`` times and the training list is
    shuffled again.

    Args:
        args: Object exposing ``min_value``, ``max_value``,
            ``sequence_length``, ``is_sorted``, ``permutation_type``,
            ``num_sequences``, ``chance_in_train`` and ``num_copies``.

    Returns:
        A ``(train_examples, validation_examples)`` tuple of lists of
        formatted lines.
    """
    def _build_example():
        # One example: a random sequence paired with its permutation.
        source = generate_random_sequence(
            args.min_value,
            args.max_value,
            args.sequence_length,
            args.is_sorted,
        )
        target = apply_permutation(source, args.permutation_type)
        return format_sequence(source, target)

    examples = [_build_example() for _ in range(args.num_sequences)]
    random.shuffle(examples)

    # int() truncates, so a fractional split point rounds towards zero.
    split_at = int(args.num_sequences * args.chance_in_train)
    train_part, val_part = examples[:split_at], examples[split_at:]

    # Only the training portion is duplicated; validation stays as-is.
    repeated_train = [
        line for line in train_part for _ in range(args.num_copies)
    ]
    random.shuffle(repeated_train)

    return repeated_train, val_part

def write_dataset(dataset, file_path):
    """Write all dataset items to a text file, replacing any existing content.

    Items are written back to back in iteration order; no separator or
    newline is added between them.

    Args:
        dataset: Iterable of strings to write.
        file_path: Destination path, opened in text write mode.
    """
    with open(file_path, "w") as out_file:
        out_file.writelines(dataset)

if __name__ == "__main__":
    # Script entry point: parse the options, seed the RNG, generate the
    # datasets and write the train/validation files.

    def _parse_bool(text):
        """Map the string 'true' (any casing) to True and anything else to False."""
        return str(text).lower() == "true"

    parser = argparse.ArgumentParser(description="Generate sequence datasets")
    parser.add_argument("--is_sorted", type=_parse_bool, default=True)

    # Integer options, registered in the order they appear in --help.
    integer_options = (
        ("--min_value", 0),
        ("--max_value", 100),
        ("--sequence_length", 6),
        ("--num_copies", 1),
        ("--num_sequences", 10000),
    )
    for option_name, option_default in integer_options:
        parser.add_argument(option_name, type=int, default=option_default)

    parser.add_argument("--chance_in_train", type=float, default=0.7)
    parser.add_argument(
        "--permutation_type",
        type=str,
        default="reversal",
        choices=["reversal", "copy"],
    )
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    # Seed before any generation so a given seed reproduces the same files.
    random.seed(args.seed)

    sequence_type = "sorted" if args.is_sorted else "unsorted"
    output_dir = f"data/sequences/{sequence_type}/{args.min_value}-{args.max_value}/{args.permutation_type}"
    os.makedirs(output_dir, exist_ok=True)

    train_sequences, val_sequences = generate_datasets(args)

    # The training file name records the copy count; validation is always test.txt.
    train_file = os.path.join(output_dir, f"train_{args.num_copies}.txt")
    val_file = os.path.join(output_dir, "test.txt")

    for destination, examples in (
        (train_file, train_sequences),
        (val_file, val_sequences),
    ):
        write_dataset(examples, destination)

    summary_lines = [
        f"Generated {sequence_type} sequences with {args.permutation_type} permutation:",
        f"- {len(train_sequences)} training examples",
        f"- {len(val_sequences)} validation examples",
        f"Training data saved to: {train_file}",
        f"Validation data saved to: {val_file}",
    ]
    print("\n".join(summary_lines))