import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Type
import einops
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
import lineax as lx
import abc
import warnings
import jax.tree_util as jtu
from diffusion_crf.matrix.matrix_base import *
from diffusion_crf.base import auto_vmap
from plum import dispatch
from diffusion_crf.matrix.dense import DenseMatrix
from diffusion_crf.matrix.diagonal import DiagonalMatrix
from diffusion_crf.matrix.symbolic import *
from diffusion_crf.matrix.tags import Tags, TAGS
from diffusion_crf.util.svd import my_svd

class Diagonal2x2BlockMatrix(AbstractMatrix):
  """A (2N x 2N) matrix made of a 2x2 grid of (N x N) diagonal blocks.

  The matrix is stored compactly in `elements`, where `elements[i, j, :]` is the
  diagonal of block (i, j):

      [[diag(elements[0, 0]), diag(elements[0, 1])],
       [diag(elements[1, 0]), diag(elements[1, 1])]]

  Any axes in front of the trailing (2, 2, N) axes are batch dimensions.
  """

  tags: Tags
  elements: Float[Array, '2 2 N'] # The elements of the matrix

  def __init__(
      self,
      elements: Float[Array, '2 2 N'],
      tags: Tags
  ):
    self.elements = elements
    self.tags = tags

  @property
  def batch_size(self) -> Union[None,int,Tuple[int, ...]]:
    """Batch shape: None if unbatched, an int for a single batch axis, else a tuple."""
    if self.elements.ndim > 4:
      return self.elements.shape[:-3]
    elif self.elements.ndim == 4:
      return self.elements.shape[0]
    elif self.elements.ndim == 3:
      return None
    else:
      raise ValueError(f"Invalid number of dimensions: {self.elements.ndim}")

  @property
  def shape(self):
    """Full (batch..., 2N, 2N) shape of the matrix this object represents."""
    batch_shape = self.elements.shape[:-3]
    return batch_shape + (2*self.elements.shape[-1], 2*self.elements.shape[-1])

  @classmethod
  def zeros(cls, shape: Tuple[int, ...]) -> 'Diagonal2x2BlockMatrix':
    """Unbatched zero matrix of size shape[-1] x shape[-1] (which must be even)."""
    dim = shape[-1]
    assert shape[-2] == shape[-1]
    assert dim%2 == 0
    elements = jnp.zeros((2, 2, dim//2))
    return Diagonal2x2BlockMatrix(elements, tags=TAGS.zero_tags)

  @classmethod
  def eye(cls, dim: int) -> 'Diagonal2x2BlockMatrix':
    """Identity matrix of size dim x dim, built via `from_diagonal`."""
    return cls.from_diagonal(DiagonalMatrix.eye(dim))

  @auto_vmap
  def as_matrix(self):
    """Materialize the full dense (2N x 2N) array."""
    # Construct a block matrix from self.elements that has a height of H*C and width of W*C
    # where each block is a diagonal matrix of size CxC that has the entries self.elements[i,j,:]
    H, W, C = self.elements.shape[-3:]
    if len(self.elements.shape) > 3:
      block_diagonal = jnp.zeros((self.elements.shape[:-3] + (H*C, W*C)))
    else:
      block_diagonal = jnp.zeros((H*C, W*C))

    for i in range(H):
      for j in range(W):
        block_start_row = i * C
        block_start_col = j * C
        block = jnp.diag(self.elements[...,i, j, :])
        block_diagonal = block_diagonal.at[...,block_start_row:block_start_row+C, block_start_col:block_start_col+C].set(block)

    return block_diagonal

  @auto_vmap
  def __neg__(self) -> 'Diagonal2x2BlockMatrix':
    return Diagonal2x2BlockMatrix(-self.elements, tags=self.tags)

  @auto_vmap
  def to_dense(self) -> DenseMatrix:
    """Convert to a DenseMatrix carrying the same tags."""
    return DenseMatrix(self.as_matrix(), tags=self.tags)

  @classmethod
  def from_diagonal(cls, diagonal: DiagonalMatrix) -> 'Diagonal2x2BlockMatrix':
    """Embed a (2N x 2N) diagonal matrix: its first N entries become the top-left
    block, its last N entries the bottom-right block, and the off-diagonal blocks
    are zero."""
    # If diagonal is batched, then vmap over the batch dimension
    if diagonal.batch_size is not None:
      return jax.vmap(cls.from_diagonal)(diagonal)
    assert diagonal.shape[-1]%2 == 0
    dim = diagonal.shape[-1]//2
    A, D = diagonal.elements[:dim], diagonal.elements[dim:]
    # zeros_like keeps the input dtype so jnp.array below does not up-cast
    # (e.g. float32 inputs when jax_enable_x64 is on).
    B, C = jnp.zeros_like(A), jnp.zeros_like(A)
    # Deliberately not `cls(...)`: subclasses may have a different __init__ signature.
    return Diagonal2x2BlockMatrix(jnp.array([[A, B],
                                             [C, D]]),
                                             tags=diagonal.tags)

  @auto_vmap
  def project_dense(self, dense: DenseMatrix) -> 'Diagonal2x2BlockMatrix':
    """Project a square (2N x 2N) dense matrix onto this structure by keeping only
    the diagonals of its four (N x N) blocks. `self` is not read."""
    elements = dense.elements
    assert dense.shape[0] == dense.shape[1]
    assert dense.shape[0]%2 == 0
    dim = dense.shape[0]//2
    A, B, C, D = elements[:dim, :dim], elements[:dim, dim:], elements[dim:, :dim], elements[dim:, dim:]
    A_diag = jnp.diag(A)
    D_diag = jnp.diag(D)
    B_diag = jnp.diag(B)
    C_diag = jnp.diag(C)
    elements = jnp.array([[A_diag, B_diag],
                          [C_diag, D_diag]])
    out = Diagonal2x2BlockMatrix(elements, tags=dense.tags)
    return out

class ParametricSymmetricDiagonal2x2BlockMatrix(Diagonal2x2BlockMatrix):
  """Symmetric block matrix parameterized by an unconstrained 2x2-block factor.

  The stored parameter `_elements` (call it U) is mapped to U^T U after its
  diagonal blocks are forced to be strictly positive (|.| + 1e-8). U^T U is
  always symmetric positive semi-definite. When U is block upper-triangular it
  is positive definite, and this is a Cholesky-style parameterization.
  """

  tags: Tags
  _elements: Float[Array, '2 2 N']  # unconstrained factor U

  def __init__(self, elements: Float[Array, '2 2 N']):
    self._elements = elements
    self.tags = TAGS.symmetric_tags

  @property
  def batch_size(self) -> Union[None,int,Tuple[int, ...]]:
    """Batch shape of the parameter: None if unbatched, int for one axis, else tuple."""
    if self._elements.ndim > 4:
      # The trailing (2, 2, N) axes are the matrix itself; everything before is batch.
      return self._elements.shape[:-3]
    elif self._elements.ndim == 4:
      return self._elements.shape[0]
    elif self._elements.ndim == 3:
      return None
    else:
      raise ValueError(f"Invalid number of dimensions: {self._elements.ndim}")

  @property
  @auto_vmap
  def elements(self) -> Float[Array, '2 2 N']:
    """The symmetric matrix U^T U, with U's diagonal blocks made positive."""
    diag_idx = jnp.array([0, 1])
    # Blocks (0, 0) and (1, 1) of U; the small offset keeps them away from zero.
    diag_elements = jnp.abs(self._elements[diag_idx, diag_idx]) + 1e-8
    _elements = self._elements.at[diag_idx,diag_idx].set(diag_elements)

    _elementsT = _elements.swapaxes(-2, -3)
    return einops.einsum(_elementsT, _elements, 'i j a, j k a -> i k a')

################################################################################################################

@dispatch
def make_parametric_symmetric_matrix(matrix: Diagonal2x2BlockMatrix) -> ParametricSymmetricDiagonal2x2BlockMatrix:
  """Reparameterize a symmetric block matrix by the transpose of its Cholesky factor."""
  upper_factor = matrix.get_cholesky().T
  return ParametricSymmetricDiagonal2x2BlockMatrix(upper_factor.elements)

# @dispatch
# @symbolic_add
# def mat_add(A: Diagonal2x2BlockMatrix, B: Union[Scalar, float]) -> Diagonal2x2BlockMatrix:
#   return Diagonal2x2BlockMatrix(A.elements + B, tags=A.tags).fix_to_tags()

@dispatch
@symbolic_add
def mat_add(A: Diagonal2x2BlockMatrix, B: Diagonal2x2BlockMatrix) -> Diagonal2x2BlockMatrix:
  """Sum of two block matrices: the compact diagonals add entrywise."""
  summed = A.elements + B.elements
  combined_tags = A.tags.add_update(B.tags)
  return Diagonal2x2BlockMatrix(summed, tags=combined_tags).fix_to_tags()

@dispatch
@symbolic_add
def mat_add(A: Diagonal2x2BlockMatrix, B: DenseMatrix) -> DenseMatrix:
  """Block + dense: materialize the block matrix and add densely."""
  dense_A = A.as_matrix()
  combined_tags = A.tags.add_update(B.tags)
  return DenseMatrix(dense_A + B.elements, tags=combined_tags).fix_to_tags()

@dispatch
@symbolic_add
def mat_add(A: DenseMatrix, B: Diagonal2x2BlockMatrix) -> DenseMatrix:
  """Dense + block: addition commutes, so defer to the (block, dense) overload."""
  block_first, dense_second = B, A
  return mat_add(block_first, dense_second)

@dispatch
@symbolic_add
def mat_add(A: Diagonal2x2BlockMatrix, B: DiagonalMatrix) -> Diagonal2x2BlockMatrix:
  """Block + diagonal: lift the diagonal matrix into block form, then add."""
  return mat_add(A, Diagonal2x2BlockMatrix.from_diagonal(B))

@dispatch
@symbolic_add
def mat_add(A: DiagonalMatrix, B: Diagonal2x2BlockMatrix) -> Diagonal2x2BlockMatrix:
  """Diagonal + block: addition commutes, so defer to the (block, diagonal) overload."""
  block_first, diag_second = B, A
  return mat_add(block_first, diag_second)

################################################################################################################

@dispatch
@symbolic_scalar_mul
def scalar_mul(A: Diagonal2x2BlockMatrix, s: Scalar) -> Diagonal2x2BlockMatrix:
  """Scale every stored entry by s; the block pattern is unchanged."""
  scaled = s*A.elements
  return Diagonal2x2BlockMatrix(scaled, tags=A.tags.scalar_mul_update()).fix_to_tags()

################################################################################################################

@dispatch
@symbolic_mat_mul
def mat_mul(A: Diagonal2x2BlockMatrix, B: Diagonal2x2BlockMatrix) -> Diagonal2x2BlockMatrix:
  """Block product: output block (i, k) is sum_j A_ij * B_jk, taken entrywise on the diagonals."""
  product = einops.einsum(A.elements, B.elements, 'i j a, j k a -> i k a')
  return Diagonal2x2BlockMatrix(product, tags=A.tags.mat_mul_update(B.tags)).fix_to_tags()

@dispatch
@symbolic_mat_mul
def mat_mul(A: Diagonal2x2BlockMatrix, B: DenseMatrix) -> DenseMatrix:
  """Block @ dense, computed by materializing the block matrix."""
  left = A.as_matrix()
  return DenseMatrix(left@B.elements, tags=A.tags.mat_mul_update(B.tags)).fix_to_tags()

@dispatch
@symbolic_mat_mul
def mat_mul(A: DenseMatrix, B: Diagonal2x2BlockMatrix) -> DenseMatrix:
  """Dense @ block, computed by materializing the block matrix."""
  right = B.as_matrix()
  return DenseMatrix(A.elements@right, tags=A.tags.mat_mul_update(B.tags)).fix_to_tags()

@dispatch
@symbolic_mat_mul
def mat_mul(A: Diagonal2x2BlockMatrix, b: Float[Array, 'N']) -> Float[Array, 'M']:
  """Matrix-vector product using only the compact diagonals.

  The vector is viewed as a (2, N) array of top/bottom halves; output half i is
  sum_j elements[i, j] * half_j, entrywise.
  """
  halves = b.reshape((2, -1))
  out_halves = einops.einsum(A.elements, halves, 'i j a, j a -> i a')
  return out_halves.reshape(-1)

@dispatch
@symbolic_mat_mul
def mat_mul(A: Diagonal2x2BlockMatrix, B: DiagonalMatrix) -> Diagonal2x2BlockMatrix:
  """Block @ diagonal: lift the diagonal matrix into block form, then multiply."""
  return mat_mul(A, Diagonal2x2BlockMatrix.from_diagonal(B))

@dispatch
@symbolic_mat_mul
def mat_mul(A: DiagonalMatrix, B: Diagonal2x2BlockMatrix) -> Diagonal2x2BlockMatrix:
  """Diagonal @ block: lift the diagonal matrix into block form, then multiply."""
  return mat_mul(Diagonal2x2BlockMatrix.from_diagonal(A), B)

################################################################################################################

@dispatch
def transpose(A: Diagonal2x2BlockMatrix) -> Diagonal2x2BlockMatrix:
  """Transpose by swapping the block-row and block-column axes (each diagonal block is its own transpose)."""
  flipped = jnp.swapaxes(A.elements, -3, -2)
  return Diagonal2x2BlockMatrix(flipped, tags=A.tags.transpose_update()).fix_to_tags()

################################################################################################################

@dispatch
@symbolic_solve
def matrix_solve(A: Diagonal2x2BlockMatrix,
                 b: Float[Array, 'N']) -> Float[Array, 'N']:
  """Solve A x = b by applying the closed-form inverse of A to b."""
  return mat_mul(get_matrix_inverse(A), b)

@dispatch
@symbolic_solve
def matrix_solve(A: Diagonal2x2BlockMatrix,
                 B: Union[Diagonal2x2BlockMatrix, DenseMatrix, DiagonalMatrix]) -> Union[Diagonal2x2BlockMatrix, DenseMatrix]:
  """Solve A X = B by left-multiplying B with the closed-form inverse of A."""
  return mat_mul(get_matrix_inverse(A), B)

@dispatch
@symbolic_solve
def matrix_solve(A: DenseMatrix, B: Diagonal2x2BlockMatrix) -> DenseMatrix:
  """Dense A, block B: densify B and use the dense solver."""
  B_dense = B.to_dense()
  return matrix_solve(A, B_dense)

@dispatch
@symbolic_solve
def matrix_solve(A: DiagonalMatrix, B: Diagonal2x2BlockMatrix) -> Diagonal2x2BlockMatrix:
  """Diagonal A, block B: lift A into block form and use the block solver."""
  return matrix_solve(Diagonal2x2BlockMatrix.from_diagonal(A), B)

################################################################################################################

@dispatch
@symbolic_matrix_inverse
def get_matrix_inverse(A: Diagonal2x2BlockMatrix) -> Diagonal2x2BlockMatrix:
  """Invert a 2x2-block matrix whose blocks are diagonal.

  Each diagonal position n holds an independent 2x2 matrix
  [[a_n, b_n], [c_n, d_n]], so the inverse is computed entrywise with the
  adjugate formula  1/(a d - b c) * [[d, -b], [-c, a]].

  Unlike the Schur-complement form, which divides by `a` and `d`, this only
  needs a nonzero determinant. Matrices with zero diagonal blocks such as
  [[0, I], [I, 0]] are therefore inverted without producing NaNs.
  """
  elements = A.elements
  # Index with `...` so any leading batch axes are preserved.
  a = elements[..., 0, 0, :]
  b = elements[..., 0, 1, :]
  c = elements[..., 1, 0, :]
  d = elements[..., 1, 1, :]
  det = a*d - b*c

  top_row = jnp.stack([d, -b], axis=-2)
  bottom_row = jnp.stack([-c, a], axis=-2)
  out_elements = jnp.stack([top_row, bottom_row], axis=-3)/det[..., None, None, :]

  out_tags = A.tags.inverse_update()
  return Diagonal2x2BlockMatrix(out_elements, tags=out_tags).fix_to_tags()

################################################################################################################

@dispatch
@symbolic_log_det
def get_log_det(A: Diagonal2x2BlockMatrix) -> Scalar:
  """Return log|det(A)| (the sign is discarded).

  A is permutation-similar to a block-diagonal matrix of N independent 2x2
  matrices [[a_n, b_n], [c_n, d_n]], so det(A) = prod_n (a_n d_n - b_n c_n).
  Computing a d - b c directly avoids dividing by `a`, which produced NaN
  (-inf + inf) for nonsingular matrices whose top-left block has zeros.
  """
  elements = A.elements
  det = elements[..., 0, 0, :]*elements[..., 1, 1, :] - elements[..., 0, 1, :]*elements[..., 1, 0, :]
  # Sum over the N diagonal positions only, so leading batch axes survive.
  return jnp.log(jnp.abs(det)).sum(axis=-1)

################################################################################################################

@dispatch
@symbolic_cholesky
def get_cholesky(A: Diagonal2x2BlockMatrix) -> Diagonal2x2BlockMatrix:
  """Lower block-triangular Cholesky factor of a symmetric 2x2-block matrix.

  For each position the factor of [[a, b], [b, c]] is
  [[sqrt(a), 0], [b/sqrt(a), sqrt(c - b^2/a)]]. The off-diagonal value is taken
  from the upper-right block; the lower-left block is not read.
  """
  rows = A.elements.reshape((4, -1))
  top_left, off_diag, bottom_right = rows[0], rows[1], rows[3]

  l11 = jnp.sqrt(top_left)
  l21 = off_diag/l11
  l22 = jnp.sqrt(bottom_right - l21**2)
  factor = jnp.array([[l11, jnp.zeros_like(l21)],
                      [l21, l22]])
  # When the tags mark A as zero, the formula above yields 0/0 NaNs; return zeros instead.
  factor = jnp.where(A.tags.is_nonzero, factor, jnp.zeros_like(factor))
  return Diagonal2x2BlockMatrix(factor, tags=A.tags.cholesky_update()).fix_to_tags()

################################################################################################################

@dispatch
@symbolic_exp
def get_exp(A: Diagonal2x2BlockMatrix) -> Diagonal2x2BlockMatrix:
  """Matrix exponential of a 2x2-block matrix with diagonal blocks.

  A is permutation-similar to a block-diagonal matrix of N independent 2x2
  matrices M_n = [[e00[n], e01[n]], [e10[n], e11[n]]]. exp() commutes with that
  similarity, so exp(A) has the same structure, with exp(M_n) at position n.
  This replaces the previous dense (2N x 2N) expm. The old version also crashed
  because it rebound `A` to a raw array before reading `A.tags`.
  """
  # (..., 2, 2, N) -> (..., N, 2, 2): one small square matrix per diagonal position.
  small_blocks = jnp.moveaxis(A.elements, -1, -3)
  batched_expm = jnp.vectorize(jax.scipy.linalg.expm, signature='(n,n)->(n,n)')
  exp_blocks = batched_expm(small_blocks)
  out_elements = jnp.moveaxis(exp_blocks, -3, -1)

  out_tags = A.tags.exp_update()
  return Diagonal2x2BlockMatrix(out_elements, tags=out_tags).fix_to_tags()

################################################################################################################

@dispatch
@symbolic_svd
def get_svd(A: Diagonal2x2BlockMatrix) -> Tuple[DenseMatrix, DiagonalMatrix, DenseMatrix]:
  """Dense SVD fallback: factor the materialized matrix with `my_svd` and wrap
  its three outputs, in order, as (DenseMatrix, DiagonalMatrix, DenseMatrix)."""
  left, singular, right = my_svd(A.as_matrix())
  wrapped = (
    DenseMatrix(left, tags=TAGS.no_tags),
    DiagonalMatrix(singular, tags=TAGS.no_tags),
    DenseMatrix(right, tags=TAGS.no_tags),
  )
  return tuple(factor.fix_to_tags() for factor in wrapped)

################################################################################################################

if __name__ == '__main__':
  from debug import *
  from diffusion_crf.matrix.matrix_base import matrix_tests

  # Turn on x64
  jax.config.update('jax_enable_x64', True)
  key = random.PRNGKey(0)
  k1, k2 = random.split(key)

  def init_mat(key):
    """Random symmetric block matrix A^T A, wrapped in the parametric form."""
    A = random.normal(key, (2, 2, 4))
    A = Diagonal2x2BlockMatrix(A, tags=TAGS.no_tags)
    A = A.T@A
    return ParametricSymmetricDiagonal2x2BlockMatrix(A.elements)

  # Smoke-test the parametric matrix, unbatched and batched (including the SVD fallback).
  sym = init_mat(key)
  elts = sym.elements

  keys = random.split(key, 100)
  A = jax.vmap(init_mat)(keys)
  U, s, V = A.get_svd()

  # Generic matrix tests on random Diagonal2x2BlockMatrix instances
  A = random.normal(k1, (2, 2, 4))
  B = random.normal(k2, (2, 2, 4))
  A = Diagonal2x2BlockMatrix(A, tags=TAGS.no_tags)
  B = Diagonal2x2BlockMatrix(B, tags=TAGS.no_tags)
  matrix_tests(key, A, B)

  # Check that zero matrices are handled correctly
  A = Diagonal2x2BlockMatrix.zeros_like(A)
  matrix_tests(key, A, B)