"""
Matrix casting utilities to convert between different matrix types.

This module provides a unified approach to casting matrices between different
types in the matrix package, respecting the established precedence rules.
"""

from typing import TypeVar, Union, Type, Optional, Any
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from .matrix_base import AbstractMatrix
from .dense import DenseMatrix
from .diagonal import DiagonalMatrix
from .diagonal_block_2x2 import Diagonal2x2BlockMatrix
from .diagonal_block_3x3 import Diagonal3x3BlockMatrix
from .block_2x2 import Block2x2Matrix

# Generic matrix type variable, bounded by the package's abstract base class.
T = TypeVar("T", bound=AbstractMatrix)
# Alias for "a matrix class" (not an instance).
MatrixType = Type[AbstractMatrix]

# Casting precedence table. Position 0 holds the highest-precedence type and
# the last position holds the lowest; get_precedence() returns these indices.
MATRIX_PRECEDENCE: list[MatrixType] = [
    DenseMatrix,             # highest precedence
    Block2x2Matrix,          # general 2x2 block matrix
    Diagonal3x3BlockMatrix,  # 3x3 block of diagonal matrices
    Diagonal2x2BlockMatrix,  # 2x2 block of diagonal matrices
    DiagonalMatrix,          # lowest precedence
]

def get_precedence(matrix_type: MatrixType) -> int:
    """
    Return the precedence rank of a matrix class.

    The rank is the class's position in MATRIX_PRECEDENCE, so smaller numbers
    mean higher precedence.

    Args:
        matrix_type: The matrix class to rank.

    Returns:
        The index of ``matrix_type`` in MATRIX_PRECEDENCE, or
        ``len(MATRIX_PRECEDENCE)`` (one past the lowest registered rank)
        when the class is not listed.
    """
    for rank, registered in enumerate(MATRIX_PRECEDENCE):
        if registered == matrix_type:
            return rank
    # Unregistered classes sort after every registered one.
    return len(MATRIX_PRECEDENCE)

def highest_precedence_type(types: list[MatrixType]) -> MatrixType:
    """
    Pick the matrix class that wins under the precedence rules.

    Args:
        types: Candidate matrix classes; must be non-empty.

    Returns:
        The candidate with the smallest get_precedence() rank. When several
        candidates share that rank, the earliest one in ``types`` is returned.

    Raises:
        ValueError: If ``types`` is empty.
    """
    if types:
        # min() keeps the first of equally-ranked candidates.
        return min(types, key=get_precedence)
    raise ValueError("Empty list of types provided")

def _to_dense_matrix(matrix: AbstractMatrix) -> AbstractMatrix:
    """
    Return a dense representation of ``matrix``.

    A ``DenseMatrix`` is returned unchanged; any other matrix is converted
    through its ``to_dense`` method.

    Raises:
        ValueError: If ``matrix`` is not a ``DenseMatrix`` and has no
            ``to_dense`` method.
    """
    if isinstance(matrix, DenseMatrix):
        return matrix  # Already dense
    if not hasattr(matrix, 'to_dense'):
        # Should never happen with our matrix types, but just in case
        raise ValueError(f"Cannot convert {type(matrix).__name__} to dense format")
    return matrix.to_dense()


def cast_matrix(
    matrix: AbstractMatrix,
    target_type: Union[MatrixType, AbstractMatrix, None] = None,
    allow_downcast: bool = False
) -> AbstractMatrix:
    """
    Cast a matrix to a specified type or to a compatible type for operations.

    This function implements the casting logic based on the precedence rules
    established in the matrix package. If no target_type is specified, it
    will return the original matrix.

    A dense representation is only materialized on the paths that actually
    need it (dense target, projection onto a structured type), not when a
    ``from_diagonal`` constructor or ``cast_like`` fallback is used.

    Args:
        matrix: The matrix to cast
        target_type: The target matrix type or a matrix instance to cast to.
                    If None, returns the original matrix.
        allow_downcast: Whether to allow casting to a lower precedence type.
                        Default is False which prevents potential loss of information.

    Returns:
        The matrix cast to the target type (the input itself when it is
        already an instance of the target type or target_type is None).

    Raises:
        ValueError: If ``matrix`` is None, if the cast is a downcast and
            ``allow_downcast`` is False, or if a required dense conversion
            is not available.
    """
    if matrix is None:
        raise ValueError("Cannot cast None to a matrix type")

    # No target: nothing to do.
    if target_type is None:
        return matrix

    # A matrix instance as target means "cast to that instance's class".
    if isinstance(target_type, AbstractMatrix):
        target_type_cls = type(target_type)
    else:
        target_type_cls = target_type

    # isinstance (not exact type equality) so subclasses count as already cast.
    if isinstance(matrix, target_type_cls):
        return matrix

    matrix_type_cls = type(matrix)

    # Lower rank = higher precedence, so moving to a larger rank is a downcast.
    if (get_precedence(matrix_type_cls) < get_precedence(target_type_cls)
            and not allow_downcast):
        raise ValueError(
            f"Cannot downcast from {matrix_type_cls.__name__} "
            f"to {target_type_cls.__name__} without allow_downcast=True"
        )

    if target_type_cls == DenseMatrix:
        return _to_dense_matrix(matrix)

    if target_type_cls in (Block2x2Matrix, Diagonal3x3BlockMatrix, Diagonal2x2BlockMatrix):
        if isinstance(matrix, DiagonalMatrix):
            # Dedicated constructor; no dense round-trip needed.
            return target_type_cls.from_diagonal(matrix)
        dense_matrix = _to_dense_matrix(matrix)
        # project_dense is an instance method, so build a zero template first.
        template = target_type_cls.zeros(dense_matrix.shape[0])
        return template.project_dense(dense_matrix)

    if target_type_cls == DiagonalMatrix:
        dense_matrix = _to_dense_matrix(matrix)
        template = DiagonalMatrix.zeros(dense_matrix.shape[0])
        return template.project_dense(dense_matrix)

    # Fallback for types not handled above: delegate to cast_like.
    if isinstance(target_type, AbstractMatrix):
        return matrix.cast_like(target_type)
    # Build a zero matrix of the target type to serve as the cast_like template.
    # This assumes the target type provides a ``zeros(dim, tag_module=...)``
    # constructor.
    zero_matrix = target_type_cls.zeros(matrix.dim, tag_module=matrix.tags)
    return matrix.cast_like(zero_matrix)

def cast_compatible(
    matrix1: AbstractMatrix, 
    matrix2: AbstractMatrix
) -> tuple[AbstractMatrix, AbstractMatrix]:
    """
    Bring two matrices to a common type so they can be combined.

    When the two matrices have different classes, both are cast with
    cast_matrix() to whichever class highest_precedence_type() selects;
    matrix1's class wins ties.

    Args:
        matrix1: First matrix.
        matrix2: Second matrix.

    Returns:
        ``(matrix1, matrix2)`` unchanged when their classes are identical,
        otherwise the pair after casting each to the common class.
    """
    first_cls, second_cls = type(matrix1), type(matrix2)

    if first_cls != second_cls:
        common_cls = highest_precedence_type([first_cls, second_cls])
        # matrix1 is cast before matrix2, preserving evaluation order.
        return cast_matrix(matrix1, common_cls), cast_matrix(matrix2, common_cls)

    return matrix1, matrix2