import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Type
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import lineax as lx
import abc
import warnings
import jax.tree_util as jtu
from diffusion_crf.matrix.matrix_base import *
from diffusion_crf.base import auto_vmap
from plum import dispatch
from diffusion_crf.matrix.dense import DenseMatrix
from diffusion_crf.matrix.diagonal import DiagonalMatrix
from diffusion_crf.matrix.diagonal_block_2x2 import Diagonal2x2BlockMatrix
from diffusion_crf.matrix.symbolic import *
from diffusion_crf.matrix.tags import Tags, TAGS
from diffusion_crf.util.svd import my_svd

class Diagonal3x3BlockMatrix(AbstractMatrix):
  """Square matrix stored as a 3x3 grid of diagonal blocks.

  ``elements[..., i, j, :]`` holds the diagonal of block ``(i, j)``, so an
  unbatched ``(3, 3, N)`` array stands for a ``(3N, 3N)`` matrix.  Any axes in
  front of the last three are treated as batch axes by ``batch_size`` and
  ``shape``.
  """

  tags: Tags
  elements: Float[Array, '3 3 N'] # The elements of the matrix

  def __init__(
      self,
      elements: Float[Array, '3 3 N'],
      tags: Tags
  ):
    """Store the block diagonals and the tags describing the matrix."""
    self.elements = elements
    self.tags = tags

  @property
  def batch_size(self) -> Union[None,int,Tuple[int, ...]]:
    """Batch size derived from the axes in front of the trailing ``(3, 3, N)``.

    Returns ``None`` when unbatched, an ``int`` for a single batch axis and a
    tuple for several batch axes.  Raises ``ValueError`` for fewer than three
    axes.
    """
    if self.elements.ndim > 4:
      return self.elements.shape[:-3]
    elif self.elements.ndim == 4:
      return self.elements.shape[0]
    elif self.elements.ndim == 3:
      return None
    else:
      raise ValueError(f"Invalid number of dimensions: {self.elements.ndim}")

  @property
  def shape(self):
    """Shape of the represented matrix: batch axes followed by ``(3N, 3N)``."""
    batch_shape = self.elements.shape[:-3]
    return batch_shape + (3*self.elements.shape[-1], 3*self.elements.shape[-1])

  @classmethod
  def zeros(cls, shape: Tuple[int, ...]) -> 'Diagonal3x3BlockMatrix':
    """Return an all-zero matrix tagged with ``TAGS.zero_tags``.

    Only the last two entries of ``shape`` are read; they must be equal and
    divisible by 3.  Leading (batch) entries of ``shape`` are ignored.
    """
    dim = shape[-1]
    assert shape[-2] == shape[-1] == dim
    assert dim%3 == 0
    elements = jnp.zeros((3, 3, dim//3))
    # The concrete class is named explicitly (not ``cls``): subclasses in this
    # file define a different ``__init__`` signature.
    return Diagonal3x3BlockMatrix(elements, tags=TAGS.zero_tags)

  @classmethod
  def eye(cls, dim: int) -> 'Diagonal3x3BlockMatrix':
    """Return the ``dim x dim`` identity, built from ``DiagonalMatrix.eye``."""
    return cls.from_diagonal(DiagonalMatrix.eye(dim))

  @auto_vmap
  def as_matrix(self):
    """Materialise the dense ``(H*C, W*C)`` array of this matrix.

    Block ``(i, j)`` of the result is the ``C x C`` diagonal matrix whose
    diagonal is ``elements[i, j, :]``.
    """
    # The result is a grid of diagonal blocks; it is NOT block diagonal, since
    # the off-diagonal blocks are generally non-zero.
    H, W, C = self.elements.shape[-3:]
    # ``shape[:-3]`` is the empty tuple when unbatched, so a single expression
    # covers both the batched and unbatched layouts.
    block_matrix = jnp.zeros(self.elements.shape[:-3] + (H*C, W*C))

    for i in range(H):
      for j in range(W):
        block_start_row = i * C
        block_start_col = j * C
        block = jnp.diag(self.elements[...,i, j, :])
        block_matrix = block_matrix.at[...,block_start_row:block_start_row+C, block_start_col:block_start_col+C].set(block)

    return block_matrix

  @auto_vmap
  def __neg__(self) -> 'Diagonal3x3BlockMatrix':
    """Return the element-wise negation, keeping the same tags."""
    return Diagonal3x3BlockMatrix(-self.elements, tags=self.tags)

  @auto_vmap
  def to_dense(self) -> DenseMatrix:
    """Return a ``DenseMatrix`` holding ``as_matrix()`` with the same tags."""
    return DenseMatrix(self.as_matrix(), tags=self.tags)

  @classmethod
  def from_diagonal(cls, diagonal: DiagonalMatrix) -> 'Diagonal3x3BlockMatrix':
    """Embed a diagonal matrix whose size is divisible by 3.

    The diagonal is cut into three consecutive thirds which become the blocks
    ``(0, 0)``, ``(1, 1)`` and ``(2, 2)``; all other blocks are zero.  The
    tags of ``diagonal`` are carried over.
    """
    # If diagonal is batched, then vmap over the batch dimension
    if diagonal.batch_size is not None:
      return jax.vmap(cls.from_diagonal)(diagonal)
    assert diagonal.shape[-1] % 3 == 0
    dim = diagonal.shape[-1] // 3
    A, B, C = diagonal.elements[:dim], diagonal.elements[dim:2*dim], diagonal.elements[2*dim:]
    zeros = jnp.zeros((dim,))
    return Diagonal3x3BlockMatrix(jnp.array([[A, zeros, zeros],
                                             [zeros, B, zeros],
                                             [zeros, zeros, C]]), tags=diagonal.tags)

  @auto_vmap
  def project_dense(self, dense: DenseMatrix) -> 'Diagonal3x3BlockMatrix':
    """Project a square dense matrix onto the 3x3 diagonal-block structure.

    The dense matrix (side length divisible by 3) is split into a 3x3 grid of
    equal square blocks and only the main diagonal of each block is kept.
    The tags of ``dense`` are carried over.  ``self`` is not read.
    """
    elements = dense.elements
    assert dense.shape[0] == dense.shape[1]
    assert dense.shape[0]%3 == 0
    dim = dense.shape[0]//3
    # Entry (i, j) is the main diagonal of the (i, j) sub-block of ``dense``.
    block_diagonals = [[jnp.diag(elements[i*dim:(i + 1)*dim, j*dim:(j + 1)*dim])
                        for j in range(3)]
                       for i in range(3)]
    return Diagonal3x3BlockMatrix(jnp.array(block_diagonals), tags=dense.tags)

class ParametricSymmetricDiagonal3x3BlockMatrix(Diagonal3x3BlockMatrix):
  """Symmetric 3x3 diagonal-block matrix parameterised by a raw factor.

  The raw ``_elements`` are stored and the public ``elements`` are computed
  on access as ``E^T E`` for every diagonal index, so the exposed 3x3 slices
  are symmetric by construction.  The tags are always ``TAGS.symmetric_tags``.
  """

  tags: Tags
  _elements: Float[Array, '3 3 N']

  def __init__(self, elements: Float[Array, '3 3 N']):
    """Store ``elements`` as the raw factor (not as the matrix entries)."""
    self._elements = elements
    self.tags = TAGS.symmetric_tags

  @property
  def batch_size(self) -> Union[None,int,Tuple[int, ...]]:
    """Batch size derived from the axes in front of the trailing ``(3, 3, N)``.

    Returns ``None`` when unbatched, an ``int`` for a single batch axis and a
    tuple for several batch axes.  Raises ``ValueError`` for fewer than three
    axes.
    """
    if self._elements.ndim > 4:
      # Drop the trailing (3, 3, N) axes; slicing with [:-1] would wrongly
      # report the two 3x3 axes as batch axes.
      return self._elements.shape[:-3]
    elif self._elements.ndim == 4:
      return self._elements.shape[0]
    elif self._elements.ndim == 3:
      return None
    else:
      raise ValueError(f"Invalid number of dimensions: {self._elements.ndim}")

  @property
  def elements(self) -> Float[Array, '3 3 N']:
    """Block diagonals of the symmetric matrix built from the raw factor.

    For each diagonal index ``a`` the result is
    ``_elements[:, :, a].T @ _elements[:, :, a]``.  The einsum pattern names
    exactly three axes, so it applies to unbatched ``(3, 3, N)`` parameters.
    """
    _elementsT = self._elements.swapaxes(-2, -3)
    return einops.einsum(_elementsT, self._elements, 'i j a, j k a -> i k a')

################################################################################################################

@dispatch
def make_parametric_symmetric_matrix(matrix: Diagonal3x3BlockMatrix) -> ParametricSymmetricDiagonal3x3BlockMatrix:
  """Build the parametric symmetric form from the transposed Cholesky factor of ``matrix``."""
  transposed_factor = matrix.get_cholesky().T
  return ParametricSymmetricDiagonal3x3BlockMatrix(transposed_factor.elements)

# @dispatch
# @symbolic_add
# def mat_add(A: Diagonal3x3BlockMatrix, B: Union[Scalar, float]) -> Diagonal3x3BlockMatrix:
#   return Diagonal3x3BlockMatrix(A.elements + B, tags=A.tags).fix_to_tags()

@dispatch
@symbolic_add
def mat_add(A: Diagonal3x3BlockMatrix, B: Diagonal3x3BlockMatrix) -> Diagonal3x3BlockMatrix:
  """Add two 3x3 diagonal-block matrices by summing their stored block diagonals."""
  summed_elements = A.elements + B.elements
  summed_tags = A.tags.add_update(B.tags)
  result = Diagonal3x3BlockMatrix(summed_elements, tags=summed_tags)
  return result.fix_to_tags()

@dispatch
@symbolic_add
def mat_add(A: Diagonal3x3BlockMatrix, B: DenseMatrix) -> DenseMatrix:
  """Add a dense matrix to a 3x3 diagonal-block matrix; the result is dense."""
  dense_sum = A.as_matrix() + B.elements
  summed_tags = A.tags.add_update(B.tags)
  result = DenseMatrix(dense_sum, tags=summed_tags)
  return result.fix_to_tags()

@dispatch
@symbolic_add
def mat_add(A: Diagonal3x3BlockMatrix, B: DiagonalMatrix) -> Diagonal3x3BlockMatrix:
  """Add a diagonal matrix by first embedding it in the 3x3 diagonal-block layout."""
  return mat_add(A, Diagonal3x3BlockMatrix.from_diagonal(B))

@dispatch
@symbolic_add
def mat_add(A: Diagonal3x3BlockMatrix, B: Diagonal2x2BlockMatrix) -> DenseMatrix:
  """Add a 2x2 diagonal-block matrix; both operands are converted to dense first."""
  dense_A = A.to_dense()
  dense_B = B.to_dense()
  return mat_add(dense_A, dense_B)

@dispatch
@symbolic_add
def mat_add(A: DenseMatrix, B: Diagonal3x3BlockMatrix) -> DenseMatrix:
  """Dense + 3x3 diagonal-block: delegate to the overload with swapped operands."""
  swapped_sum = mat_add(B, A)
  return swapped_sum

@dispatch
@symbolic_add
def mat_add(A: DiagonalMatrix, B: Diagonal3x3BlockMatrix) -> Diagonal3x3BlockMatrix:
  """Diagonal + 3x3 diagonal-block: delegate to the overload with swapped operands."""
  swapped_sum = mat_add(B, A)
  return swapped_sum

@dispatch
@symbolic_add
def mat_add(A: Diagonal2x2BlockMatrix, B: Diagonal3x3BlockMatrix) -> DenseMatrix:
  """2x2 diagonal-block + 3x3 diagonal-block: both operands are converted to dense first."""
  dense_A = A.to_dense()
  dense_B = B.to_dense()
  return mat_add(dense_A, dense_B)

################################################################################################################

@dispatch
@symbolic_scalar_mul
def scalar_mul(A: Diagonal3x3BlockMatrix, s: Scalar) -> Diagonal3x3BlockMatrix:
  """Scale every stored block diagonal of ``A`` by the scalar ``s``."""
  scaled_elements = s*A.elements
  scaled_tags = A.tags.scalar_mul_update()
  result = Diagonal3x3BlockMatrix(scaled_elements, tags=scaled_tags)
  return result.fix_to_tags()

################################################################################################################

@dispatch
@symbolic_mat_mul
def mat_mul(A: Diagonal3x3BlockMatrix, b: Float[Array, 'N']) -> Float[Array, 'M']:
  """Multiply the matrix by a flat vector.

  ``b`` is viewed as three stacked chunks (one per block column), contracted
  against the block diagonals, and flattened again.  The result is replaced
  by zeros wherever ``A.tags.is_nonzero`` is false.
  """
  b_chunks = b.reshape((3, -1))
  product = einops.einsum(A.elements, b_chunks, 'i j a, j a -> i a')
  flat_product = product.ravel()
  return jnp.where(A.tags.is_nonzero, flat_product, jnp.zeros_like(flat_product))

@dispatch
@symbolic_mat_mul
def mat_mul(A: Diagonal3x3BlockMatrix, B: Diagonal3x3BlockMatrix) -> Diagonal3x3BlockMatrix:
  """Multiply two 3x3 diagonal-block matrices.

  For every diagonal index the 3x3 slices of ``A`` and ``B`` are multiplied,
  which keeps the product in the same storage layout.
  """
  product_elements = einops.einsum(A.elements, B.elements, 'i j a, j k a -> i k a')
  product_tags = A.tags.mat_mul_update(B.tags)
  result = Diagonal3x3BlockMatrix(product_elements, tags=product_tags)
  return result.fix_to_tags()

@dispatch
@symbolic_mat_mul
def mat_mul(A: Diagonal3x3BlockMatrix, B: DenseMatrix) -> DenseMatrix:
  """Multiply a 3x3 diagonal-block matrix by a dense matrix via the dense form of ``A``."""
  dense_product = A.as_matrix()@B.elements
  product_tags = A.tags.mat_mul_update(B.tags)
  result = DenseMatrix(dense_product, tags=product_tags)
  return result.fix_to_tags()

@dispatch
@symbolic_mat_mul
def mat_mul(A: Diagonal3x3BlockMatrix, B: DiagonalMatrix) -> Diagonal3x3BlockMatrix:
  """Right-multiply by a diagonal matrix after embedding it in the 3x3 diagonal-block layout."""
  return mat_mul(A, Diagonal3x3BlockMatrix.from_diagonal(B))

@dispatch
@symbolic_mat_mul
def mat_mul(A: Diagonal3x3BlockMatrix, B: Diagonal2x2BlockMatrix) -> DenseMatrix:
  """Multiply by a 2x2 diagonal-block matrix; both operands are converted to dense first."""
  dense_A = A.to_dense()
  dense_B = B.to_dense()
  return mat_mul(dense_A, dense_B)

@dispatch
@symbolic_mat_mul
def mat_mul(A: DenseMatrix, B: Diagonal3x3BlockMatrix) -> DenseMatrix:
  """Multiply a dense matrix by a 3x3 diagonal-block matrix via the dense form of ``B``."""
  dense_product = A.elements@B.as_matrix()
  product_tags = A.tags.mat_mul_update(B.tags)
  result = DenseMatrix(dense_product, tags=product_tags)
  return result.fix_to_tags()

@dispatch
@symbolic_mat_mul
def mat_mul(A: DiagonalMatrix, B: Diagonal3x3BlockMatrix) -> Diagonal3x3BlockMatrix:
  """Left-multiply by a diagonal matrix after embedding it in the 3x3 diagonal-block layout."""
  return mat_mul(Diagonal3x3BlockMatrix.from_diagonal(A), B)

@dispatch
@symbolic_mat_mul
def mat_mul(A: Diagonal2x2BlockMatrix, B: Diagonal3x3BlockMatrix) -> DenseMatrix:
  """2x2 diagonal-block times 3x3 diagonal-block: both operands are converted to dense first."""
  dense_A = A.to_dense()
  dense_B = B.to_dense()
  return mat_mul(dense_A, dense_B)

################################################################################################################

@dispatch
def transpose(A: Diagonal3x3BlockMatrix) -> Diagonal3x3BlockMatrix:
  """Transpose by swapping the two block-index axes of the stored elements.

  Each block is diagonal, so exchanging block ``(i, j)`` with block ``(j, i)``
  transposes the represented matrix.
  """
  transposed_tags = A.tags.transpose_update()
  transposed_elements = A.elements.swapaxes(-2, -3)
  result = Diagonal3x3BlockMatrix(transposed_elements, tags=transposed_tags)
  return result.fix_to_tags()

################################################################################################################

@dispatch
@symbolic_solve
def matrix_solve(A: Diagonal3x3BlockMatrix,
                 b: Float[Array, 'N']) -> Float[Array, 'N']:
  """Solve ``A x = b`` for a vector by applying the explicit inverse of ``A``."""
  return mat_mul(get_matrix_inverse(A), b)

@dispatch
@symbolic_solve
def matrix_solve(A: Diagonal3x3BlockMatrix,
                 B: Union[Diagonal3x3BlockMatrix, Diagonal2x2BlockMatrix, DenseMatrix, DiagonalMatrix]) -> Union[Diagonal3x3BlockMatrix, DenseMatrix]:
  """Solve ``A X = B`` for a matrix right-hand side by applying the explicit inverse of ``A``.

  The return type is whatever the matching ``mat_mul`` overload produces for
  the type of ``B``.
  """
  return mat_mul(get_matrix_inverse(A), B)

@dispatch
@symbolic_solve
def matrix_solve(A: DenseMatrix, B: Diagonal3x3BlockMatrix) -> DenseMatrix:
  """Solve a dense system with a 3x3 diagonal-block right-hand side, converted to dense first."""
  dense_B = B.to_dense()
  return matrix_solve(A, dense_B)

@dispatch
@symbolic_solve
def matrix_solve(A: DiagonalMatrix, B: Diagonal3x3BlockMatrix) -> Diagonal3x3BlockMatrix:
  """Solve a diagonal system after embedding ``A`` in the 3x3 diagonal-block layout."""
  return matrix_solve(Diagonal3x3BlockMatrix.from_diagonal(A), B)

@dispatch
@symbolic_solve
def matrix_solve(A: Diagonal2x2BlockMatrix, B: Diagonal3x3BlockMatrix) -> DenseMatrix:
  """Solve a 2x2 diagonal-block system; both operands are converted to dense first."""
  dense_A = A.to_dense()
  dense_B = B.to_dense()
  return matrix_solve(dense_A, dense_B)

################################################################################################################

@dispatch
@symbolic_matrix_inverse
def get_matrix_inverse(A: Diagonal3x3BlockMatrix) -> Diagonal3x3BlockMatrix:
  """Invert the matrix with the closed-form 3x3 adjugate formula.

  The nine block diagonals are treated as the entries of a 3x3 matrix and the
  formula is evaluated element-wise along the diagonal index.  There is no
  guard against a zero determinant.
  """
  flat_entries = A.elements.reshape((9, -1))
  m11, m12, m13, m21, m22, m23, m31, m32, m33 = flat_entries

  # Determinant by the rule of Sarrus: three positive then three negative products.
  det = (m11 * m22 * m33 + m12 * m23 * m31 + m13 * m21 * m32
         - m13 * m22 * m31 - m12 * m21 * m33 - m11 * m23 * m32)
  inv_det = 1 / det

  # Entries of the adjugate (transposed cofactor matrix), row by row.
  adj_11 = m22 * m33 - m23 * m32
  adj_12 = m13 * m32 - m12 * m33
  adj_13 = m12 * m23 - m13 * m22
  adj_21 = m23 * m31 - m21 * m33
  adj_22 = m11 * m33 - m13 * m31
  adj_23 = m13 * m21 - m11 * m23
  adj_31 = m21 * m32 - m22 * m31
  adj_32 = m12 * m31 - m11 * m32
  adj_33 = m11 * m22 - m12 * m21

  first_row = [adj_11 * inv_det, adj_12 * inv_det, adj_13 * inv_det]
  second_row = [adj_21 * inv_det, adj_22 * inv_det, adj_23 * inv_det]
  third_row = [adj_31 * inv_det, adj_32 * inv_det, adj_33 * inv_det]
  inverse_elements = jnp.array([first_row, second_row, third_row])

  inverse_tags = A.tags.inverse_update()
  return Diagonal3x3BlockMatrix(inverse_elements, tags=inverse_tags).fix_to_tags()

################################################################################################################

@dispatch
@symbolic_log_det
def get_log_det(A: Diagonal3x3BlockMatrix) -> Scalar:
  """Return the sum over the diagonal index of ``log|det|`` of each 3x3 slice.

  The determinant of each slice is expanded along its first row.
  """
  flat_entries = A.elements.reshape((9, -1))
  m11, m12, m13, m21, m22, m23, m31, m32, m33 = flat_entries

  # 2x2 minors belonging to the first-row entries.
  minor_11 = m22 * m33 - m23 * m32
  minor_12 = m21 * m33 - m23 * m31
  minor_13 = m21 * m32 - m22 * m31

  det = m11 * minor_11 - m12 * minor_12 + m13 * minor_13

  return jnp.sum(jnp.log(jnp.abs(det)))

################################################################################################################

@dispatch
@symbolic_cholesky
def get_cholesky(A: Diagonal3x3BlockMatrix) -> Diagonal3x3BlockMatrix:
  """Compute the lower Cholesky factor with the closed-form 3x3 recurrences.

  Only the lower-triangular block diagonals of ``A`` are read; the strictly
  upper blocks of the result are zero.  No check is made that the square-root
  arguments are non-negative.
  """
  flat_entries = A.elements.reshape((9, -1))
  # The strictly upper entries (m12, m13, m23) are not needed.
  m11, _, _, m21, m22, _, m31, m32, m33 = flat_entries

  # First column of the factor.
  L11 = jnp.sqrt(m11)
  L21 = m21 / L11
  L31 = m31 / L11
  # Second column.
  L22 = jnp.sqrt(m22 - L21**2)
  L32 = (m32 - L21*L31) / L22
  # Third column.
  L33 = jnp.sqrt(m33 - L31**2 - L32**2)

  upper_zero = jnp.zeros_like(L21)
  factor_rows = [[L11, upper_zero, upper_zero],
                 [L21, L22, upper_zero],
                 [L31, L32, L33]]
  factor_elements = jnp.array(factor_rows)

  factor_tags = A.tags.cholesky_update()
  return Diagonal3x3BlockMatrix(factor_elements, tags=factor_tags).fix_to_tags()

################################################################################################################

@dispatch
@symbolic_exp
def get_exp(A: Diagonal3x3BlockMatrix) -> Diagonal3x3BlockMatrix:
  """Compute the matrix exponential via the dense representation.

  The matrix is materialised, passed to ``jax.scipy.linalg.expm``, and the
  main diagonal of each of the nine resulting blocks is read back into the
  ``(3, 3, N)`` layout.  A warning is emitted on every call because this path
  works on the full dense matrix.
  """
  warnings.warn('Using inefficient dense matrix exponential for Diagonal3x3BlockMatrix')
  # Keep ``A`` bound to the input matrix: its tags are needed below, and
  # rebinding the name to the dense array would make ``A.tags`` fail.
  exp_dense = jax.scipy.linalg.expm(A.as_matrix())
  # Split rows and columns into the 3x3 grid of (N, N) blocks.
  exp_blocks = einops.rearrange(exp_dense, '(A H) (B W) -> A B H W', A=3, B=3)
  # Keep only the main diagonal of every block.
  exp_elements = jax.vmap(jax.vmap(jnp.diag))(exp_blocks)
  out_tags = A.tags.exp_update()
  return Diagonal3x3BlockMatrix(exp_elements, tags=out_tags).fix_to_tags()

################################################################################################################

@dispatch
@symbolic_svd
def get_svd(A: Diagonal3x3BlockMatrix) -> Tuple[DenseMatrix, DiagonalMatrix, DenseMatrix]:
  """Run ``my_svd`` on the dense form of ``A`` and wrap its three outputs.

  The first and third outputs are wrapped as ``DenseMatrix`` and the second as
  ``DiagonalMatrix``, each created with ``TAGS.no_tags``.
  """
  left_elts, singular_elts, right_elts = my_svd(A.as_matrix())
  left = DenseMatrix(left_elts, tags=TAGS.no_tags).fix_to_tags()
  singular = DiagonalMatrix(singular_elts, tags=TAGS.no_tags).fix_to_tags()
  right = DenseMatrix(right_elts, tags=TAGS.no_tags).fix_to_tags()
  return left, singular, right

################################################################################################################

if __name__ == '__main__':
  import matplotlib.pyplot as plt
  from debug import *
  from diffusion_crf.matrix.matrix_base import matrix_tests

  # Run the checks in double precision
  jax.config.update('jax_enable_x64', True)
  key = random.PRNGKey(0)

  # 3x3 diagonal-block matrix tests on two random operands
  key_a, key_b = random.split(key)
  block_shape = (3, 3, 4)
  A = Diagonal3x3BlockMatrix(random.normal(key_a, block_shape), tags=TAGS.no_tags)
  B = Diagonal3x3BlockMatrix(random.normal(key_b, block_shape), tags=TAGS.no_tags)
  matrix_tests(key, A, B)

  # Repeat with a zero left operand to check that zero matrices are handled correctly
  A = Diagonal3x3BlockMatrix.zeros_like(A)
  matrix_tests(key, A, B)
  # Drop into the debugger so the results can be inspected interactively
  import pdb; pdb.set_trace()