import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Type
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import lineax as lx
import abc
import warnings
import jax.tree_util as jtu
from diffusion_crf.base import *
from plum import dispatch
import diffusion_crf.util as util
from diffusion_crf.matrix.tags import Tags, TAGS

# Public names exported by `from <this module> import *`.
__all__ = [
  'AbstractMatrix',
  'make_parametric_symmetric_matrix',
]

################################################################################################################

class AbstractMatrix(AbstractBatchableObject, abc.ABC):
  """Abstract base class for the structured matrix types in this package.

  Concrete subclasses provide `shape`, `zeros`, `eye`, `as_matrix` and
  `__neg__`.  The arithmetic operators and linear-algebra methods below
  forward to module-level multiple-dispatch functions (`mat_add`,
  `scalar_mul`, `mat_mul`, `transpose`, `matrix_solve`, ...) that this module
  declares with `@dispatch.abstract`.  Implementations for concrete matrix
  types are presumably registered elsewhere.

  Each matrix carries `tags` (a `Tags` instance) with symbolic flags such as
  zero / inf / symmetric / identity.  Methods decorated with `auto_vmap`
  (from `diffusion_crf.base`) presumably vectorize over batched instances;
  confirm against that helper.
  """
  tags: eqx.AbstractVar[Tags]

  # When True, `fix_to_tags` rewrites the numeric values so that they agree
  # with the zero/inf tags (see `_force_fix_tags`).  Disabled by default.
  __fix_tags__ = False

  @auto_vmap
  def _force_fix_tags(self) -> 'AbstractMatrix':
    """Return a matrix whose values agree with the zero/inf tags.

    Unlike `fix_to_tags`, this ignores `__fix_tags__`.  Where `tags.is_zero`
    holds, the values are taken from a zero matrix.  Where `tags.is_inf`
    holds, they are taken from an all-inf matrix.  The inf substitution is
    applied second, so it wins if both flags are set.
    """
    # Create the zero matrix
    zero = self.zeros_like(self)
    mat = util.where(self.tags.is_zero, zero, self)

    # Create the inf matrix
    inf = self.inf_like(self)
    mat = util.where(self.tags.is_inf, inf, mat)
    return mat

  def fix_to_tags(self) -> 'AbstractMatrix':
    """Apply `_force_fix_tags` when `__fix_tags__` is enabled, else return self.

    This can help debugging but is not actually necessary because
    symbolic evaluation will automatically fix the tags.
    """
    if not self.__fix_tags__:
      return self
    return self._force_fix_tags()

  def cast_like(self, other: 'AbstractMatrix') -> 'AbstractMatrix':
    """Cast this matrix to be the same type as another matrix.

    This is a simple implementation that uses matrix addition with zeros.
    For more sophisticated casting options, use the `cast_matrix` function.

    Args:
        other: The matrix to cast like

    Returns:
        This matrix cast to the same type as other
    """
    return self + self.zeros_like(other)

  def cast_to(self, target_type: Union[Type['AbstractMatrix'], 'AbstractMatrix'],
             allow_downcast: bool = False) -> 'AbstractMatrix':
    """Cast this matrix to a specified type.

    This uses the casting logic from the cast module, which handles
    various matrix types according to their precedence.

    Args:
        target_type: The target matrix type or instance
        allow_downcast: Whether to allow casting to a lower precedence type

    Returns:
        This matrix cast to the target type
    """
    # Import here to avoid circular imports
    from diffusion_crf.matrix.cast import cast_matrix
    return cast_matrix(self, target_type, allow_downcast)

  @classmethod
  def zeros_like(cls, other: 'AbstractMatrix') -> 'AbstractMatrix':
    """Return a zero-valued matrix shaped like `other`, tagged as zero.

    The values come from the parent class's `zeros_like`; only the tags are
    replaced here.
    """
    zero = super().zeros_like(other)
    return eqx.tree_at(lambda x: x.tags, zero, TAGS.zero_tags)

  @classmethod
  def inf_like(cls, other: 'AbstractMatrix') -> 'AbstractMatrix':
    """Return an all-inf matrix shaped like `other`, tagged as inf.

    It starts from the parent class's zero-valued copy of `other` and fills
    every inexact (floating or complex) array leaf with inf.  Other leaves are
    kept as-is.
    """
    # The matrix will mostly look like zeros
    zero = super().zeros_like(other)

    # Set all of the values to inf
    params, static = eqx.partition(zero, eqx.is_inexact_array)
    params = jtu.tree_map(lambda x: jnp.inf*jnp.ones_like(x), params)
    inf = eqx.combine(params, static)

    # Set the tags to inf
    inf = eqx.tree_at(lambda x: x.tags, inf, TAGS.inf_tags)
    return inf

  def set_eye(self) -> 'AbstractMatrix':
    """Return the identity of size `self.shape[0]` built by `self.eye`, tagged as identity."""
    out = self.eye(self.shape[0])
    mat = eqx.tree_at(lambda x: x.tags, out, TAGS.eye_tags)
    return mat.fix_to_tags()

  def set_symmetric(self) -> 'AbstractMatrix':
    """Mark this matrix as symmetric without changing its values.

    The flag is only set when the tags object has an `is_symmetric` field.
    """
    # Values are left untouched; the matrix is NOT symmetrized, only flagged.
    mat = self
    if 'is_symmetric' in mat.tags.__dict__:
      mat = eqx.tree_at(lambda x: x.tags.is_symmetric, mat, jnp.array(True))
    return mat.fix_to_tags()

  def set_zero(self) -> 'AbstractMatrix':
    """Return a zero matrix shaped like this one, tagged as zero."""
    out = self.zeros_like(self)
    mat = eqx.tree_at(lambda x: x.tags, out, TAGS.zero_tags)
    return mat.fix_to_tags()

  def set_inf(self) -> 'AbstractMatrix':
    """Return an all-inf matrix shaped like this one, tagged as inf."""
    return self.inf_like(self)

  def jitter(self, amount: Scalar = 1e-8) -> 'AbstractMatrix':
    """Currently a no-op: returns `self` unchanged and ignores `amount`.

    TODO: Remove this function.  `amount` is kept only so existing callers
    keep working.  An earlier version added `amount * I` to nonzero matrices.
    That code sat behind an early `return` and could never run, so it has
    been dropped.
    """
    return self

  @property
  def is_zero(self):
    """Whether the tags mark this matrix as zero (the negation of `tags.is_nonzero`)."""
    # NOTE(review): `~` assumes `tags.is_nonzero` is a JAX/NumPy boolean
    # array; on a plain Python bool `~True` is -2 (truthy) — confirm.
    return ~self.tags.is_nonzero

  @property
  def is_inf(self):
    """Whether the tags mark this matrix as inf."""
    return self.tags.is_inf

  @property
  def is_symmetric(self):
    """Whether the tags mark this matrix as symmetric."""
    return self.tags.is_symmetric

  @property
  def is_eye(self):
    """Whether the tags mark this matrix as the identity."""
    return self.tags.is_eye

  @property
  @abc.abstractmethod
  def shape(self):
    """Shape of the (unbatched) matrix; `shape[0]` and `shape[1]` are rows and columns."""

  @property
  def ndim(self):
    """Number of dimensions, i.e. `len(self.shape)`."""
    return len(self.shape)

  @classmethod
  @abc.abstractmethod
  def zeros(cls, shape: Tuple[int, ...]) -> 'AbstractMatrix':
    """Construct a zero matrix of the given shape."""

  @classmethod
  @abc.abstractmethod
  def eye(cls, dim: int) -> 'AbstractMatrix':
    """Construct the `dim` x `dim` identity matrix."""

  @abc.abstractmethod
  def as_matrix(self) -> Float[Array, "M N"]:
    """Return the dense array representation of this matrix."""

  @abc.abstractmethod
  def __neg__(self) -> 'AbstractMatrix':
    """Return the negated matrix `-self`."""

  def __repr__(self):
    # Show the dense form; this calls `as_matrix()` on every repr.
    return f'{type(self).__name__}(\n{self.as_matrix()}\n)'

  def __add__(self, other: 'AbstractMatrix') -> 'AbstractMatrix':
    return mat_add(self, other)

  def __sub__(self, other: 'AbstractMatrix') -> 'AbstractMatrix':
    # Subtraction is addition of the negated right operand.
    return mat_add(self, -other)

  def __mul__(self, other: Scalar) -> 'AbstractMatrix':
    other = jnp.array(other)
    return scalar_mul(self, other)

  def __rmul__(self, other: Scalar) -> 'AbstractMatrix':
    other = jnp.array(other)
    return scalar_mul(self, other)

  def __matmul__(self, other: 'AbstractMatrix') -> 'AbstractMatrix':
    return mat_mul(self, other)

  def __truediv__(self, other: Scalar) -> 'AbstractMatrix':
    # Division by a scalar is multiplication by its reciprocal.
    other = jnp.array(other)
    return scalar_mul(self, 1/other)

  @auto_vmap
  def transpose(self):
    """Return the transpose (forwards to the module-level `transpose`)."""
    return transpose(self)

  @property
  def T(self):
    """Shorthand for `self.transpose()`."""
    return self.transpose()

  @auto_vmap
  def solve(self, other: 'AbstractMatrix') -> 'AbstractMatrix':
    """Solve `self @ X = other` for `X` (forwards to `matrix_solve`)."""
    return matrix_solve(self, other)

  @auto_vmap
  def get_inverse(self) -> 'AbstractMatrix':
    """Return the inverse (forwards to `get_matrix_inverse`)."""
    return get_matrix_inverse(self)

  @auto_vmap
  def get_log_det(self) -> Scalar:
    """Return the log determinant (forwards to the module-level `get_log_det`, no mask)."""
    return get_log_det(self)

  @auto_vmap
  def get_cholesky(self) -> 'AbstractMatrix':
    """Return the Cholesky factor (forwards to the module-level `get_cholesky`)."""
    return get_cholesky(self)

  @auto_vmap
  def get_exp(self) -> 'AbstractMatrix':
    """Return the matrix exponential (forwards to the module-level `get_exp`)."""
    return get_exp(self)

  @auto_vmap
  def get_svd(self) -> Tuple['AbstractMatrix', 'AbstractMatrix', 'AbstractMatrix']:
    """Return `(U, s, V)` (forwards to the module-level `get_svd`)."""
    return get_svd(self)

  @auto_vmap
  def make_parametric_symmetric_matrix(self) -> 'AbstractMatrix':
    """Forward to the module-level `make_parametric_symmetric_matrix`."""
    return make_parametric_symmetric_matrix(self)

################################################################################################################

@dispatch.abstract
def mat_add(A: AbstractMatrix, B: AbstractMatrix) -> AbstractMatrix:
  """Return the sum `A + B` of two matrices.

  This is an abstract multiple-dispatch declaration.  Concrete
  implementations are expected to be registered for specific matrix types.
  `AbstractMatrix.__add__` and `AbstractMatrix.__sub__` call it.

  **Arguments**:

  - `A` - Left operand
  - `B` - Right operand

  **Returns**:

  - The matrix `A + B`
  """
  ...

@dispatch.abstract
def scalar_mul(A: AbstractMatrix, s: Scalar) -> AbstractMatrix:
  """Return the matrix `A` scaled by the scalar `s`.

  This is an abstract multiple-dispatch declaration.
  `AbstractMatrix.__mul__`, `__rmul__` and `__truediv__` call it after
  converting the scalar with `jnp.array`.

  **Arguments**:

  - `A` - Matrix to scale
  - `s` - Scalar factor

  **Returns**:

  - The matrix `s * A`
  """
  ...

@dispatch.abstract
def mat_mul(A: AbstractMatrix, B: AbstractMatrix) -> AbstractMatrix:
  """Return the matrix product `A @ B`.

  This is an abstract multiple-dispatch declaration that
  `AbstractMatrix.__matmul__` calls.  `matrix_tests` also evaluates
  `A @ x` with a dense vector `x`, so implementations presumably accept
  arrays as the right operand as well.

  **Arguments**:

  - `A` - Left factor
  - `B` - Right factor

  **Returns**:

  - The product `A @ B`
  """
  ...

@dispatch.abstract
def transpose(A: AbstractMatrix) -> AbstractMatrix:
  """Return the transpose of `A`.

  This is an abstract multiple-dispatch declaration used by
  `AbstractMatrix.transpose` and `AbstractMatrix.T`.

  **Arguments**:

  - `A` - Matrix to transpose

  **Returns**:

  - The matrix `A^T`
  """
  ...

@dispatch.abstract
def matrix_solve(A: AbstractMatrix, B: AbstractMatrix) -> AbstractMatrix:
  """Return `X` satisfying the linear system `A X = B`.

  This is an abstract multiple-dispatch declaration used by
  `AbstractMatrix.solve`.  `matrix_tests` checks it with a dense vector
  right-hand side against `inv(A) @ b`.

  **Arguments**:

  - `A` - Coefficient matrix
  - `B` - Right-hand side

  **Returns**:

  - The solution `X` of `A X = B`
  """
  ...

@dispatch.abstract
def get_matrix_inverse(A: AbstractMatrix) -> AbstractMatrix:
  """Return the inverse `A^{-1}`.

  This is an abstract multiple-dispatch declaration used by
  `AbstractMatrix.get_inverse`.  `matrix_tests` compares it with
  `jnp.linalg.inv`.

  **Arguments**:

  - `A` - Square matrix to invert

  **Returns**:

  - The inverse of A
  """
  ...

@dispatch.abstract
def get_log_det(A: AbstractMatrix, mask: Optional[Bool[Array, 'D']] = None) -> Scalar:
  """Return the log determinant of `A`.

  This is an abstract multiple-dispatch declaration.
  `AbstractMatrix.get_log_det` calls it without a mask.  `matrix_tests`
  compares the result with `jnp.linalg.slogdet(A)[1]`, i.e. the log of the
  absolute determinant.

  **Arguments**:

  - `A` - Matrix whose log determinant is computed
  - `mask` - Optional boolean mask; its exact effect is up to each
    implementation (presumably it selects which dimensions participate —
    confirm)

  **Returns**:

  - The log determinant of A
  """
  ...

@dispatch.abstract
def get_cholesky(A: AbstractMatrix) -> AbstractMatrix:
  """Return the Cholesky factor of `A`.

  This is an abstract multiple-dispatch declaration used by
  `AbstractMatrix.get_cholesky`.  `matrix_tests` compares it with
  `jnp.linalg.cholesky`, i.e. the lower-triangular `L` with `A = L L^T`.

  **Arguments**:

  - `A` - Positive definite matrix to factor

  **Returns**:

  - The Cholesky factor of A
  """
  ...

@dispatch.abstract
def get_exp(A: AbstractMatrix) -> AbstractMatrix:
  """Compute the matrix exponential of a matrix.

  This is an abstract multiple-dispatch declaration used by
  `AbstractMatrix.get_exp`.

  **Arguments**:

  - `A` - Square matrix to exponentiate

  **Returns**:

  - The matrix exponential `exp(A)`
  """
  pass

@dispatch.abstract
def get_svd(A: AbstractMatrix) -> Tuple['AbstractMatrix', 'AbstractMatrix', 'AbstractMatrix']:
  """Compute the SVD of a matrix.

  This is an abstract multiple-dispatch declaration used by
  `AbstractMatrix.get_svd`.  `matrix_tests` checks the result against
  `U_d, s_d, Vh_d = jnp.linalg.svd(A)`.  It expects `U == U_d`,
  `s == diag(s_d)` and `V == Vh_d.T`, i.e. `A = U @ s @ V.T`, with `s` a
  diagonal matrix of singular values.

  **Arguments**:

  - `A` - Matrix to decompose

  **Returns**:

  - A tuple `(U, s, V)` of matrices with `A = U @ s @ V.T`
  """
  pass

@dispatch.abstract
def make_parametric_symmetric_matrix(matrix: AbstractMatrix) -> AbstractMatrix:
  """Re-express a symmetric matrix in a parametric form with unconstrained elements.

  This is an abstract multiple-dispatch declaration used by
  `AbstractMatrix.make_parametric_symmetric_matrix`.

  **Arguments**:

  - `matrix` - The symmetric matrix to re-parametrize

  **Returns**:

  - The same matrix in parametric symmetric form
  """
  ...

################################################################################################################
################################################################################################################
################################################################################################################

def matrices_equal(A: Union[AbstractMatrix, Float[Array, 'M N']], B: Union[AbstractMatrix, Float[Array, 'M N']]):
  """Return whether `A` and `B` are close according to `jnp.allclose`.

  Either argument may be an `AbstractMatrix`, which is densified with
  `as_matrix()`, or a plain array.  An operand that contains any nan or inf
  entry is treated as degenerate.  It is replaced as a whole by a
  pseudo-random normal matrix drawn from one fixed PRNG key, the same key
  for both operands.  Consequently two degenerate operands of equal shape
  compare equal, while a degenerate operand almost surely differs from a
  finite one.  The finite entries of a degenerate operand are not compared.
  """
  def densify(M):
    return M.as_matrix() if isinstance(M, AbstractMatrix) else M

  Amat = densify(A)
  Bmat = densify(B)

  # A single fixed key shared by both sides, so two degenerate operands map
  # to the identical stand-in matrix.
  key = random.PRNGKey(12380746)

  def replace_if_nonfinite(M):
    if not jnp.all(jnp.isfinite(M)):
      return random.normal(key, M.shape)
    return M

  return jnp.allclose(replace_if_nonfinite(Amat), replace_if_nonfinite(Bmat))

def matrix_tests(key, A, B):
  """Cross-check structured matrix operations on `A` and `B` against dense jnp.

  Each check computes a result through the structured API and the matching
  dense result from `as_matrix()`, then compares them with `matrices_equal`.
  It raises `ValueError` on the first mismatch.  Inverse and solve are only
  checked when `A` is square.  Cholesky, log-determinant and SVD are checked
  on the square matrix `J = A @ A.T`.

  `key` seeds the random vector used by the matrix-vector and solve checks.
  """
  A_dense = A.as_matrix()
  B_dense = B.as_matrix()

  # Transpose.
  if not matrices_equal(A.T, A_dense.T):
    raise ValueError(f"Transpose test failed.  Expected {A.T}, got {A_dense.T}")

  # Matrix + matrix.
  total = A + B
  total_dense = A_dense + B_dense
  if not matrices_equal(total, total_dense):
    raise ValueError(f"Addition test failed.  Expected {total}, got {total_dense}")

  # Matrix @ matrix.
  product = A@B.T
  product_dense = A_dense@B_dense.T
  if not matrices_equal(product, product_dense):
    raise ValueError(f"Matrix multiplication test failed.  Expected {product}, got {product_dense}")

  # Matrix @ dense vector.
  vec = random.normal(key, (A.shape[1],))
  image = A@vec
  image_dense = A_dense@vec
  if not matrices_equal(image, image_dense):
    raise ValueError(f"Matrix vector product test failed.  Expected {image}, got {image_dense}")

  # Scalar * matrix.
  scaled = 2.0*A
  scaled_dense = 2.0*A_dense
  if not matrices_equal(scaled, scaled_dense):
    raise ValueError(f"Scalar multiplication test failed.  Expected {scaled}, got {scaled_dense}")

  is_square = A.shape[0] == A.shape[1]
  if is_square:
    # Inverse.
    inverse = A.get_inverse()
    inverse_dense = jnp.linalg.inv(A_dense)
    if not matrices_equal(inverse, inverse_dense):
      raise ValueError(f"Matrix inverse test failed.  Expected {inverse}, got {inverse_dense}")

    # Solve, checked against the dense inverse applied to the right-hand side.
    rhs = random.normal(key, (A.shape[1],))
    solution = A.solve(rhs)
    solution_dense = inverse_dense@rhs
    if not matrices_equal(solution, solution_dense):
      raise ValueError(f"Matrix solve test failed.  Expected {solution}, got {solution_dense}")

  # The remaining checks use the square Gram matrix A A^T.
  gram = A@A.T
  gram_chol = gram.get_cholesky()._force_fix_tags()
  gram_dense = gram.as_matrix()
  gram_chol_dense = jnp.linalg.cholesky(gram_dense)
  # A zero-tagged Gram matrix is expected to have an all-zero factor, so the
  # dense reference is replaced by zeros in that case.
  if gram.is_zero:
    gram_chol_dense = jnp.zeros_like(gram_chol_dense)
  if not matrices_equal(gram_chol, gram_chol_dense):
    raise ValueError(f"Cholesky decomposition test failed.  Expected {gram_chol}, got {gram_chol_dense}")

  # Log determinant, compared with log|det| from slogdet.
  logdet = gram.get_log_det()
  logdet_dense = jnp.linalg.slogdet(gram_dense)[1]
  if not matrices_equal(logdet, logdet_dense):
    raise ValueError(f"Log determinant test failed.  Expected {logdet}, got {logdet_dense}")

  # SVD: expects gram = U @ s @ V.T with s diagonal, i.e. V matches Vh^T.
  (left, sing, right) = gram.get_svd()
  left_dense, sing_dense, right_dense = jnp.linalg.svd(gram_dense)
  if not (matrices_equal(left, left_dense)
          and matrices_equal(right, right_dense.T)
          and matrices_equal(sing, jnp.diag(sing_dense))):
    raise ValueError(f"SVD test failed.  Expected {left, sing, right}, got {left_dense, sing_dense, right_dense}")

def performance_tests(A, B):
  """Run a chain of structured matrix operations on `A` and `B`.

  Each intermediate result is forced to finish with `block_until_ready`
  applied to every leaf, presumably so the function can be timed or
  profiled.  Because `A + B` and `C4 @ B` are both formed, and `C6` is
  inverted, `A` and `B` must be square matrices of the same shape.  The
  Cholesky step additionally assumes the inverse `C7` is positive
  definite — confirm for the inputs used.
  """
  # Basic operations
  C1 = A + B
  C1 = jtu.tree_map(lambda x: x.block_until_ready(), C1)
  C2 = C1 - B
  C2 = jtu.tree_map(lambda x: x.block_until_ready(), C2)
  C3 = 2.0 * C2
  C3 = jtu.tree_map(lambda x: x.block_until_ready(), C3)
  C4 = C3 / 2.0
  C4 = jtu.tree_map(lambda x: x.block_until_ready(), C4)
  C5 = C4 @ B
  C5 = jtu.tree_map(lambda x: x.block_until_ready(), C5)
  C6 = C5.T
  C6 = jtu.tree_map(lambda x: x.block_until_ready(), C6)

  # Single matrix operations
  C7 = C6.get_inverse()
  C7 = jtu.tree_map(lambda x: x.block_until_ready(), C7)
  C8 = C7.get_cholesky()
  C8 = jtu.tree_map(lambda x: x.block_until_ready(), C8)

  # Matrix-vector operations
  x = jnp.ones(A.shape[1])
  y = C8 @ x
  z = C8.solve(x).reshape(-1, 1) @ y.reshape(1, -1)  # Outer product to get matrix
  z = jtu.tree_map(lambda x: x.block_until_ready(), z)  # Force computation to complete

  # Get log determinant (scalar).
  # NOTE(review): the original comment said the scalar is "converted back to
  # matrix", but no such conversion happens on these lines — confirm.
  log_det = C8.get_log_det()
  log_det = jtu.tree_map(lambda x: x.block_until_ready(), log_det)