import torch

import numpy as np


def normalize_vector(v):
    """Scale the vectors stored along dimension 1 of ``v`` to unit length.

    The Euclidean norm is taken over dimension 1 and clamped from below at
    1e-8 before dividing, so an all-zero vector is returned as zeros rather
    than triggering a division by zero.

    :param v: tensor with at least two dimensions whose dimension 1 holds the
        vector components (e.g. ``batch*3``).
    :return: tensor of the same shape as ``v``; for floating-point input the
        dtype and device of ``v`` are kept.
    """
    # keepdim leaves a size-1 axis in place of dimension 1, so the division
    # below broadcasts without an explicit view/expand and works for inputs
    # with more than two dimensions as well.
    v_mag = v.pow(2).sum(1, keepdim=True).sqrt()
    # A Python scalar bound does not take part in dtype promotion and needs
    # no host-to-device copy, unlike a freshly built float32 tensor.
    v_mag = v_mag.clamp(min=1e-8)

    return v / v_mag


def cross_product(u, v):
    """Row-wise 3-D cross product ``u x v``.

    Only columns 0, 1 and 2 of each input are read.

    :param u: tensor indexed as ``batch*3`` (left operand).
    :param v: tensor indexed as ``batch*3`` (right operand).
    :return: ``batch*3`` tensor, where ``batch`` is taken from ``u``.
    """
    n_rows = u.shape[0]

    ux, uy, uz = u[:, 0], u[:, 1], u[:, 2]
    vx, vy, vz = v[:, 0], v[:, 1], v[:, 2]

    # Components in x, y, z order.
    components = (
        uy * vz - uz * vy,
        uz * vx - ux * vz,
        ux * vy - uy * vx,
    )

    # Each component becomes one column of the result.
    columns = [c.view(n_rows, 1) for c in components]
    return torch.cat(columns, 1)  # batch*3


def decode_extrinsic_encoding(encoding):
    """Decode translation + two-vector rotation encodings into 4x4 matrices.

    Each row of ``encoding`` is read as three translation components
    (columns 0-2) followed by two raw 3-vectors (columns 3-5 and 6-8). The
    raw vectors are orthonormalised into a 3x3 block whose columns are
    ``x``, ``y`` and ``z``:

    * ``x`` is the first raw vector, normalised;
    * ``z`` is the normalised cross product of ``x`` and the second raw vector;
    * ``y`` is the cross product of ``z`` and ``x``.

    The result holds that block in its upper-left 3x3, the translation in its
    last column and ``[0, 0, 0, 1]`` as its bottom row.
    NOTE(review): the result name suggests these are camera poses in the
    end-effector frame - confirm against callers.

    :param encoding: torch tensor or numpy array, either a single encoding
        (1-D) or a batch of encodings (2-D).
    :return: float32 matrices of shape ``batch*4*4`` (``4*4`` for a 1-D
        input); a numpy array if the input was a numpy array, otherwise a
        torch tensor on the input's device.
    """
    # Remember the caller's format so the output can mirror it.
    came_from_numpy = isinstance(encoding, np.ndarray)
    if came_from_numpy:
        encoding = torch.tensor(encoding)

    was_unbatched = len(encoding.shape) == 1
    if was_unbatched:
        encoding = encoding.unsqueeze(dim=0)

    translation = encoding[:, :3]
    first_raw = encoding[:, 3:6]  # batch*3
    second_raw = encoding[:, 6:9]  # batch*3

    # Build an orthonormal frame from the two raw vectors.
    x = normalize_vector(first_raw)  # batch*3
    z = normalize_vector(cross_product(x, second_raw))  # batch*3
    y = cross_product(z, x)  # batch*3

    # Stacking on a new last axis makes x, y, z the matrix columns.
    rots = torch.stack((x, y, z), dim=2)  # batch*3*3

    camera_poses_in_eef = torch.zeros(
        encoding.shape[0], 4, 4, dtype=torch.float32, device=encoding.device
    )
    camera_poses_in_eef[:, :3, :3] = rots
    camera_poses_in_eef[:, :3, 3] = translation
    camera_poses_in_eef[:, 3, 3] = 1

    if was_unbatched:
        camera_poses_in_eef = camera_poses_in_eef.squeeze(0)

    if came_from_numpy:
        camera_poses_in_eef = camera_poses_in_eef.detach().cpu().numpy()

    return camera_poses_in_eef


def normalise_rgb(data_array, eps=1e-8):
    """
    Standardise a tensor over dimensions 2 and 3: subtract the mean and divide
    by the standard deviation, both computed separately for every index of the
    leading dimensions.

    The standard deviation is clamped from below at ``eps`` so that a constant
    slice (standard deviation 0) maps to zeros instead of NaN.

    :param data_array: Assumes a pytorch tensor with at least four dimensions
        (presumably batch*channel*height*width - confirm with callers)
    :param eps: lower bound applied to the standard deviation before dividing
    :return: tensor of the same shape as ``data_array``
    """
    mu = data_array.mean(dim=(2, 3), keepdim=True)
    # Only slices whose std falls below eps are affected by the clamp.
    std = data_array.std(dim=(2, 3), keepdim=True).clamp(min=eps)

    return (data_array - mu) / std
