"""Gaussian and Energy metric implementations in Python.

Anonymized submission for ICLR 2023.

September 28, 2022.

Requires scipy, sklearn, opt_einsum, numpy.
"""

import itertools

import numpy as np
from opt_einsum import contract
from scipy.linalg import sqrtm, orthogonal_procrustes
from scipy.optimize import linear_sum_assignment
import scipy.sparse
from scipy.spatial.distance import squareform
from scipy.stats import ortho_group
from sklearn.utils.validation import check_array
from sklearn.utils.extmath import randomized_svd
from sklearn.utils.validation import check_random_state
from sklearn.metrics.pairwise import pairwise_kernels


def check_equal_shapes(X, Y, nd=2, zero_pad=False):
    """
    Validate that 'X' and 'Y' are 'nd'-dimensional arrays of equal shape.

    Parameters
    ----------
    X : array-like
        First array; passed through sklearn's 'check_array' with
        'allow_nd=True'.

    Y : array-like
        Second array; validated the same way as 'X'.

    nd : int
        Number of dimensions both arrays must have.

    zero_pad : bool
        If True and the shapes differ only along the final axis, the
        narrower array is padded with trailing zeros along that axis
        so that both arrays end up with the same shape.

    Returns
    -------
    X, Y : ndarray
        The validated (and possibly zero-padded) arrays.

    Raises
    ------
    ValueError
        If either array does not have 'nd' dimensions, or if the
        shapes differ and cannot be reconciled by zero-padding.
    """
    X = check_array(X, allow_nd=True)
    Y = check_array(Y, allow_nd=True)

    if not (X.ndim == nd and Y.ndim == nd):
        raise ValueError(
            "Expected {}d arrays, but shapes were {} and "
            "{}.".format(nd, X.shape, Y.shape)
        )

    # Nothing to reconcile when the shapes already agree.
    if X.shape == Y.shape:
        return X, Y

    # Padding can only fix a mismatch confined to the final axis.
    can_pad = zero_pad and (X.shape[:-1] == Y.shape[:-1])
    if not can_pad:
        raise ValueError(
            "Expected arrays with equal dimensions, "
            "but got arrays with shapes {} and {}."
            "".format(X.shape, Y.shape))

    # Common size of the final axis after padding.
    width = max(X.shape[-1], Y.shape[-1])

    padded = []
    for arr in (X, Y):
        # No padding on any axis except trailing zeros on the last one.
        spec = np.zeros((nd, 2), dtype="int")
        spec[-1, -1] = width - arr.shape[-1]
        padded.append(np.pad(arr, spec))

    return padded[0], padded[1]


def align(X, Y, group="orth"):
    """
    Compute the linear map that best aligns 'X' onto 'Y' within a
    given group of isometries.

    Because every allowed alignment belongs to a sub-group of the
    orthogonal group, the same map is optimal for both the angular
    distance and the Euclidean distance.

    Parameters
    ----------
    X : (m x n) ndarray.
        Activations over 'm' inputs and 'n' neurons from the first
        network. This is the array that gets transformed.

    Y : (m x n) ndarray.
        Activations over 'm' inputs and 'n' neurons from the second
        network. This array is held fixed.

    group : str
        Which family of alignments to search over. One of
        ("orth", "perm", "identity").

    Returns
    -------
    T : (n x n) ndarray or sparse matrix.
        Operator such that 'X @ T' is optimally aligned to 'Y'. By
        symmetry, 'Y @ T.transpose()' is optimally aligned to 'X'.
        The "orth" group yields a dense ndarray; "perm" and
        "identity" yield scipy sparse matrices.

    Raises
    ------
    ValueError
        If 'group' is not one of the recognized names.
    """
    if group == "identity":
        # No alignment at all: a sparse identity over the neurons.
        return scipy.sparse.eye(X.shape[1])

    if group == "perm":
        # Match columns of X to columns of Y so that the summed inner
        # products of the matched pairs is as large as possible.
        rows, cols = linear_sum_assignment(X.T @ Y, maximize=True)
        size = rows.size
        entries = np.ones(size)
        return scipy.sparse.csr_matrix(
            (entries, (rows, cols)), shape=(size, size)
        )

    if group == "orth":
        rotation, _ = orthogonal_procrustes(X, Y)
        return rotation

    raise ValueError(f"Specified group '{group}' not recognized.")


def sq_bures_metric(A, B):
    """
    Slow way to compute the square of the Bures metric between two
    symmetric positive (semi-)definite matrices.

    The value returned is
    'trace(A) + trace(B) - 2 * nuclear_norm(sqrt(A) @ sqrt(B))',
    with the matrix square roots obtained from eigendecompositions.

    Parameters
    ----------
    A : (n x n) ndarray
        Symmetric positive (semi-)definite matrix.

    B : (n x n) ndarray
        Symmetric positive (semi-)definite matrix.

    Returns
    -------
    float
        Squared Bures metric between 'A' and 'B'.
    """
    va, ua = np.linalg.eigh(A)
    vb, ub = np.linalg.eigh(B)

    # For singular or nearly singular inputs, 'eigh' can return
    # eigenvalues that are slightly negative due to rounding. Clip them
    # at zero so the square roots below do not evaluate to NaN.
    va = np.clip(va, 0.0, None)
    vb = np.clip(vb, 0.0, None)
    sva, svb = np.sqrt(va), np.sqrt(vb)

    # diag(sva) @ ua.T @ ub @ diag(svb) has the same singular values as
    # sqrt(A) @ sqrt(B), since the two differ only by orthogonal factors.
    cross = (sva[:, None] * ua.T) @ (ub * svb[None, :])
    nuclear_norm = np.sum(np.linalg.svd(cross, compute_uv=False))

    return np.sum(va) + np.sum(vb) - 2 * nuclear_norm


def rand_orth(m, n=None, random_state=None):
    """
    Draw a random matrix whose columns or rows are orthonormal.

    A square orthogonal matrix of size 'max(m, n)' is sampled with
    'scipy.stats.ortho_group' and then cropped to the requested shape.

    Parameters
    ----------
    m : int
        Number of rows.

    n : int
        Number of columns (defaults to 'm' when None, giving a
        square matrix).

    random_state : int or np.random.RandomState
        Seed or generator controlling the random draw.

    Returns
    -------
    Q : ndarray
        Random matrix of shape (m x n). Its columns are orthonormal
        when m > n, its rows are orthonormal when m < n, and it is a
        full orthogonal matrix when m == n.
    """
    rng = check_random_state(random_state)
    if n is None:
        n = m

    side = max(m, n)
    Q = ortho_group.rvs(side, random_state=rng)

    # Crop whichever axis is larger than requested.
    if m < side:
        Q = Q[:m]
    if n < side:
        Q = Q[:, :n]

    return Q


class GaussianStochasticMetric:
    """
    Distance between two sets of Gaussian response distributions
    (one mean vector and one covariance matrix per input), computed
    after aligning the second set to the first.
    """

    def __init__(
            self, alpha=1.0, group="orth", init="means", niter=1000, tol=1e-8,
            random_state=None, n_restarts=1
        ):
        """
        Parameters
        ----------
        alpha : float between 0 and 2
            When alpha == 0, only uses covariance
            When alpha == 1, equals Wasserstein
            When alpha == 2, only uses means (i.e. deterministic metric)

        group : str
            Alignment group passed to 'align'.

        init : str
            How the alignment is initialized in 'fit': "means" aligns
            the mean vectors, "rand" draws a random orthogonal matrix.

        niter : int
            Maximum number of alternating-alignment iterations.

        tol : float
            Convergence tolerance on the decrease of the loss.

        random_state : int or np.random.RandomState
            Seed or generator used for "rand" initializations.

        n_restarts : int
            Number of times the fit is repeated; the best run is kept.
            Must be 1 when init == "means".

        Raises
        ------
        ValueError
            If 'alpha' lies outside [0, 2].
        """

        if (alpha < 0) or (alpha > 2):
            raise ValueError("alpha parameter should be between zero and two.")
        self.alpha = alpha
        self.group = group
        self.init = init
        self.niter = niter
        self.tol = tol
        self._rs = check_random_state(random_state)
        self.n_restarts = n_restarts
        if self.init == "means":
            # The "means" initialization depends only on the data, so
            # every restart would start from the same point.
            assert n_restarts == 1

    def fit(self, X, Y):
        """
        Find the alignment of 'Y' onto 'X'.

        Parameters
        ----------
        X : tuple (means_X, covs_X)
            'means_X' is (m x n) and 'covs_X' is (m x n x n).

        Y : tuple (means_Y, covs_Y)
            Same shapes as the entries of 'X'.

        Returns
        -------
        self
            With 'T' set to the best alignment found and 'loss_hist'
            set to the loss history of the run that produced it.

        Raises
        ------
        ValueError
            If 'n_restarts' is less than one or 'init' is not one of
            ("means", "rand").
        """
        means_X, covs_X = X
        means_Y, covs_Y = Y

        assert means_X.shape == means_Y.shape
        assert covs_X.shape == covs_Y.shape
        assert means_X.shape[0] == covs_X.shape[0]
        assert means_X.shape[1] == covs_X.shape[1]
        assert means_X.shape[1] == covs_X.shape[2]

        if self.n_restarts < 1:
            raise ValueError("n_restarts must be at least 1.")

        best_loss = np.inf
        best_T = None
        best_loss_hist = None

        for _ in range(self.n_restarts):

            if self.init == "means":
                init_T = align(means_Y, means_X, group=self.group)
            elif self.init == "rand":
                init_T = rand_orth(means_X.shape[1], random_state=self._rs)
            else:
                raise ValueError(
                    f"Specified init '{self.init}' not recognized."
                )

            T, loss_hist = _fit_gaussian_alignment(
                means_X, means_Y, covs_X, covs_Y, init_T,
                self.alpha, self.group, self.niter, self.tol
            )

            # Always keep the first run so that 'T' is defined even if
            # its final loss is not finite; afterwards keep improvements.
            if (best_T is None) or (loss_hist[-1] < best_loss):
                best_loss = loss_hist[-1]
                best_T = T
                best_loss_hist = loss_hist

        self.T = best_T
        # Report the history of the selected run, not of the last one.
        self.loss_hist = best_loss_hist
        return self

    def transform(self, X, Y):
        """
        Apply the fitted alignment to 'Y'; 'X' is returned unchanged.

        Returns
        -------
        X : tuple
            The input 'X', untouched.

        Y_aligned : tuple (means, covs)
            'means_Y @ T' and, for every covariance, 'T.T @ cov @ T'.
        """
        means_Y, covs_Y = Y

        # 'align' returns scipy sparse matrices for the "perm" and
        # "identity" groups; densify so the batched product below works.
        T = self.T.toarray() if scipy.sparse.issparse(self.T) else self.T

        return X, (
            means_Y @ T,
            # Batched T.T @ cov @ T, i.e. the contraction
            # "ijk,jl,kp->ilp" of (covs_Y, T, T).
            T.T @ covs_Y @ T
        )

    def score(self, X, Y):
        """
        Distance between 'X' and 'Y' after aligning 'Y' with the
        fitted 'T'.

        Returns
        -------
        float
            Signed square root of the mean over inputs of
            'alpha * ||mean difference||^2 + (2 - alpha) * Bures^2'.
        """
        X, Y = self.transform(X, Y)
        mX, sX = X
        mY, sY = Y

        A = np.sum((mX - mY) ** 2, axis=1)
        B = np.array([sq_bures_metric(sx, sy) for sx, sy in zip(sX, sY)])
        mn = np.mean(self.alpha * A + (2 - self.alpha) * B)
        # mn should always be positive but sometimes numerical rounding errors
        # cause mn to be very slightly negative, causing sqrt(mn) to be nan.
        # Thus, we take sqrt(abs(mn)) and pass through the sign. Any large
        # negative outputs should be caught by unit tests.
        return np.sign(mn) * np.sqrt(abs(mn))


class EnergyStochasticMetric:
    """
    Energy-distance based metric between two sets of sampled
    responses, computed after aligning the second set to the first.
    """

    def __init__(self, group="orth"):
        """
        Parameters
        ----------
        group : str
            Alignment group passed to 'align'.
        """
        self.group = group

    def fit(self, X, Y, niter=100, tol=1e-6):
        """
        Fit the alignment 'Q' of 'Y' onto 'X' by iteratively
        reweighted alignment, reducing the mean Euclidean distance
        over all (repeat of X, repeat of Y) pairs within each image.

        Parameters
        ----------
        X : (images x repeats x neurons) ndarray
        Y : (images x repeats x neurons) ndarray
        niter : int
            Maximum number of reweighting iterations (at least 1).
        tol : float
            Stop once the loss decreases by less than this amount.

        Returns
        -------
        self
            With 'Q', 'w' and 'loss_hist' set.

        Raises
        ------
        ValueError
            If 'niter' is less than one.
        """
        # X.shape = (images x repeats x neurons)
        # Y.shape = (images x repeats x neurons)

        assert X.shape == Y.shape

        if niter < 1:
            # Without at least one iteration no alignment is computed.
            raise ValueError("niter must be at least 1.")

        r = X.shape[1]

        # Within each image, pair every repeat of X with every repeat
        # of Y and stack all pairs into two row-aligned 2d arrays.
        idx = np.array(list(itertools.product(range(r), range(r))))
        X = np.vstack([x[idx[:, 0]] for x in X])
        Y = np.vstack([y[idx[:, 1]] for y in Y])

        w = np.ones(X.shape[0])
        loss_hist = [np.mean(np.linalg.norm(X - Y, axis=-1))]

        for _ in range(niter):
            Q = align(w[:, None] * Y, w[:, None] * X, group=self.group)
            resid = np.linalg.norm(X - Y @ Q, axis=-1)
            loss_hist.append(np.mean(resid))
            # Row weights 1 / sqrt(residual), floored to avoid division
            # by zero, so the next weighted least-squares alignment
            # approximates minimizing the sum of unsquared residuals.
            w = 1 / np.maximum(np.sqrt(resid), 1e-6)
            if (loss_hist[-2] - loss_hist[-1]) < tol:
                break

        self.w = w
        self.Q = Q
        self.loss_hist = loss_hist
        return self

    def transform(self, X, Y):
        """
        Apply the fitted alignment to 'Y'; 'X' is returned unchanged.

        Returns
        -------
        X : ndarray
            The input 'X', untouched.

        Y_aligned : (images x repeats x neurons) ndarray
            'Y @ Q' applied along the neuron axis.
        """
        # X.shape = (images x repeats x neurons)
        # Y.shape = (images x repeats x neurons)
        assert X.shape == Y.shape

        # 'align' returns scipy sparse matrices for the "perm" and
        # "identity" groups; densify so the batched product below works.
        Q = self.Q.toarray() if scipy.sparse.issparse(self.Q) else self.Q

        # Batched matrix product, i.e. the contraction "ijk,kl->ijl".
        return X, Y @ Q

    def score(self, X, Y):
        """
        Energy distance between 'X' and the aligned 'Y', averaged
        over images.

        Returns
        -------
        float
            Signed square root of
            'E|x - y| - 0.5 * (E|x - x'| + E|y - y'|)', where the
            expectations are sample averages. The estimate under the
            root can be negative for finite samples (for example when
            'X' and 'Y' hold the same samples); in that case the sign
            is passed through instead of returning NaN, matching
            'GaussianStochasticMetric.score'.

        Raises
        ------
        ValueError
            If there are fewer than two repeats per image, since the
            within-set distances are then undefined.
        """
        X, Y = self.transform(X, Y)
        m = X.shape[0] # num images
        n_samples = X.shape[1]

        if n_samples < 2:
            raise ValueError(
                "At least two repeats per image are required to "
                "compute the energy distance."
            )

        # Distinct pairs for within-set distances; all ordered pairs
        # for the cross-set distance.
        combs = np.array(list(
            itertools.combinations(range(n_samples), 2)
        ))
        prod = np.array(list(
            itertools.product(range(n_samples), range(n_samples))
        ))

        d_xy, d_xx, d_yy = 0, 0, 0
        for x, y in zip(X, Y):
            d_xy += np.mean(np.linalg.norm(x[prod[:, 0]] - y[prod[:, 1]], axis=-1))
            d_xx += np.mean(np.linalg.norm(x[combs[:, 0]] - x[combs[:, 1]], axis=-1))
            d_yy += np.mean(np.linalg.norm(y[combs[:, 0]] - y[combs[:, 1]], axis=-1))

        sq_dist = (d_xy / m) - .5*((d_xx / m) + (d_yy / m))
        return np.sign(sq_dist) * np.sqrt(abs(sq_dist))



def _fit_gaussian_alignment(
        means_X, means_Y, covs_X, covs_Y, T, alpha, group, niter, tol
    ):
    """
    Alternating optimization of the alignment 'T' of Gaussian
    distributions (means_Y, covs_Y) onto (means_X, covs_X).

    Each iteration (1) finds, for every pair of covariance square
    roots, the orthogonal matrix that best matches them under the
    current 'T', and (2) re-solves for 'T' within 'group' on the
    stacked means and rotated covariance square roots.

    Parameters
    ----------
    means_X, means_Y : (m x n) ndarray
        Mean vectors for each of 'm' inputs.

    covs_X, covs_Y : (m x n x n) ndarray
        Covariance matrices for each of 'm' inputs.

    T : (n x n) ndarray or sparse matrix
        Initial alignment.

    alpha : float
        Scale applied to the means; the covariance square roots are
        scaled by '2 - alpha'.

    group : str
        Alignment group passed to 'align' when solving for 'T'.

    niter : int
        Maximum number of iterations.

    tol : float
        Stop once the loss decreases by less than this amount (only
        checked from the third iteration on).

    Returns
    -------
    T : (n x n) ndarray or sparse matrix
        Final alignment (the initial 'T' if 'niter' is zero).

    loss_hist : list of float
        Loss after each iteration.
    """
    # Symmetric square roots of the covariances. Eigenvalues are clipped
    # at zero because 'eigh' can return slightly negative values for
    # (nearly) singular covariances, which would turn into NaN.
    vX, uX = np.linalg.eigh(covs_X)
    sX = contract("ijk,ik,ilk->ijl", uX, np.sqrt(np.clip(vX, 0.0, None)), uX)

    vY, uY = np.linalg.eigh(covs_Y)
    sY = contract("ijk,ik,ilk->ijl", uY, np.sqrt(np.clip(vY, 0.0, None)), uY)

    # NOTE(review): the blocks are scaled by alpha and (2 - alpha) before
    # the norm is squared, so the least-squares problem weights them by
    # alpha**2 and (2 - alpha)**2, whereas GaussianStochasticMetric.score
    # weights the squared terms by alpha and (2 - alpha). These agree
    # only when alpha == 1 -- confirm this is intended.

    # The target side does not depend on T, so build it once.
    A = np.vstack(
        [alpha * means_X] +
        [(2 - alpha) * sx for sx in sX]
    )
    scaled_means_Y = alpha * means_Y
    scaled_sY = [(2 - alpha) * sy for sy in sY]

    loss_hist = []

    for i in range(niter):
        # Best orthogonal match of each transformed sqrt-covariance of Y
        # to the corresponding sqrt-covariance of X, given the current T.
        Qs = [align(T.T @ sy, sx, group="orth") for sx, sy in zip(sX, sY)]
        B = np.vstack(
            [scaled_means_Y] +
            [Q.T @ sy for Q, sy in zip(Qs, scaled_sY)]
        )
        T = align(B, A, group=group)
        loss_hist.append(np.linalg.norm(A - B @ T))
        # The convergence test is skipped during the first two iterations.
        if (i >= 2) and ((loss_hist[-2] - loss_hist[-1]) < tol):
            break

    return T, loss_hist


if __name__ == "__main__":
    # ---- Demo 1: Gaussian metric on random diagonal-covariance data ----
    np.random.seed(42)
    n_models, n_samples, n_latent = 2, 100, 10
    mu = np.random.randn(n_models, n_samples, n_latent)
    logvars = np.random.randn(n_models, n_samples, n_latent)

    # One diagonal covariance matrix per (model, sample), with the
    # exponentiated log-variances on the diagonal.
    cov = np.zeros((n_models, n_samples, n_latent, n_latent))
    for model, model_logvars in enumerate(logvars):
        cov[model] = np.stack(
            [np.diag(np.exp(lv)) for lv in model_logvars], 0
        )

    first_model = (mu[0], cov[0])
    second_model = (mu[1], cov[1])

    gaussian_metric = GaussianStochasticMetric(alpha=1.)
    gaussian_metric.fit(first_model, second_model)
    print("Gaussian dist:" , gaussian_metric.score(first_model, second_model))

    # ---- Demo 2: energy metric on sampled responses ----
    np.random.seed(42)
    energy_metric = EnergyStochasticMetric()
    n_models, n_images, n_samples, n_latent = 2, 1, 1024, 5
    X = np.random.randn(n_models, n_images, n_samples, n_latent)

    # Randomly scale and rotate the second model's samples.
    scaling = np.diag([1,2,3,4,5])
    rotation = np.linalg.qr(np.random.randn(n_latent, n_latent))[0]
    X[1, 0] = X[1, 0] @ scaling @ rotation

    energy_metric.fit(X[0], X[1], niter=100, tol=1E-4)
    print("Energy dist", energy_metric.score(X[0], X[1]))

