"""
Processes sequence text files into binary format for training.
"""
import os
import pickle
import numpy as np
import re
import argparse

# Command-line configuration. Each option below selects which dataset
# directory / file names are read and written further down in this script.
parser = argparse.ArgumentParser(description='Process sequence dataset')
parser.add_argument('--min_value', type=int, default=0,
                    help="Minimum value; used in the '<min>-<max>' data directory name.")
parser.add_argument('--max_value', type=int, default=100,
                    help="Maximum value; used in the '<min>-<max>' data directory name.")
# The default must be the *string* "True": argparse only applies `type` to
# string defaults, so a bool default would stay a bool and never equal "True".
parser.add_argument('--is_sorted', type=str, default="True",
                    help='"True" selects the sorted dataset directory, otherwise unsorted.')
parser.add_argument('--num_copies', type=int, default=1,
                    help='Selects train_<num_copies>.txt as input and train_<num_copies>.bin as output.')
parser.add_argument('--permutation_type', type=str, default="reversal",
                    help='Name of the permutation sub-directory under the value-range directory.')
args = parser.parse_args()

min_value = args.min_value
max_value = args.max_value
is_sorted = args.is_sorted
num_copies = args.num_copies
permutation_type = args.permutation_type

# --is_sorted normally arrives as a string, but a non-string parser default
# would arrive unconverted; str() plus a case-insensitive compare makes
# "True", "true" and True all select the sorted dataset.
sequence_type = "sorted" if str(is_sorted).strip().lower() == "true" else "unsorted"
base_dir = os.path.join("data", "sequences", sequence_type, f'{min_value}-{max_value}', permutation_type)
# Outputs (.bin files and meta.pkl) are written next to the input text files.
output_dir = base_dir

os.makedirs(output_dir, exist_ok=True)

train_file_path = os.path.join(base_dir, f'train_{num_copies}.txt')
val_file_path = os.path.join(base_dir, 'test.txt')

print(f"Loading training data from: {train_file_path}")
print(f"Loading validation data from: {val_file_path}")

try:
    with open(train_file_path, 'r', encoding='utf-8') as f:
        train_data = f.read()
    with open(val_file_path, 'r', encoding='utf-8') as f:
        val_data = f.read()
except FileNotFoundError as e:
    # SystemExit with a message prints it to stderr and exits with status 1.
    raise SystemExit(f"Error: {e}")

# Join with an explicit newline. If the train file has no trailing newline,
# plain concatenation would fuse its last line with the first validation line
# (e.g. "... 7" + "3 ..." -> "73"), and the vocabulary scan would pick up a
# bogus token. '\n' is never itself matched as a token, so this adds nothing.
all_data = train_data + '\n' + val_data

def find_characters(data_string):
    """Return the set of distinct tokens occurring in ``data_string``.

    A token is a maximal run of digits, a ``%`` sign, or any other single
    non-whitespace character. Whitespace, including newlines, never forms
    a token.
    """
    token_re = re.compile(r'\d+|%|\S')
    return {match.group(0) for match in token_re.finditer(data_string)}

def process_data(s, stoi, block_size):
    """Encode every non-blank line of ``s`` into one flat list of token ids.

    Each encoded line is followed by the id of the ``'\\n'`` token, which
    separates examples in the flat stream.

    Args:
        s: Raw text, one example per line.
        stoi: Mapping from token string to integer id.
        block_size: Currently unused; kept so existing callers keep working.

    Returns:
        Flat list of integer token ids.
    """
    # In this script '\n' is assigned id 0; fall back to 0 for mappings that
    # lack it so the previous hard-coded separator is preserved.
    sep_id = stoi.get('\n', 0)
    ids = []
    for line in s.split('\n'):
        # Skip empty and whitespace-only lines (e.g. a stray '\r' left by CRLF files).
        if not line.strip():
            continue
        # Encode once (the original encoded every line twice and discarded the first result).
        ids.extend(encode(line, stoi))
        ids.append(sep_id)
    return ids

def get_block_size(s, stoi):
    """Return the token length of the longest non-empty line in ``s``.

    Lengths come from ``encode`` alone, so no separator token is counted.
    Returns 0 when ``s`` has no non-empty lines.
    """
    line_lengths = (len(encode(line, stoi)) for line in s.split('\n') if line != "")
    return max(line_lengths, default=0)

def encode_string(s, stonum):
    """Encode a single line of text as a list of token ids.

    The line is tokenized with the same pattern used to build the vocabulary
    (digit runs, ``%``, or any single non-whitespace character). As a result,
    repeated spaces, tabs, trailing whitespace or a CRLF ``\\r`` no longer
    produce unknown tokens. For lines whose tokens are separated by single
    spaces, the result is identical to splitting on ``" "``.

    Args:
        s: One line of text (without ``\\n``).
        stonum: Mapping from token string to integer id.

    Returns:
        List of integer ids, one per token, in order.

    Raises:
        KeyError: if a token in ``s`` is not present in ``stonum``.
    """
    # Must stay in sync with the pattern in find_characters.
    tokens = re.findall(r'\d+|%|\S', s)
    try:
        return [stonum[tok] for tok in tokens]
    except KeyError as err:
        raise KeyError(f"unknown token {err.args[0]!r} in line {s!r}") from None

# Build the vocabulary. Id 0 is reserved for '\n', which separates encoded
# lines in the output stream; the discovered tokens follow in sorted
# (lexicographic string) order.
chars = sorted(find_characters(all_data))
print("All unique characters:", ' '.join(chars))

itos = dict(enumerate(['\n'] + [ch for ch in chars if ch != '\n']))
stoi = {token: token_id for token_id, token in itos.items()}
idx = len(itos)  # next unused id

vocab_size = len(stoi)
print(f"Vocabulary size: {vocab_size}")

def encode(s, stoi):
    """Map one line of text to a list of token ids via ``encode_string``."""
    return encode_string(s, stonum=stoi)

block_size = max(get_block_size(train_data, stoi), get_block_size(val_data, stoi))
print(f"Block size: {block_size}")

train_ids = process_data(train_data, stoi, block_size)
val_ids = process_data(val_data, stoi, block_size)

print(f"Train has {len(train_ids):,} tokens")
print(f"Val has {len(val_ids):,} tokens")

# Ids are 0..vocab_size-1 and are stored as uint16. Fail loudly if they cannot
# fit: older NumPy versions silently wrap out-of-range Python ints, which
# would corrupt the .bin files without any error.
token_dtype = np.uint16
if vocab_size > np.iinfo(token_dtype).max + 1:
    raise SystemExit(
        f"Error: vocabulary size {vocab_size} does not fit in {np.dtype(token_dtype).name} token ids"
    )

train_ids = np.array(train_ids, dtype=token_dtype)
val_ids = np.array(val_ids, dtype=token_dtype)

train_output = os.path.join(output_dir, f'train_{num_copies}.bin')
val_output = os.path.join(output_dir, 'val.bin')

train_ids.tofile(train_output)
val_ids.tofile(val_output)

meta = {
    'block_size': block_size,
    'vocab_size': vocab_size,
    'itos': itos,
    'stoi': stoi,
    'min_value': min_value,
    'max_value': max_value,
    'is_sorted': is_sorted,
    'permutation_type': permutation_type,
}

# NOTE(review): meta.pkl is shared across --num_copies runs but the vocabulary
# is derived from train_<num_copies>.txt; confirm all copies share one vocab.
meta_output = os.path.join(output_dir, 'meta.pkl')

with open(meta_output, 'wb') as f:
    pickle.dump(meta, f)

print(f"Processing complete. Data saved to {output_dir}")