"""
Additional AST nodes for operations on matrices. The nodes in this module
are meant to represent optimization of matrix expressions within codegen's
target languages that cannot be represented by SymPy expressions.

As an example, we can use :func:`sympy.codegen.rewriting.optimize` and the
``matinv_opt`` optimization provided in :mod:`sympy.codegen.rewriting` to
transform matrix multiplication under certain assumptions:

    >>> from sympy import symbols, MatrixSymbol
    >>> n = symbols('n', integer=True)
    >>> A = MatrixSymbol('A', n, n)
    >>> x = MatrixSymbol('x', n, 1)
    >>> expr = A**(-1) * x
    >>> from sympy import assuming, Q
    >>> from sympy.codegen.rewriting import matinv_opt, optimize
    >>> with assuming(Q.fullrank(A)):
    ...     optimize(expr, [matinv_opt])
    MatrixSolve(A, vector=x)
"""

from .ast import Token
from sympy.matrices import MatrixExpr
from sympy.core.sympify import sympify


class MatrixSolve(Token, MatrixExpr):
    """AST node standing for the solution of a linear matrix equation.

    The node represents the ``X`` that satisfies ``matrix * X = vector``.
    Code printers render it with the target language's own solve operation
    (see the examples below).

    Parameters
    ==========

    matrix : MatrixSymbol

      Coefficient matrix of the linear system. The solving operation is only
      valid when this matrix is square and full-rank (i.e. all of its columns
      are linearly independent).

    vector : MatrixSymbol

      One-column matrix holding the right-hand side of the system whose
      coefficients are given by ``matrix``.

    Examples
    ========

    >>> from sympy import symbols, MatrixSymbol
    >>> from sympy.codegen.matrix_nodes import MatrixSolve
    >>> n = symbols('n', integer=True)
    >>> A = MatrixSymbol('A', n, n)
    >>> x = MatrixSymbol('x', n, 1)
    >>> from sympy.printing.numpy import NumPyPrinter
    >>> NumPyPrinter().doprint(MatrixSolve(A, x))
    'numpy.linalg.solve(A, x)'
    >>> from sympy import octave_code
    >>> octave_code(MatrixSolve(A, x))
    'A \\\\ x'

    """
    # The same tuple serves as both the slot list and the field list.
    _fields = ('matrix', 'vector')
    __slots__ = _fields

    # Per-field converters, named after the fields above.
    # NOTE(review): presumably looked up by ``Token`` when building an
    # instance -- confirm against ``Token`` in ``.ast``.
    _construct_matrix = _construct_vector = staticmethod(sympify)

    @property
    def shape(self):
        """Shape of this node, taken directly from ``vector``."""
        rhs = self.vector
        return rhs.shape

    def _eval_derivative(self, x):
        """Return the derivative of this solve with respect to ``x``.

        Writing ``X = MatrixSolve(A, b)``, i.e. ``A*X = b``, differentiation
        gives ``A*X' = b' - A'*X``. The derivative is therefore another solve
        against the same ``A`` with that expression as right-hand side.
        """
        coeffs = self.matrix
        rhs = self.vector
        rhs_prime = rhs.diff(x)
        coeffs_prime = coeffs.diff(x)
        # The undifferentiated solution ``X`` appears in the new right-hand side.
        current_solution = MatrixSolve(coeffs, rhs)
        return MatrixSolve(coeffs, rhs_prime - coeffs_prime * current_solution)
