from tslearn.backend import instantiate_backend
from tslearn.utils import (
    check_equal_size,
    to_time_series,
    to_time_series_dataset,
    ts_size,
)
from tslearn.metrics import (soft_dtw)
from ns_numba_ops import sdtw

def cdist_ns_dtw(
    dataset1, dataset2=None, gamma=1.0, be=None, compute_with_backend=False
):
    r"""Compute a cross-similarity matrix with the ``sdtw`` kernel.

    Every pair ``(dataset1[i], dataset2[j])`` is scored with
    ``ns_numba_ops.sdtw(ts1, ts2, gamma=gamma)`` and the scores are gathered
    in a matrix. The signature is modelled on tslearn's ``cdist_soft_dtw``.

    NOTE(review): ``sdtw`` is defined outside this file; it is presumably an
    implementation of Soft-DTW [1]_ -- confirm against ``ns_numba_ops``.
    Soft-DTW itself is defined as:

    .. math::

        \text{soft-DTW}_{\gamma}(X, Y) =
            \min_{\pi}{}^\gamma \sum_{(i, j) \in \pi} \|X_i - Y_j\|^2

    where :math:`\min^\gamma` is the soft-min operator of parameter
    :math:`\gamma`. In the limit case :math:`\gamma = 0`,
    :math:`\min^\gamma` reduces to a hard-min operator and soft-DTW is
    defined as the square of the DTW similarity measure.

    Parameters
    ----------
    dataset1 : array-like, shape=(n_ts1, sz1, d) or (n_ts1, sz1) or (sz1,)
        A dataset of time series.
        If shape is (n_ts1, sz1), the dataset is composed of univariate
        time series.
        If shape is (sz1,), the dataset is composed of a unique univariate
        time series.
    dataset2 : None or array-like, shape=(n_ts2, sz2, d) or (n_ts2, sz2) \
            or (sz2,) (default: None)
        Another dataset of time series. If `None`, self-similarity of
        `dataset1` is returned.
        If shape is (n_ts2, sz2), the dataset is composed of univariate
        time series.
        If shape is (sz2,), the dataset is composed of a unique univariate
        time series.
    gamma : float (default 1.)
        Gamma parameter forwarded to ``sdtw``.
    be : Backend object or string or None
        Backend specification passed to ``instantiate_backend``. The
        resulting backend is used to convert the inputs to time series
        datasets and to allocate the output matrix. It is *not* forwarded
        to ``sdtw``.
    compute_with_backend : bool, default=False
        Currently unused: it is accepted for signature compatibility with
        tslearn's ``cdist_soft_dtw`` but is never read nor forwarded to
        ``sdtw``.

    Returns
    -------
    array-like, shape=(n_ts1, n_ts2)
        Cross-similarity matrix, allocated with ``be.empty``.

    Examples
    --------
    >>> sim = cdist_ns_dtw([[1, 2, 2, 3], [1., 2., 3., 4.]], gamma=.01)
    >>> sim.shape
    (2, 2)

    See Also
    --------
    tslearn.metrics.soft_dtw : Compute Soft-DTW
    tslearn.metrics.cdist_soft_dtw : Cross similarity matrix between time
        series datasets using Soft-DTW
    tslearn.metrics.cdist_soft_dtw_normalized : Cross similarity matrix
        between time series datasets using a normalized version of Soft-DTW

    References
    ----------
    .. [1] M. Cuturi, M. Blondel "Soft-DTW: a Differentiable Loss Function for
       Time-Series," ICML 2017.
    """
    # The backend must be resolved from the raw inputs, before conversion.
    be = instantiate_backend(be, dataset1, dataset2)
    dataset1 = to_time_series_dataset(dataset1, dtype=be.float64, be=be)

    self_similarity = dataset2 is None
    if self_similarity:
        dataset2 = dataset1
    else:
        dataset2 = to_time_series_dataset(dataset2, dtype=be.float64, be=be)

    dists = be.empty((dataset1.shape[0], dataset2.shape[0]))

    def _effective_series(dataset):
        # Return the series of `dataset`, each cut to the length reported
        # by `ts_size` when the series do not all share the same size.
        if check_equal_size(dataset, be=be):
            return list(dataset)
        return [ts[: ts_size(ts)] for ts in dataset]

    # Trimming depends only on the series itself, so do it once per series
    # instead of once per (i, j) pair inside the double loop.
    series1 = _effective_series(dataset1)
    series2 = series1 if self_similarity else _effective_series(dataset2)

    for i, ts1 in enumerate(series1):
        for j, ts2 in enumerate(series2):
            if self_similarity and j < i:
                # Self-similarity: entry (j, i) was filled in an earlier row,
                # so mirror it instead of calling the kernel again.
                dists[i, j] = dists[j, i]
            else:
                dists[i, j] = sdtw(ts1, ts2, gamma=gamma)

    return dists

if __name__ == "__main__":
    # Manual smoke run: score two short univariate series against an
    # identical second dataset and print the resulting matrix.
    first_dataset = [[1, 2, 2, 3], [1., 2., 3., 4.]]
    second_dataset = [[1, 2, 2, 3], [1., 2., 3., 4.]]
    print(cdist_ns_dtw(first_dataset, second_dataset, gamma=.01))

