import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Type
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import lineax as lx
import abc
import warnings
import jax.tree_util as jtu
from diffusion_crf.matrix.matrix_base import *
from diffusion_crf.base import auto_vmap
from plum import dispatch
from diffusion_crf.matrix.dense import DenseMatrix
from diffusion_crf.matrix.symbolic import *
from diffusion_crf.matrix.tags import Tags, TAGS
from diffusion_crf.util.svd import my_svd

class DiagonalMatrix(AbstractMatrix):
  """Square matrix stored only by its diagonal.

  `elements[..., i]` is entry (i, i) of the matrix. Any leading axes of
  `elements` are batch axes (see `batch_size` and `shape`).
  """

  tags: Tags
  elements: Float[Array, 'N'] # The elements of the matrix

  def __init__(
      self,
      elements: Float[Array, 'N'],
      tags: Tags
  ):
    self.elements = elements
    self.tags = tags

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch shape implied by `elements`.

    Returns None for an unbatched (N,) diagonal, an int for (B, N), and the
    tuple of leading axes for higher-rank `elements`.

    Raises:
      ValueError: if `elements` is a 0-d array.
    """
    if self.elements.ndim > 2:
      return self.elements.shape[:-1]
    elif self.elements.ndim == 2:
      return self.elements.shape[0]
    elif self.elements.ndim == 1:
      return None
    else:
      raise ValueError(f"Invalid number of dimensions: {self.elements.ndim}")

  @property
  def shape(self):
    """Full (dense) shape: the batch axes followed by (N, N)."""
    dim = self.elements.shape[-1]
    return self.elements.shape[:-1] + (dim, dim)

  @classmethod
  def zeros(cls, dim: int) -> 'DiagonalMatrix':
    """All-zero `dim` x `dim` diagonal matrix tagged as zero."""
    return DiagonalMatrix(jnp.zeros(dim), tags=TAGS.zero_tags)

  @classmethod
  def eye(cls, dim: int) -> 'DiagonalMatrix':
    """`dim` x `dim` identity tagged as eye."""
    return DiagonalMatrix(jnp.ones(dim), tags=TAGS.eye_tags)

  @auto_vmap
  def as_matrix(self) -> Float[Array, "N N"]:
    """Materialize the dense (N, N) array with `elements` on the diagonal.

    NOTE(review): batching over leading axes is presumably handled by
    `auto_vmap` (jnp.diag of a 2-D input would extract, not build) — confirm.
    """
    return jnp.diag(self.elements)

  @auto_vmap
  def __neg__(self) -> 'DiagonalMatrix':
    """Negate every diagonal entry, keeping the same tags."""
    return DiagonalMatrix(-self.elements, tags=self.tags)

  @auto_vmap
  def to_dense(self) -> DenseMatrix:
    """Convert to a `DenseMatrix` carrying the same tags."""
    return DenseMatrix(jnp.diag(self.elements), tags=self.tags)

  @auto_vmap
  def project_dense(self, dense: 'DenseMatrix') -> 'DiagonalMatrix':
    """Keep only the diagonal of `dense`, reusing `dense`'s tags.

    Returns a `DiagonalMatrix`; the previous `DenseMatrix` return annotation
    did not match what this method constructs.
    """
    diag_elements = jnp.diag(dense.elements)
    return DiagonalMatrix(diag_elements, tags=dense.tags)

class ParametricSymmetricDiagonalMatrix(DiagonalMatrix):
  """Diagonal matrix that is positive on the diagonal by construction.

  The unconstrained parameters are stored in `_elements`. The public
  `elements` property maps them through `abs(.) + 1e-8`, so each diagonal
  entry it exposes is strictly positive for finite inputs. The tags are always
  `TAGS.symmetric_tags`.
  """

  tags: Tags
  _elements: Float[Array, 'N']

  def __init__(self, elements: Float[Array, 'N']):
    self._elements = elements
    self.tags = TAGS.symmetric_tags

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    # The raw parameters have the same shape as `elements`, so inspect them
    # directly instead of going through the positivity map.
    rank = self._elements.ndim
    if rank == 1:
      return None
    if rank == 2:
      return self._elements.shape[0]
    if rank > 2:
      return self._elements.shape[:-1]
    raise ValueError(f"Invalid number of dimensions: {self._elements.ndim}")

  @property
  def elements(self) -> Float[Array, 'N']:
    magnitude = jnp.abs(self._elements)
    # Small jitter keeps the diagonal away from exactly zero.
    return magnitude + 1e-8

@dispatch
def make_parametric_symmetric_matrix(matrix: DiagonalMatrix) -> ParametricSymmetricDiagonalMatrix:
  """Re-express a diagonal matrix in its parametric symmetric form.

  `ParametricSymmetricDiagonalMatrix.elements` is `abs(_elements) + 1e-8`.
  Its raw parameters are therefore the diagonal entries themselves, not a
  Cholesky factor. Passing `matrix.elements` directly makes the result
  represent `matrix` (up to the 1e-8 jitter, for non-negative diagonals).
  Passing the Cholesky factor, as before, produced a matrix whose diagonal was
  sqrt(diag(matrix)).
  """
  return ParametricSymmetricDiagonalMatrix(matrix.elements)

################################################################################################################

# @dispatch
# @symbolic_add
# def mat_add(A: DiagonalMatrix, B: Union[Scalar, float]) -> DiagonalMatrix:
#   return DiagonalMatrix(A.elements + B, tags=A.tags).fix_to_tags()

@dispatch
@symbolic_add
def mat_add(A: DiagonalMatrix, B: DiagonalMatrix) -> DiagonalMatrix:
  """Diagonal + diagonal: only the two diagonals need to be summed."""
  summed = A.elements + B.elements
  return DiagonalMatrix(summed, tags=A.tags.add_update(B.tags)).fix_to_tags()

@dispatch
@symbolic_add
def mat_add(A: DiagonalMatrix, B: DenseMatrix) -> DenseMatrix:
  """Diagonal + dense: densify the diagonal operand and add elementwise."""
  combined_tags = A.tags.add_update(B.tags)
  dense_sum = A.as_matrix() + B.elements
  return DenseMatrix(dense_sum, tags=combined_tags).fix_to_tags()

@dispatch
@symbolic_add
def mat_add(A: DenseMatrix, B: DiagonalMatrix) -> DenseMatrix:
  """Dense + diagonal: addition commutes, so defer to the (diagonal, dense) overload."""
  swapped_sum = mat_add(B, A)
  return swapped_sum

################################################################################################################

@dispatch
@symbolic_scalar_mul
def scalar_mul(A: DiagonalMatrix, s: Scalar) -> DiagonalMatrix:
  """Scale every diagonal entry of `A` by the scalar `s`."""
  scaled = s*A.elements
  scaled_tags = A.tags.scalar_mul_update()
  return DiagonalMatrix(scaled, tags=scaled_tags).fix_to_tags()

################################################################################################################

@dispatch
@symbolic_mat_mul
def mat_mul(A: DiagonalMatrix, B: DiagonalMatrix) -> DiagonalMatrix:
  """Diagonal @ diagonal: the product is diagonal with entrywise-multiplied diagonals."""
  product_diag = A.elements*B.elements
  return DiagonalMatrix(product_diag, tags=A.tags.mat_mul_update(B.tags)).fix_to_tags()

@dispatch
@symbolic_mat_mul
def mat_mul(A: DiagonalMatrix, B: DenseMatrix) -> DenseMatrix:
  """Compute A @ B for diagonal A and dense B.

  Left-multiplying by diag(a) scales row i of B by a[i]. A broadcast does this
  in O(N^2), instead of materializing diag(a) and running an O(N^3) matmul.
  It also avoids the `0 * inf = nan` terms that the zero off-diagonal entries
  would inject when B contains infinities. Leading batch axes broadcast the
  same way as before.
  """
  new_tags = A.tags.mat_mul_update(B.tags)
  # a[..., :, None] has shape (..., N, 1) and broadcasts across the columns of B.
  out_elements = A.elements[..., :, None]*B.elements
  return DenseMatrix(out_elements, tags=new_tags).fix_to_tags()

@dispatch
@symbolic_mat_mul
def mat_mul(A: DenseMatrix, B: DiagonalMatrix) -> DenseMatrix:
  """Compute A @ B for dense A and diagonal B.

  Right-multiplying by diag(b) scales column j of A by b[j]. A broadcast does
  this in O(N^2), instead of materializing diag(b) and running an O(N^3) matmul.
  """
  new_tags = A.tags.mat_mul_update(B.tags)
  # b[..., None, :] has shape (..., 1, N) and broadcasts down the rows of A.
  out_elements = A.elements*B.elements[..., None, :]
  return DenseMatrix(out_elements, tags=new_tags).fix_to_tags()

@dispatch
@symbolic_mat_mul
def mat_mul(A: DiagonalMatrix, b: Float[Array, 'N']) -> Float[Array, 'M']:
  """Diagonal @ vector: scale each entry of `b` by the matching diagonal entry."""
  scaled_vector = A.elements*b
  return scaled_vector

################################################################################################################

@dispatch
def transpose(A: DiagonalMatrix) -> DiagonalMatrix:
  """A diagonal matrix is its own transpose, so the input is returned unchanged."""
  return A

################################################################################################################

@dispatch
@symbolic_solve
def matrix_solve(A: DiagonalMatrix, B: DiagonalMatrix) -> DiagonalMatrix:
  """Solve A X = B for diagonal A and B; X is diagonal with entries b / a."""
  quotient = B.elements/A.elements
  return DiagonalMatrix(quotient, tags=A.tags.solve_update(B.tags)).fix_to_tags()

@dispatch
@symbolic_solve
def matrix_solve(A: DiagonalMatrix, B: DenseMatrix) -> DenseMatrix:
  """Solve A X = B for diagonal A and dense B, i.e. return A^{-1} B.

  The solution is X[..., i, j] = B[..., i, j] / a[..., i]: row i of B is
  divided by a[i]. The previous `[..., None, :]` broadcast divided each column
  j by a[j], which computes B A^{-1`} instead. That disagreed with the vector
  overload (`b / a`) and the (dense, diagonal) overload, which both compute
  A^{-1} B.
  """
  A_elements = A.elements
  # (..., N, 1) broadcasts each a[i] across row i of B.
  out_elements = B.elements/A_elements[..., :, None]
  out_tags = A.tags.solve_update(B.tags)
  return DenseMatrix(out_elements, tags=out_tags).fix_to_tags()

@dispatch
@symbolic_solve
def matrix_solve(A: DenseMatrix, B: DiagonalMatrix) -> DenseMatrix:
  """Solve A X = B with dense A by densifying B and using the dense solver."""
  B_dense = B.to_dense()
  return matrix_solve(A, B_dense)

@dispatch
@symbolic_solve
def matrix_solve(A: DiagonalMatrix, b: Float[Array, 'N']) -> Float[Array, 'M']:
  """Solve A x = b for diagonal A: divide each entry of `b` by the diagonal."""
  return b/A.elements

################################################################################################################

@dispatch
@symbolic_matrix_inverse
def get_matrix_inverse(A: DiagonalMatrix) -> DiagonalMatrix:
  """Inverse of a diagonal matrix: the reciprocal of every diagonal entry."""
  inverse_tags = A.tags.inverse_update()
  reciprocal = 1/A.elements
  return DiagonalMatrix(reciprocal, tags=inverse_tags).fix_to_tags()

@dispatch
@symbolic_log_det
def get_log_det(A: DiagonalMatrix) -> Scalar:
  """log|det A|: the sum of log-magnitudes of the diagonal entries."""
  log_magnitudes = jnp.log(jnp.abs(A.elements))
  return log_magnitudes.sum()

@dispatch
@symbolic_cholesky
def get_cholesky(A: DiagonalMatrix) -> DiagonalMatrix:
  """Cholesky factor of a diagonal matrix: the elementwise square root of the diagonal."""
  root = jnp.sqrt(A.elements)
  return DiagonalMatrix(root, tags=A.tags.cholesky_update()).fix_to_tags()

@dispatch
@symbolic_exp
def get_exp(A: DiagonalMatrix) -> DiagonalMatrix:
  """Matrix exponential of a diagonal matrix: exponentiate each diagonal entry."""
  exp_tags = A.tags.exp_update()
  exponentiated = jnp.exp(A.elements)
  return DiagonalMatrix(exponentiated, tags=exp_tags).fix_to_tags()

@dispatch
@symbolic_svd
def get_svd(A: DiagonalMatrix) -> Tuple[DiagonalMatrix, 'DiagonalMatrix', DiagonalMatrix]:
  """SVD of a diagonal matrix, returned as (U, S, V).

  - U is the identity.
  - S holds the singular values |a_i|.
  - V = diag(s_i), where s_i is -1 for negative entries and +1 otherwise.

  U @ S @ V therefore reproduces A. Because V is diagonal, V equals V^T, so
  either the U S V or the U S V^T convention gives A.

  The signs must come from the original entries. Computing them from |a_i|
  made V the identity, so the factors reconstructed |A|, and gave NaN
  (0 / 0) for zero entries. Mapping zero to +1 keeps V orthogonal.
  """
  A_elts = jnp.abs(A.elements)
  S = DiagonalMatrix(A_elts, tags=TAGS.no_tags).fix_to_tags()
  ones = jnp.ones_like(A.elements)
  A_signs = jnp.where(A.elements < 0, -ones, ones)
  U = A.set_eye()
  V = DiagonalMatrix(A_signs, tags=TAGS.no_tags).fix_to_tags()
  return U, S, V

################################################################################################################
def _tagged_test_elements(raw, tags):
  """Overwrite random diagonal entries so that they agree with `tags`.

  Checks are applied in the order zero, eye, inf; a later match overrides an
  earlier one. Tags matching none of them (e.g. symmetric or no tags) leave
  `raw` unchanged, since any diagonal matrix is already symmetric.
  """
  if tags.is_zero:
    raw = jnp.zeros_like(raw)
  if tags.is_eye:
    raw = jnp.ones_like(raw)
  if tags.is_inf:
    raw = jnp.full_like(raw, jnp.inf)
  return raw

def correctness_tests(key):
  """Run `matrix_tests` on DiagonalMatrix pairs for every combination of tags.

  Each combination gets fresh, independent PRNG subkeys derived from `key`:
  one for A's elements, one for B's elements, and one passed to
  `matrix_tests`. Failures are caught, reported, and summarized at the end.

  Args:
    key: PRNG key from which all randomness in the run is derived.

  Returns:
    A list of dicts with keys 'A_tags', 'B_tags' and 'error', one per failing
    combination (empty if everything passed).
  """
  from diffusion_crf.matrix.matrix_base import matrix_tests
  from itertools import product

  # Turn on x64
  jax.config.update('jax_enable_x64', True)

  # All available tags
  tag_options = [
    TAGS.symmetric_tags,
    TAGS.zero_tags,
    TAGS.eye_tags,
    TAGS.no_tags,
    TAGS.inf_tags
  ]

  # Test all combinations of tags
  failed_tests = []
  total_tests = 0

  for tag_A, tag_B in product(tag_options, tag_options):
    total_tests += 1
    print("\nTesting combination:")
    print(f"A tags: {tag_A}")
    print(f"B tags: {tag_B}")

    # Split once into independent subkeys. The old `key, _ = split(key)`
    # re-derived k1, so the test key was reused to draw A's elements.
    key, k1, k2, test_key = random.split(key, 4)

    # Random base diagonals, then forced to agree with their tags.
    A_raw = _tagged_test_elements(random.normal(k1, (4,)), tag_A)
    B_raw = _tagged_test_elements(random.normal(k2, (4,)), tag_B)

    A = DiagonalMatrix(A_raw, tags=tag_A)
    B = DiagonalMatrix(B_raw, tags=tag_B)

    # Deliberately broad: this is a test harness that collects every failure.
    try:
      matrix_tests(test_key, A, B)
      print("✓ Tests passed")
    except Exception as e:
      print(f"✗ Tests failed with error: {str(e)}")
      failed_tests.append({
        'A_tags': tag_A,
        'B_tags': tag_B,
        'error': str(e)
      })

  # Print summary
  print("\n" + "="*80)
  print(f"Test Summary: {total_tests - len(failed_tests)}/{total_tests} tests passed")

  if failed_tests:
    print("\nFailed Test Combinations:")
    for i, test in enumerate(failed_tests, 1):
      print(f"\n{i}. Failed combination:")
      print(f"   A tags: {test['A_tags']}")
      print(f"   B tags: {test['B_tags']}")
      print(f"   Error: {test['error']}")

  return failed_tests

def performance_tests(key):
  """Profile the shared benchmark for DiagonalMatrix and DiagonalEagerLinearOperator.

  Both implementations are built from the same two random length-4 diagonals.
  Each benchmark runs under a JAX profiler trace written to /tmp/tensorboard.
  """
  # Alias the benchmark so it is not confused with this function's own name.
  from diffusion_crf.matrix.matrix_base import performance_tests as run_benchmark
  from latent_linear_sde import DiagonalEagerLinearOperator
  key1, key2 = random.split(key)
  A_raw = random.normal(key1, (4,))
  B_raw = random.normal(key2, (4,))

  # First DiagonalMatrix, then the eager diagonal linear operator for comparison.
  constructors = (
    lambda diag: DiagonalMatrix(diag, tags=TAGS.no_tags),
    DiagonalEagerLinearOperator,
  )
  for build in constructors:
    with jax.profiler.trace("/tmp/tensorboard"):
      run_benchmark(build(A_raw), build(B_raw))


def vjp_test(key):
  """Smoke-test reverse-mode differentiation through `NaturalGaussian`.

  Builds a NaturalGaussian from a fixed symmetric DiagonalMatrix and a random
  vector `h`, then pulls back an all-ones cotangent with `jax.vjp`. A leftover
  `pdb.set_trace()` used to stop here, which halted the `__main__` script
  before the correctness tests ran.

  Args:
    key: PRNG key used to draw the diagonal entries and `h`.

  Returns:
    The cotangent with respect to `h`.
  """
  a, h = random.normal(key, (2, 4))
  A = DiagonalMatrix(a, tags=TAGS.symmetric_tags)
  from diffusion_crf.gaussian.dist import NaturalGaussian

  # Differentiate only w.r.t. `h`; A is closed over as a constant.
  def to_natural_gaussian(h):
    return NaturalGaussian(A, h)
  nat_gaussian, vjp_fn = jax.vjp(to_natural_gaussian, h)

  ones = jtu.tree_map(jnp.ones_like, nat_gaussian)
  dh, = vjp_fn(ones)
  # jax.vjp returns cotangents shaped like the primal inputs.
  assert dh.shape == h.shape
  return dh





if __name__ == '__main__':
  # Script entry point: runs the VJP smoke test, then the tag-combination
  # correctness sweep. Both use the same PRNGKey(0).
  # NOTE(review): `plt` and the names star-imported from the local `debug`
  # module are not used below; `debug` may install hooks on import — confirm
  # before removing either import.
  import matplotlib.pyplot as plt
  from debug import *
  key = random.PRNGKey(0)
  vjp_test(key)
  correctness_tests(key)
  # performance_tests(key)
