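"""
Sample text from a trained model.

Loads the pickled model from /tmp/my_model, encodes a start prompt, and prints
`num_samples` generated continuations, each followed by a dashed separator line.
"""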

import torch
from myModel import *
from preprocess import *

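# (Disabled) Original inline setup: load a text file as one character string and
# build a char-level vocabulary with an encoder and decoder. The `encode`/`decode`
# used further down are not defined in this file; presumably they now come from
# the star imports above (likely `preprocess`) — confirm.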

#with open('/tmp/input.txt', 'r') as f:
#    data = f.read()
#print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
#chars = sorted(list(set(data)))
#vocab_size = len(chars)
#print("all the unique characters:", ''.join(chars))
#print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers
#stoi = { ch:i for i,ch in enumerate(chars) }
#itos = { i:ch for i,ch in enumerate(chars) }
#def encode(s):
#    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
#def decode(l):
#    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string


# ---------------------------------------------------------------------------
# Sampling configuration
# ---------------------------------------------------------------------------

start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 10 # number of samples to draw
max_new_tokens = 500 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = None #200 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
# NOTE(review): `dtype` is not applied anywhere in this script; the model runs
# in whatever precision it was saved with. Kept so existing configs still work.
dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'

# Seed both the CPU and CUDA generators so sampling is reproducible.
# (PyTorch documents torch.cuda.manual_seed as silently ignored without CUDA.)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

# Load the full pickled model object. map_location remaps any tensors that were
# saved on a GPU onto `device`, so a CUDA-trained checkpoint can still be
# sampled on a CPU-only machine (the default device here).
# SECURITY NOTE: torch.load unpickles arbitrary Python objects, and /tmp is
# usually writable by other local users -- only load a checkpoint you trust.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects a pickled nn.Module; on those versions either pass
# weights_only=False (trusted files only) or switch to saving a state_dict.
model = torch.load('/tmp/my_model', map_location=device)

model.eval()  # leave training mode (affects layers such as dropout)
model.to(device)

# ---------------------------------------------------------------------------
# Prompt encoding
# ---------------------------------------------------------------------------

# A start string of the form "FILE:<path>" means: read the prompt from <path>.
if start.startswith('FILE:'):
    with open(start[len('FILE:'):], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = encode(start)
# len() rather than truthiness so this also works if encode returns an array
# or tensor. A zero-length prompt leaves nothing to condition on; fail here
# with a clear message rather than somewhere inside model.generate.
if len(start_ids) == 0:
    raise ValueError("start prompt encodes to zero tokens; provide a non-empty prompt")

# Add a leading batch dimension of size 1: shape (1, len(start_ids)).
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------

# Sampling never back-propagates, so disable autograd to avoid recording a
# computation graph during generation (saves memory and time).
with torch.no_grad():
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # y[0]: the single sequence in the batch, decoded back to text.
        print(decode(y[0].tolist()))
        print('---------------')
