import numpy as np
import torch
from train.behavioral_cloning.models.OrigActs import LstmOrigActs


class Network(LstmOrigActs):
    """Configuration subclass of ``LstmOrigActs``.

    Only supplies the ``num_actions`` / ``num_camera`` constructor arguments
    to the parent; all model behavior lives in ``LstmOrigActs``.
    """

    def __init__(self, num_actions=8, num_camera=18):
        """Build the network.

        Args:
            num_actions: forwarded to ``LstmOrigActs.__init__`` as
                ``num_actions``. Defaults to 8, the value previously hard-coded.
            num_camera: forwarded to ``LstmOrigActs.__init__`` as
                ``num_camera``. Defaults to 18, the value previously
                hard-coded. NOTE(review): presumably a count of camera
                action classes/bins; confirm against ``LstmOrigActs``.
        """
        super().__init__(num_actions=num_actions, num_camera=num_camera)


if __name__ == "__main__":
    # Smoke test: run random inputs through the network and print output shapes.
    batch_size, seq_len = 2, 10

    # All inputs share the same leading (batch, time) dims. Previously the
    # action tensors used batch 1 while the frames used batch 2.
    in_frames = torch.from_numpy(
        np.random.randn(batch_size, seq_len, 3, 64, 64).astype(np.float32)
    )
    # NOTE(review): 9 binary-action features vs. num_actions=8 in Network;
    # confirm this width against what LstmOrigActs expects.
    binary_actions = torch.from_numpy(
        np.random.randn(batch_size, seq_len, 9).astype(np.float32)
    )
    camera_actions = torch.from_numpy(
        np.random.randn(batch_size, seq_len, 2).astype(np.float32)
    )

    net = Network()
    # Forward-only shape check: no autograd graph is needed.
    with torch.no_grad():
        out = net({"pov": in_frames, "binary_actions": binary_actions, "camera_actions": camera_actions})
    print("out_shape", out["logits"].shape, out["camera"].shape)
