import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, PyTree, Scalar, Bool
import warnings
import abc
from plum import dispatch, ModuleType
from functools import wraps
from diffusion_crf.matrix import AbstractMatrix
import diffusion_crf.util as util

# AbstractMatrix = ModuleType("diffusion_crf.matrix", "AbstractMatrix")

# Public API: the two validation hooks plus one decorator factory per
# supported matrix operation.
__all__ = [
  'psd_check',
  'inverse_check',
  'symbolic_add',
  'symbolic_scalar_mul',
  'symbolic_mat_mul',
  'symbolic_solve',
  'symbolic_matrix_inverse',
  'symbolic_log_det',
  'symbolic_cholesky',
  'symbolic_exp',
  'symbolic_svd',
]

def psd_check(J: AbstractMatrix):
  """Validation hook for matrices expected to be symmetric and positive (semi)definite.

  Every check is currently disabled, so ``J`` is handed back untouched.
  When enabled (see the comments below), the checks would raise via
  ``eqx.error_if`` in three cases: ``J`` has a negative eigenvalue, ``J`` is
  not exactly equal to its transpose, or ``J`` is not tagged as symmetric.
  """
  # Disabled checks, kept for reference:
  # J = eqx.error_if(J, jnp.any(jnp.linalg.eigvalsh(J.as_matrix()) < 0), "Matrix must be positive definite")
  # J = eqx.error_if(J, jnp.any(J.as_matrix() - J.T.as_matrix() != 0), "Matrix must be symmetric for real!")
  # J = eqx.error_if(J, ~J.tags.is_symmetric, "Matrix must be symmetric")
  checked = J
  return checked

def inverse_check(A: AbstractMatrix):
  """Validation hook run before inverting/solving with ``A``.

  Currently a no-op that returns ``A`` unchanged. A disabled check (below)
  would raise via ``eqx.error_if`` when ``A`` is tagged as zero.
  """
  # Disabled check, kept for reference:
  # A = eqx.error_if(A, A.tags.is_zero, "Cannot invert a zero matrix")
  checked = A
  return checked

################################################################################################################

def _check_add_types(A: AbstractMatrix):
  """Per-operand validation for addition; no checks are needed, so ``A`` is returned as-is."""
  checked = A
  return checked

# def _symbolic_add(A: AbstractMatrix, B: AbstractMatrix, out: AbstractMatrix) -> AbstractMatrix:
#   out_tags = A.tags.add_update(B.tags)
#   out = eqx.tree_at(lambda x: x.tags, out, out_tags)
#   out = out.fix_to_tags()
#   return out

def symbolic_add(f):
  """Decorate a matrix-addition implementation ``f(A, B)``.

  The returned wrapper passes each operand through ``_check_add_types``
  (A first, then B) and returns ``f``'s result unchanged.

  NOTE: a symbolic post-pass (``_symbolic_add``) was planned for results
  that are not an ``AbstractMatrix``. That helper is commented out in this
  module, so no post-processing happens.
  """

  @wraps(f)
  def wrapper(A: AbstractMatrix, B: AbstractMatrix) -> AbstractMatrix:
    lhs = _check_add_types(A)
    rhs = _check_add_types(B)
    return f(lhs, rhs)

  return wrapper

################################################################################################################

def _check_scalar_mul_types(A: AbstractMatrix):
  """Validation for scalar multiplication; nothing to check, so ``A`` is returned as-is."""
  checked = A
  return checked

def _symbolic_scalar_mul(A: AbstractMatrix, s: Scalar, out: AbstractMatrix) -> AbstractMatrix:
  """Re-tag the result of a scalar multiplication.

  The tags are taken from ``A.tags.scalar_mul_update()`` and written onto
  ``out`` with ``eqx.tree_at``. The result then has ``fix_to_tags()`` called
  on it. ``s`` is accepted for signature symmetry but is not used here.
  """
  new_tags = A.tags.scalar_mul_update()
  retagged = eqx.tree_at(lambda m: m.tags, out, new_tags)
  return retagged.fix_to_tags()

def symbolic_scalar_mul(f):
  """Decorate a matrix-times-scalar implementation ``f(A, s)``.

  The wrapper validates ``A`` with ``_check_scalar_mul_types`` and returns
  ``f(A, s)`` unchanged.

  NOTE: the symbolic post-pass ``_symbolic_scalar_mul(A, s, out)`` (meant for
  results that are not an ``AbstractMatrix``) is currently disabled.
  """

  @wraps(f)
  def wrapper(A: AbstractMatrix, s: Scalar) -> AbstractMatrix:
    return f(_check_scalar_mul_types(A), s)

  return wrapper

################################################################################################################

def _check_mat_mul_types(A: AbstractMatrix):
  """Validation for the left operand of a matrix product; currently returns ``A`` as-is."""
  checked = A
  return checked

# @dispatch
# def _symbolic_mat_mul(A: AbstractMatrix, B: AbstractMatrix, out: AbstractMatrix) -> AbstractMatrix:
#   # Update the tags for the output
#   out_tags = A.tags.mat_mul_update(B.tags)
#   out = eqx.tree_at(lambda x: x.tags, out, out_tags)
#   out = out.fix_to_tags()

#   if A.shape[-1] == A.shape[-2]:
#     # If A is the identity, then return B
#     out = jax.lax.cond(A.tags.is_eye, lambda: B, lambda: out)

#   # If A is zero, then return zero
#   out = jax.lax.cond(A.tags.is_zero, lambda: out.set_zero(), lambda: out)

#   # If A is inf, then return inf
#   out = jax.lax.cond(A.tags.is_inf, lambda: out.set_inf(), lambda: out)

#   if B.shape[-1] == B.shape[-2]:
#     # If B is the identity, then return A
#     out = jax.lax.cond(B.tags.is_eye, lambda: A, lambda: out)

#   # If B is zero, then return zero
#   out = jax.lax.cond(B.tags.is_zero, lambda: out.set_zero(), lambda: out)

#   # If B is inf, then return inf
#   out = jax.lax.cond(B.tags.is_inf, lambda: out.set_inf(), lambda: out)

#   return out

@dispatch
def _symbolic_mat_mul(A: AbstractMatrix, b: Float[Array, 'N'], out: Float[Array, 'M']) -> Float[Array, 'M']:
  """Override a matrix-vector product using ``A``'s symbolic tags.

  ``out`` is presumably the numerically computed product of ``A`` and ``b``.
  The overrides are applied one after another, each via ``util.where`` on
  the running result, in this order:
    * ``is_eye``  -> ``b`` (only considered when ``A`` is square)
    * ``is_zero`` -> zeros
    * ``is_inf``  -> ``+inf``
  """
  tags = A.tags
  square = A.shape[-2] == A.shape[-1]

  result = util.where(tags.is_eye, b, out) if square else out
  result = util.where(tags.is_zero, jnp.zeros_like(result), result)
  return util.where(tags.is_inf, jnp.ones_like(result)*jnp.inf, result)

def symbolic_mat_mul(f):
  """Decorate a matrix-multiplication implementation ``f(A, B)``.

  ``B`` may be an ``AbstractMatrix`` or a vector. The wrapper validates ``A``
  with ``_check_mat_mul_types`` and returns ``f(A, B)`` unchanged.

  NOTE: the symbolic post-pass ``_symbolic_mat_mul(A, B, out)`` (meant for
  results that are not an ``AbstractMatrix``) is currently disabled.
  """

  @wraps(f)
  def wrapper(A: AbstractMatrix, B: Union[AbstractMatrix, Float[Array, 'N']]) -> Union[AbstractMatrix, Float[Array, 'N']]:
    return f(_check_mat_mul_types(A), B)

  return wrapper

################################################################################################################

def _check_solve_types(A: AbstractMatrix):
  """Validate the coefficient matrix of a solve by delegating to ``inverse_check``."""
  return inverse_check(A)

# @dispatch
# def _symbolic_solve(A: AbstractMatrix, B: AbstractMatrix, out: AbstractMatrix) -> AbstractMatrix:
#   # Update the tags for the output
#   out_tags = A.tags.solve_update(B.tags)
#   out = eqx.tree_at(lambda x: x.tags, out, out_tags)
#   out = out.fix_to_tags()

#   if A.shape[-1] == A.shape[-2]:
#     # If A is the identity, then return B
#     out = jax.lax.cond(A.tags.is_eye, lambda: B, lambda: out)

#   # If A is zero, then return inf
#   out = jax.lax.cond(A.tags.is_zero, lambda: out.set_inf(), lambda: out)

#   # If A is inf, then return zero
#   out = jax.lax.cond(A.tags.is_inf, lambda: out.set_zero(), lambda: out)

#   # If B is the identity, then return the inverse of A (don't do anything for this)
#   pass

#   # If B is zero, then return zero
#   out = jax.lax.cond(B.tags.is_zero, lambda: out.set_zero(), lambda: out)

#   # If B is inf, then return inf
#   out = jax.lax.cond(B.tags.is_inf, lambda: out.set_inf(), lambda: out)

#   return out

@dispatch
def _symbolic_solve(A: AbstractMatrix, b: Float[Array, 'N'], out: Float[Array, 'N']) -> Float[Array, 'N']:
  """Override a linear-solve result using ``A``'s symbolic tags.

  ``out`` is presumably the numerically computed solution of ``A x = b``.
  The overrides are applied one after another, each via ``util.where`` on
  the running result, in this order:
    * ``is_eye``  -> ``b`` (only considered when ``A`` is square)
    * ``is_zero`` -> ``+inf``
    * ``is_inf``  -> zeros
  """
  tags = A.tags
  square = A.shape[-2] == A.shape[-1]

  result = util.where(tags.is_eye, b, out) if square else out
  result = util.where(tags.is_zero, jnp.ones_like(result)*jnp.inf, result)
  return util.where(tags.is_inf, jnp.zeros_like(result), result)

def symbolic_solve(f):
  """Decorate a linear-solve implementation ``f(A, B)``.

  The wrapper validates ``A`` with ``_check_solve_types`` and returns
  ``f(A, B)`` unchanged. ``f`` is therefore called under the assumption that
  ``A`` is invertible.

  NOTE: the symbolic fallback ``_symbolic_solve(A, B, out)`` (meant for
  inf/zero/eye ``A`` when the result is not an ``AbstractMatrix``) is
  currently disabled.
  """

  @wraps(f)
  def wrapper(A: AbstractMatrix, B: Union[AbstractMatrix, Float[Array, 'N']]) -> Union[AbstractMatrix, Float[Array, 'N']]:
    return f(_check_solve_types(A), B)

  return wrapper

################################################################################################################

def _check_matrix_inverse_types(A: AbstractMatrix):
  """Validate a matrix before inversion by delegating to ``inverse_check``."""
  return inverse_check(A)

def _symbolic_matrix_inverse(A: AbstractMatrix, out: AbstractMatrix) -> AbstractMatrix:
  """Override a computed inverse using ``A``'s symbolic tags.

  The overrides are applied one after another, each via ``util.where`` on
  the running result, in this order:
    * ``is_eye``  -> ``A`` itself (the identity is its own inverse)
    * ``is_zero`` -> ``A.set_inf()``
    * ``is_inf``  -> ``A.set_zero()``
  """
  tags = A.tags
  result = util.where(tags.is_eye, A, out)
  result = util.where(tags.is_zero, A.set_inf(), result)
  return util.where(tags.is_inf, A.set_zero(), result)

def symbolic_matrix_inverse(f):
  """Decorate a matrix-inverse implementation ``f(A)``.

  The wrapper validates ``A`` with ``_check_matrix_inverse_types`` and
  returns ``f(A)`` unchanged.

  NOTE: the symbolic post-pass ``_symbolic_matrix_inverse(A, out)`` is
  currently disabled.
  """

  @wraps(f)
  def wrapper(A: AbstractMatrix) -> AbstractMatrix:
    return f(_check_matrix_inverse_types(A))

  return wrapper

################################################################################################################

def _check_log_det_types(A: AbstractMatrix):
  """Validation for log-determinant inputs; currently returns ``A`` as-is."""
  checked = A
  return checked

def _symbolic_log_det(A: AbstractMatrix, out: Scalar) -> Scalar:
  """Override a log-determinant using ``A``'s symbolic tags.

  ``out`` is presumably the numerically computed log-determinant of ``A``.
  The overrides are applied one after another, each via ``util.where`` on
  the running result, in this order:
    * ``is_zero`` -> ``-inf`` (log of a zero determinant)
    * ``is_inf``  -> ``+inf``
    * ``is_eye``  -> ``0`` (log-determinant of the identity)

  Args:
    A: Matrix whose tags drive the overrides.
    out: Log-determinant value to be patched.

  Returns:
    ``out`` with the tag-based overrides applied.
  """
  # If A is zero, then return -inf
  out = util.where(A.tags.is_zero, jnp.ones_like(out)*-jnp.inf, out)

  # If A is inf, then return +inf. Build the value from ones: the previous
  # ``zeros_like(out)*inf`` evaluates 0*inf, which is NaN under IEEE-754.
  out = util.where(A.tags.is_inf, jnp.ones_like(out)*jnp.inf, out)

  # If A is the identity, then return zero
  out = util.where(A.tags.is_eye, jnp.zeros_like(out), out)

  return out

def symbolic_log_det(f):
  """Decorate a log-determinant implementation ``f(A)``.

  The wrapper validates ``A`` with ``_check_log_det_types`` and returns
  ``f(A)`` unchanged.

  NOTE: the symbolic post-pass ``_symbolic_log_det(A, out)`` is currently
  disabled.
  """

  @wraps(f)
  def wrapper(A: AbstractMatrix) -> Scalar:
    return f(_check_log_det_types(A))

  return wrapper

################################################################################################################

def _check_cholesky_types(A: AbstractMatrix):
  """Validation for Cholesky inputs; currently returns ``A`` as-is."""
  checked = A
  return checked

def _symbolic_cholesky(A: AbstractMatrix, out: AbstractMatrix) -> AbstractMatrix:
  """Override a computed Cholesky factor using ``A``'s symbolic tags.

  The overrides are applied one after another, each via ``util.where`` on
  the running result, in this order:
    * ``is_eye``  -> ``A`` itself
    * ``is_zero`` -> ``A.set_zero()``
    * ``is_inf``  -> ``A.set_inf()``
  """
  tags = A.tags
  result = util.where(tags.is_eye, A, out)
  result = util.where(tags.is_zero, A.set_zero(), result)
  return util.where(tags.is_inf, A.set_inf(), result)

def symbolic_cholesky(f):
  """Decorate a Cholesky implementation ``f(A)``.

  The wrapper validates ``A`` with ``_check_cholesky_types`` and returns
  ``f(A)`` unchanged.

  NOTE: the symbolic post-pass ``_symbolic_cholesky(A, out)`` is currently
  disabled.
  """

  @wraps(f)
  def wrapper(A: AbstractMatrix) -> AbstractMatrix:
    return f(_check_cholesky_types(A))

  return wrapper

################################################################################################################

def _check_exp_types(A: AbstractMatrix):
  """Validation for matrix-exponential inputs; currently returns ``A`` as-is."""
  checked = A
  return checked

def _symbolic_exp(A: AbstractMatrix, out: AbstractMatrix) -> AbstractMatrix:
  """Override a computed matrix exponential using ``A``'s symbolic tags.

  The overrides are applied one after another, each via ``util.where`` on
  the running result, in this order:
    * ``is_zero`` -> ``A.set_eye()`` (exp of the zero matrix is the identity)
    * ``is_inf``  -> ``A.set_inf()``
  """
  tags = A.tags
  result = util.where(tags.is_zero, A.set_eye(), out)
  return util.where(tags.is_inf, A.set_inf(), result)

def symbolic_exp(f):
  """Decorate a matrix-exponential implementation ``f(A)``.

  The wrapper validates ``A`` with ``_check_exp_types`` and returns ``f(A)``
  unchanged.

  NOTE: the symbolic post-pass ``_symbolic_exp(A, out)`` is currently
  disabled.
  """

  @wraps(f)
  def wrapper(A: AbstractMatrix) -> AbstractMatrix:
    return f(_check_exp_types(A))

  return wrapper

################################################################################################################

def _check_svd_types(A: AbstractMatrix):
  """Validation for SVD inputs; currently returns ``A`` as-is."""
  checked = A
  return checked

def symbolic_svd(f):
  """Decorate an SVD implementation ``f(A)``.

  The wrapper validates ``A`` with ``_check_svd_types`` and returns ``f(A)``
  unchanged.

  NOTE(review): a symbolic post-pass ``_symbolic_svd(A, out)`` is referenced
  only in disabled code. No such helper is defined in this section of the
  module, so it would need to be written before that pass could be
  re-enabled.
  """

  @wraps(f)
  def wrapper(A: AbstractMatrix) -> AbstractMatrix:
    return f(_check_svd_types(A))

  return wrapper
