import os
import torch

import numpy as np

from network_architectures import DenseCorrespondenceModel

def main() -> None:
    """Smoke-test DenseCorrespondenceModel on random input.

    Builds a random 144x256 RGB image and a random pinhole intrinsic matrix,
    instantiates the model with random (untrained) weights, and prints the
    camera-to-end-effector pose it regresses. Because the weights are random,
    the printed pose carries no meaning; the script only exercises model
    construction and the ``get_camera_to_eef_pose`` inference path.
    """
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # Image size in pixels (rows x columns).
    image_height, image_width = 144, 256

    print(f'Creating {image_height}x{image_width} random RGB image.')
    # Random RGB image: float64 values in [0, 1), shape (H, W, 3).
    np_rgb = np.random.rand(image_height, image_width, 3)

    # Random pinhole intrinsic matrix K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
    # Focal lengths are drawn uniformly from [0, 256) px (same range as before);
    # the principal point is the image centre (cx = W / 2, cy = H / 2). The
    # previous version used cy = 144, i.e. the bottom edge of a 144-row image.
    max_focal_length = 256.
    intrinsics = np.array([[max_focal_length * np.random.rand(), 0., image_width / 2.],
                           [0., max_focal_length * np.random.rand(), image_height / 2.],
                           [0., 0., 1.]])

    # Model settings
    gripper_model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'example_gripper_model.ply')
    max_depth = 1000  # in mm
    # Lower/upper bounds for the 9-dimensional direct-regression encoding.
    # NOTE(review): presumably per-dimension min/max used to (de)normalise the
    # regressor output -- confirm against network_architectures.
    direct_regression_model_encoding_min = torch.tensor(9 * [-1.], dtype=torch.float32)
    direct_regression_model_encoding_max = torch.tensor(9 * [1.], dtype=torch.float32)

    print('Initialising the Dense Correspondence Model with random weights.')
    # Initialise the model and move to device
    dense_correspondence_model = DenseCorrespondenceModel(
        gripper_model_path=gripper_model_path,
        intrinsic=intrinsics,
        encoder_channels=[3, 4, 8, 16, 32],
        depth_and_segmentation_decoder_channels=[32, 16, 16, 38],
        regressor_neurons=[16, 16, 16],
        direct_regression_model_encoding_min=direct_regression_model_encoding_min,
        direct_regression_model_encoding_max=direct_regression_model_encoding_max,
        dropout=0.25,
        max_depth=max_depth)

    dense_correspondence_model.to(device)
    # Inference mode: puts layers such as dropout into evaluation behaviour.
    dense_correspondence_model.eval()

    # Regress camera to end-effector pose.
    # NOTE(review): np_rgb is passed as a NumPy array on the CPU; this assumes
    # get_camera_to_eef_pose handles tensor conversion and device placement.
    camera_to_eef_pose = dense_correspondence_model.get_camera_to_eef_pose(np_rgb)
    print('Predicted camera to end-effector pose: \n', camera_to_eef_pose)


if __name__ == '__main__':
    main()
