import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Type
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import lineax as lx
import abc
import warnings
import jax.tree_util as jtu
from diffusion_crf.matrix.matrix_base import *
from plum import dispatch, ModuleType
from diffusion_crf.matrix.symbolic import *
from diffusion_crf.matrix.tags import Tags, TAGS
from diffusion_crf.util.svd import my_svd

class DenseMatrix(AbstractMatrix):
  """Matrix whose entries are stored explicitly in a dense array.

  `elements` holds the raw array (the last two axes are the matrix axes, any
  leading axes are treated as batch axes by `batch_size`) and `tags` holds the
  `Tags` object that accompanies it.
  """

  tags: Tags
  elements: Float[Array, 'M N'] # Dense array holding the matrix entries

  def __init__(
      self,
      elements: Float[Array, 'M N'],
      tags: Tags
  ):
    """Store the raw element array together with its tags."""
    self.tags = tags
    self.elements = elements

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Leading batch dimension(s) of `elements`.

    Returns None for a plain 2D matrix, the size of the leading axis for a 3D
    array, and the tuple of all leading axes for anything with more than 3
    dimensions. Raises ValueError when `elements` has fewer than 2 dimensions.
    """
    ndim = self.elements.ndim
    if ndim == 2:
      return None
    if ndim == 3:
      return self.elements.shape[0]
    if ndim > 3:
      return self.elements.shape[:-2]
    raise ValueError(f"Invalid number of dimensions: {self.elements.ndim}")

  @property
  def shape(self):
    """Shape of the underlying element array (batch axes included)."""
    element_shape = self.elements.shape
    return element_shape

  @classmethod
  def zeros(cls, shape: Tuple[int, ...]) -> 'DenseMatrix':
    """Build an all-zero DenseMatrix of the given shape, tagged with `TAGS.zero_tags`."""
    zero_elements = jnp.zeros(shape)
    return DenseMatrix(zero_elements, tags=TAGS.zero_tags)

  @classmethod
  def eye(cls, dim: int) -> 'DenseMatrix':
    """Build a `dim x dim` identity DenseMatrix, tagged with `TAGS.eye_tags`."""
    identity_elements = jnp.eye(dim)
    return DenseMatrix(identity_elements, tags=TAGS.eye_tags)

  def as_matrix(self) -> Float[Array, "M N"]:
    """Return the dense element array itself."""
    dense = self.elements
    return dense

  def __neg__(self) -> 'DenseMatrix':
    """Negate every element; the tags are carried over unchanged."""
    negated = -self.elements
    return DenseMatrix(negated, tags=self.tags)

  def project_dense(self, dense: 'DenseMatrix') -> 'DenseMatrix':
    """Return `dense` as is; this method applies no transformation."""
    return dense

class ParametricSymmetricDenseMatrix(DenseMatrix):
  """Symmetric dense matrix parameterized by an unconstrained square array.

  The stored `_elements` are the raw parameters. The `elements` property turns
  them into an upper triangular factor `U` with a strictly positive diagonal
  and exposes the product `U.T @ U`.
  """

  tags: Tags
  _elements: Float[Array, 'N N'] # Raw parameters from which `elements` is derived

  def __init__(self, elements: Float[Array, 'N N']):
    """Store the raw parameters and tag the matrix with `TAGS.symmetric_tags`."""
    self._elements = elements
    self.tags = TAGS.symmetric_tags

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Leading batch dimension(s) of the raw parameters.

    Returns None for a 2D array, the size of the leading axis for a 3D array,
    and the tuple of all leading axes for more than 3 dimensions. Raises
    ValueError when the parameters have fewer than 2 dimensions.
    """
    if self._elements.ndim > 3:
      return self._elements.shape[:-2]
    elif self._elements.ndim == 3:
      return self._elements.shape[0]
    elif self._elements.ndim == 2:
      return None
    else:
      raise ValueError(f"Invalid number of dimensions: {self._elements.ndim}")

  @property
  def elements(self) -> Float[Array, 'N N']:
    """Dense entries `U.T @ U` built from the raw parameters.

    The diagonal is always addressed on the last two axes, so any leading
    batch axes are handled the same way as by `jnp.triu` and the einsum below.
    """
    # Make the diagonal elements of self._elements positive.  The leading
    # Ellipsis keeps the index on the trailing matrix axes; without it a
    # batched array would be indexed along its batch axis instead.
    diag_idx = jnp.arange(self._elements.shape[-1])
    diag_elements = self._elements[..., diag_idx, diag_idx]
    _elements = self._elements.at[..., diag_idx, diag_idx].set(jnp.abs(diag_elements) + 1e-8)

    # Also make sure that the matrix is upper triangular
    _elements = jnp.triu(_elements)

    # Return _elements.T@_elements (transpose taken over the last two axes)
    return jnp.einsum('...ji,...jk->...ik', _elements, _elements)

@dispatch
def make_parametric_symmetric_matrix(matrix: DenseMatrix) -> ParametricSymmetricDenseMatrix:
  """Wrap `matrix` as a ParametricSymmetricDenseMatrix.

  The raw parameters are the elements of the transposed Cholesky factor of
  `matrix`.
  """
  cholesky_factor = matrix.get_cholesky()
  raw_parameters = cholesky_factor.T.elements
  return ParametricSymmetricDenseMatrix(raw_parameters)

################################################################################################################

# @dispatch
# @symbolic_add
# def mat_add(A: DenseMatrix, B: Union[Scalar, float]) -> DenseMatrix:
#   return DenseMatrix(A.elements + B, tags=A.tags).fix_to_tags()

@dispatch
@symbolic_add
def mat_add(A: DenseMatrix, B: DenseMatrix) -> DenseMatrix:
  """Add two dense matrices elementwise.

  The tags of the sum come from `A.tags.add_update(B.tags)`, and the result is
  passed through `fix_to_tags()` before being returned.
  """
  summed_elements = A.elements + B.elements
  summed_tags = A.tags.add_update(B.tags)
  total = DenseMatrix(summed_elements, tags=summed_tags)
  return total.fix_to_tags()

################################################################################################################

@dispatch
@symbolic_scalar_mul
def scalar_mul(A: DenseMatrix, s: Scalar) -> DenseMatrix:
  """Multiply every element of `A` by the scalar `s`.

  The tags come from `A.tags.scalar_mul_update()`, and the result is passed
  through `fix_to_tags()` before being returned.
  """
  scaled_elements = s*A.elements
  scaled_tags = A.tags.scalar_mul_update()
  scaled = DenseMatrix(scaled_elements, tags=scaled_tags)
  return scaled.fix_to_tags()

################################################################################################################

@dispatch
@symbolic_mat_mul
def mat_mul(A: DenseMatrix, B: DenseMatrix) -> DenseMatrix:
  """Matrix product `A @ B` of two dense matrices.

  The tags come from `A.tags.mat_mul_update(B.tags)`, and the result is passed
  through `fix_to_tags()` before being returned.
  """
  product_elements = A.elements@B.elements
  product_tags = A.tags.mat_mul_update(B.tags)
  result = DenseMatrix(product_elements, tags=product_tags)
  return result.fix_to_tags()

@dispatch
@symbolic_mat_mul
def mat_mul(A: DenseMatrix, b: Float[Array, 'N']) -> Float[Array, 'M']:
  """Matrix-vector product of the dense matrix `A` with the raw array `b`.

  Returns a raw array rather than a matrix object.
  """
  matrix_elements = A.elements
  return matrix_elements@b

################################################################################################################

@dispatch
def transpose(A: DenseMatrix) -> DenseMatrix:
  """Transpose `A` by swapping the last two axes of its elements.

  The tags of `A` are reused unchanged for the result.
  """
  swapped_elements = A.elements.swapaxes(-1, -2)
  transposed = DenseMatrix(swapped_elements, tags=A.tags)
  return transposed

  # To avoid needing a new tag, we can just always swap the axes
  # if A.shape[-1] == A.shape[-2]:
  #   A_elements = jnp.where(A.tags.is_symmetric, A.elements, A.elements.T)
  # else:
  #   A_elements = A.elements.swapaxes(-1, -2)
  # return DenseMatrix(A_elements, tags=A.tags)
  # new_tags = A.tags.transpose_update()
  # return DenseMatrix(A_elements, tags=new_tags).fix_to_tags()

################################################################################################################

@dispatch
@symbolic_solve
def matrix_solve(A: DenseMatrix, B: DenseMatrix) -> DenseMatrix:
  """Solve `A X = B` for `X` with `jnp.linalg.solve`.

  The tags come from `A.tags.solve_update(B.tags)`, and the result is passed
  through `fix_to_tags()` before being returned.
  """
  solution_elements = jnp.linalg.solve(A.elements, B.elements)
  solution_tags = A.tags.solve_update(B.tags)
  solution = DenseMatrix(solution_elements, tags=solution_tags)
  return solution.fix_to_tags()

@dispatch
@symbolic_solve
def matrix_solve(A: DenseMatrix, b: Float[Array, 'N']) -> Float[Array, 'M']:
  """Solve `A x = b` for a raw array right-hand side.

  Returns the raw array produced by `jnp.linalg.solve`.
  """
  coefficients = A.elements
  solution = jnp.linalg.solve(coefficients, b)
  return solution

################################################################################################################

@dispatch
@symbolic_matrix_inverse
def get_matrix_inverse(A: DenseMatrix) -> DenseMatrix:
  """Invert `A` with `jnp.linalg.inv`.

  The tags come from `A.tags.inverse_update()`, and the result is passed
  through `fix_to_tags()` before being returned.
  """
  inverse_elements = jnp.linalg.inv(A.elements)
  inverse_tags = A.tags.inverse_update()
  inverse = DenseMatrix(inverse_elements, tags=inverse_tags)
  return inverse.fix_to_tags()

################################################################################################################

@dispatch
@symbolic_log_det
def get_log_det(A: DenseMatrix) -> Scalar:
  """Return the second output of `jnp.linalg.slogdet` for `A`.

  That output is the log of the absolute value of the determinant; the sign
  (first output) is discarded.
  """
  _sign, log_abs_det = jnp.linalg.slogdet(A.elements)
  return log_abs_det

################################################################################################################

@dispatch
@symbolic_cholesky
def get_cholesky(A: DenseMatrix) -> DenseMatrix:
  """Cholesky factor of `A` as computed by `jnp.linalg.cholesky`.

  The tags come from `A.tags.cholesky_update()`, and the result is passed
  through `fix_to_tags()` before being returned.
  """
  factor_elements = jnp.linalg.cholesky(A.elements)
  factor_tags = A.tags.cholesky_update()
  factor = DenseMatrix(factor_elements, tags=factor_tags)
  return factor.fix_to_tags()

################################################################################################################

@dispatch
@symbolic_exp
def get_exp(A: DenseMatrix) -> DenseMatrix:
  """Matrix exponential of `A` via `jax.scipy.linalg.expm`.

  The tags come from `A.tags.exp_update()`, and the result is passed through
  `fix_to_tags()` before being returned.
  """
  exp_elements = jax.scipy.linalg.expm(A.elements)
  exp_tags = A.tags.exp_update()
  exponential = DenseMatrix(exp_elements, tags=exp_tags)
  return exponential.fix_to_tags()

################################################################################################################

@dispatch
@symbolic_svd
def get_svd(A: DenseMatrix) -> Tuple[DenseMatrix, Any, DenseMatrix]:
  """Decompose `A` with `my_svd` and wrap the three outputs as matrix objects.

  The first and third outputs become DenseMatrix instances and the second a
  DiagonalMatrix. All three are created with `TAGS.no_tags` and passed through
  `fix_to_tags()`.
  """
  # Imported inside the function, as in the rest of this module's lazy imports
  from diffusion_crf.matrix.diagonal import DiagonalMatrix
  left_elements, singular_elements, right_elements = my_svd(A.elements)
  return (
    DenseMatrix(left_elements, tags=TAGS.no_tags).fix_to_tags(),
    DiagonalMatrix(singular_elements, tags=TAGS.no_tags).fix_to_tags(),
    DenseMatrix(right_elements, tags=TAGS.no_tags).fix_to_tags(),
  )

################################################################################################################

def _raw_matrix_for_tags(key, tags, dim=4):
  """Draw a random `dim x dim` array and overwrite it to match `tags`.

  The flags are checked in a fixed order (symmetric, zero, eye, inf), so a
  later matching flag replaces what an earlier one produced.
  """
  raw = random.normal(key, (dim, dim))
  if tags.is_symmetric:
    raw = raw @ raw.T
  if tags.is_zero:
    raw = jnp.zeros_like(raw)
  if tags.is_eye:
    raw = jnp.eye(dim)
  if tags.is_inf:
    raw = jnp.full_like(raw, jnp.inf)
  return raw

def correctness_tests(key):
  """Run `matrix_tests` on DenseMatrix pairs for every ordered pair of tags.

  For each combination a pair of 4x4 matrices consistent with the tags is
  built and handed to `matrix_tests`. Any exception is recorded rather than
  propagated, and a pass/fail summary is printed at the end.

  Args:
    key: PRNG key that seeds all random draws. Fresh subkeys are split off
      for every combination so no key is consumed twice.

  Side effects:
    Enables 64-bit mode in JAX via `jax.config.update` and prints progress
    and the summary to stdout.
  """

  from diffusion_crf.matrix.matrix_base import matrix_tests
  from itertools import product

  # Turn on x64
  jax.config.update('jax_enable_x64', True)

  # All available tags
  tag_options = [
    TAGS.symmetric_tags,
    TAGS.zero_tags,
    TAGS.eye_tags,
    TAGS.no_tags,
    TAGS.inf_tags
  ]

  # Test all combinations of tags
  failed_tests = []
  total_tests = 0

  for tag_A, tag_B in product(tag_options, repeat=2):
    total_tests += 1
    print("\nTesting combination:")
    print(f"A tags: {tag_A}")
    print(f"B tags: {tag_B}")

    # One split yields the carry key plus an independent subkey for each
    # consumer, so the key given to matrix_tests is never one that already
    # generated A or B.
    key, k1, k2, test_key = random.split(key, 4)

    # Generate random matrices and modify them according to the tags
    A = DenseMatrix(_raw_matrix_for_tags(k1, tag_A), tags=tag_A)
    B = DenseMatrix(_raw_matrix_for_tags(k2, tag_B), tags=tag_B)

    # Deliberately broad: this is a sweep that should report every failing
    # combination instead of stopping at the first one.
    try:
      matrix_tests(test_key, A, B)
      print("✓ Tests passed")
    except Exception as e:
      print(f"✗ Tests failed with error: {str(e)}")
      failed_tests.append({
        'A_tags': tag_A,
        'B_tags': tag_B,
        'error': str(e)
      })

  # Print summary
  print("\n" + "="*80)
  print(f"Test Summary: {total_tests - len(failed_tests)}/{total_tests} tests passed")

  if failed_tests:
    print("\nFailed Test Combinations:")
    for i, test in enumerate(failed_tests, 1):
      print(f"\n{i}. Failed combination:")
      print(f"   A tags: {test['A_tags']}")
      print(f"   B tags: {test['B_tags']}")
      print(f"   Error: {test['error']}")

def performance_tests(key):
  """Profile the shared performance test on two matrix wrappers.

  Two random 4x4 arrays are drawn from `key`. They are wrapped first as
  DenseMatrix (with `TAGS.no_tags`) and then as MatrixEagerLinearOperator, and
  each pair is run through the `performance_tests` helper from `matrix_base`
  inside a `jax.profiler.trace` context writing to "/tmp/tensorboard".
  """
  # The local alias keeps the imported helper distinct from this function's name
  from diffusion_crf.matrix.matrix_base import performance_tests as run_performance_tests
  from latent_linear_sde import MatrixEagerLinearOperator
  key_a, key_b = random.split(key)
  raw_a = random.normal(key_a, (4, 4))
  raw_b = random.normal(key_b, (4, 4))

  # Test DenseMatrix with profiler
  with jax.profiler.trace("/tmp/tensorboard"):
    dense_a = DenseMatrix(raw_a, tags=TAGS.no_tags)
    dense_b = DenseMatrix(raw_b, tags=TAGS.no_tags)
    result = run_performance_tests(dense_a, dense_b)

  # Test MatrixEagerLinearOperator with profiler
  with jax.profiler.trace("/tmp/tensorboard"):
    operator_a = MatrixEagerLinearOperator(raw_a)
    operator_b = MatrixEagerLinearOperator(raw_b)
    result = run_performance_tests(operator_a, operator_b)


# Script entry point: runs the DenseMatrix correctness sweep when the module is
# executed directly.
if __name__ == '__main__':
  # NOTE(review): `plt` is not referenced in the visible code — confirm this
  # import is still needed.
  import matplotlib.pyplot as plt
  # NOTE(review): star import from a `debug` module that is not part of this
  # file; presumably debugging helpers — verify before removing.
  from debug import *
  key = random.PRNGKey(0)
  correctness_tests(key)
  # performance_tests(key)
