import cv2
import copy
import torch

import numpy as np
import open3d as o3d
import torch.nn as nn

import network_building_blocks as blocks

from scipy.spatial.transform import Rotation
from utils import decode_extrinsic_encoding, normalise_rgb


class DirectRegressionModel(nn.Module):
    """Directly regress a 9-value camera-to-end-effector pose encoding from an RGB image.

    A convolutional encoder (``blocks.EncoderFlexible``) feeds a fully connected
    regressor (``blocks.MLPFlexible``) whose last layer has 9 outputs. Those
    outputs are a *normalised* encoding: ``un_normalise`` maps them back using
    ``encoding_min`` / ``encoding_max``, and ``decode_extrinsic_encoding`` turns
    the result into a pose.
    """

    def __init__(self, encoder_channels, regressor_neurons, dropout=0.,
                 input_height=144, input_width=256,
                 encoding_min=torch.tensor(9 * [-1.], dtype=torch.float32),
                 encoding_max=torch.tensor(9 * [1.], dtype=torch.float32)):
        """Build the encoder and the regressor head.

        Args:
            encoder_channels: channel list for ``blocks.EncoderFlexible``
                (3x3 kernels, stride 2, ``double_conv=True``).
            regressor_neurons: hidden-layer widths of the regressor. The input
                width (flattened encoder output) and the output width (9) are
                added automatically.
            dropout: dropout probability used in the encoder and the regressor.
            input_height, input_width: expected image size, used to compute the
                flattened encoder output size.
            encoding_min, encoding_max: per-component bounds of the 9-value
                encoding, used by ``un_normalise``.
        """
        super().__init__()

        self.encoder = blocks.EncoderFlexible(channels=encoder_channels, kernels=[3] * len(encoder_channels),
                                              strides=[2] * len(encoder_channels),
                                              final_activation='relu', bias=False, batchnorm=True, final_batchnorm=True,
                                              dropout=dropout, final_dropout=dropout, residuals=False,
                                              double_conv=True)

        # Flattened encoder output size. This assumes the encoder shrinks each
        # spatial dimension by 2 ** (len(encoder_channels) - 1) with ceil
        # rounding. It must agree with EncoderFlexible, or the flattened
        # features will not match the regressor's input width.
        input_size = [input_height, input_width]
        encoding_size = [encoder_channels[-1],
                         int(np.ceil(input_size[0] / (2 ** (len(encoder_channels) - 1)))),
                         int(np.ceil(input_size[1] / (2 ** (len(encoder_channels) - 1))))]

        self.regressor_neurons = [int(np.prod(encoding_size))] + regressor_neurons + [9]

        self.regressor = blocks.MLPFlexible(self.regressor_neurons, activation='relu', final_activation='none',
                                            dropout=dropout, bias=True, batchnorm=True, final_batchnorm=False,
                                            final_dropout=False)

        self.apply(blocks.init_weights)

        # NOTE: the default bound tensors are built once, when the class is
        # defined, and are shared by every instance that uses the defaults.
        # They are plain attributes (not buffers), so ``.to(device)`` does not
        # move them; ``un_normalise`` handles the device itself.
        self.encoding_min = encoding_min
        self.encoding_max = encoding_max

    def forward(self, x):
        """Return the normalised pose encoding (9 values per sample).

        ``x`` is presumably a normalised RGB batch of shape (batch, 3, H, W),
        as built in ``get_camera_to_eef_pose``. TODO confirm.
        """
        features = self.encoder(x)
        # Keep the batch axis explicit. view(-1, n) would silently fold a
        # feature-size mismatch into the batch axis instead of raising.
        features = features.flatten(start_dim=1)
        normalised_camera_to_eef_pose_encoding = self.regressor(features)

        return normalised_camera_to_eef_pose_encoding

    @staticmethod
    def _as_numpy(values):
        """Return ``values`` as a NumPy array, moving torch tensors to the CPU first."""
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def un_normalise(self, normalised_camera_to_eef_pose_encoding):
        """Map a normalised encoding back to the ``[encoding_min, encoding_max]`` range.

        Computes ``half_range * x + centre`` per component, where ``centre`` is
        ``(max + min) / 2`` and ``half_range`` is ``(max - min) / 2`` clamped to
        at least 1e-8. So -1 maps to ``encoding_min`` and +1 to ``encoding_max``.
        The bounds broadcast over any leading (batch) dimensions.

        Args:
            normalised_camera_to_eef_pose_encoding: torch tensor or NumPy array
                whose last dimension holds the encoding components.

        Returns:
            The un-normalised encoding, of the same kind as the input (a tensor
            on the input's device, or a NumPy array).
        """
        encoding = normalised_camera_to_eef_pose_encoding
        lower = self._as_numpy(self.encoding_min)
        upper = self._as_numpy(self.encoding_max)

        centre = np.asarray((upper + lower) / 2)
        # Clamp so degenerate (min == max) or inverted ranges never collapse to zero.
        half_range = np.asarray(np.maximum((upper - lower) / 2, 1e-8))

        if torch.is_tensor(encoding):
            centre = torch.from_numpy(centre).to(encoding.device)
            half_range = torch.from_numpy(half_range).to(encoding.device)
        return half_range * encoding + centre

    def get_camera_to_eef_pose(self, np_rgb):
        """Predict the camera-to-end-effector pose for a single image.

        Switches the model to eval mode and leaves it there.

        Args:
            np_rgb: image array laid out as (H, W, channels). It is permuted to
                channels-first, batched, and passed through ``normalise_rgb``.

        Returns:
            The first pose returned by ``decode_extrinsic_encoding``, as a
            float64 NumPy array (presumably a 4x4 homogeneous transform. TODO
            confirm against ``decode_extrinsic_encoding``).
        """
        self.eval()

        with torch.no_grad():
            # Use the device the weights actually live on; a hard-coded 'cuda:0'
            # breaks models placed on any other GPU.
            device = next(self.encoder.parameters()).device
            torch_rgb = torch.tensor(np_rgb, dtype=torch.float32).permute(2, 0, 1).unsqueeze(dim=0).to(device)
            normalised_rgb = normalise_rgb(torch_rgb)
            normalised_camera_to_eef_pose_encoding = self.forward(normalised_rgb)
            camera_to_eef_pose_encoding = self.un_normalise(normalised_camera_to_eef_pose_encoding)

        camera_to_eef_pose = decode_extrinsic_encoding(camera_to_eef_pose_encoding)
        camera_to_eef_pose = camera_to_eef_pose[0].detach().cpu().numpy().astype(float)

        return camera_to_eef_pose


class SparseCorrespondenceModel(nn.Module):
    """Estimate the camera pose from predicted 2D keypoints and known 3D keypoints.

    An encoder-decoder with a spatial soft-argmax output predicts normalised 2D
    keypoint locations. The ones not near the image border are matched
    row-for-row with ``keypoints_3d`` and passed to OpenCV PnP.
    """

    def __init__(self, keypoints_3d, intrinsics, distortion_coefficients,
                 encoder_channels=[3, 4, 8, 16, 32], decoder_channels=[32, 16, 16, 38], dropout=0.25,
                 num_output_channels=1):
        """Build the keypoint network and store the PnP inputs.

        Args:
            keypoints_3d: array of 3D keypoints with one row per predicted 2D
                keypoint, in the same order as the network's keypoints
                (presumably expressed in the end-effector frame. TODO confirm).
            intrinsics: camera matrix passed to OpenCV PnP.
            distortion_coefficients: distortion coefficients passed to OpenCV PnP.
            encoder_channels, decoder_channels, dropout, num_output_channels:
                forwarded to ``blocks.EncoderDecoder``.
        """
        super().__init__()

        # NOTE: the list defaults are shared between calls. They are only handed
        # to EncoderDecoder here, which is assumed not to mutate them.
        self.encoder_decoder = blocks.EncoderDecoder(encoder_channels=encoder_channels,
                                                     decoder_channels=decoder_channels,
                                                     dropout=dropout,
                                                     double_conv=True,
                                                     num_output_channels=num_output_channels,
                                                     final_activation='spatialsoftargmax')
        self.keypoints_3d = keypoints_3d
        self.intrinsics = intrinsics
        self.distortion_coefficients = distortion_coefficients

    def forward(self, x):
        """Return the first output of the encoder-decoder (normalised 2D keypoints)."""
        keypoints, _ = self.encoder_decoder(x)
        return keypoints

    def un_normalise_and_filter_keypoints(self, keypoint_pred, keypoints_3d, height, width):
        """Drop keypoints near the image border and convert the rest to pixels.

        A keypoint is kept when both of its first two normalised coordinates lie
        strictly inside (-0.98, 0.98). Kept coordinates are mapped from [-1, 1]
        to [0, width - 1] x [0, height - 1].

        Args:
            keypoint_pred: normalised keypoints of shape (K, >=2), or a 1-D
                array for a single keypoint.
            keypoints_3d: 3D keypoints with one row per predicted keypoint.
            height, width: image size in pixels.

        Returns:
            ``(keypoints_2d, keypoints_3d)`` restricted to the kept keypoints.
            ``keypoints_2d`` always has shape (M, 2), and M may be 0 or 1.
        """
        # A single keypoint can arrive 1-D (e.g. after the upstream np.squeeze).
        keypoint_pred = np.atleast_2d(keypoint_pred)
        inside = np.abs(keypoint_pred[:, :2]) < 0.98
        # Keep a 1-D index array. Squeezing it to 0-d for a single match would
        # drop the keypoint axis and break the pixel conversion below.
        chosen_indices = np.flatnonzero(np.all(inside, axis=1))

        keypoints_2d = (keypoint_pred[chosen_indices, :2] + 1) / 2. * np.array([width - 1, height - 1])
        # Fancy indexing already returns a copy, so the caller's array is untouched.
        keypoints_3d = keypoints_3d[chosen_indices, :]

        return keypoints_2d, keypoints_3d

    def _solve_pnp(self, keypoints_3d, keypoints_2d):
        """Run RANSAC PnP, falling back to plain PnP if it raises.

        Any exception from the fallback is reported as failure (best effort).

        Returns:
            ``(success, rotvec, pos)``. ``rotvec`` and ``pos`` are ``None`` when
            both solvers raised.
        """
        try:
            success, rotvec, pos, _inliers = cv2.solvePnPRansac(keypoints_3d, keypoints_2d, self.intrinsics,
                                                                self.distortion_coefficients)
        except Exception:
            try:
                success, rotvec, pos = cv2.solvePnP(keypoints_3d, keypoints_2d, self.intrinsics,
                                                    self.distortion_coefficients)
            except Exception:
                return False, None, None
        return success, rotvec, pos

    def get_camera_to_eef_pose(self, np_rgb):
        """Estimate the camera pose in the end-effector frame from one image.

        Switches the model to eval mode and leaves it there.

        Args:
            np_rgb: image array laid out as (H, W, channels).

        Returns:
            ``(camera_pose_in_eef, success)``: a 4x4 float64 transform and the
            PnP success flag. On failure the transform is the identity.
        """
        self.eval()

        with torch.no_grad():
            # Use the device the weights actually live on; a hard-coded 'cuda:0'
            # breaks models placed on any other GPU.
            device = next(self.encoder_decoder.parameters()).device
            torch_rgb = torch.tensor(np_rgb, dtype=torch.float32).permute(2, 0, 1).unsqueeze(dim=0).to(device)
            normalised_rgb = normalise_rgb(torch_rgb)
            normalised_keypoints_2d = self.forward(normalised_rgb)

            keypoints_2d, keypoints_3d = self.un_normalise_and_filter_keypoints(
                np.squeeze(normalised_keypoints_2d.detach().cpu().numpy()),
                self.keypoints_3d,
                height=np_rgb.shape[0],
                width=np_rgb.shape[1])

        success, rotvec, pos = self._solve_pnp(keypoints_3d, keypoints_2d)
        if not success:
            return np.eye(4).astype(float), success

        # PnP gives the end-effector pose in the camera frame. Invert it to get
        # the camera pose in the end-effector frame.
        predicted_eef_pose_in_cam = np.eye(4)
        predicted_eef_pose_in_cam[:3, :3] = Rotation.from_rotvec(np.squeeze(rotvec)).as_matrix()
        predicted_eef_pose_in_cam[:3, 3] = np.squeeze(pos)

        camera_pose_in_eef = np.linalg.pinv(predicted_eef_pose_in_cam)

        return camera_pose_in_eef.astype(float), success


class DenseCorrespondenceModel(nn.Module):
    """Refine a directly regressed camera pose with ICP against a gripper model.

    Predicts a depth map and a gripper segmentation mask from the RGB image.
    The masked depth is back-projected into a point cloud in the camera frame
    and aligned with point-to-point ICP to points sampled from the gripper
    mesh. ``DirectRegressionModel`` provides the ICP initialisation.
    """

    def __init__(self, gripper_model_path, intrinsic,
                 encoder_channels=[3, 4, 8, 16, 32], depth_and_segmentation_decoder_channels=[32, 16, 16, 38],
                 regressor_neurons=[16, 16, 16],
                 direct_regression_model_encoding_min=torch.tensor(9 * [-1.], dtype=torch.float32),
                 direct_regression_model_encoding_max=torch.tensor(9 * [1.], dtype=torch.float32), dropout=0.25,
                 max_depth=1., image_width=256, image_height=144):
        """Build the depth, segmentation and direct-regression networks.

        Args:
            gripper_model_path: mesh file read with ``o3d.io.read_triangle_mesh``.
                1024 points are sampled uniformly from it.
            intrinsic: camera matrix indexable as ``intrinsic[row, col]``.
                Entries [0, 0], [1, 1], [0, 2] and [1, 2] are read as fx, fy,
                cx and cy.
            encoder_channels: encoder channels shared by all three networks.
            depth_and_segmentation_decoder_channels: decoder channels of the
                depth and segmentation encoder-decoders.
            regressor_neurons: hidden widths of the direct-regression head.
            direct_regression_model_encoding_min,
            direct_regression_model_encoding_max: encoding bounds forwarded to
                ``DirectRegressionModel``.
            dropout: dropout probability for all networks.
            max_depth: factor applied to the depth output by ``un_normalise_depth``.
            image_width, image_height: input image size. Used for the Open3D
                intrinsics and as the direct-regression model's input size.
        """
        super().__init__()

        self.max_depth = max_depth
        self.intrinsic = o3d.camera.PinholeCameraIntrinsic()
        self.intrinsic.set_intrinsics(width=image_width, height=image_height,
                                      fx=float(intrinsic[0, 0]), fy=float(intrinsic[1, 1]),
                                      cx=float(intrinsic[0, 2]), cy=float(intrinsic[1, 2]))

        # load gripper point cloud
        gripper_mesh = o3d.io.read_triangle_mesh(gripper_model_path)
        self.gripper_pcd_in_gripper_frame = gripper_mesh.sample_points_uniformly(number_of_points=2 ** 10)

        self.depth_encoder_decoder = blocks.EncoderDecoder(
            encoder_channels=encoder_channels,
            decoder_channels=depth_and_segmentation_decoder_channels,
            dropout=dropout,
            double_conv=True,
            num_output_channels=1,
            final_activation='relu')

        self.segmentation_encoder_decoder = blocks.EncoderDecoder(
            encoder_channels=encoder_channels,
            decoder_channels=depth_and_segmentation_decoder_channels,
            dropout=dropout,
            double_conv=True,
            num_output_channels=1,
            final_activation='sigmoid')

        self.direct_regression_model = DirectRegressionModel(encoder_channels=encoder_channels,
                                                             regressor_neurons=regressor_neurons,
                                                             dropout=dropout,
                                                             input_height=image_height,
                                                             input_width=image_width,
                                                             encoding_min=direct_regression_model_encoding_min,
                                                             encoding_max=direct_regression_model_encoding_max)

    def depth_forward(self, x):
        """Return the depth network's first output (normalised, ReLU-activated depth)."""
        normalised_depth, _ = self.depth_encoder_decoder(x)
        return normalised_depth

    def segmentation_forward(self, x):
        """Return the segmentation network's first output (sigmoid probabilities)."""
        probs, _ = self.segmentation_encoder_decoder(x)
        return probs

    def un_normalise_depth(self, normalised_depth):
        """Scale a normalised depth by ``max_depth``."""
        depth = self.max_depth * normalised_depth
        return depth

    def get_camera_to_eef_pose(self, np_rgb, segmentation_threshold=0.5):
        """Estimate the camera-to-end-effector transform for one image via ICP.

        Switches the model to eval mode and leaves it there.

        Args:
            np_rgb: image array laid out as (H, W, channels).
            segmentation_threshold: probability above which a pixel counts as gripper.

        Returns:
            A 4x4 NumPy transform. This is the ICP result, or the direct
            regression estimate when no pixel is segmented as gripper.
        """
        self.eval()

        with torch.no_grad():
            # Use the device the weights actually live on; a hard-coded 'cuda:0'
            # breaks models placed on any other GPU.
            device = next(self.depth_encoder_decoder.parameters()).device
            torch_rgb = torch.tensor(np_rgb, dtype=torch.float32).permute(2, 0, 1).unsqueeze(dim=0).to(device)
            normalised_rgb = normalise_rgb(torch_rgb)

            depth = self.un_normalise_depth(self.depth_forward(normalised_rgb))
            depth = depth[0, 0].cpu().numpy()  # first image, first channel

            probs = self.segmentation_forward(normalised_rgb)
            mask = (probs > segmentation_threshold)[0, 0].cpu().numpy()

            camera_to_eef_pose_init = self.direct_regression_model.get_camera_to_eef_pose(np_rgb)

            # Background pixels get depth 0, presumably dropped by Open3D as invalid.
            masked_depth = o3d.geometry.Image(depth * mask)
            # NOTE(review): Open3D divides depth values by depth_scale. If
            # un_normalise_depth yields metres (max_depth defaults to 1.), 1000
            # would shrink the cloud 1000x relative to the 0.05 ICP distance.
            # Confirm the depth units.
            gripper_pcd_in_camera_frame = o3d.geometry.PointCloud.create_from_depth_image(
                depth=masked_depth,
                intrinsic=self.intrinsic,
                depth_scale=1000)

            # Nothing segmented: ICP has no source points to align, so keep the
            # direct regression estimate.
            if not gripper_pcd_in_camera_frame.has_points():
                return camera_to_eef_pose_init

            # Aligning camera-frame points onto gripper-frame points yields the
            # camera-to-eef transform directly.
            reg_p2p = o3d.pipelines.registration.registration_icp(
                source=gripper_pcd_in_camera_frame, target=self.gripper_pcd_in_gripper_frame,
                max_correspondence_distance=0.05,
                init=camera_to_eef_pose_init,
                estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(),
                criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=100))

            camera_to_eef_pose = reg_p2p.transformation

            return camera_to_eef_pose
