import argparse
import os
import shutil

import numpy as np
import tensorflow_datasets as tfds
from PIL import Image
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
# install with pip install -U sentence-transformers

def decode_inst(inst):
    """Decode a language instruction stored as an array of byte values.

    Zero entries (padding) are dropped; the remaining values are turned into
    a ``bytes`` object and decoded as UTF-8.

    Args:
        inst: numpy array of integer byte values, possibly zero-padded.

    Returns:
        The decoded instruction string.
    """
    kept_codes = inst[inst != 0].tolist()
    raw = bytes(kept_codes)
    return raw.decode("utf-8")

# GCS location of every supported Language-Table variant, keyed by the
# dataset name accepted on the command line via --name. Each variant lives
# directly under the shared gs://gresearch/robotics/ bucket prefix.
dataset_directories = {
    variant: f'gs://gresearch/robotics/{variant}'
    for variant in (
        'language_table',
        'language_table_sim',
        'language_table_blocktoblock_sim',
        'language_table_blocktoblock_4block_sim',
        'language_table_blocktoblock_oracle_sim',
        'language_table_blocktoblockrelative_oracle_sim',
        'language_table_blocktoabsolute_oracle_sim',
        'language_table_blocktorelative_oracle_sim',
        'language_table_separate_oracle_sim',
    )
}

def main(args):
    data_directory = os.path.join(dataset_directories[args.name], args.version)
    root_directory = os.path.join(args.dir, args.name)
    builder = tfds.builder_from_directory(data_directory)
    episode_ds = builder.as_dataset(split="train")
    dataset_size = len(episode_ds)
    frame_counter = []

    inst_save_path = os.path.join(root_directory, "labels")
    action_save_path = os.path.join(root_directory, "actions")
    os.makedirs(inst_save_path, exist_ok=True)
    os.makedirs(action_save_path, exist_ok=True)

    mode = 'train'
    for episode_id, episode in tqdm(enumerate(episode_ds)):
        episode_split = [args.train_size, args.train_size+args.val_size, dataset_size]

        if episode_id == episode_split[0]:
            mode = 'val'
        if episode_id == episode_split[1]:
            mode = 'test'
        if episode_id == episode_split[2]:
            break

        folder_path = os.path.join(root_directory, mode, str(episode_id))
        if os.path.isdir(folder_path): 
            #continue # to skip processed folder
            pass
        else:
            os.makedirs(folder_path)
        
        actions = []
        for step_id, step in enumerate(episode['steps'].as_numpy_iterator()):
            # image
            if args.get_image:
                frame = step['observation']['rgb']
                image = Image.fromarray(frame)
                width, height = image.size
                image = image.resize((int(width/2), int(height/2))) # resized for storage

                save_path = os.path.join(root_directory, mode, str(episode_id), f"test_{step_id+1}.png")
                image.save(save_path, "PNG")

            # instruction
            if args.get_inst:
                if step_id == 0:
                    inst = step['observation']['instruction']

                    save_path = os.path.join(inst_save_path, f"{episode_id}.npy")
                    with open(save_path, 'wb') as f:
                        np.save(f, inst)

            # action
            if args.get_action:
                actions.append(step['action'])

        if args.get_action:
            save_path = os.path.join(action_save_path, f"{episode_id}.npy")
            with open(save_path, 'wb') as f:
                np.save(f, np.array(actions))

        frame_counter.append(step_id+1)
        if step_id + 1 < args.min_length:
            fr = os.path.join(root_directory, mode, str(episode_id), f"test_{step_id+1}.png")
            for step_extend_id in range(step_id + 1, args.min_length + 1):
                to = os.path.join(root_directory, mode, str(episode_id), f"test_{step_extend_id+1}.png")
                os.system(f"cp {fr} {to}")

    print("minimum frames:", min(frame_counter))
    print("maximum frames:", max(frame_counter))
    print("average frames:", sum(frame_counter)/len(frame_counter))

    # instruction sentence embedding
    inst_directory = os.path.join(root_directory, "labels")
    sentences = [decode_inst(np.load(os.path.join(inst_directory, f"{i}.npy"))) for i in range(dataset_size)]
    model = SentenceTransformer('sentence-transformers/sentence-t5-base')
    embeddings = model.encode(sentences)
    np.save(os.path.join(inst_directory, "inst.npy"), embeddings)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--version', type=str, default='0.0.1')
    parser.add_argument('--name', type=str, default='language_table_blocktoblock_4block_sim')
    parser.add_argument('--dir', type=str, default='DATASET_PATH')
    parser.add_argument('--train_size', type=int, default=7000) # train 7000/8000
    parser.add_argument('--val_size', type=int, default=500) # val 500, test: rest(500)
    parser.add_argument('--min_length', type=int, default=6) # if shorter, copy to make it longer
    
    # flags for partial save, 
    # python langtable.py --get_image False --get_inst False
    parser.add_argument('--get_image', type=bool, default=True) 
    parser.add_argument('--get_inst', type=bool, default=True) 
    parser.add_argument('--get_action', type=bool, default=True) 

    args = parser.parse_args()
    main(args)
