import torch

import numpy as np

from network_architectures import DirectRegressionModel

def main():
    """Smoke-test DirectRegressionModel on a random image.

    Builds a DirectRegressionModel with random weights, feeds it a random
    144x256x3 array and prints the predicted camera-to-end-effector pose.
    The weights are untrained, so the printed pose is not meaningful; the
    script only checks that the model can be constructed and run end to end.
    """
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    print('Creating 144x256 random RGB image.')
    # float64 values in [0, 1) with shape (144, 256, 3).
    # NOTE(review): confirm get_camera_to_eef_pose expects this value range/dtype
    # rather than e.g. uint8 in [0, 255].
    np_rgb = np.random.rand(144, 256, 3)

    # encoding_min and encoding_max are passed to the model for normalisation.
    # Each has 9 entries, presumably one per component of the pose encoding.
    # TODO confirm against DirectRegressionModel.
    encoding_min = torch.tensor(9 * [-1.], dtype=torch.float32)
    encoding_max = torch.tensor(9 * [1.], dtype=torch.float32)

    print('Initialising the Direct Regression Model with random weights.')
    # Initialise the model and move to device
    direct_regression_model = DirectRegressionModel(encoder_channels=[3, 4, 8, 16, 32],
                                                    regressor_neurons=[16, 16, 16],
                                                    dropout=0.25,
                                                    encoding_min=encoding_min,
                                                    encoding_max=encoding_max)
    direct_regression_model.to(device)
    # Switch to inference mode. Assuming DirectRegressionModel is a
    # torch.nn.Module, it starts in training mode, where the dropout=0.25
    # configured above would randomly zero activations and make the
    # prediction non-deterministic.
    direct_regression_model.eval()

    # Regress camera to end-effector pose; no autograd graph is needed for inference.
    with torch.no_grad():
        camera_to_eef_pose = direct_regression_model.get_camera_to_eef_pose(np_rgb)
    print('Predicted camera to end-effector pose: \n', camera_to_eef_pose)


if __name__ == '__main__':
    main()
