"""Product of Riemannian metrics.

Define the metric of a product manifold endowed with a product metric.
"""

import joblib

import geomstats.backend as gs
import geomstats.errors
import geomstats.vectorization
from geomstats.geometry.riemannian_metric import RiemannianMetric


class ProductRiemannianMetric(RiemannianMetric):
    """Class for product of Riemannian metrics.

    Parameters
    ----------
    metrics : list
        List of metrics in the product.
    default_point_type : str, {'vector', 'matrix'}
        Point type.
        Optional, default: 'vector'.
    n_jobs : int
        Number of jobs for parallel computing.
        Optional, default: 1.
    """

    def __init__(self, metrics, default_point_type="vector", n_jobs=1):
        self.n_metrics = len(metrics)
        dims = [metric.dim for metric in metrics]
        signatures = [metric.signature for metric in metrics]

        sig_pos = sum(sig[0] for sig in signatures)
        sig_neg = sum(sig[1] for sig in signatures)
        super(ProductRiemannianMetric, self).__init__(
            dim=sum(dims),
            signature=(sig_pos, sig_neg),
            default_point_type=default_point_type,
        )

        self.metrics = metrics
        self.dims = dims
        self.signatures = signatures
        self.n_jobs = n_jobs

    def metric_matrix(self, base_point=None, point_type=None):
        """Compute the matrix of the inner-product.

        Matrix of the inner-product defined by the Riemmanian metric
        at point base_point of the manifold.

        Parameters
        ----------
        base_point : array-like, shape=[..., n_metrics, dim] or
            [..., dim]
            Point on the manifold at which to compute the inner-product matrix.
            Optional, default: None.
        point_type : str, {'vector', 'matrix'}
            Type of representation used for points.
            Optional, default: None.

        Returns
        -------
        matrix : array-like, shape=[..., dim, dim] or
        [..., dim + n_metrics, dim + n_metrics]
            Matrix of the inner-product at the base point.

        """
        if point_type is None:
            point_type = self.default_point_type

        base_point = gs.to_ndarray(base_point, to_ndim=2)
        base_point = gs.to_ndarray(base_point, to_ndim=3)

        matrix = gs.zeros([len(base_point), self.dim, self.dim])
        cum_dim = 0
        for i in range(self.n_metrics):
            cum_dim_next = cum_dim + self.dims[i]
            if point_type == "matrix":
                matrix_next = self.metrics[i].metric_matrix(base_point[:, i])
            elif point_type == "vector":
                matrix_next = self.metrics[i].metric_matrix(
                    base_point[:, cum_dim:cum_dim_next, cum_dim:cum_dim_next]
                )
            else:
                raise ValueError(
                    "invalid point_type argument: {}, expected "
                    "either matrix of vector".format(point_type)
                )
            matrix[:, cum_dim:cum_dim_next, cum_dim:cum_dim_next] = matrix_next
            cum_dim = cum_dim_next
        return matrix[0] if len(base_point) == 1 else matrix

    def is_intrinsic(self, point):
        """Test in a point is represented in intrinsic coordinates.

        This method is only useful for `point_type=vector`.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point on the product manifold.

        Returns
        -------
        intrinsic : array-like, shape=[...,]
            Whether intrinsic coordinates are used for all manifolds.
        """
        if self.default_point_type != "vector":
            raise ValueError("Invalid default_point_type: 'vector' expected.")

        if point.shape[-1] == self.dim:
            return True
        if point.shape[-1] == sum(dim + 1 for dim in self.dims):
            return False
        raise ValueError("Input shape does not match the dimension of the manifold")

    @staticmethod
    def _get_method(metric, method_name, metric_args):
        return getattr(metric, method_name)(**metric_args)

    def _iterate_over_metrics(self, func, args, intrinsic=False):

        cum_index = (
            gs.cumsum(self.dims, axis=0)[:-1]
            if intrinsic
            else gs.cumsum(gs.array([k + 1 for k in self.dims]), axis=0)
        )
        arguments = {key: gs.split(args[key], cum_index, axis=1) for key in args.keys()}
        args_list = [
            {key: arguments[key][j] for key in args.keys()}
            for j in range(self.n_metrics)
        ]
        pool = joblib.Parallel(n_jobs=self.n_jobs, prefer="threads")
        out = pool(
            joblib.delayed(self._get_method)(self.metrics[i], func, args_list[i])
            for i in range(self.n_metrics)
        )
        return out

    def inner_product(
        self, tangent_vec_a, tangent_vec_b, base_point=None, point_type=None
    ):
        """Compute the inner-product of two tangent vectors at a base point.

        Inner product defined by the Riemannian metric at point `base_point`
        between tangent vectors `tangent_vec_a` and `tangent_vec_b`.

        Parameters
        ----------
        tangent_vec_a : array-like, shape=[..., dim + 1]
            First tangent vector at base point.
        tangent_vec_b : array-like, shape=[..., dim + 1]
            Second tangent vector at base point.
        base_point : array-like, shape=[..., dim + 1]
            Point on the manifold.
            Optional, default: None.
        point_type : str, {'vector', 'matrix'}
            Type of representation used for points.
            Optional, default: None.

        Returns
        -------
        inner_prod : array-like, shape=[...,]
            Inner-product of the two tangent vectors.
        """
        if base_point is None:
            base_point = gs.empty((self.n_metrics, self.dim))

        if point_type is None:
            point_type = self.default_point_type
        geomstats.errors.check_parameter_accepted_values(
            point_type, "point_type", ["vector", "matrix"]
        )

        tangent_vec_a = gs.to_ndarray(tangent_vec_a, to_ndim=2)
        tangent_vec_b = gs.to_ndarray(tangent_vec_b, to_ndim=2)
        base_point = gs.to_ndarray(base_point, to_ndim=2)
        if point_type == "vector":
            intrinsic = self.is_intrinsic(tangent_vec_b)
            args = {
                "tangent_vec_a": tangent_vec_a,
                "tangent_vec_b": tangent_vec_b,
                "base_point": base_point,
            }
            inner_prod = self._iterate_over_metrics("inner_product", args, intrinsic)
            return gs.sum(gs.stack(inner_prod, axis=1), axis=1)

        inner_products = [
            metric.inner_product(
                tangent_vec_a[..., i, :],
                tangent_vec_b[..., i, :],
                base_point[..., i, :],
            )
            for i, metric in enumerate(self.metrics)
        ]
        return sum(inner_products)

    def exp(self, tangent_vec, base_point=None, point_type=None):
        """Compute the Riemannian exponential of a tangent vector.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., dim]
            Tangent vector at a base point.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.
            Optional, default: None.
        point_type : str, {'vector', 'matrix'}
            Type of representation used for points.
            Optional, default: None.

        Returns
        -------
        exp : array-like, shape=[..., dim]
            Point on the manifold equal to the Riemannian exponential
            of tangent_vec at the base point.
        """
        if base_point is None:
            base_point = [
                None,
            ] * self.n_metrics

        if point_type is None:
            point_type = self.default_point_type
        geomstats.errors.check_parameter_accepted_values(
            point_type, "point_type", ["vector", "matrix"]
        )

        if point_type == "vector":
            tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=2)
            base_point = gs.to_ndarray(base_point, to_ndim=2)
            intrinsic = self.is_intrinsic(base_point)
            args = {"tangent_vec": tangent_vec, "base_point": base_point}
            exp = self._iterate_over_metrics("exp", args, intrinsic)
            return gs.hstack(exp)

        tangent_vec = gs.to_ndarray(tangent_vec, to_ndim=3)
        base_point = gs.to_ndarray(base_point, to_ndim=3)
        exp = gs.stack(
            [
                self.metrics[i].exp(tangent_vec[:, i], base_point[:, i])
                for i in range(self.n_metrics)
            ],
            axis=1,
        )
        return exp[0] if len(tangent_vec) == 1 else exp

    def log(self, point, base_point=None, point_type=None):
        """Compute the Riemannian logarithm of a point.

        Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point on the manifold.
        base_point : array-like, shape=[..., dim]
            Point on the manifold.
            Optional, default: None.
        point_type : str, {'vector', 'matrix'}
            Type of representation used for points.
            Optional, default: None.

        Returns
        -------
        log : array-like, shape=[..., dim]
            Tangent vector at the base point equal to the Riemannian logarithm
            of point at the base point.
        """
        if base_point is None:
            base_point = [None] * self.n_metrics

        if point_type is None:
            point_type = self.default_point_type
        geomstats.errors.check_parameter_accepted_values(
            point_type, "point_type", ["vector", "matrix"]
        )

        if point_type == "vector":
            point = gs.to_ndarray(point, to_ndim=2)
            base_point = gs.to_ndarray(base_point, to_ndim=2)
            intrinsic = self.is_intrinsic(base_point)
            args = {"point": point, "base_point": base_point}
            logs = self._iterate_over_metrics("log", args, intrinsic)
            logs = gs.concatenate(logs, axis=1)
            return logs

        point = gs.to_ndarray(point, to_ndim=2, axis=0)
        point = gs.to_ndarray(point, to_ndim=3, axis=0)
        base_point = gs.to_ndarray(base_point, to_ndim=2, axis=0)
        base_point = gs.to_ndarray(base_point, to_ndim=3, axis=0)
        logs = gs.stack(
            [
                self.metrics[i].log(point[:, i], base_point[:, i])
                for i in range(self.n_metrics)
            ],
            axis=1,
        )
        return logs


import jax


class ProductSameRiemannianMetric(RiemannianMetric):
    """Jax based vmapped product metircs over the same manifold"""

    def __init__(self, metric, mul, default_point_type="vector", **kwargs):
        self.metric = metric
        self.mul = mul
        self.dim = self.mul * self.metric.dim

        super().__init__(
            dim=self.dim,
            signature=(metric.signature[0] * self.mul, metric.signature[1] * self.mul),
            default_point_type=default_point_type,
        )

    def _iterate_over_metrics(self, func, kwargs, in_axes=-2, out_axes=-2):
        method = getattr(self.metric, func)
        in_axes = []
        args_list = []
        for key, value in kwargs.items():
            if hasattr(value, "reshape"):
                # value : shape=[..., mul*dim] -> shape=[..., mul, dim]
                value = value.reshape((*value.shape[:-1], self.mul, -1))
                in_axes.append(-2)
            else:
                in_axes.append(None)
            args_list.append(value)
        # out = jax.vmap(method, in_axes=in_axes, out_axes=out_axes)(**args)
        # NOTE: Annoyingly cannot pass named arguments with in_axes! https://github.com/google/jax/issues/7465
        out = jax.vmap(method, in_axes=in_axes, out_axes=out_axes)(*args_list)
        out = out.reshape((*out.shape[:out_axes], -1))
        # value : shape=[..., mul, dim] -> shape=[..., mul*dim]
        return out

    def exp(self, tangent_vec, base_point):
        return self._iterate_over_metrics(
            "exp",
            {
                "tangent_vec": tangent_vec,
                "base_point": base_point,
            },
        )

    def log(self, point, base_point=None, point_type=None):
        return self._iterate_over_metrics(
            "log",
            {
                "point": point,
                "base_point": base_point,
                # "point_type": point_type,
            },
        )

    def squared_norm(self, vector, base_point=None):
        return gs.sum(
            self._iterate_over_metrics(
                "squared_norm",
                {
                    "vector": vector,
                    "base_point": base_point,
                },
                out_axes=-1,
            ),
            axis=-1,
        )
