import numpy as np
import torch
from train.behavioral_cloning.models.LightCnnLstmOrigActs32 import Net


class BaseNetwork(Net):
    """Net constructed with a fixed set of keyword arguments.

    Every instance is built with camera_head=True, lstm_hidden_dim=256,
    lstm_layer_dim=1, action_out_dim=8 and camera_out_dim=22; this class
    accepts no constructor arguments of its own.
    """

    def __init__(self):
        super().__init__(
            camera_head=True,
            lstm_hidden_dim=256,
            lstm_layer_dim=1,
            action_out_dim=8,
            camera_out_dim=22,
        )


class Network(BaseNetwork):
    """Dict-in / dict-out wrapper around BaseNetwork.

    Pulls the 'pov', 'binary_actions' and 'camera_actions' entries out of the
    input mapping, passes them positionally to the inherited forward, and
    repackages the three returned values under named keys.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Run the inherited forward pass on a mapping of named inputs.

        Args:
            x: mapping that must provide the keys 'pov', 'binary_actions' and
                'camera_actions'; a missing key raises KeyError.

        Returns:
            dict with keys 'logits', 'camera' and 'value', holding (in that
            order) the three values returned by the inherited forward.
        """
        positional = (x['pov'], x['binary_actions'], x['camera_actions'])
        logits, camera, value = super().forward(*positional)
        return dict(logits=logits, camera=camera, value=value)


if __name__ == "__main__":
    # Smoke test: push one random batch through Network and print output shapes.
    # Leading dims are 1 and 10 for every input; presumably (batch, time, ...)
    # — TODO confirm against Net's expected layout.
    in_frames = torch.from_numpy(np.random.randn(1, 10, 3, 32, 32).astype(np.float32))
    # NOTE(review): 9 binary-action features here vs action_out_dim=8 in
    # BaseNetwork — confirm the input width Net actually expects.
    binary_actions = torch.from_numpy(np.random.randn(1, 10, 9).astype(np.float32))
    camera_actions = torch.from_numpy(np.random.randn(1, 10, 2).astype(np.float32))

    net = Network()
    # Only output shapes are inspected, so skip building the autograd graph.
    with torch.no_grad():
        out = net({"pov": in_frames, "binary_actions": binary_actions, "camera_actions": camera_actions})
    print("out_shape", out["logits"].shape, out["camera"].shape)
