import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Type
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import lineax as lx
import abc
import warnings
import jax.tree_util as jtu
from diffusion_crf.matrix.matrix_base import *
from diffusion_crf.matrix.dense import *
from diffusion_crf.matrix.diagonal import *
from diffusion_crf.base import auto_vmap
from plum import dispatch

class MatrixWithInverse(AbstractMatrix):
  """A matrix bundled with a precomputed inverse.

  Shape, batch size, tags and the dense representation are all taken from the
  forward matrix; the inverse is only carried along so that operations with a
  closed-form inverse can reuse it instead of recomputing one.
  """

  matrix: AbstractMatrix          # the forward matrix A
  inverse_matrix: AbstractMatrix  # A^{-1}; the constructor checks only its shape, not that it truly inverts `matrix`

  def __init__(self, matrix: AbstractMatrix, inverse_matrix: AbstractMatrix):
    """Pair `matrix` with `inverse_matrix`.

    Raises:
      ValueError: if the two matrices do not have the same shape.
    """
    if matrix.shape != inverse_matrix.shape:
      raise ValueError(f"Matrix and inverse matrix must have the same shape, got {matrix.shape} and {inverse_matrix.shape}")
    self.matrix = matrix
    self.inverse_matrix = inverse_matrix

  @property
  def tags(self):
    """Tags of the forward matrix."""
    return self.matrix.tags

  def set_eye(self) -> 'AbstractMatrix':
    """Set both halves to the identity (the identity is its own inverse)."""
    new_mat = self.matrix.set_eye()
    new_inv = self.inverse_matrix.set_eye()
    return MatrixWithInverse(new_mat, new_inv)

  def set_symmetric(self) -> 'AbstractMatrix':
    """Apply `set_symmetric` to both the forward matrix and its inverse."""
    new_mat = self.matrix.set_symmetric()
    new_inv = self.inverse_matrix.set_symmetric()
    return MatrixWithInverse(new_mat, new_inv)

  def set_zero(self) -> 'AbstractMatrix':
    """Set the forward matrix to zero; its inverse is represented by `set_inf`."""
    new_mat = self.matrix.set_zero()
    new_inv = self.inverse_matrix.set_inf()
    return MatrixWithInverse(new_mat, new_inv)

  def set_inf(self) -> 'AbstractMatrix':
    """Set the forward matrix to `inf`; its inverse becomes the zero matrix."""
    new_mat = self.matrix.set_inf()
    new_inv = self.inverse_matrix.set_zero()
    return MatrixWithInverse(new_mat, new_inv)

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch size of the forward matrix."""
    return self.matrix.batch_size

  @property
  def shape(self):
    """Shape of the forward matrix (the inverse has the same shape)."""
    return self.matrix.shape

  @classmethod
  def zeros(cls, shape: Tuple[int, ...]) -> 'MatrixWithInverse':
    """Construct a zero matrix paired with its (infinite) inverse.

    Mirrors `set_zero`: the inverse of the zero matrix is represented by the
    `set_inf` matrix, not by another zero matrix.
    """
    matrix = AbstractMatrix.zeros(shape)
    inverse_matrix = matrix.set_inf()
    return MatrixWithInverse(matrix, inverse_matrix)

  @classmethod
  def eye(cls, dim: int) -> 'MatrixWithInverse':
    """Construct the `dim`-dimensional identity paired with itself as inverse."""
    matrix = AbstractMatrix.eye(dim)
    inverse_matrix = AbstractMatrix.eye(dim)
    return MatrixWithInverse(matrix, inverse_matrix)

  @auto_vmap
  def as_matrix(self):
    """Array representation of the forward matrix."""
    return self.matrix.as_matrix()

  @auto_vmap
  def __neg__(self) -> 'MatrixWithInverse':
    """Negate both halves, since (-A)^{-1} = -(A^{-1})."""
    return MatrixWithInverse(-self.matrix, -self.inverse_matrix)

  @auto_vmap
  def to_dense(self) -> DenseMatrix:
    """Convert the forward matrix to a DenseMatrix (the inverse is dropped)."""
    return DenseMatrix(self.as_matrix(), tags=self.tags)

################################################################################################################

@dispatch
def mat_add(A: MatrixWithInverse, B: MatrixWithInverse) -> AbstractMatrix:
  """Add two matrices-with-inverse, returning a plain matrix.

  The cached inverses are not used: only the forward matrices are summed and
  the result carries no inverse.
  """
  forward_a, forward_b = A.matrix, B.matrix
  return mat_add(forward_a, forward_b)

@dispatch
def mat_add(A: MatrixWithInverse, B: AbstractMatrix) -> AbstractMatrix:
  """Add a matrix-with-inverse to a plain matrix, ignoring A's cached inverse."""
  forward_a = A.matrix
  return mat_add(forward_a, B)

@dispatch
def mat_add(A: AbstractMatrix, B: MatrixWithInverse) -> AbstractMatrix:
  """Add a plain matrix to a matrix-with-inverse, ignoring B's cached inverse."""
  forward_b = B.matrix
  return mat_add(A, forward_b)

################################################################################################################

@dispatch
def scalar_mul(A: MatrixWithInverse, s: Scalar) -> MatrixWithInverse:
  """Scale A by the scalar s, keeping the inverse consistent.

  Since (s A)^{-1} = (1/s) A^{-1}, the cached inverse is scaled by 1/s.
  `MatrixWithInverse.tags` reads the forward matrix's tags, so any tag update
  presumably happens inside `s*A.matrix` (verify against AbstractMatrix).
  """
  return MatrixWithInverse(s*A.matrix, (1/s)*A.inverse_matrix)

################################################################################################################

@dispatch
def mat_mul(A: MatrixWithInverse, b: Float[Array, 'N']) -> Float[Array, 'M']:
  """Matrix-vector product A b, computed with the forward matrix only."""
  forward = A.matrix
  return forward @ b

@dispatch
def mat_mul(A: MatrixWithInverse, B: MatrixWithInverse) -> MatrixWithInverse:
  """Product A B, carrying its inverse (A B)^{-1} = B^{-1} A^{-1} along."""
  product = mat_mul(A.matrix, B.matrix)
  product_inv = mat_mul(B.inverse_matrix, A.inverse_matrix)
  return MatrixWithInverse(product, product_inv)

@dispatch
def mat_mul(A: MatrixWithInverse, B: AbstractMatrix) -> AbstractMatrix:
  """Product A B with a plain right operand; A's cached inverse is dropped."""
  forward_a = A.matrix
  return mat_mul(forward_a, B)

@dispatch
def mat_mul(A: AbstractMatrix, B: MatrixWithInverse) -> AbstractMatrix:
  """Product A B with a plain left operand; B's cached inverse is dropped."""
  forward_b = B.matrix
  return mat_mul(A, forward_b)

################################################################################################################

@dispatch
def transpose(A: MatrixWithInverse) -> MatrixWithInverse:
  """Transpose both halves, since (A^T)^{-1} = (A^{-1})^T."""
  forward_t = transpose(A.matrix)
  inverse_t = transpose(A.inverse_matrix)
  return MatrixWithInverse(forward_t, inverse_t)

################################################################################################################

@dispatch
def matrix_solve(A: MatrixWithInverse, B: MatrixWithInverse) -> MatrixWithInverse:
  """Solve A X = B for X, keeping track of X's inverse.

  X = A^{-1} B is formed from A's cached inverse, and its inverse is
  X^{-1} = B^{-1} (A^{-1})^{-1} = B^{-1} A.
  """
  sol = mat_mul(A.inverse_matrix, B.matrix)
  # The inverse of A^{-1} B uses A itself, not A^{-1}.
  sol_inv = mat_mul(B.inverse_matrix, A.matrix)
  return MatrixWithInverse(sol, sol_inv)

@dispatch
def matrix_solve(A: MatrixWithInverse, B: AbstractMatrix) -> AbstractMatrix:
  """Solve A X = B by multiplying with A's cached inverse: X = A^{-1} B."""
  a_inv = A.inverse_matrix
  return mat_mul(a_inv, B)

@dispatch
def matrix_solve(A: AbstractMatrix, B: MatrixWithInverse) -> AbstractMatrix:
  """Solve A X = B using only B's forward matrix; B's cached inverse is unused."""
  rhs = B.matrix
  return matrix_solve(A, rhs)

################################################################################################################

@dispatch
def get_matrix_inverse(A: MatrixWithInverse) -> MatrixWithInverse:
  """Return A^{-1} by swapping the forward matrix and its cached inverse."""
  new_forward, new_inverse = A.inverse_matrix, A.matrix
  return MatrixWithInverse(new_forward, new_inverse)

################################################################################################################

@dispatch
def get_log_det(A: MatrixWithInverse) -> Scalar:
  """Log-determinant, delegated to the forward matrix."""
  forward = A.matrix
  return forward.get_log_det()

@dispatch
def get_cholesky(A: MatrixWithInverse) -> AbstractMatrix:
  """Cholesky factor of the forward matrix; the cached inverse is not used."""
  forward = A.matrix
  return forward.get_cholesky()

@dispatch
def get_exp(A: MatrixWithInverse) -> AbstractMatrix:
  """Matrix exponential of the forward matrix; the cached inverse is not used."""
  forward = A.matrix
  return forward.get_exp()

################################################################################################################

if __name__ == '__main__':
  from debug import *
  from diffusion_crf.matrix.matrix_base import matrix_tests
  from diffusion_crf.matrix.diagonal_block_3x3 import Diagonal3x3BlockMatrix

  # Turn on x64 before any arrays are created.
  jax.config.update('jax_enable_x64', True)
  key = random.PRNGKey(0)

  # Diagonal 3x3 block matrix tests, each operand wrapped with its inverse.
  # Split into three keys so the key handed to matrix_tests is not one that was
  # already consumed to sample A and B.
  test_key, k1, k2 = random.split(key, 3)
  A = random.normal(k1, (3, 3, 4))
  B = random.normal(k2, (3, 3, 4))
  A = Diagonal3x3BlockMatrix(A, tags=TAGS.no_tags)
  Ainv = A.get_inverse()
  A_with_inv = MatrixWithInverse(A, Ainv)
  B = Diagonal3x3BlockMatrix(B, tags=TAGS.no_tags)
  Binv = B.get_inverse()
  B_with_inv = MatrixWithInverse(B, Binv)
  matrix_tests(test_key, A_with_inv, B_with_inv)