import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Type
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import lineax as lx
import abc
import warnings
import jax.tree_util as jtu
from plum import dispatch
import numpy as np

# Public names exported by this module: the tag container and the namespace of preset tags.
__all__ = ['Tags', 'TAGS']

# Selects which of the two Tags/TAGS definitions below is created at import time:
#   True  -> the variant that stores only is_nonzero and is_inf
#   False -> the variant that also stores is_symmetric and is_eye
switch = True

if switch:
  class Tags(eqx.Module):
    """Boolean flags describing a matrix; knowing them allows more efficient code.

    This variant stores only ``is_nonzero`` and ``is_inf``.  ``is_symmetric``
    and ``is_eye`` are exposed as properties that return constants.
    """
    is_nonzero: Bool # Stored as "non-zero" so that a zero matrix can be created with jtu.tree_map(jnp.zeros_like, ...)
    is_inf: Bool

    @property
    def is_zero(self):
      """Negation of the stored ``is_nonzero`` flag."""
      return ~self.is_nonzero

    @property
    def is_symmetric(self):
      """Always True: symmetry is not tracked here, for performance."""
      return jnp.array(True)

    @property
    def is_eye(self):
      """Always False: the identity property is not tracked here, for performance."""
      return jnp.array(False)

    def __str__(self):
      return f"Tags(is_nonzero={self.is_nonzero}, is_inf={self.is_inf})"

    def add_update(self, update: 'Tags') -> 'Tags':
      """Return the tags of A + B, where ``self`` tags A and ``update`` tags B.

      The sum is nonzero if either operand is nonzero and infinite if either
      operand is infinite.  Entries are (zero flag, inf flag); "." marks a
      cell covered by symmetry.

              B:   =0         ≠0         =∞         ≠∞
      A =0     (=0, ≠∞)   (≠0, ≠∞)   (≠0, =∞)   (≠0, ≠∞)
      A ≠0      .         (≠0, ≠∞)   (≠0, =∞)   (≠0, ≠∞)
      A =∞      .          .         (≠0, =∞)   (≠0, =∞)
      A ≠∞      .          .          .         (≠0, ≠∞)
      """
      return Tags(
        self.is_nonzero | update.is_nonzero,
        self.is_inf | update.is_inf,
      )

    def mat_mul_update(self, update: 'Tags') -> 'Tags':
      """Return the tags of A @ B, where ``self`` tags A and ``update`` tags B.

      The product is nonzero only if both operands are nonzero and infinite
      if either operand is infinite.  "?" marks the undefined 0·∞ cases.

              B:   =0         ≠0         =∞         ≠∞
      A =0     (=0, ≠∞)   (=0, ≠∞)    ?         (=0, ≠∞)
      A ≠0     (=0, ≠∞)   (≠0, ≠∞)   (≠0, =∞)   (≠0, ≠∞)
      A =∞      ?         (≠0, =∞)   (≠0, =∞)   (≠0, =∞)
      A ≠∞     (=0, ≠∞)   (≠0, ≠∞)   (≠0, =∞)   (≠0, ≠∞)
      """
      return Tags(
        self.is_nonzero & update.is_nonzero,
        self.is_inf | update.is_inf,
      )

    def scalar_mul_update(self) -> 'Tags':
      """Return the tags after scalar multiplication; both flags are left unchanged."""
      return self

    def transpose_update(self) -> 'Tags':
      """Return the tags after transposition; both flags are left unchanged."""
      return self

    def solve_update(self, update: 'Tags') -> 'Tags':
      """Return the tags of A⁻¹ B, where ``self`` tags A and ``update`` tags B.

      "?" marks the undefined cases.

              B:   =0         ≠0         =∞         ≠∞
      A =0      ?         (≠0, =∞)   (≠0, =∞)   (≠0, =∞)
      A ≠0     (=0, ≠∞)   (≠0, ≠∞)   (≠0, =∞)   (≠0, ≠∞)
      A =∞     (=0, ≠∞)   (=0, ≠∞)    ?         (=0, ≠∞)
      A ≠∞     (=0, ≠∞)   (≠0, ≠∞)   (≠0, =∞)   (≠0, ≠∞)
      """
      a_not_inf = ~self.is_inf
      # A zero (non-invertible) A sends any nonzero B to infinity.
      blows_up = self.is_zero & update.is_nonzero
      return Tags(
        a_not_inf & update.is_nonzero,
        (update.is_inf & a_not_inf) | blows_up,
      )

    def inverse_update(self) -> 'Tags':
      """Return the tags of A⁻¹: nonzero unless A is infinite, infinite when A is zero."""
      return Tags(~self.is_inf, self.is_zero)

    def cholesky_update(self) -> 'Tags':
      """Return a new Tags carrying the same two flags, for the Cholesky factor."""
      return Tags(self.is_nonzero, self.is_inf)

    def exp_update(self) -> 'Tags':
      """Return the tags after exp: always marked nonzero, inf flag carried over."""
      return Tags(jnp.ones_like(self.is_nonzero), self.is_inf)

  class TAGS:
    """Namespace of preset Tags instances."""
    # Tags stores only (is_nonzero, is_inf) in this variant, so the symmetric,
    # identity and untagged presets all carry the same field values.
    symmetric_tags = Tags(
      is_nonzero=np.array(True),
      is_inf=np.array(False),
    )
    zero_tags = Tags(
      is_nonzero=np.array(False),
      is_inf=np.array(False),
    )
    eye_tags = Tags(
      is_nonzero=np.array(True),
      is_inf=np.array(False),
    )
    no_tags = Tags(
      is_nonzero=np.array(True),
      is_inf=np.array(False),
    )
    # Lower-triangular matrices reuse the untagged preset (same object).
    lower_triangular_tags = no_tags

    # Preset for an infinite matrix: nonzero and infinite.  There is no stored
    # symmetry flag here; Tags.is_symmetric is a property that always returns True.
    inf_tags = Tags(
      is_nonzero=np.array(True),
      is_inf=np.array(True),
    )

################################################################################################################

if not switch:
  class Tags(eqx.Module):
    """Contains different properties of a matrix.  Knowing these can facilitate more efficient code.

    This variant stores four flags: ``is_nonzero``, ``is_symmetric``,
    ``is_eye`` and ``is_inf``.  Every ``*_update`` method returns the tags of
    the result of the corresponding matrix operation.
    """
    is_nonzero: Bool # Use non-zero so that creating a zero matrix can be done with jtu.tree_map(jnp.zeros_like, ...)
    is_symmetric: Bool
    is_eye: Bool
    is_inf: Bool

    def __str__(self):
      return f"Tags(is_nonzero={self.is_nonzero}, is_symmetric={self.is_symmetric}, is_eye={self.is_eye}, is_inf={self.is_inf})"

    @property
    def is_zero(self):
      """Negation of the stored ``is_nonzero`` flag."""
      return ~self.is_nonzero

    def add_update(self, update: 'Tags') -> 'Tags':
      """Return the tags of A + B, where ``self`` tags A and ``update`` tags B."""
      is_nonzero_after_add = self.is_nonzero | update.is_nonzero
      is_symmetric_after_add = self.is_symmetric & update.is_symmetric
      # The sum is the identity only when one operand is the identity and the other is zero
      is_eye_after_add = (self.is_eye & update.is_zero) | (update.is_eye & self.is_zero)
      is_inf_after_add = self.is_inf | update.is_inf
      return Tags(is_nonzero_after_add, is_symmetric_after_add, is_eye_after_add, is_inf_after_add)

    def mat_mul_update(self, update: 'Tags') -> 'Tags':
      """Return the tags of A @ B, where ``self`` tags A and ``update`` tags B."""
      is_nonzero_after_mul = self.is_nonzero & update.is_nonzero
      is_symmetric_after_mul = jnp.zeros_like(self.is_symmetric) # The symmetric tag is cleared
      is_eye_after_mul = self.is_eye & update.is_eye
      is_inf_after_mul = self.is_inf | update.is_inf
      return Tags(is_nonzero_after_mul, is_symmetric_after_mul, is_eye_after_mul, is_inf_after_mul)

    def scalar_mul_update(self) -> 'Tags':
      """Return the tags after scalar multiplication."""
      is_nonzero_after_scalar_mul = self.is_nonzero
      is_symmetric_after_scalar_mul = self.is_symmetric
      is_eye_after_scalar_mul = jnp.zeros_like(self.is_eye) # Removes the eye tag
      is_inf_after_scalar_mul = self.is_inf
      return Tags(is_nonzero_after_scalar_mul, is_symmetric_after_scalar_mul, is_eye_after_scalar_mul, is_inf_after_scalar_mul)

    def transpose_update(self) -> 'Tags':
      """Return the tags after transposition; all flags are left unchanged."""
      return self

    def solve_update(self, update: 'Tags') -> 'Tags':
      """Return the tags of A⁻¹ B, where ``self`` tags A and ``update`` tags B.

      A zero A combined with a nonzero B is tagged infinite, which agrees with
      ``inverse_update`` tagging the inverse of a zero matrix as infinite.
      """
      # When A is zero (non-invertible), result should be infinite for nonzero B
      zero_case_inf = self.is_zero & update.is_nonzero

      is_nonzero_after_solve = ~self.is_inf & update.is_nonzero
      is_symmetric_after_solve = jnp.zeros_like(self.is_symmetric) # The symmetric tag is cleared
      is_eye_after_solve = self.is_eye & update.is_eye
      is_inf_after_solve = (update.is_inf & ~self.is_inf) | zero_case_inf
      return Tags(is_nonzero_after_solve, is_symmetric_after_solve, is_eye_after_solve, is_inf_after_solve)

    def inverse_update(self) -> 'Tags':
      """Return the tags of A⁻¹: nonzero unless A is infinite, infinite when A is zero."""
      is_nonzero_after_invert = ~self.is_inf
      is_symmetric_after_invert = self.is_symmetric
      is_eye_after_invert = self.is_eye
      is_inf_after_invert = self.is_zero
      return Tags(is_nonzero_after_invert, is_symmetric_after_invert, is_eye_after_invert, is_inf_after_invert)

    def cholesky_update(self) -> 'Tags':
      """Return the tags of the Cholesky factor."""
      is_nonzero_after_cholesky = self.is_nonzero
      is_symmetric_after_cholesky = jnp.zeros_like(self.is_symmetric) # Not symmetric after cholesky
      is_eye_after_cholesky = self.is_eye
      is_inf_after_cholesky = self.is_inf
      return Tags(is_nonzero_after_cholesky, is_symmetric_after_cholesky, is_eye_after_cholesky, is_inf_after_cholesky)

    def exp_update(self) -> 'Tags':
      """Return the tags after exp: always marked nonzero, inf flag carried over."""
      is_nonzero_after_exp = jnp.ones_like(self.is_nonzero)
      is_symmetric_after_exp = jnp.zeros_like(self.is_symmetric) # The symmetric tag is cleared
      is_eye_after_exp = self.is_zero # exp of a zero matrix is tagged as the identity
      is_inf_after_exp = self.is_inf
      return Tags(is_nonzero_after_exp, is_symmetric_after_exp, is_eye_after_exp, is_inf_after_exp)

  class TAGS:
    """Namespace of preset Tags instances for the four-flag variant."""
    symmetric_tags = Tags(
      is_nonzero=jnp.array(True),
      is_symmetric=jnp.array(True),
      is_eye=jnp.array(False),
      is_inf=jnp.array(False),
    )
    zero_tags = Tags(
      is_nonzero=jnp.array(False),
      is_symmetric=jnp.array(True),
      is_eye=jnp.array(False),
      is_inf=jnp.array(False),
    )
    eye_tags = Tags(
      is_nonzero=jnp.array(True),
      is_symmetric=jnp.array(True),
      is_eye=jnp.array(True),
      is_inf=jnp.array(False),
    )
    no_tags = Tags(
      is_nonzero=jnp.array(True),
      is_symmetric=jnp.array(False),
      is_eye=jnp.array(False),
      is_inf=jnp.array(False),
    )
    # Lower-triangular matrices reuse the untagged preset (same object).
    lower_triangular_tags = no_tags

    # is_symmetric is set to True here so that an infinite covariance can be
    # propagated through the algorithm.
    inf_tags = Tags(
      is_nonzero=jnp.array(True),
      is_symmetric=jnp.array(True),
      is_eye=jnp.array(False),
      is_inf=jnp.array(True),
    )



################################################################################################################
################################################################################################################
################################################################################################################
################################################################################################################



def create_tag(is_zero: bool, is_inf: bool) -> Tags:
  """Build a Tags from an (is_zero, is_inf) pair; helper for the tests below.

  Tags is constructed with two positional fields (is_nonzero, is_inf), so this
  helper only works with the two-field Tags variant.

  Args:
    is_zero: Whether the matrix is zero (Python bool or boolean array).
    is_inf: Whether the matrix is infinite (Python bool or boolean array).

  Returns:
    A Tags whose fields are boolean arrays, with is_nonzero being the logical
    negation of ``is_zero``.
  """
  # `~` on a Python bool is integer bitwise inversion (~True == -2, ~False == -1)
  # and is deprecated since Python 3.12, so negate logically on a boolean array.
  is_nonzero = jnp.logical_not(jnp.asarray(is_zero, dtype=bool))
  return Tags(is_nonzero, jnp.asarray(is_inf, dtype=bool))

def assert_tags_equal(tag1: Tags, tag2: Tags):
  """Assert that two tags agree on ``is_zero`` and then on ``is_inf``.

  Raises:
    AssertionError: On the first mismatching flag; the message names the flag
      and shows both tags.
  """
  zero_flags_match = tag1.is_zero == tag2.is_zero
  assert zero_flags_match, f"Zero mismatch: {tag1} vs {tag2}"

  inf_flags_match = tag1.is_inf == tag2.is_inf
  assert inf_flags_match, f"Inf mismatch: {tag1} vs {tag2}"

def test_add_update():
  """Check Tags.add_update against the upper triangle of the addition table."""
  # Operand categories from the table ("nonzero" and "noninf" are built with the same flags)
  zero = create_tag(is_zero=True, is_inf=False)
  nonzero = create_tag(is_zero=False, is_inf=False)
  inf = create_tag(is_zero=False, is_inf=True)
  noninf = create_tag(is_zero=False, is_inf=False)

  # Expected outcomes, written as (is_zero, is_inf)
  ZERO = (True, False)
  FINITE = (False, False)
  INFINITE = (False, True)

  # (A, B, expected flags of A + B)
  cases = (
    (zero, zero, ZERO),
    (zero, nonzero, FINITE),
    (zero, inf, INFINITE),
    (zero, noninf, FINITE),

    (nonzero, nonzero, FINITE),
    (nonzero, inf, INFINITE),
    (nonzero, noninf, FINITE),

    (inf, inf, INFINITE),
    (inf, noninf, INFINITE),

    (noninf, noninf, FINITE),
  )

  for lhs, rhs, (want_zero, want_inf) in cases:
    expected = create_tag(is_zero=want_zero, is_inf=want_inf)
    assert_tags_equal(lhs.add_update(rhs), expected)

  print("All add_update tests passed!")

def test_mat_mul_update():
  """Check Tags.mat_mul_update against every defined cell of the multiplication table."""
  # Operand categories from the table ("nonzero" and "noninf" are built with the same flags)
  zero = create_tag(is_zero=True, is_inf=False)
  nonzero = create_tag(is_zero=False, is_inf=False)
  inf = create_tag(is_zero=False, is_inf=True)
  noninf = create_tag(is_zero=False, is_inf=False)

  # Expected outcomes, written as (is_zero, is_inf)
  ZERO = (True, False)
  FINITE = (False, False)
  INFINITE = (False, True)

  # (A, B, expected flags of A @ B)
  cases = (
    (zero, zero, ZERO),
    (zero, nonzero, ZERO),
    # (zero, inf) is undefined and therefore not checked
    (zero, noninf, ZERO),

    (nonzero, zero, ZERO),
    (nonzero, nonzero, FINITE),
    (nonzero, inf, INFINITE),
    (nonzero, noninf, FINITE),

    # (inf, zero) is undefined and therefore not checked
    (inf, nonzero, INFINITE),
    (inf, inf, INFINITE),
    (inf, noninf, INFINITE),

    (noninf, zero, ZERO),
    (noninf, nonzero, FINITE),
    (noninf, inf, INFINITE),
    (noninf, noninf, FINITE),
  )

  for lhs, rhs, (want_zero, want_inf) in cases:
    expected = create_tag(is_zero=want_zero, is_inf=want_inf)
    assert_tags_equal(lhs.mat_mul_update(rhs), expected)

  print("All mat_mul_update tests passed!")

def test_solve_update():
  """Check Tags.solve_update against every defined cell of the solve table."""
  # Operand categories from the table ("nonzero" and "noninf" are built with the same flags)
  zero = create_tag(is_zero=True, is_inf=False)
  nonzero = create_tag(is_zero=False, is_inf=False)
  inf = create_tag(is_zero=False, is_inf=True)
  noninf = create_tag(is_zero=False, is_inf=False)

  # Expected outcomes, written as (is_zero, is_inf)
  ZERO = (True, False)
  FINITE = (False, False)
  INFINITE = (False, True)

  # (A, B, expected flags of A.solve_update(B))
  cases = (
    # (zero, zero) is undefined and therefore not checked
    (zero, nonzero, INFINITE),
    (zero, inf, INFINITE),
    (zero, noninf, INFINITE),

    (nonzero, zero, ZERO),
    (nonzero, nonzero, FINITE),
    (nonzero, inf, INFINITE),
    (nonzero, noninf, FINITE),

    (inf, zero, ZERO),
    (inf, nonzero, ZERO),
    # (inf, inf) is undefined and therefore not checked
    (inf, noninf, ZERO),

    (noninf, zero, ZERO),
    (noninf, nonzero, FINITE),
    (noninf, inf, INFINITE),
    (noninf, noninf, FINITE),
  )

  for lhs, rhs, (want_zero, want_inf) in cases:
    expected = create_tag(is_zero=want_zero, is_inf=want_inf)
    assert_tags_equal(lhs.solve_update(rhs), expected)

  print("All solve_update tests passed!")

if __name__ == "__main__":
  # Run each table-driven check in order, then report overall success.
  for run_check in (test_add_update, test_mat_mul_update, test_solve_update):
    run_check()
  print("All tests passed!")