"""

Formats:
    Quaternion: 4 real numbers saved in this order: (r, i, j, k).
    Rotation-Translation: unit quaternion with non-negative real part represents the rotation. The translation is represented as an element of R^{3}.
    Dual quaternion: 8 real numbers representing two quaternions, real part and dual part concatenated in this order. Quaternions are saved in this order: (r, i, j, k).
    DQmat: Quaternions are represented as subalgebra of M2(C). Dual numbers are realized as the subalgebra of M2(R) generated by I_{2} and the nilpotent matrix with all zeros except at upper right corners. DQmat represents a dual quaternion as the sum of the tensor product of the real part quaternion with I_{2} added to the tensor product of the nilpotent matrix with the dual part.
    Matrix: The format used by Arrigoni et al, 2016.

    A matrix of blocks is an array with shape (n, m, s, s). The corresponding block matrix is an array with shape (n*s, m*s) such that the (i, j) block is mb[i, j, :, :].
"""

import csv
from datetime import datetime

import numpy as np
import scipy.stats as st
import scipy.linalg as sla
import scipy.sparse.linalg as ssl


##################################################################################
#### Conversion: Matrix of blocks <-> Block matrix
##################################################################################
def mb2bm(mb, block_size = 4):
    """
    Flatten a matrix of blocks into a single block matrix.

    mb has shape (n, m, block_size, block_size). The result has shape
    (n*block_size, m*block_size) and its (i, j) block equals mb[i, j].
    """
    n_rows, n_cols = mb.shape[0], mb.shape[1]
    # Bring the in-block row axis next to the block-row axis before merging.
    interleaved = mb.swapaxes(1, 2)
    return interleaved.reshape(n_rows * block_size, n_cols * block_size)


def bm2mb(bm, block_size = 4):
    """
    Split a block matrix into a matrix of blocks.

    bm has shape (n*block_size, m*block_size). The result has shape
    (n, m, block_size, block_size) and result[i, j] is the (i, j) block of bm.
    The result is a view of bm whenever NumPy can reshape without copying.
    """
    n_rows = bm.shape[0] // block_size
    n_cols = bm.shape[1] // block_size
    split = bm.reshape(n_rows, block_size, n_cols, block_size)
    return split.swapaxes(1, 2)



##################################################################################
#### Conversion: Dual quaternion <-> DQmat
##################################################################################
# Constants
# 2x2 matrix realizations of the dual-number basis {1, eps}:
# 1 -> I_2, eps -> N = [[0, 1], [0, 0]] (nilpotent, N @ N == 0).
DQ_REAL_CONST = np.identity(2, dtype=np.cdouble)
DQ_DUAL_CONST = np.zeros((2, 2), dtype=np.cdouble)
DQ_DUAL_CONST[0, 1] = 1

def q2m2(q):
    """
    Map quaternions (r, i, j, k) to their 2x2 complex matrix representation.

    q has shape (..., 4). Returns a cdouble array of shape (..., 2, 2):

        [[ r + i*1j,  j + k*1j],
         [-j + k*1j,  r - i*1j]]
    """
    z = q[..., 0] + 1j*q[..., 1]
    w = q[..., 2] + 1j*q[..., 3]

    out = np.empty((*q.shape[:-1], 2, 2), dtype=np.cdouble)
    out[..., 0, 0] = z
    out[..., 0, 1] = w
    out[..., 1, 0] = -np.conj(w)
    out[..., 1, 1] = np.conj(z)
    return out

def m22q(m2):
    """
    Read quaternions (r, i, j, k) back from their M2(C) representation.

    m2 has shape (..., 2, 2); only the first row is used:
    r + i*1j is entry (0, 0) and j + k*1j is entry (0, 1).
    Returns a double array of shape (..., 4).
    """
    top_left = m2[..., 0, 0]
    top_right = m2[..., 0, 1]
    components = (top_left.real, top_left.imag, top_right.real, top_right.imag)
    return np.stack(components, axis=-1).astype(np.double)

def batchedKron(batch, mat):
    """
    Kronecker product of every matrix in a batch with one fixed matrix.

    batch has shape (..., n, m) and mat has shape (n1, m1). The result has shape
    (..., n*n1, m*m1) with
        result[..., i1*n1 + j1, i2*m1 + j2] = batch[..., i1, i2] * mat[j1, j2],
    i.e. np.kron applied to each matrix of the batch.
    """
    n1, m1 = mat.shape[0], mat.shape[1]
    rows, cols = batch.shape[-2], batch.shape[-1]
    # Stretch every batch entry into an n1 x m1 patch ...
    stretched = np.repeat(np.repeat(batch, n1, axis=-2), m1, axis=-1)
    # ... and scale each patch by a copy of mat.
    return stretched * np.tile(mat, (rows, cols))

def _dqreal(dq):
    """
    Return the real-part quaternion of the dual quaternions dq.

    dq has shape (..., 8); the real part is the first four entries of the last
    axis. The result is a view into dq.
    """
    return dq[..., :4]

def _dqdual(dq):
    """
    Return the dual-part quaternion of the dual quaternions dq.

    dq has shape (..., 8); the dual part is the last four entries of the last
    axis. The result is a view into dq.
    """
    return dq[..., 4:]

def dq2dqmat(dq):
    """
    Convert dual quaternions to DQmat format.

    dq has shape (..., 8): the first four entries of the last axis are the real
    part and the last four the dual part. Returns a complex array of shape
    (..., 4, 4) equal to kron(M(real), DQ_REAL_CONST) + kron(M(dual), DQ_DUAL_CONST),
    where M is the M2(C) representation computed by q2m2.
    """
    real_block = batchedKron(q2m2(_dqreal(dq)), DQ_REAL_CONST)
    dual_block = batchedKron(q2m2(_dqdual(dq)), DQ_DUAL_CONST)
    return real_block + dual_block

def _innerProduct(x, y):
    """
    Frobenius inner product <x, y> = sum_ab x[a, b] * conj(y[a, b]) = trace(x @ y^H).

    The product is taken over the last two axes; any leading (batch) axes of x and y
    broadcast against each other. Returns an array with the broadcast batch shape.
    """
    # Element-wise form of trace(x @ y^H). Unlike `y.conj().T`, which reverses *all*
    # axes of a batched y, this only pairs the last two axes.
    return np.sum(x * np.conj(y), axis=(-2, -1))

def _tenbackproj(dqmat, basisel):
    """
    Coefficient of dqmat along the basis element basisel.

    Computes <dqmat, basisel> / <basisel, basisel> with the Frobenius inner product.
    dqmat has shape (..., 4, 4) and basisel has shape (4, 4); the result has the
    batch shape of dqmat and is complex.
    """
    overlap = _innerProduct(dqmat, basisel)
    self_overlap = _innerProduct(basisel, basisel)
    return overlap / self_overlap

def _dqmat2dq_util(dqmat, basisel):
    """
    Extract one quaternion component (real or dual part) of DQmat matrices.

    basisel is the dual-number basis matrix of the wanted part (DQ_REAL_CONST or
    DQ_DUAL_CONST). The (0, 0) and (0, 1) entries of the quaternion's M2(C) factor
    are recovered by projecting dqmat onto kron(E, basisel) for the matrix units
    E = E00 and E = E01. Their real and imaginary parts give (r, i) and (j, k).
    Returns a double array of shape (..., 4).
    """
    coords = np.zeros((*dqmat.shape[:-2], 4), dtype=np.double)
    unit_e00 = np.array([[1, 0], [0, 0]], dtype=np.cdouble)  # entry -> r + i*1j
    unit_e01 = np.array([[0, 1], [0, 0]], dtype=np.cdouble)  # entry -> j + k*1j

    for offset, unit in ((0, unit_e00), (2, unit_e01)):
        coefficient = _tenbackproj(dqmat, np.kron(unit, basisel))
        coords[..., offset] = coefficient.real
        coords[..., offset + 1] = coefficient.imag

    return coords

def dqmat2dq(dqmat):
    """
    Convert DQmat matrices back to dual quaternions.

    dqmat has shape (..., 4, 4). Returns a double array of shape (..., 8): the
    real-part quaternion (coordinates along DQ_REAL_CONST) followed by the
    dual-part quaternion (coordinates along DQ_DUAL_CONST).
    """
    parts = [_dqmat2dq_util(dqmat, basis) for basis in (DQ_REAL_CONST, DQ_DUAL_CONST)]
    return np.concatenate(parts, axis=-1)



##################################################################################
#### Conversion: Dual quaternion <-> Rotation-Translation
##################################################################################
def dq2rt(dq):
    """
    Convert dual quaternions to rotation-translation representation.

    dq has shape (..., 8). Returns r, t where
    r has shape (..., 4): the real-part quaternion (a view into dq), and
    t has shape (..., 3): the vector part of dual * conj(r), computed in M2(C).
    """
    rot = dq[..., :4]
    dual_mat = q2m2(dq[..., 4:])
    # Conjugate transpose in M2(C) is quaternion conjugation.
    rot_adjoint = np.swapaxes(q2m2(rot).conj(), -1, -2)
    trans = m22q(dual_mat @ rot_adjoint)[..., 1:]
    return rot, trans

def rt2dq(r, t):
    """
    Convert rotation-translation representation to dual quaternions.

    r has shape (..., 4): rotations as quaternions.
    t has shape (..., 3): translations.
    Returns dq of shape (..., 8) with real part r and dual part t * r, where t is
    treated as the pure quaternion (0, t).
    """
    batch_shape = r.shape[:-1]

    pure_t = np.zeros((*batch_shape, 4), dtype=np.double)
    pure_t[..., 1:] = t

    dq = np.zeros((*batch_shape, 8), dtype=np.double)
    dq[..., :4] = r
    dq[..., 4:] = m22q(q2m2(pure_t) @ q2m2(r))
    return dq



##################################################################################
#### Conversion: Rotation-Translation <-> Matrix
##################################################################################
def r2om(r):
    """
    Convert rotations given as quaternions (r, i, j, k) to 3x3 rotation matrices.

    r has shape (..., 4); the standard unit-quaternion formula is applied without
    normalizing r. Output is a double array of shape (..., 3, 3).
    """
    w, x, y, z = r[..., 0], r[..., 1], r[..., 2], r[..., 3]

    entries = (
        (1 - 2*(y**2 + z**2), 2*(x*y - z*w),       2*(x*z + y*w)),
        (2*(x*y + z*w),       1 - 2*(x**2 + z**2), 2*(y*z - x*w)),
        (2*(x*z - y*w),       2*(y*z + x*w),       1 - 2*(x**2 + y**2)),
    )

    out = np.zeros((*r.shape[:-1], 3, 3), dtype=np.double)
    for row_idx, row in enumerate(entries):
        for col_idx, value in enumerate(row):
            out[..., row_idx, col_idx] = value
    return out

def om2r(om):
    """
    Convert 3x3 rotation matrices to unit quaternions (r, i, j, k) with
    non-negative real part.

    om has shape (..., 3, 3).
    Output has shape (..., 4).

    Uses Shepperd's method. All products 4*q_a*q_b of the quaternion components
    are linear in the entries of om. q is read off the row of that 4x4 product
    matrix whose diagonal entry 4*q_a**2 is largest. The diagonal always sums to 4,
    so the pivot is at least 1 and no near-zero division occurs. The plain trace
    formula, by contrast, divides by zero for rotations by pi (trace == -1) and
    loses accuracy near them.
    """
    m00 = om[..., 0, 0]
    m11 = om[..., 1, 1]
    m22 = om[..., 2, 2]

    # Diagonal of the product matrix: 4*r**2, 4*i**2, 4*j**2, 4*k**2.
    rr = 1 + m00 + m11 + m22
    ii = 1 + m00 - m11 - m22
    jj = 1 - m00 + m11 - m22
    kk = 1 - m00 - m11 + m22
    # Off-diagonal entries: 4*r*i, 4*r*j, 4*r*k, 4*i*j, 4*i*k, 4*j*k.
    ri = om[..., 2, 1] - om[..., 1, 2]
    rj = om[..., 0, 2] - om[..., 2, 0]
    rk = om[..., 1, 0] - om[..., 0, 1]
    ij = om[..., 0, 1] + om[..., 1, 0]
    ik = om[..., 0, 2] + om[..., 2, 0]
    jk = om[..., 1, 2] + om[..., 2, 1]

    products = np.stack([
        np.stack([rr, ri, rj, rk], axis=-1),
        np.stack([ri, ii, ij, ik], axis=-1),
        np.stack([rj, ij, jj, jk], axis=-1),
        np.stack([rk, ik, jk, kk], axis=-1),
    ], axis=-2)  # shape (..., 4, 4)

    pivot = np.asarray(np.argmax(np.stack([rr, ii, jj, kk], axis=-1), axis=-1))
    pivot = pivot[..., np.newaxis, np.newaxis]
    # Row `pivot` equals 4*q_p*q. Dividing by 2*sqrt(4*q_p**2) = 4*|q_p| gives q with q_p > 0.
    row = np.take_along_axis(products, pivot, axis=-2)[..., 0, :]
    scale = 2*np.sqrt(np.take_along_axis(row, pivot[..., 0], axis=-1))
    out = np.asarray(row / scale, dtype=np.double)

    # q and -q represent the same rotation; pick the one with non-negative real part.
    out = out * np.where(out[..., :1] < 0, -1.0, 1.0)
    return out

def rt2mat(r, t):
    """
    Convert rotation-translation representation to 4x4 homogeneous matrices.

    r has shape (..., 4): rotation quaternions (converted with r2om).
    t has shape (..., 3): translations.
    Output has shape (..., 4, 4) and equals [[R, t], [0, 0, 0, 1]].
    """
    out = np.zeros((*r.shape[:-1], 4, 4), dtype=np.double)
    out[..., :3, :3] = r2om(r)
    out[..., :3, 3] = t
    # Homogeneous coordinate: one in the bottom-right corner.
    out[..., 3, 3] = 1
    return out

def mat2rt(mat):
    """
    Convert 4x4 homogeneous matrices to rotation-translation representation.

    mat has shape (..., 4, 4). Returns r, t where
    r has shape (..., 4): om2r applied to the upper-left 3x3 block, and
    t has shape (..., 3): the first three entries of the last column (a view into mat).
    """
    rotation = om2r(mat[..., 0:3, 0:3])
    translation = mat[..., 0:3, -1]
    return rotation, translation



##################################################################################
#### Conversion: Dual quaternion <-> Matrix
##################################################################################
def dqmat2mat(dqmat):
    """
    Convert SE(3) elements from DQmat representation to 4x4 homogeneous matrices.

    Goes DQmat -> dual quaternion -> rotation-translation -> matrix.
    """
    dq = dqmat2dq(dqmat)
    rotation, translation = dq2rt(dq)
    return rt2mat(rotation, translation)

def mat2dqmat(mat):
    """
    Convert SE(3) elements from 4x4 homogeneous matrices to DQmat representation.

    Goes matrix -> rotation-translation -> dual quaternion -> DQmat.
    """
    rotation, translation = mat2rt(mat)
    dq = rt2dq(rotation, translation)
    return dq2dqmat(dq)



##################################################################################
#### Conversion: Axis-Angle <-> Quaternion
##################################################################################
def aa2q(axis, angle):
    """
    Convert axis-angle representation to quaternions with non-negative real part.

    axis has shape (..., 3) and angle has shape (..., 1), in radians.
    Returns (cos(angle/2), sin(angle/2) * axis), negated where the real part is
    negative. The result is a unit quaternion when axis is a unit vector.
    Output has shape (..., 4).
    """
    half = 0.5*angle

    out = np.zeros((*axis.shape[:-1], 4), dtype=np.double)
    out[..., 0] = np.cos(half[..., 0])
    out[..., 1:] = axis
    out[..., 1:] *= np.sin(half)

    # q and -q encode the same rotation; keep the representative with r >= 0.
    out *= np.where(out[..., :1] >= 0, 1, -1)
    return out

def q2aa(q):
    """
    Convert quaternions with non-negative real part to axis-angle representation.

    q has shape (..., 4) and is not modified.
    Output is axis, angle.
    axis has shape (..., 3): the normalized vector part of q. Where the vector part
    is zero (the identity rotation, for which any axis is valid), (1, 0, 0) is
    returned instead of NaNs.
    angle has shape (..., 1): 2*arccos(r) in radians, with r clipped to [-1, 1] so
    rounding errors just above 1 do not produce NaN.
    """
    # Work on a copy: dividing the slice q[..., 1:] in place would overwrite q.
    vec = np.array(q[..., 1:], dtype=np.double)
    vec_norm = sla.norm(vec, axis=-1, keepdims=True)
    nonzero = vec_norm > 0
    axis = np.divide(vec, vec_norm, out=np.zeros_like(vec), where=nonzero)
    axis[..., 0] = np.where(nonzero[..., 0], axis[..., 0], 1.0)

    angle = 2*np.arccos(np.clip(q[..., 0], -1.0, 1.0)).reshape((*axis.shape[:-1], 1))

    return axis, angle



##################################################################################
#### Projection of dual quaternion onto SE(3)
##################################################################################
def conjugate(dqmat):
    """
    Conjugate dual quaternions given in DQmat format.

    The two diagonal 2x2 blocks are complex-conjugated and the two off-diagonal
    2x2 blocks are negated. In the M2(C) representation this is quaternion
    conjugation, applied to the real and dual parts at once.
    dqmat has shape (..., 4, 4). It is modified IN PLACE and also returned;
    callers that need the input intact pass a copy.
    """
    for diag_block in (np.s_[..., 0:2, 0:2], np.s_[..., 2:, 2:]):
        dqmat[diag_block] = np.conj(dqmat[diag_block])

    for off_block in (np.s_[..., 0:2, 2:], np.s_[..., 2:, 0:2]):
        dqmat[off_block] *= -1

    return dqmat

def normSquared(dqmat):
    """
    Squared dual quaternion norm dq * conj(dq), in DQmat format.

    dqmat has shape (..., 4, 4) and is left unmodified.
    """
    conj_copy = conjugate(dqmat.copy())
    return dqmat @ conj_copy

def _dnsqrt(dqmat):
    """
    Square root of dual numbers embedded in the dual quaternion algebra (DQmat format).

    dqmat has shape (..., 4, 4). For a dual number a + b*eps, the diagonal holds a
    and the entries (0, 1) and (2, 3) hold b. The result
        sqrt(a + b*eps) = sqrt(a) + b / (2*sqrt(a)) * eps
    is written back IN PLACE into those entries; dqmat is also returned.
    Other entries are left untouched.
    """
    real = np.sqrt(dqmat[..., 0, 0])
    dual = 0.5*dqmat[..., 0, 1]/real

    # Real part: the whole diagonal.
    for i in range(4):
        dqmat[..., i, i] = real

    # Dual part: the upper-right entry of each diagonal 2x2 block.
    dual_indices = [(0, 1), (2, 3)]
    for row, col in dual_indices:
        dqmat[..., row, col] = dual

    return dqmat

def norm(dqmat):
    """
    Dual quaternion norm sqrt(dq * conj(dq)), returned as a dual number in DQmat format.

    dqmat has shape (..., 4, 4) and is left unmodified.
    """
    squared = normSquared(dqmat)
    return _dnsqrt(squared)

def _dninv(dqmat):
    """
    Inverse of dual numbers embedded in the dual quaternion algebra (DQmat format).

    dqmat has shape (..., 4, 4). For a dual number a + b*eps (a on the diagonal,
    b at entries (0, 1) and (2, 3)) the inverse is 1/a - b/a**2 * eps. Returns a
    new array of the same shape and dtype with all other entries zero; dqmat is
    not modified.
    """
    a = dqmat[..., 0, 0]
    b = dqmat[..., 0, 1]

    inv = np.zeros_like(dqmat)
    inv_real = 1/a
    for i in range(4):
        inv[..., i, i] = inv_real

    inv_dual = -b/np.power(a, 2)
    inv[..., 0, 1] = inv_dual
    inv[..., 2, 3] = inv_dual

    return inv

def normalize(dqmat):
    """
    Scale dual quaternions to unit norm, i.e. compute dq / norm(dq).

    dqmat has shape (..., 4, 4) and is left unmodified.
    """
    inverse_norm = _dninv(norm(dqmat))
    return dqmat @ inverse_norm



##################################################################################
#### Erdos-Renyi graph
##################################################################################
def randomERGraph(n, p):
    """
    Sample the adjacency matrix of an Erdos-Renyi graph on n vertices.

    Each off-diagonal edge is present independently with probability p (drawn with
    scipy.stats.bernoulli from NumPy's global RNG). The matrix is symmetric, and
    its diagonal is set to 1.
    Output is an integer array of shape (n, n).
    """
    draws = st.bernoulli.rvs(p, size=(n, n))
    upper = np.triu(draws, 1)
    adjacency = upper + upper.T
    np.fill_diagonal(adjacency, 1)
    return adjacency

def applyERGraph(adjmat, blkmat, block_size = 4):
    """
    Scale every block of a block matrix by the matching adjacency entry.

    adjmat has shape (n, n); blkmat has shape (m*n, m*n) with m = block_size.
    The (i, j) block of the output is the (i, j) block of blkmat times adjmat[i, j].
    """
    mask = np.kron(adjmat, np.ones((block_size, block_size)))
    return blkmat * mask



##################################################################################
#### Random observation generation
##################################################################################
def _uniformAxis(shape):
    """
    Sample unit vectors in R^3 uniformly from the sphere S^2.

    Standard normal samples are normalized, which yields a uniform direction.
    Output has shape (*shape, 3).
    """
    samples = st.norm.rvs(size=(*shape, 3))
    lengths = sla.norm(samples, axis=-1, keepdims=True)
    return samples / lengths

def _uniformAngle(shape,uprange):
    """
    Sample angles uniformly from the interval [0, uprange].

    Output has shape (*shape, 1).
    """
    size = (*shape, 1)
    return st.uniform.rvs(loc=0, scale=uprange, size=size)

def _gaussianAngle(shape, sigma=1, mean=0):
    """
    Sample angles from a normal distribution N(mean, sigma**2).

    Output has shape (*shape, 1).
    """
    size = (*shape, 1)
    return st.norm.rvs(loc=mean, scale=sigma, size=size)


def _gaussianTranslation(shape, sigma=1, mean=0):
    """
    Sample translations whose three coordinates are i.i.d. N(mean, sigma**2).

    Output has shape (*shape, 3).
    """
    size = (*shape, 3)
    return st.norm.rvs(loc=mean, scale=sigma, size=size)

def randomUniformGaussian(shape, sigma=1):
    """
    Sample random rigid motions.

    The rotation axis is uniform on the sphere and the rotation angle is uniform in
    [0, 2*pi]. (NOTE: this is not the Haar-uniform distribution on SO(3).)
    The translation has i.i.d. N(0, sigma**2) coordinates.
    Output is r, t.
    r has shape (*shape, 4): quaternions with non-negative real part.
    t has shape (*shape, 3).
    """
    # Draw order (axis, angle, translation) fixes how the global RNG stream is consumed.
    r = aa2q(_uniformAxis(shape), _uniformAngle(shape, 2* np.pi))
    t = _gaussianTranslation(shape, sigma=sigma)
    return r, t

def randomUniformGaussianNoise(shape, uprange, sigma_t = 0):
    """
    Sample random rigid motions used as noise.

    The rotation axis is uniform on the sphere and the rotation angle is uniform in
    [0, uprange]. The translation has i.i.d. N(0, sigma_t**2) coordinates.
    Output is r, t.
    r has shape (*shape, 4): quaternions with non-negative real part.
    t has shape (*shape, 3).
    """
    # Draw order (axis, angle, translation) fixes how the global RNG stream is consumed.
    r = aa2q(_uniformAxis(shape), _uniformAngle(shape, uprange=uprange))
    t = _gaussianTranslation(shape, sigma=sigma_t)
    return r, t


def randomGaussianGaussian(shape, sigma_r=1, sigma_t=1):
    """
    Sample random rigid motions with Gaussian rotation angle and translation.

    The rotation axis is uniform on the sphere and the angle is N(0, sigma_r**2).
    The translation has i.i.d. N(0, sigma_t**2) coordinates.
    Output is r, t.
    r has shape (*shape, 4): quaternions with non-negative real part.
    t has shape (*shape, 3).
    """
    # Draw order (axis, angle, translation) fixes how the global RNG stream is consumed.
    r = aa2q(_uniformAxis(shape), _gaussianAngle(shape, sigma=sigma_r))
    t = _gaussianTranslation(shape, sigma=sigma_t)
    return r, t



##################################################################################
#### Alignment
##################################################################################
def calculateBestAligner(estimate, original):
    """
    Calculate the best aligner: the dual quaternion (in DQmat representation) that
    aligns estimate with original. It combines the normalized mean of the per-vertex
    relative rotations and the mean of the per-vertex relative translations.
    estimate, original have shape (n, 1, 4, 4) (DQmat, matrix of blocks).
    Output is in DQmat format. It has shape (1, 4, 4), so it broadcasts against
    (n, 1, 4, 4) arrays in `estimate @ aligner`.
    """
    # Prepare the dq's: per-vertex relative transform conj(estimate_i) * original_i
    tmp = estimate.copy()
    tmp = conjugate(tmp)
    tmp = tmp @ original
    tmp = dqmat2dq(tmp)
    rotpart = q2m2(tmp[..., 0:4])
    transpart = q2m2(tmp[..., 4::])
    # 2 * dual * conj(real): with the rt2dq convention dual = t * r, this is twice the
    # relative translation as a pure quaternion in M2(C). The transpose assumes 4-D input.
    transpart = 2*(transpart @ rotpart.conj().transpose((0, 1, 3, 2)))

    # Calculate the rotational part: arithmetic mean of the relative rotations in M2(C)
    # NOTE(review): the per-vertex signs of the quaternions are not aligned before
    # averaging, so q_i and -q_i terms could cancel; confirm the estimates' signs are
    # consistent.
    q = np.mean(rotpart, axis=0).squeeze()
    # Normalize the quaternion: 0.5 * trace(q q^H) is |q|^2 for an M2(C) quaternion
    q_norm = 0.5 * np.sum(np.diagonal(q @ q.conj().T)).real
    q /= np.sqrt(q_norm)
    # Keep the representative with non-negative real part (entry (0, 0) is r + i*1j)
    if q[0, 0].real < 0:
        q *= -1


    # Calculate the translational part: mean of (twice) the relative translations
    t = np.mean(transpart, axis=0)

    # Construct the output: real part q, dual part mean_translation * q
    # (the 0.5 undoes the factor 2 above)
    dq = np.zeros((1, 8), dtype=np.double)
    dq[0, 0:4] = m22q(q)
    dq[0, 4:] = m22q(0.5 * (t @ q))
    # Convert the output to DQmat
    dq = dq2dqmat(dq)

    return dq



##################################################################################
#### Dual quaternion power iteration
##################################################################################
def _normVector(dqvec):
    """
    Dual-number norm of a vector of dual quaternions in DQmat, matrix-of-blocks format.

        sqrt( sum_j conj(dqvec[j]) * dqvec[j] )

    The sum runs over axis 0 (kept with size 1) and sqrt is the dual number square
    root. dqvec is left unmodified.
    """
    squared_terms = conjugate(dqvec.copy()) @ dqvec
    total = np.sum(squared_terms, axis=0, keepdims=True)
    return _dnsqrt(total)

def _normalizeVector(dqvec):
    """
    Normalize a vector of dual quaternions by its dual-number norm.

    The norm is
        sqrt( sum_j conj(dqvec[j]) * dqvec[j] )
    and the vector is multiplied on the right by its inverse.
    dqvec is a block matrix of shape (4*s, 4) whose 4x4 blocks are in DQmat format.
    Returns a new block matrix of the same shape.
    """
    blocks = bm2mb(dqvec.copy())
    inverse_norm = _dninv(_normVector(blocks))
    return mb2bm(blocks @ inverse_norm)


def dqpower(mat, max_iter = 100, x0 = None):
    """
    Dual quaternion power iteration on the Hermitian dual quaternion matrix mat.

    mat has shape (4*s, 4*s); every 4x4 block is a dual quaternion in DQmat format.
    x0, if given, is the starting vector and is used as is. Otherwise a random rigid
    motion per vertex is drawn (randomUniformGaussian) and normalized.
    Each of the max_iter iterations multiplies by mat and renormalizes with
    _normalizeVector.
    Returns the estimate of the eigenvector of the largest eigenvalue, of shape
    (4*s, 4); every 4x4 block is a dual quaternion in DQmat format.
    """
    if x0 is not None:
        estimate = x0
    else:
        blocks_shape = (mat.shape[0]//4, 1)
        rotation, translation = randomUniformGaussian(blocks_shape)
        estimate = mb2bm(dq2dqmat(rt2dq(rotation, translation)))
        estimate = _normalizeVector(estimate)

    for _ in range(max_iter):
        estimate = _normalizeVector(mat @ estimate)

    return estimate



##################################################################################
#### Method of Arrigoni et al (2016)
##################################################################################
def _arrigoni_calc_eigenvector(Y):
    """
    Return the four eigenvectors of Y with largest-magnitude eigenvalues.

    Uses scipy.sparse.linalg.eigs (ARPACK) with k=4; the result has shape
    (Y.shape[0], 4) and may be complex.
    """
    _, eigenvectors = ssl.eigs(Y, k=4)
    return eigenvectors

def _arrigoni_partial_proj(G):
    """
    Change the basis of the 4-column eigenvector matrix G so that, in every 4x4 block,
    the 4th row approaches (0, 0, 0, c). The blocks are later dehomogenized by c.
    G has shape (4n, 4), presumably the output of _arrigoni_calc_eigenvector.
    Returns the real part of G @ P; its first two columns are swapped if the first
    3x3 block has negative determinant.
    This function mimics the behavior of the code of Arrigoni et al, 2016.
    """
    # Calculate the projection coefficients
    # U collects the 4th row of every 4x4 block; shape (n, 4)
    U = G[3::4]
    u, s, vh = sla.svd(U)
    # Treat the last three right singular vectors as the null space of U
    # NOTE(review): for complex U the right singular vectors are conj(vh[k, :]); vh.T
    # (no conjugation) matches that only when vh is real.
    alpha = vh.T[:, -3::]
    # Treat the first left singular vector as image of U
    # beta is scaled so that, for real vh, U @ beta = u[:, 0] * sum(u[:, 0])
    beta = (1/s[0]) * (vh.T[:, 0].reshape(-1, 1) * np.sum(u[:, 0]))
    P = np.concatenate((alpha, beta), axis=1)

    # Carry out the initial projection
    G = G @ P

    # Take the real part
    G = np.real(G)

    # Ensure the determinant of the first block in G is positive
    # Arrigoni et al do that in their implementation, for some reason
    # (swapping two columns flips the sign of the determinant)
    U = G[0:3, 0:3]
    if sla.det(U) < 0:
        G[:, [0, 1]] = G[:, [1, 0]]

    return G

def __rounder(R):
    """
    Project a 3x3 matrix onto the nearest rotation matrix.

    R is squeezed to 3x3, decomposed as u @ diag(s) @ vh, and replaced by
    u @ diag(1, 1, det(u @ vh)) @ vh, so the result has determinant +1.
    """
    left, _, right = sla.svd(R.squeeze())
    correction = np.diag((1, 1, sla.det(left @ right))).astype(R.dtype.type)
    return left @ correction @ right

def _round_rotations(R):
    """
    Round each 3x3 matrix in R to its nearest rotation matrix (via __rounder).

    R is indexed along axis 0 (e.g. shape (n, 3, 3) or (n, 1, 3, 3)).
    Output is the stack of rounded matrices with all singleton axes squeezed,
    e.g. (n, 3, 3) for n > 1.
    """
    round_each = np.vectorize(__rounder, signature="(3, 3)->(3, 3)")
    pieces = np.split(R, R.shape[0], axis=0)
    rounded = round_each(pieces)
    return rounded.squeeze()

def _arrigoni_round(G):
    """
    Round the block matrix G so that its 4x4 blocks are elements of SE(3).

    G has shape (4n, 4). After _arrigoni_partial_proj, every rotation part is replaced
    by its nearest rotation matrix. Each last column is then divided by its
    bottom-right entry, and the last row is set to (0, 0, 0, 1).
    Returns a block matrix of shape (4n, 4).
    This function mimics the behavior of the code of Arrigoni et al, 2016.
    """
    blocks = bm2mb(_arrigoni_partial_proj(G))

    # Rotations: nearest rotation matrix of every 3x3 part.
    rotations = _round_rotations(blocks[..., 0:3, 0:3])
    blocks[..., 0:3, 0:3] = rotations[:, np.newaxis]

    # Translations: dehomogenize by the bottom-right entry (which becomes 1).
    bottom_right = blocks[..., -1, -1].copy()
    blocks[..., :, -1] /= bottom_right[..., np.newaxis]

    # Last row of each block becomes (0, 0, 0, 1).
    blocks[..., -1, 0:3] = 0

    return mb2bm(blocks)

def arrigoni(Y):
    """
    Apply the method of Arrigoni et al (2016) to the observation matrix Y.

    The four leading eigenvectors of Y are computed and their 4x4 blocks are rounded
    onto SE(3). Returns a (4n, 4) block matrix in matrix representation.
    """
    eigenvectors = _arrigoni_calc_eigenvector(Y)
    return _arrigoni_round(eigenvectors)



##################################################################################
#### IRLS
##################################################################################
def _stopping_criteria(current_estimate, previous_estimate, irls_iter_no, max_irls_iter, tol, verbose=False):
    """
    Decide whether the IRLS iteration should stop.

    current_estimate and previous_estimate are in DQmat format, shape (n, 1, 4, 4).
    irls_iter_no and max_irls_iter are scalars.
    Returns True when irls_iter_no >= max_irls_iter. Otherwise it forms
    c = mean_j(current[j] * conj(previous[j])) and computes
    ||1 - c * conj(c)|| / n, with the product taken as an R^8 dual quaternion vector.
    It returns True when that value is below tol, printing it first if verbose.
    """
    if irls_iter_no >= max_irls_iter:
        return True

    n = current_estimate.shape[0]
    mean_corr = np.mean(current_estimate @ conjugate(previous_estimate.copy()), axis=0, keepdims=True)
    corr_dq = dqmat2dq(mean_corr @ conjugate(mean_corr.copy()))

    identity_dq = np.zeros(8)
    identity_dq[0] = 1
    deviation = sla.norm(identity_dq - corr_dq)/n

    if verbose:
        print("Iteration {:d}: ||1-corr|| = {:.4f}".format(irls_iter_no+1, deviation))

    return bool(deviation < tol)

def _dq_irls_weights(Y, estimate):
    """
    Calculate the weights of the dual quaternion IRLS scheme.

    Y is the observation matrix in DQmat, matrix-of-blocks format, shape (n, n, 4, 4).
    estimate is in DQmat format. It has shape (n, 1, 4, 4).
    For each pair (i, j), the residual w_ij is the Euclidean norm (as an R^{8}
    vector) of the difference between the clean observation implied by estimate
    and Y[i, j]. The weight is the Cauchy-type 1/(1 + (w_ij/c)^2) with
    c = theta * 1.482 * MAD(w) and theta = 2.
    Output has shape (n, n); it is symmetric when Y is dual quaternion Hermitian.
    """
    # Tuning constant
    theta = 2

    # Calculate the absolute difference (dqmat2dq only reads its input, so no copy of Y)
    residual = dqmat2dq(_cleanObservationMatrix(estimate)) - dqmat2dq(Y)
    n = Y.shape[0]
    W = sla.norm(residual, axis=-1).reshape(n, n)
    c = 1.482*np.median(np.abs(W - np.median(W))) * theta
    # A zero MAD (more than half of the residuals identical) would give W/c = 0/0 = nan.
    # Fall back to machine epsilon: zero residuals keep weight 1, clear outliers get ~0.
    c = max(c, np.finfo(np.double).eps)
    W = 1/(1+np.power(W/c, 2))
    return W

def _arrigoni_irls_weights(Y, estimate):
    """
    Calculate the weights of the IRLS scheme of the method of Arrigoni et al (2016).

    Y is the observation matrix in matrix representation, matrix-of-blocks format,
    shape (n, n, 4, 4).
    estimate is in DQmat format. It has shape (n, 1, 4, 4).
    For each pair (i, j), the residual w_ij is the Frobenius norm of the difference
    between the 4x4 matrix implied by estimate and Y[i, j]. The weight is the
    Cauchy-type 1/(1 + (w_ij/c)^2) with c = theta * 1.482 * MAD(w) and theta = 2.
    Output has shape (n, n).
    """
    # Tuning constant
    theta = 2

    # Calculate the absolute difference
    Y_clean = dqmat2mat(_cleanObservationMatrix(estimate.copy()))
    Y_clean -= Y
    n = Y.shape[0]
    W = sla.norm(Y_clean, axis=(-2, -1), ord="fro").reshape(n, n)
    c = 1.482*np.median(np.abs(W - np.median(W))) * theta
    # A zero MAD (more than half of the residuals identical) would give W/c = 0/0 = nan.
    # Fall back to machine epsilon: zero residuals keep weight 1, clear outliers get ~0.
    c = max(c, np.finfo(np.double).eps)
    W = 1/(1+np.power(W/c, 2))
    return W

def _irls_applyWeights(Y, weights):
    """
    Return the weighted observation matrix.

    Y is a block matrix with 4x4 blocks; block (i, j) is multiplied by weights[i, j].
    """
    weighted = applyERGraph(weights, Y)
    return weighted

def arrigoni_irls(Y, x0 = None, irls_max_iter = 15, stopping_tol = 4*1e-3, verbose = False):
    """
    An IRLS scheme using the method of Arrigoni et al (2016).

    Y is the observation matrix in matrix representation, block matrix format (4n x 4n).
    x0, if given, is the initial estimate. It is either a block matrix (ndim <= 2) in
    matrix representation, which is converted to DQmat, or an array already in
    DQmat, matrix-of-blocks format.
    Each iteration reweights Y with _arrigoni_irls_weights and reruns arrigoni. It
    stops after irls_max_iter iterations or when _stopping_criteria reports
    convergence for tolerance stopping_tol.

    Returns (first_estimate, last_estimate) when x0 is None, otherwise only the last
    estimate. Both are in DQmat, block matrix format. With irls_max_iter == 0 the
    last estimate equals the initial one.
    """
    # Run an initial iteration
    Y_w = Y.copy()
    Y_mb = bm2mb(Y.copy())
    if x0 is None:
        previous_estimate = mat2dqmat(bm2mb(arrigoni(Y_w)))
    else:
        previous_estimate = x0.copy()
        if previous_estimate.ndim <= 2:
            previous_estimate = mat2dqmat(bm2mb(previous_estimate))
    first_estimate = previous_estimate.copy()
    # Bind before the loop so irls_max_iter == 0 returns the initial estimate instead
    # of raising UnboundLocalError.
    current_estimate = previous_estimate

    # Run the remaining iterations
    for i in range(irls_max_iter):
        weights = _arrigoni_irls_weights(Y_mb, previous_estimate)
        Y_w = _irls_applyWeights(Y, weights)
        current_estimate = mat2dqmat(bm2mb(arrigoni(Y_w)))

        if _stopping_criteria(current_estimate, previous_estimate, i, irls_max_iter, tol=stopping_tol, verbose=verbose):
            break
        previous_estimate = current_estimate.copy()

    if x0 is None:
        return mb2bm(first_estimate), mb2bm(current_estimate)
    else:
        return mb2bm(current_estimate)

def _dqpower_wnorm(Y, max_iter):
    """
    Run the dual quaternion power iteration on Y, then normalize each block.

    Returns the estimate in DQmat, matrix-of-blocks format, with every entry scaled
    to unit dual quaternion norm.
    """
    blocks = bm2mb(dqpower(Y, max_iter = max_iter))
    return normalize(blocks)

def dqpower_irls(Y, x0 = None, max_iter = 20, irls_max_iter = 15, stopping_tol = 1e-4, verbose=False):
    """
    An IRLS scheme using the dual quaternion power iteration.

    Y is the observation matrix in DQmat, block matrix format (4n x 4n).
    x0, if given, is the initial estimate in DQmat format, either as a block matrix
    (ndim <= 2) or as a matrix of blocks.
    Each iteration reweights Y with _dq_irls_weights and reruns the power iteration
    (max_iter steps). It stops after irls_max_iter iterations or when
    _stopping_criteria reports convergence for tolerance stopping_tol.

    Returns the first and last estimate, in this order, when x0 is None; otherwise
    only the last estimate. Both are in block matrix format. With irls_max_iter == 0
    the last estimate equals the initial one.
    """
    # Run an initial iteration
    Y_w = Y.copy()
    Y_mb = bm2mb(Y.copy())
    if x0 is None:
        previous_estimate = _dqpower_wnorm(Y_w, max_iter)
    else:
        previous_estimate = x0.copy()
        if previous_estimate.ndim <= 2:
            previous_estimate = bm2mb(previous_estimate)
    first_estimate = previous_estimate.copy()
    # Bind before the loop so irls_max_iter == 0 returns the initial estimate instead
    # of raising UnboundLocalError.
    current_estimate = previous_estimate

    # Run the remaining iterations
    for i in range(irls_max_iter):
        weights = _dq_irls_weights(Y_mb, previous_estimate)
        Y_w = _irls_applyWeights(Y, weights)
        current_estimate = _dqpower_wnorm(Y_w, max_iter=max_iter)

        if _stopping_criteria(current_estimate, previous_estimate, i, irls_max_iter, tol=stopping_tol, verbose=verbose):
            break
        previous_estimate = current_estimate.copy()

    if x0 is None:
        return mb2bm(first_estimate), mb2bm(current_estimate)
    else:
        return mb2bm(current_estimate)



##################################################################################
#### Experiment
##################################################################################
# Constants
# Default column names for the experiment CSV files opened by openCSVFile.
FIELDS = [
    "n",
    "sigma_r",
    "sigma_t",
    "p",
    "q",
    "rep_no",
    "mean_rotation_error",
    "mean_translation_error",
]

def angle2radians(angle):
    """Convert an angle (scalar or array) from degrees to radians."""
    half_turn = np.pi
    return angle * half_turn / 180

def getnow():
    """
    Return the current local date and time formatted as "YYYY-mm-dd-HHMMSS".
    """
    now = datetime.now()
    return now.strftime("%Y-%m-%d-%H%M%S")

def openCSVFile(dirpath, test_name, writeheader=True, fields = FIELDS):
    """
    Open a new CSV file "<dirpath>/<test_name>-<timestamp>.csv" for writing.

    The timestamp comes from getnow(). When writeheader is true, the header row
    (fields) is written immediately.
    Returns the open file handle and a csv.DictWriter bound to it. The caller is
    responsible for closing the handle. If creating the writer or writing the
    header fails, the file is closed before the exception propagates.
    """
    fn = dirpath + "/" + test_name + "-" + getnow() + ".csv"
    f = open(fn, 'w', newline='')
    try:
        dw = csv.DictWriter(f, delimiter=',', fieldnames=fields)
        if writeheader:
            dw.writeheader()
    except BaseException:
        # Do not leak the handle when we cannot hand it to the caller.
        f.close()
        raise

    return f, dw

def logmsg(msg):
    """
    Print msg to standard output, prefixed with the current getnow() timestamp.
    """
    stamp = getnow()
    print(stamp, " ::: ", msg)


def generateGroundTruth(n):
    """
    Generate n random ground-truth rigid motions in DQmat format.

    Rotations and translations come from randomUniformGaussian.
    Output is a matrix of blocks of shape (n, 1, 4, 4).
    """
    rotations, translations = randomUniformGaussian((n, 1))
    return dq2dqmat(rt2dq(rotations, translations))

def _hermitizeDQmat(mat):
    """
    Make a square matrix of dual quaternion blocks Hermitian in the dual quaternion sense.

    mat is in matrix-of-blocks format, shape (n, n, 4, 4). For every i < j, the
    block (j, i) is overwritten with conjugate(mat[i, j]). The strict upper triangle
    and the diagonal blocks are left unchanged.
    mat is modified IN PLACE and returned.
    """
    rows, cols = np.triu_indices(mat.shape[0], k=1)
    # Fancy indexing yields a copy, so conjugate() does not touch mat here.
    upper_blocks = mat[rows, cols, :, :]
    mat[cols, rows, :, :] = conjugate(upper_blocks)
    return mat

def _cleanObservationMatrix(original):
    """
    Calculate the clean (noise-free) observation matrix.

    original is in DQmat format with shape (n, 1, 4, 4) and is left unmodified.
    Returns the (n, n, 4, 4) matrix of blocks whose (i, j) block is
    original[i] @ conjugate(original[j]).
    """
    adjoint = np.transpose(conjugate(original.copy()), (1, 0, 2, 3))
    return original @ adjoint



def calculateError(estimate, original):
    """
    Calculate the rotational and translational estimation errors.

    The estimate is first aligned to the original with calculateBestAligner.
    estimate, original have shape (n, 1, 4, 4) (DQmat, matrix of blocks).
    Returns flat arrays rerr and terr of per-entry errors:
      rerr: the angle, in radians in [0, pi], of the relative rotation between the
            aligned estimate and the original. NOTE: this is theta; the previous
            formula returned 2*theta.
      terr: the Euclidean distance between the translations.
    """
    # Calculate best aligner and apply it to the estimate
    ba = calculateBestAligner(estimate, original)
    estimate = estimate @ ba

    # Extract the rotation and translation parts
    r_e, t_e = dq2rt(dqmat2dq(estimate))
    r_o, t_o = dq2rt(dqmat2dq(original))

    # Rotational part: for unit quaternions with <r_e, r_o> = cos(theta/2),
    # 2*<r_e, r_o>^2 - 1 = cos(theta). Squaring makes this sign-invariant (q ~ -q).
    r_e = r_e / sla.norm(r_e, axis=-1, keepdims=True)
    r_o = r_o / sla.norm(r_o, axis=-1, keepdims=True)
    cos_theta = np.clip(2*np.sum(r_e*r_o, axis=-1)**2 - 1, -1, 1)
    rerr = np.arccos(cos_theta)

    # Translational part
    terr = sla.norm(t_e - t_o, ord=2, axis=-1)

    return rerr.flatten(), terr.flatten()




def generateObservations_addi(gt_dqmat, sigma_r, sigma_t,p):
    """
    Generate the observation matrix with additive noise.
    gt_dqmat is the ground truth in dqmat, matrix of blocks format, shape (n, 1, 4, 4).
    sigma_r, sigma_t are the standard deviations of the rotational and translational noise.
    p is the probability that an (off-diagonal) measurement is present.
    Returns (Y_dqmat, Y_mat): the observations in DQmat and in matrix representation,
    both as block matrices with missing entries zeroed.
    """
    n = gt_dqmat.shape[0]

    # Generate presence graph
    exist_graph = randomERGraph(n, p)

    # Generate the full, clean observation matrix in DQmat representation
    Y1_dqmat = _cleanObservationMatrix(gt_dqmat)

    # print('Y1_dqmat:', Y1_dqmat.shape)

    # Generate a full (not yet Hermitian) dual quaternion noise matrix in DQmat representation
    noise = dq2dqmat(rt2dq(*randomGaussianGaussian((n, n), sigma_r=sigma_r, sigma_t=sigma_t)))
    # With zero angle and zero translation every entry of noise0 should be the identity
    # dual quaternion (the 4x4 identity in DQmat). It is still drawn through the RNG,
    # so it also advances the random stream.
    noise0 = dq2dqmat(rt2dq(*randomGaussianGaussian((n, n), sigma_r=0, sigma_t=0)))


    # Apply noise to measurements and hermitize the resulting matrix
    # Additive perturbation by (noise - identity); the result is generally not a unit dual quaternion.
    Y1_dqmat = Y1_dqmat + noise - noise0  # NOTE(review): the original author flagged this line as questionable
    Y1_dqmat = _hermitizeDQmat(Y1_dqmat)

    # Convert all matrices of blocks to block matrices
    # NOTE(review): the perturbed real parts are not re-normalized before r2om, so the
    # 3x3 parts of Y1_mat need not be rotation matrices.
    Y1_mat = mb2bm(rt2mat(*dq2rt(dqmat2dq(Y1_dqmat))))
    Y1_dqmat = mb2bm(Y1_dqmat)
    Y_dqmat = applyERGraph(exist_graph,Y1_dqmat)
    Y_mat = applyERGraph(exist_graph, Y1_mat)


    return Y_dqmat, Y_mat


def generateObservations_multi(gt_dqmat, sigma_r, sigma_t, p, q):
    """
    Generate the observation matrix with multiplicative noise and outliers.
    gt_dqmat is the ground truth in dqmat, matrix of blocks format, shape (n, 1, 4, 4).
    sigma_r sigma_t is the standard deviation of the rotational and translational noise, respectively.
    p is the probability that a measurement will be present.
    q is the probability a measurement will not be corrupted (diagonal entries are never corrupted).
    Outputs two observations matrices, one of dual quaternions (in DQmat format) and another in matrix representation for the method of Arrigoni (2016).
    Both are block matrices; missing entries are zero.
    """
    n = gt_dqmat.shape[0]

    # Generate presence graph
    exist_graph = randomERGraph(n, p)

    # Generate non-corrupted entries graph (its diagonal is 1)
    corr_graph = randomERGraph(n, q)

    # Generate the full, clean observation matrix in DQmat representation
    Y1_dqmat = _cleanObservationMatrix(gt_dqmat)

    # Generate a full (not yet Hermitian) dual quaternion noise matrix in DQmat representation
    # (n,n,4,4)array with the diags of the first two dims all identity matrix
    noise = dq2dqmat(rt2dq(*randomGaussianGaussian((n, n), sigma_r=sigma_r, sigma_t=sigma_t)))
    inds = np.diag_indices(n)
    noise[inds[0], inds[1], :, :] = np.eye(4)


    # Apply noise to measurements and hermitize the resulting matrix
    # Batched matmul: block (i, j) becomes Y1[i, j] @ noise[i, j] (noise on the right)
    Y1_dqmat = Y1_dqmat @ noise
    Y1_dqmat = _hermitizeDQmat(Y1_dqmat)

    # Convert full measurement matrix to matrix representation
    Y1_mat = rt2mat(*dq2rt(dqmat2dq(Y1_dqmat)))

    # Generate corrupted entries matrix (random rigid motions used as outliers)
    r, t = randomUniformGaussian((n, n))
    corr_dqmat = dq2dqmat(rt2dq(r, t))
    corr_dqmat = _hermitizeDQmat(corr_dqmat)
    corr_mat = rt2mat(*dq2rt(dqmat2dq(corr_dqmat)))

    # Convert all matrices of blocks to block matrices
    Y1_dqmat = mb2bm(Y1_dqmat)
    Y1_mat = mb2bm(Y1_mat)
    corr_dqmat = mb2bm(corr_dqmat)
    corr_mat = mb2bm(corr_mat)

    # Apply the graphs: keep M where corr_graph is 1, use the outlier C elsewhere,
    # then zero out the blocks that are not present in exist_graph
    apply_graphs = lambda M, C: applyERGraph(exist_graph, \
                           applyERGraph(corr_graph, M) + applyERGraph(1-corr_graph, C))
    Y_dqmat = apply_graphs(Y1_dqmat, corr_dqmat)
    Y_mat = apply_graphs(Y1_mat, corr_mat)


    return Y_dqmat, Y_mat


















