import pickle
from promptrl.envs.alfworld_viz import AlfworldVizDataset
from transformers import GPT2TokenizerFast

# GPT-2 tokenizer handed to every dataset below and reused later in this
# script to turn token ids back into text.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")


def _load_split(mode, num_samples):
    """Build one AlfworldVizDataset split using the arguments shared by all splits.

    Only ``mode`` and ``num_samples`` differ between the splits loaded by this
    script; everything else is identical.
    NOTE(review): the meaning of the positional arguments ('img', the 1..6
    list, None) is defined by AlfworldVizDataset and is not visible here.
    """
    return AlfworldVizDataset(
        'img',
        [1, 2, 3, 4, 5, 6],  # fresh list per call, never shared between splits
        None,
        tokenizer,
        num_samples=num_samples,
        mode=mode,
        data_mode='img',
    )


print('Loading datasets..')
train_dataset = _load_split('train', 200)
id_dataset = _load_split('eval_in_distribution', float('inf'))
ood_dataset = _load_split('eval_out_of_distribution', float('inf'))

# Report how many samples each split ended up with.
for caption, split in (
    ('Train dataset size: ', train_dataset),
    ('Eval id dataset size: ', id_dataset),
    ('Eval ood dataset size: ', ood_dataset),
):
    print(caption, len(split))

print('Processing..')
def untok(data, tok=None):
    """Decode tokenized goals and actions in place, keeping the token ids.

    For every row, the current values of ``row['goal']`` and
    ``row['actions']`` are preserved under ``'goal_tok'`` and
    ``'actions_tok'``; ``'goal'`` is then replaced by its decoded form and
    ``'actions'`` by a list with one decoded entry per action.

    Args:
        data: Iterable of mutable mappings, each with ``'goal'`` and
            ``'actions'`` keys. ``'actions'`` must itself be iterable.
        tok: Object exposing ``decode(ids)``. When omitted, the module-level
            ``tokenizer`` is used, which matches the previous behaviour.

    Returns:
        The same ``data`` object that was passed in; rows are mutated in
        place, no copy is made.

    Raises:
        KeyError: If a row lacks ``'goal'`` or ``'actions'``.
    """
    if tok is None:
        # Backward-compatible default: fall back to the script-wide tokenizer.
        tok = tokenizer
    decode = tok.decode
    for row in data:
        # Stash the raw ids first so they survive the overwrite below.
        row['goal_tok'] = row['goal']
        row['actions_tok'] = row['actions']
        row['goal'] = decode(row['goal_tok'])
        row['actions'] = [decode(ids) for ids in row['actions_tok']]
    return data
# Decode each split's rows. untok returns the object it was given, so each
# *_d name refers to the corresponding dataset's own ``.data``.
train_d, id_d, ood_d = (
    untok(dataset.data)
    for dataset in (train_dataset, id_dataset, ood_dataset)
)

print('Dumping data..')
# One pickle file per split, written into the current working directory.
for out_name, rows in (
    ('train_data_viz.pkl', train_d),
    ('eval_id_data_viz.pkl', id_d),
    ('eval_ood_data_viz.pkl', ood_d),
):
    with open(out_name, 'wb') as f:
        pickle.dump(rows, f)

# NOTE(review): this always raises AssertionError once the files are written
# (unless Python runs with -O, which strips asserts). It looks like a
# deliberate hard stop -- confirm whether it is still wanted.
assert False
