from typing import Literal, Tuple, Union
from petsc4py import PETSc
from slepc4py import SLEPc
import numpy as np
from scipy import sparse
import logging


# Module-level logger named after this module; used below for progress messages.
logger: logging.Logger = logging.getLogger(__name__)


def scipysparse_to_PETSc(A: sparse.spmatrix) -> PETSc.Mat:
    """Convert a scipy sparse matrix to a distributed PETSc AIJ matrix.

    PETSc chooses the row distribution over ``PETSc.COMM_WORLD``. Each rank
    then inserts only the rows it owns, so no off-process values are stashed
    and communicated during assembly. ``A`` is assumed to be available in
    full on every rank (NOTE(review): confirm this holds for all callers).

    Args:
        A: Sparse matrix of shape ``(M, N)``. Any scipy sparse format is
            accepted. Duplicate COO entries are summed, matching scipy's own
            meaning of a COO matrix with repeated coordinates.

    Returns:
        The assembled ``PETSc.Mat`` with global size ``(M, N)``. The caller
        owns it and should ``destroy()`` it when done.
    """
    M, N = A.shape
    # CSR sums duplicate COO entries (INSERT mode would silently keep only
    # one of them) and gives contiguous per-row slices for bulk insertion.
    csr = sparse.csr_matrix(A)

    mat = PETSc.Mat().create(comm=PETSc.COMM_WORLD)
    mat.setSizes([M, N])
    mat.setType("aij")  # Parallel sparse matrix
    # Default preallocation: exact per-row counts would require the local row
    # range, which PETSc only reports once the matrix has been set up.
    mat.setUp()

    row_start, row_end = mat.getOwnershipRange()
    indptr, indices, values = csr.indptr, csr.indices, csr.data
    for row in range(row_start, row_end):
        lo, hi = indptr[row], indptr[row + 1]
        if lo == hi:
            continue  # empty row: nothing to insert
        mat.setValues(
            [row],
            indices[lo:hi],
            values[lo:hi],
            addv=PETSc.InsertMode.INSERT_VALUES,
        )

    mat.assemble()
    return mat


from petsc4py import PETSc
from slepc4py import SLEPc
from scipy import sparse
import numpy as np
from typing import Tuple, Union, Literal


def SLEPc_eigsolve(
    A: sparse.coo_matrix,
    k: Union[int, None],
    which: Literal["smallest", "largest"],
) -> Tuple[Union[np.ndarray, None], Union[np.ndarray, None]]:
    """Compute extremal eigenpairs of a sparse matrix with SLEPc.

    The matrix is distributed over ``PETSc.COMM_WORLD`` and solved as a
    Hermitian eigenproblem (``EPS.ProblemType.HEP``). The solver uses a
    tolerance of 1e-4 and at most 600 iterations. ``A`` should therefore be
    symmetric; this is not checked here.

    Args:
        A: Square sparse matrix of size ``N x N``.
        k: Number of eigenpairs to compute and return. ``None`` keeps
            SLEPc's default ``nev``, and that value is then used as ``k``.
        which: ``"smallest"`` or ``"largest"``, selecting eigenvalues by
            their real part.

    Returns:
        On rank 0, ``(evals, evecs)``. ``evals`` has shape ``(k,)``.
        ``evecs`` has shape ``(N, k)``, with eigenvector ``i`` in column
        ``i``. Every other rank returns ``(None, None)``.

    Raises:
        ValueError: If ``which`` is not recognised or ``k`` is not positive.
        RuntimeError: If fewer than ``k`` eigenpairs converged.
    """
    # Validate arguments before any (collective, expensive) PETSc work.
    if which == "largest":
        target = SLEPc.EPS.Which.LARGEST_REAL
    elif which == "smallest":
        target = SLEPc.EPS.Which.SMALLEST_REAL
    else:
        raise ValueError(f"Invalid value for 'which': {which}")
    if k is not None and k < 1:
        raise ValueError(f"k must be a positive integer or None, got {k}")

    rank = PETSc.COMM_WORLD.getRank()
    N = A.shape[0]

    # PETSc/SLEPc objects created below; all are destroyed in `finally`,
    # including on error paths (previously `mat` was never destroyed).
    owned = []
    try:
        logger.info("Creating PETSc matrix")
        mat = scipysparse_to_PETSc(A)
        owned.append(mat)
        logger.info("Done")

        E = SLEPc.EPS()
        E.create(comm=PETSc.COMM_WORLD)
        owned.append(E)
        E.setOperators(mat)
        E.setDimensions(nev=k)  # nev=None leaves SLEPc's default in place
        E.setProblemType(SLEPc.EPS.ProblemType.HEP)
        E.setTolerances(tol=1e-4, max_it=600)
        E.setWhichEigenpairs(target)

        logger.info("Starting eigensolver")
        E.solve()

        nconv = E.getConverged()
        if rank == 0:
            print("\n*** SLEPc Solution Results ***\n")
            print(f"Method: {E.getType()}")
            print(f"Iterations: {E.getIterationNumber()}")
            print(f"Converged eigenpairs: {nconv}")

        if k is None:
            # No explicit request: use the nev SLEPc actually solved for.
            k = E.getDimensions()[0]
        # Explicit raise instead of `assert`, which is stripped under -O.
        # nconv is identical on all ranks, so every rank raises together.
        if nconv < k:
            raise RuntimeError(
                f"Only {nconv} eigenvectors converged, expected {k}"
            )

        evals = np.zeros(k)
        evecs = np.zeros((N, k)) if rank == 0 else None

        vr, _ = mat.getVecs()
        owned.append(vr)
        vi = vr.duplicate()
        owned.append(vi)

        # A single scatter context gathers each distributed eigenvector onto
        # rank 0. It is reused for every pair because vr's layout is fixed.
        scatter, gathered = PETSc.Scatter().toZero(vr)
        owned.extend((scatter, gathered))

        for i in range(k):
            eigval = E.getEigenpair(i, vr, vi)
            evals[i] = eigval.real
            # Collective: every rank must take part in the scatter.
            scatter.scatter(vr, gathered, addv=PETSc.InsertMode.INSERT_VALUES)
            if rank == 0:
                evecs[:, i] = gathered.array
    finally:
        # Destroy in reverse creation order.
        for obj in reversed(owned):
            obj.destroy()

    if rank == 0:
        return evals, evecs
    return None, None
