from pecos.xmc.xtransformer.matcher import TransformerMatcher
import torch
import numpy as np
import torch.nn as nn
from pecos.utils import torch_util
import os
import argparse

def main():
    """Generate per-level node features from trained XR-Transformer matchers.

    For each level ``i`` in ``range(args.max_level)`` (i.e. ``--max_level`` is
    the number of levels, levels ``0 .. max_level - 1``), this:

    1. loads the matcher saved at ``<save_data_dir>/<dataset>/saved_models/<i>.model``;
    2. runs it over the pre-tokenized tensor ``<save_data_dir>/<dataset>/X.Tokenized.val.pt``;
    3. saves the returned embeddings to ``<save_data_dir>/<dataset>/Results/X.XRT.Lv<i>.npy``.
    """
    parser = argparse.ArgumentParser(description='Generate node features from XRTransformer')
    parser.add_argument('--dataset', type=str, default="ogbn-arxiv")
    parser.add_argument('--data_root_dir', type=str, default="./dataset")
    parser.add_argument('--save_data_dir', type=str, default="./data_for_XRTransformer")
    parser.add_argument('--max_level', type=int, required=True, help="Max level to use for generating node features from XRTransformer")
    args = parser.parse_args()
    print(args)

    # All inputs and outputs live under a per-dataset subdirectory.
    args.save_data_dir = os.path.join(args.save_data_dir, args.dataset)
    model_dir = os.path.join(args.save_data_dir, 'saved_models')
    results_dir = os.path.join(args.save_data_dir, 'Results')
    # np.save does not create missing parent directories; make sure the
    # output directory exists before running any (expensive) prediction.
    os.makedirs(results_dir, exist_ok=True)

    print("Loading input tensor.")
    X_tensor = torch.load(os.path.join(args.save_data_dir, 'X.Tokenized.val.pt'))
    print("Input tensor loaded.")

    # Device selection does not depend on the level, so do it once.
    device, n_gpu = torch_util.setup_device(True)
    # NOTE(review): setup_device can report n_gpu == 0 when no CUDA device is
    # available; the original `n_gpu * 64` then yielded a batch size of 0.
    # Treat CPU as a single device so the batch size stays positive.
    batch_size = max(n_gpu, 1) * 64
    batch_gen_workers = 64

    for i in range(args.max_level):
        level = str(i)

        matcher = TransformerMatcher.load(os.path.join(model_dir, level + '.model'))
        matcher.to_device(device, n_gpu=n_gpu)
        matcher.text_model.to(matcher.device)

        pred_params = matcher.pred_params.from_dict({"truncate_length": 128,
                                                     "batch_size": batch_size,
                                                     "batch_gen_workers": batch_gen_workers})

        print("Start generating node features at level " + level)
        # predict returns a pair; only the second element (the embeddings
        # used as node features) is kept.
        _, embedding = matcher.predict(X_tensor, pred_params=pred_params,
                                       batch_gen_workers=batch_gen_workers,
                                       batch_size=batch_size)
        print("Got node features. Start saving node features")
        np.save(os.path.join(results_dir, 'X.XRT.Lv{}.npy'.format(level)), embedding)
        print("Node features saved.")

if __name__ == "__main__":
    # Run the feature-generation pipeline only when executed as a script,
    # not when this module is imported.
    main()
