from T1001_loader import *

import json, time, os, sys, glob
import torch
import torch.nn as nn

sys.path.insert(0, '../..')
from sequence_models.graphmodel_utils import *
from sequence_models.utils import Tokenizer

# Load the T1001 features and wrap the four numeric arrays as torch tensors.
# `seq` is kept exactly as the loader returned it.
dist, omega, theta, phi, seq = loadT1001()
dist, omega, theta, phi = (
    torch.from_numpy(arr) for arr in (dist, omega, theta, phi)
)

# Process features: derive node features (V), neighbour indices (E_idx),
# edge features (E), a mask and the encoded sequence (S) from the raw tensors.
# The helper functions are not defined in this file; presumably they come from
# the star imports at the top — TODO confirm.
V = get_node_features(omega, theta, phi)
# Second argument is 10 — presumably the number of neighbours kept per residue.
E_idx = get_k_neighbors(dist, 10)
E = get_edge_features(dist, omega, theta, phi, E_idx)
# Order matters: the mask is computed from E *before* replace_nan() rewrites E.
mask = get_mask(E)
E = replace_nan(E)
# Sequence length; note that L is rebound to a one-element list in the
# reshape section further down.
L = len(seq)
# NOTE(review): `tokenizer` is not defined anywhere in the visible part of this
# file (only the `Tokenizer` class is imported). It must be supplied by one of
# the star imports, otherwise this line raises NameError — confirm.
S = get_S_enc(seq, tokenizer)

# Reshape everything to a batch of one.
# The residue count is derived from the sequence instead of being hard-coded
# to 140, so the script keeps working when the loader returns a target of a
# different length (behaviour is unchanged when len(seq) == 140).
seq_len = len(seq)
# Trailing dims: 10 matches node_features=10 passed to the decoder below.
V = V.view(1, seq_len, 10).float()
# Trailing dims: 10 neighbours (as requested from get_k_neighbors) x 6, which
# matches edge_features=6 passed to the decoder below.
E = E.view(1, seq_len, 10, 6).float()
E_idx = E_idx.view(1, seq_len, 10)
mask = mask.view(1, seq_len)
S = S.view(1, seq_len).long()
# The decoder receives the lengths as a list with one entry per batch element.
L = [seq_len]

# Build the decoder with freshly initialised weights (no checkpoint is loaded
# in this section of the script).
# NOTE(review): k_neighbors=30 here, while E_idx above was built with 10
# neighbours — confirm whether the decoder actually uses this value.
decoder = Struct2Seq_decoder(
    num_letters=20,
    node_features=10,
    edge_features=6,
    hidden_dim=128,
    k_neighbors=30,
    protein_features='full',
    dropout=0.10,
    use_mpnn=False,
)

# Single forward pass in eval mode with gradient tracking disabled.
decoder.eval()
with torch.no_grad():
    output = decoder(V, E, E_idx, S, L, mask)