import os
import numpy as np
import pdb
import json

# Size of the 1-D gridworld and of the generated dataset.
num_states = 5
sequence_length = 99  # number of moves sampled per sequence
num_train_sequences = 100000
num_valid_sequences = 500

# Moves are emitted as small integer token ids rather than letters
# (an earlier encoding used 'L' / 'O' / 'R' for -1 / 0 / 1).
all_moves = [-1, 0, 1]
move_to_string = {-1: 0, 0: 1, 1: 2}

# State token ids are placed directly after the move token ids so the two
# ranges can never overlap, even if the set of moves is changed later.
state_token_offset = max(move_to_string.values()) + 1
state_to_string = {state: state + state_token_offset for state in range(num_states)}

# The padding token takes the first id after the last state token.
pad_id = max(state_to_string.values()) + 1

# When False, the move that would step off either end of the grid is never
# sampled; when True, every move is sampled everywhere.
allow_all_moves_at_boundary = False
# When True, each sequence ends with one extra token naming the final state.
print_state_at_end = True

# Fixed seed so that repeated runs generate the same dataset.
rs = np.random.RandomState(0)

def generate_sequence():
  """Sample one random walk over the grid states, starting from state 0.

  Reads the module-level settings (`num_states`, `sequence_length`,
  `all_moves`, `allow_all_moves_at_boundary`, `print_state_at_end`, the two
  token tables) and draws every move from the shared generator `rs`.

  Returns:
    A pair of 1-D numpy arrays of equal length:
      * the token ids: one move token per step, followed by the token of the
        final state when `print_state_at_end` is set;
      * the state reached after each step, with the final state repeated once
        when `print_state_at_end` is set so both arrays stay aligned.
  """
  last_state = num_states - 1
  position = 0
  tokens = []
  visited = []

  for _ in range(sequence_length):
    # Decide which moves may be drawn at the current position.
    if allow_all_moves_at_boundary:
      candidates = all_moves
    elif position == 0:
      candidates = [0, 1]
    elif position == last_state:
      candidates = [-1, 0]
    else:
      candidates = all_moves
    step = rs.choice(candidates)

    # A move that would leave the grid keeps the walker where it is; this can
    # only happen when every move is allowed at the boundary.
    blocked_low = position == 0 and step == -1
    blocked_high = position == last_state and step == 1
    if not (blocked_low or blocked_high):
      position = position + step

    tokens.append(move_to_string[step])
    visited.append(position)

  if print_state_at_end:
    tokens.append(state_to_string[position])
    visited.append(position)

  return np.array(tokens), np.array(visited)

def _write_uint8_memmap(path, data):
  """Write `data` to `path` as a raw uint8 memmap with the same shape.

  Raises:
    ValueError: if any value lies outside the uint8 range. Without this check
      the assignment would silently wrap such values.
  """
  limits = np.iinfo(np.uint8)
  if data.size and (data.min() < limits.min or data.max() > limits.max):
    raise ValueError(
      f"Cannot store values in [{data.min()}, {data.max()}] as uint8 in {path}"
    )
  out = np.memmap(path, dtype=np.uint8, mode='w+', shape=data.shape)
  out[:] = data[:]
  out.flush()


# Draw every sequence first and only then the permutation, so the order in
# which the shared generator `rs` is consumed stays fixed.
num_total = num_train_sequences + num_valid_sequences
string_sequences = []
state_labels = []

for _ in range(num_total):
  string_input, state_label = generate_sequence()
  string_sequences.append(string_input)
  state_labels.append(state_label)

# Shuffle, then split: the first `num_train_sequences` shuffled indices form
# the training set and the remainder the validation set.
perm = rs.permutation(num_total)
train_indices = perm[:num_train_sequences]
valid_indices = perm[num_train_sequences:]

all_tokens = np.array(string_sequences)
all_states = np.array(state_labels)
train_data = all_tokens[train_indices]
valid_data = all_tokens[valid_indices]
train_states = all_states[train_indices]
valid_states = all_states[valid_indices]

# Output directory lives under the user's home directory; its name records the
# number of states and whether boundary moves were restricted.
save_name = f"emergence/nanoGPT/data/gridworld/{num_states}-states"
if not allow_all_moves_at_boundary:
  save_name += "-restrict-moves"
data_dir = os.path.join(os.path.expanduser("~"), save_name)
os.makedirs(data_dir, exist_ok=True)

train_file_name = os.path.join(data_dir, "train.bin")
_write_uint8_memmap(train_file_name, train_data)

valid_file_name = os.path.join(data_dir, "val.bin")
_write_uint8_memmap(valid_file_name, valid_data)

# Also save a config file with: pad_id, vocab_size, num_train_examples,
# num_val_examples, seq_len, the two token tables, num_states and
# num_state_dimensions.
# NOTE: JSON object keys are always strings, so the integer keys of
# move_to_string / state_to_string are written as strings.
config_file_name = os.path.join(data_dir, "config.json")
config = {
  "pad_id": pad_id,
  "vocab_size": pad_id + 1,  # pad_id is the largest token id
  "num_train_examples": len(train_data),
  "num_val_examples": len(valid_data),
  "seq_len": len(train_data[0]),
  "move_to_string": move_to_string,
  "state_to_string": state_to_string,
  "num_states": num_states,
  "num_state_dimensions": 1,
}
# Save config file
with open(config_file_name, 'w') as f:
  json.dump(config, f)

# State labels get a trailing axis of size 1, matching "num_state_dimensions".
train_states = train_states[:, :, np.newaxis]
valid_states = valid_states[:, :, np.newaxis]

train_states_file_name = os.path.join(data_dir, "train_states.bin")
_write_uint8_memmap(train_states_file_name, train_states)
valid_states_file_name = os.path.join(data_dir, "val_states.bin")
_write_uint8_memmap(valid_states_file_name, valid_states)

print(f"Saved data to {data_dir}")

