from typing import Tuple

import flax.linen as nn

from egxc.systems.base import System
from egxc.utils.typing import (
    Float2xBxB,
    FloatBxB,
    FloatSCF,
    FloatSCFx2xBxB,
    FloatSCFxBxB,
)
from egxc.xc_energy import XCModule


class Solver(nn.Module):
    """
    Abstract base class for all solvers.

    Subclasses must override ``__call__``. The base implementation raises
    ``NotImplementedError`` so that using ``Solver`` directly (or a subclass
    that forgot to override ``__call__``) fails loudly instead of silently
    returning ``None``, which would violate the annotated return type.

    Attributes:
        xc_module: The ``XCModule`` (from ``egxc.xc_energy``) the solver is
            configured with.
    """

    # XC energy module supplied at construction time (Flax dataclass field).
    xc_module: XCModule

    def __call__(
        self,
        initial_density_matrix: FloatBxB | Float2xBxB,
        sys: System,
    ) -> Tuple[Tuple[FloatSCF, FloatSCF], FloatSCFxBxB | FloatSCFx2xBxB]:
        """
        Run the solver on ``sys`` starting from ``initial_density_matrix``.

        Args:
            initial_density_matrix: Starting density matrix, typed as
                ``FloatBxB | Float2xBxB`` -- presumably a single matrix or a
                stacked pair of matrices (e.g. one per spin channel); confirm
                against concrete subclasses.
            sys: The ``System`` to solve.

        Returns:
            A nested tuple ``((FloatSCF, FloatSCF), FloatSCFxBxB |
            FloatSCFx2xBxB)``. The meaning of each entry is defined by the
            concrete subclass (NOTE(review): the alias names suggest
            per-SCF-iteration quantities -- confirm).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement __call__"
        )
