import torch

import numpy as np

from network_architectures import SparseCorrespondenceModel

def main():
    """Run the Sparse Correspondence Model once on randomly generated inputs.

    Builds a random RGB image, random 3D keypoints and a random camera
    intrinsic matrix, constructs a ``SparseCorrespondenceModel`` from them,
    moves it to the GPU when one is available, and prints the camera to
    end-effector pose returned by ``get_camera_to_eef_pose``. The pose is only
    reported as a prediction when the call also signals success.
    """
    # Size of the example image (rows x columns) and number of keypoints.
    image_height, image_width = 144, 256
    num_keypoints = 38

    # Use the first CUDA device when available, otherwise fall back to the CPU.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    print(f'Creating {image_height}x{image_width} random RGB image.')
    # Random image with values drawn uniformly from [0, 1).
    np_rgb = np.random.rand(image_height, image_width, 3)

    # Random keypoint locations in the end-effector's frame, one row per keypoint.
    keypoints_3d = np.random.rand(num_keypoints, 3)

    # Random camera intrinsic matrix: focal lengths are drawn from [0, 256)
    # and the principal point is placed at the image centre. The original
    # script used a vertical offset of 144 (the full image height), which put
    # the principal point on the bottom border instead of at the centre.
    # NOTE(review): assumes the usual [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
    # pinhole layout - confirm against SparseCorrespondenceModel.
    intrinsics = np.array([[256 * np.random.rand(), 0., image_width / 2.],
                           [0., 256 * np.random.rand(), image_height / 2.],
                           [0., 0., 1.]])

    # Set all distortion coefficients to zero as this is an example.
    distortion_coefficients = np.zeros(4)

    print('Initialising the Sparse Correspondence Model with random weights.')
    # NOTE(review): the last decoder width (38) equals the number of keypoints
    # here - confirm whether the model requires the two to match before
    # changing either one.
    sparse_correspondence_model = SparseCorrespondenceModel(keypoints_3d=keypoints_3d,
                                                            intrinsics=intrinsics,
                                                            distortion_coefficients=distortion_coefficients,
                                                            encoder_channels=[3, 4, 8, 16, 32],
                                                            decoder_channels=[32, 16, 16, 38],
                                                            dropout=0.25,
                                                            num_output_channels=keypoints_3d.shape[0])
    sparse_correspondence_model.to(device)
    # Evaluation mode, so that layers such as dropout behave deterministically.
    sparse_correspondence_model.eval()

    # Regress camera to end-effector pose.
    camera_to_eef_pose, success = sparse_correspondence_model.get_camera_to_eef_pose(np_rgb)

    # Do not present the returned pose as a prediction when the model flags
    # the estimate as unsuccessful.
    # NOTE(review): presumably `success` is falsy when no valid pose could be
    # computed - confirm against get_camera_to_eef_pose.
    if success:
        print('Predicted camera to end-effector pose: \n', camera_to_eef_pose)
    else:
        print('Camera to end-effector pose estimation was not successful.')


if __name__ == '__main__':
    main()
