import torch


def euclidean_distance(x, y, squared=True):
    """Returns the pairwise Euclidean distance between x and y.

    If x is size (N, D) and y is size (M, D), returns a matrix of size (N, M)
    where element (i, j) is the distance between x[i, :] and y[j, :]. If x is
    size (N, D) and y is size (D,), returns a vector of length (N,) where each
    element is the distance from a row in x to y. If both have size (D,)
    returns a 0-dim tensor. Any size-1 dimensions of the result are squeezed
    away (e.g. N == 1 with y of size (M, D) yields shape (M,)).

    Args:
        x: Tensor of points, size (N, D) or (D,).
        y: Tensor of points, size (M, D) or (D,).
        squared: If True (default) return squared distances, otherwise the
            Euclidean distance itself.

    Returns:
        Tensor of (squared) distances with the shape described above.
    """
    # (N, 1, D) - (M, D) broadcasts to (N, M, D); with y of size (D,) the
    # result is (N, 1, D).
    diffs = x.unsqueeze(-2) - y
    # Reduce over the feature axis *before* squeezing. Squeezing first would
    # also drop a size-1 feature axis (D == 1), and the sum would then
    # collapse the wrong dimension.
    dist = torch.sum(diffs**2, dim=-1).squeeze()
    if not squared:
        dist = torch.sqrt(dist)
    return dist


def euler_rotation_distance(x, y, squared=True):
    """Returns the distance between rotation matrices x and y.

    The distance is the Frobenius norm of ``I - y^T @ x``, i.e. how far x,
    expressed in the frame of y, is from the identity. Despite the function
    name, no Euler angles are involved.

    Shapes: if x is (N, 3, 3) and y is (M, 3, 3) with N, M > 1 the result is
    the pairwise (N, M) matrix. Otherwise (either input is a single (3, 3)
    matrix, or either batch has size 1) the inputs are broadcast against each
    other and the result has the broadcast batch shape. Two single matrices
    give a 0-dim tensor.

    Args:
        x: Rotation matrix (3, 3) or batch of rotation matrices (N, 3, 3).
        y: Rotation matrix (3, 3) or batch of rotation matrices (M, 3, 3).
        squared: If True (default) return the squared Frobenius norm,
            otherwise the norm itself.

    Returns:
        Tensor of (squared) rotation distances.
    """
    # Pairwise evaluation only when *both* inputs carry a real batch
    # dimension. Checking size(0) alone is not enough: a single (3, 3)
    # matrix also has size(0) == 3, which would turn an (N, 3, 3) vs (3, 3)
    # comparison into an (N, 1) result.
    # TODO: Handle this better - separate out batch operation?
    if x.dim() > 2 and y.dim() > 2 and x.size(0) > 1 and y.size(0) > 1:
        x = x.unsqueeze(-3)  # (N, 1, 3, 3) so matmul broadcasts to (N, M, 3, 3).

    R_y_t = y.transpose(-2, -1)  # w_R_g -> g_R_w
    R_y_x = R_y_t @ x  # Transform x rotations into the y frame.
    rot_err = torch.eye(3, device=x.device, dtype=x.dtype) - R_y_x
    rot_err = torch.sum(rot_err**2, dim=(-2, -1))  # Squared Frobenius norm.

    if not squared:
        rot_err = torch.sqrt(rot_err)
    return rot_err


def euclidean_path_length(path):
    """Returns the total length of the piecewise-linear path through `path`.

    Args:
        path: Tensor of waypoints with dimension (T, D).

    Returns:
        0-dim tensor holding the sum of the Euclidean lengths of the segments
        joining consecutive waypoints (0 when there are fewer than two).
    """
    steps = path[1:, :] - path[:-1, :]
    segment_lengths = steps.pow(2).sum(dim=-1).sqrt()
    return segment_lengths.sum()


class PoseDistance:
    """Callable combining position and rotation distances between poses.

    A pose is a tensor of size (..., 12): the 3D position followed by the
    3x3 rotation matrix flattened in row-major order. The distance is
    ``position_distance + rot_scale * rotation_distance``.
    """

    def __init__(self, rot_scale=1.0, squared=True):
        # Weight applied to the rotation term before it is added to the
        # position term.
        self.rot_scale = rot_scale
        # Forwarded to both component distances; True selects squared ones.
        self.squared = squared

    @staticmethod
    def _split(pose):
        """Splits a (..., 12) pose into a (..., 3) position and (..., 3, 3) rotation."""
        batch_shape = pose.size()[:-1]
        return pose[..., :3], pose[..., 3:].view(*batch_shape, 3, 3)

    def __call__(self, x, y):
        """Returns the weighted pose distance between x and y.

        Args:
            x: Pose(s) of size (..., 12) (position, then flattened rotation).
            y: Pose(s) of size (..., 12).

        Returns:
            Position distance plus ``rot_scale`` times the rotation distance.
        """
        pos_a, rot_a = self._split(x)
        pos_b, rot_b = self._split(y)
        use_squared = self.squared
        position_term = euclidean_distance(pos_a, pos_b, squared=use_squared)
        rotation_term = euler_rotation_distance(rot_a, rot_b, squared=use_squared)
        return position_term + self.rot_scale * rotation_term
