import numpy as np
import argparse
import os
from six.moves import cPickle as pkl

# Output directory for the pickled dataset.
# NOTE(review): "DIR/TO/CACHE/" looks like a placeholder -- set it to a real
# cache directory before running.
DUMP = "DIR/TO/CACHE/"

parser = argparse.ArgumentParser(
    description="Generate synthetic sentences from a random first-order Markov chain.")
parser.add_argument('--v', type=int, default=20,
                    help='vocabulary size (number of distinct tokens)')
parser.add_argument('--t', type=int, default=3,
                    help='number of tokens per sentence')
parser.add_argument('--n', type=int, default=20_000,
                    help='number of sentences to generate')
parser.add_argument('--seed', type=int, default=0,
                    help='seed passed to np.random.seed')
args = parser.parse_args()

# Reject sizes that would otherwise fail later with an obscure NumPy error
# (an empty vocabulary makes np.random.choice raise) or silently produce
# empty output (negative counts make range() empty).
if args.v < 1:
    parser.error('--v must be at least 1')
if args.t < 0:
    parser.error('--t must be non-negative')
if args.n < 0:
    parser.error('--n must be non-negative')

# Seed NumPy's global RNG so the generated dataset is reproducible.
np.random.seed(args.seed)

# Vocabulary: token i is the string str(i), so each label equals its
# row/column index in the transition matrix.
vocabulary = np.array([str(i) for i in range(args.v)])

# Random row-stochastic transition matrix: entry [i, j] is the probability
# that token j follows token i. |N(0, 1)| draws give non-negative weights;
# each row is then normalised to sum to 1 in a single vectorised step
# (keepdims=True keeps the row sums as a column so they broadcast per row).
transition_matrix = np.abs(np.random.normal(0, 1, (args.v, args.v)))
transition_matrix /= transition_matrix.sum(axis=1, keepdims=True)

def generate_sentence(length, vocab=None, transitions=None):
    """Sample one sentence of ``length`` tokens from a first-order Markov chain.

    The first token is drawn uniformly from the vocabulary. Each following
    token is drawn from the transition-matrix row of the previous token.
    Randomness comes from NumPy's global RNG (``np.random``).

    Args:
        length: number of tokens in the sentence; values <= 0 yield ''.
        vocab: sequence of token strings. Defaults to the module-level
            ``vocabulary``.
        transitions: row-stochastic matrix whose row ``i`` is the
            next-token distribution after ``vocab[i]``. Defaults to the
            module-level ``transition_matrix``.

    Returns:
        The sampled tokens joined by single spaces.
    """
    if vocab is None:
        vocab = vocabulary
    if transitions is None:
        transitions = transition_matrix
    n_tokens = len(vocab)

    tokens = []
    # Track the chain state as a row index instead of int(token), so token
    # labels need not be the strings "0".."V-1". np.random.choice(n) consumes
    # the RNG exactly like np.random.choice(vocab), so seeded output matches
    # the previous implementation.
    state = np.random.choice(n_tokens)
    for _ in range(length):
        tokens.append(vocab[state])
        # One transition is also sampled after the final token. It is unused,
        # but kept so the RNG stream (and previously cached datasets) match.
        state = np.random.choice(n_tokens, p=transitions[state])
    return ' '.join(tokens)

# Generate the synthetic dataset: args.n independent sentences of args.t tokens.
num_samples = args.n
sentences = [generate_sentence(args.t) for _ in range(num_samples)]

# Cache the sentences as a pickled list of strings. The filename records
# v, t and n.
# NOTE(review): the seed is not part of the filename, so runs with different
# --seed values overwrite each other.
out_path = os.path.join(DUMP, "v={}_t={}_n={}.pkl".format(args.v, args.t, args.n))
# open(..., "wb") does not create missing directories, so make sure DUMP exists.
os.makedirs(DUMP, exist_ok=True)
with open(out_path, "wb") as f:
    pkl.dump(sentences, f)