"""
Core operations on tensors in Tensor Ring (TR) format
"""

import warnings

import numpy as np

import tensorly as tl
from ._factorized_tensor import FactorizedTensor


def tr_to_tensor(factors):
    """Contract a list of TR cores back into the dense tensor they represent.

    Parameters
    ----------
    factors : list of 3D-arrays
        TR cores (TR-cores). Core ``k`` has shape
        ``(rank[k], tensor_shape[k], rank[k + 1])`` and the ring must close,
        i.e. ``factors[-1].shape[2] == factors[0].shape[0]``.

    Returns
    -------
    output_tensor : ndarray
        dense tensor of shape ``[core.shape[1] for core in factors]``
    """
    out_shape = [core.shape[1] for core in factors]

    # Left-to-right sweep: after absorbing core k the accumulator is a
    # (rank[0] * I_0 * ... * I_k, rank[k + 1]) matrix.
    head = factors[0]
    acc = tl.reshape(head, (-1, head.shape[2]))
    for core in factors[1:-1]:
        r_in, _, r_out = core.shape
        acc = tl.reshape(tl.dot(acc, tl.reshape(core, (r_in, -1))), (-1, r_out))

    # Close the ring: move the leading rank[0] index next to the trailing rank so
    # that both are summed in one matrix product with the (permuted) last core.
    last_core = factors[-1]
    r_last_in = last_core.shape[0]
    r_closing = last_core.shape[2]
    acc = tl.moveaxis(tl.reshape(acc, (r_closing, -1, r_last_in)), 0, -1)
    acc = tl.reshape(acc, (-1, r_last_in * r_closing))

    closing = tl.reshape(tl.moveaxis(last_core, -1, 1), (-1, out_shape[-1]))
    return tl.reshape(tl.dot(acc, closing), out_shape)


def tr_to_unfolded(factors, mode):
    """Return the mode-``mode`` unfolding of the tensor encoded by TR cores.

    The full tensor is first rebuilt with ``tr_to_tensor`` and then unfolded
    with ``tl.unfold``.

    Parameters
    ----------
    factors : list of 3D-arrays
        TR factors
    mode : int
        mode along which the unfolding matrix is computed

    Returns
    -------
    2-D array
        unfolding matrix of the reconstructed tensor at mode ``mode``
    """
    dense = tr_to_tensor(factors)
    unfolded = tl.unfold(dense, mode)
    return unfolded


def tr_to_vec(factors):
    """Return the vectorized form of the tensor encoded by TR cores.

    The full tensor is first rebuilt with ``tr_to_tensor`` and then flattened
    with ``tl.tensor_to_vec``.

    Parameters
    ----------
    factors : list of 3D-arrays
        TR factors

    Returns
    -------
    1-D array
        vectorized tensor defined by ``factors``
    """
    dense = tr_to_tensor(factors)
    vector = tl.tensor_to_vec(dense)
    return vector


def _validate_tr_tensor(tr_tensor):
    factors = tr_tensor
    n_factors = len(factors)

    if n_factors < 2:
        raise ValueError(
            "A Tensor Ring tensor should be composed of at least two factors."
            f"However, {n_factors} factor was given."
        )

    rank = []
    shape = []
    for index, factor in enumerate(factors):
        current_rank, current_shape, next_rank = tl.shape(factor)

        # Check that factors are third order tensors
        if not tl.ndim(factor) == 3:
            raise ValueError(
                "TR expresses a tensor as third order factors (tr-cores).\n"
                f"However, tl.ndim(factors[{index}]) = {tl.ndim(factor)}"
            )

        # Consecutive factors should have matching ranks
        if tl.shape(factors[index - 1])[2] != current_rank:
            raise ValueError(
                "Consecutive factors should have matching ranks\n"
                " -- e.g. tl.shape(factors[0])[2]) == tl.shape(factors[1])[0])\n"
                f"However, tl.shape(factor[{index-1}])[2] == {tl.shape(factors[index-1])[2]} but"
                f" tl.shape(factor[{index}])[0] == {current_rank}"
            )

        shape.append(current_shape)
        rank.append(current_rank)

    # Add last rank (boundary condition)
    rank.append(next_rank)

    return tuple(shape), tuple(rank)


def _tr_n_param(tensor_shape, rank):
    """Number of parameters of a TR decomposition for a given `rank` and full `tensor_shape`.

    Parameters
    ----------
    tensor_shape : int tuple
        shape of the full tensor to decompose (or approximate)

    rank : tuple
        rank of the TR decomposition

    Returns
    -------
    n_params : int
        Number of parameters of a TR decomposition of rank `rank` of a full tensor of shape `tensor_shape`
    """
    factor_params = []
    for i, s in enumerate(tensor_shape):
        factor_params.append(rank[i] * s * rank[i + 1])
    return np.sum(factor_params)


def validate_tr_rank(tensor_shape, rank="same", rounding="round"):
    """Returns the rank of a Tensor Ring Decomposition

    Parameters
    ----------
    tensor_shape : tuple
        shape of the tensor to decompose
    rank : {'same', float, tuple, int}, default is same
        way to determine the rank, by default 'same'
        if 'same': rank is computed to keep the number of parameters roughly
        the same as the full tensor (at most the same with ``rounding='floor'``)
        if float, computes a constant rank so as to keep approximately ``rank``
        times the original number of parameters; the result is at least 1
        if int (including numpy integers), the same rank is used for every mode
        if tuple, it is checked and returned
    rounding : {'round', 'floor', 'ceil'}
        how the rank computed from a float is rounded to an integer

    Returns
    -------
    rank : int list
        rank of the decomposition, of length ``len(tensor_shape) + 1``
        with ``rank[0] == rank[-1]``

    Raises
    ------
    ValueError
        on an unknown ``rounding``, an unknown string ``rank``, a rank tuple
        of the wrong length, or a rank tuple violating ``rank[0] == rank[-1]``
    """
    if rounding == "ceil":
        rounding_fun = np.ceil
    elif rounding == "floor":
        rounding_fun = np.floor
    elif rounding == "round":
        rounding_fun = np.round
    else:
        raise ValueError(f"Rounding should be round, floor or ceil, but got {rounding}")

    # Guard on type: comparing an array-like rank to a string would be ambiguous.
    if isinstance(rank, str):
        if rank != "same":
            raise ValueError(f"rank should be 'same', a float, an int or a tuple, but got {rank!r}")
        rank = float(1)

    n_dim = len(tensor_shape)
    if n_dim == 2:
        warnings.warn(
            "Determining the TR-rank for the trivial case of a matrix"
            f" (order 2 tensor) of shape {tensor_shape}, not a higher-order tensor."
        )

    if isinstance(rank, (float, np.floating)):
        # Choose the *same* rank for each mode
        n_param_tensor = np.prod(tensor_shape) * rank

        # R_k I_k R_{k+1} = R^2 I_k, so R^2 * sum_k I_k = n_param_tensor
        solution = int(rounding_fun(np.sqrt(n_param_tensor / np.sum(tensor_shape))))
        # Small fractions can round to 0, which is not a valid TR rank.
        solution = max(solution, 1)
        rank = (solution,) * (n_dim + 1)

    else:
        # Check user input for potential errors
        if isinstance(rank, (int, np.integer)):
            rank = (rank,) * (n_dim + 1)
        elif n_dim + 1 != len(rank):
            message = (
                "Provided incorrect number of ranks. "
                "Should verify len(rank) == tl.ndim(tensor)+1, "
                f"but len(rank) = {len(rank)} while tl.ndim(tensor)+1 = {n_dim + 1}"
            )
            raise ValueError(message)

        # Check first and last rank (ring boundary condition)
        if rank[0] != rank[-1]:
            message = (
                f"Provided rank[0] == {rank[0]} and rank[-1] == {rank[-1]}"
                " but boundary conditions dictate rank[0] == rank[-1]"
            )
            raise ValueError(message)

    return list(rank)


class TRTensor(FactorizedTensor):
    """Tensor stored in Tensor Ring (TR) format as a list of third-order cores.

    Parameters
    ----------
    factors : list of 3D-arrays
        TR cores; validated with ``_validate_tr_tensor`` on construction
        (raises ValueError if malformed).

    Attributes
    ----------
    shape : int tuple
        shape of the full tensor (middle dimension of each core)
    rank : int tuple
        TR-rank, one entry longer than ``shape``
    factors : list of 3D-arrays
        the cores, stored as given (not copied)

    Notes
    -----
    ``__setitem__`` replaces a core in place without re-validating it or
    updating ``shape``/``rank``.
    """

    def __init__(self, factors):
        super().__init__()

        # Raises ValueError if the cores do not form a valid ring
        full_shape, tr_rank = _validate_tr_tensor(factors)

        self.shape, self.rank = tuple(full_shape), tuple(tr_rank)
        self.factors = factors

    def __getitem__(self, index):
        cores = self.factors
        return cores[index]

    def __setitem__(self, index, value):
        cores = self.factors
        cores[index] = value

    def __iter__(self):
        position, n_cores = 0, len(self)
        while position < n_cores:
            yield self[position]
            position += 1

    def __len__(self):
        cores = self.factors
        return len(cores)

    def __repr__(self):
        return f"factors list : rank-{self.rank} tensor ring tensor of shape {self.shape}"

    def to_tensor(self):
        """Dense tensor reconstructed from the cores."""
        return tr_to_tensor(self)

    def to_unfolding(self, mode):
        """Mode-``mode`` unfolding of the reconstructed tensor."""
        return tr_to_unfolded(self, mode)

    def to_vec(self):
        """Vectorized form of the reconstructed tensor."""
        return tr_to_vec(self)
