from abc import abstractclassmethod, abstractmethod
from typing import (
    get_args,
    Callable,
    Literal,
    TypeAlias,
    NamedTuple,
)

import symo.special_matrix as sm
from symo.group import I, S, B, O, Eq
import numpy as np
import torch
import plum

# Plum dispatcher that registers the multiple-dispatch ``factor`` entry point.
dispatch = plum.Dispatcher()
# NOTE(review): ``dispatch_init`` is not referenced anywhere in the visible part
# of this file — confirm whether it is still needed.
dispatch_init = plum.Dispatcher()


def contract(subscript: str, *operands):
    """Evaluate an Einstein-summation expression with ``torch.einsum``.

    Args:
        subscript: einsum subscript string, e.g. ``"i,j->ij"``.
        *operands: tensors referenced by ``subscript``, in order.

    Returns:
        The tensor produced by ``torch.einsum``.
    """
    result = torch.einsum(subscript, *operands)
    return result


# Array-like values accepted by the factors: torch tensors or numpy arrays.
NDArray: TypeAlias = torch.Tensor | np.ndarray
# Weight initialiser: called with a shape tuple and returns the initial weights
# (the factor classes in this file default to ``torch.zeros``), or None.
InitFn: TypeAlias = Callable[[tuple], NDArray] | None
# Literal tags for the ``source`` argument of ``factor``; ``_factor_from_source``
# maps each tag to what gets built.
WithInit: TypeAlias = Literal["with_init"]  # instance via ``cls.from_init_fn``
ClassOnly: TypeAlias = Literal["class_only"]  # the factor class itself
FromCov: TypeAlias = Literal["cov_source"]  # instance via ``cls.from_cov``
FromParam: TypeAlias = Literal["param_source"]  # instance via ``cls.from_param``
# NOTE(review): ``_factor_from_source`` raises NotImplementedError for this tag.
FromFactor: TypeAlias = Literal["factor_source"]
FromSource: TypeAlias = FromCov | FromParam | FromFactor | WithInit | ClassOnly

# Payload passed to ``factor`` alongside ``source``: an array, a tuple of arrays,
# or an ``InitFn`` (which includes None).
Value: TypeAlias = NDArray | tuple[NDArray, ...] | InitFn


def factor_class(groups: Eq):
    """Return the factor class (not an instance) that ``factor`` picks for *groups*."""
    (tag,) = get_args(ClassOnly)
    return factor(groups, source=tag, value=None)


def factor_init(groups: Eq, init_fn: InitFn | None = None):
    """Build a factor for *groups* whose weights come from ``init_fn``.

    ``init_fn`` receives the weight shape tuple. When it is None, the selected
    factor class applies its own default (``torch.zeros`` for the classes in
    this file).
    """
    (tag,) = get_args(WithInit)
    return factor(groups, source=tag, value=init_fn)


def factor_from_cov(groups: Eq, cov):
    """Build a factor for *groups* whose weights are estimated from ``cov``.

    ``cov`` reaches the selected class's ``from_cov`` via
    ``_factor_from_source``. The catch-all overload ignores it and returns a
    ``ZeroFactor``.
    """
    (tag,) = get_args(FromCov)
    return factor(groups, source=tag, value=cov)


def factor_from_param(groups: Eq, param: NDArray | tuple):
    """Build a factor for *groups* whose weights are estimated from ``param``.

    ``param`` reaches the selected class's ``from_param`` via
    ``_factor_from_source``. The catch-all overload ignores it and returns a
    ``ZeroFactor``.
    """
    (tag,) = get_args(FromParam)
    return factor(groups, source=tag, value=param)


def _factor_from_source(cls, src: FromSource, eq: Eq, *args, **kwargs):
    if src == "class_only":
        return cls
    elif src == "with_init":
        return cls.from_init_fn(eq, *args, **kwargs)
    elif src == "param_source":
        return cls.from_param(eq, *args, **kwargs)
    elif src == "cov_source":
        return cls.from_cov(eq, *args, **kwargs)
    raise NotImplementedError(f"'{src}' isn't implemented.")


# Factor base class


class FactorBase(torch.nn.Module):
    """Base class for covariance factors attached to an equivalence ``eq``.

    Subclasses implement ``cov``, which returns the dense covariance block,
    and ``matvec``, which applies that block (or its transpose when
    ``transpose`` is True) to a vector.

    NOTE: ``torch.nn.Module`` does not use ``abc.ABCMeta``, so the
    ``@abstractmethod`` markers here document the interface but do not
    prevent instantiation.
    """

    eq: Eq

    def __init__(self, eq):
        super().__init__()
        # Assigned after Module.__init__, which sets up the attribute bookkeeping
        # that Module.__setattr__ relies on.
        self.eq = eq

    @abstractmethod
    def cov(self, surrogate: bool = False) -> torch.Tensor:
        """Return the dense covariance block represented by this factor."""

    @abstractmethod
    def matvec(self, vec: NDArray, transpose: bool = False) -> torch.Tensor:
        """Multiply ``vec`` by the covariance block (or its transpose)."""


class Factor(FactorBase):
    """Factor whose parameters live in a single ``weights`` buffer.

    Concrete subclasses provide ``outer_estimate`` and ``cov_estimate`` as
    classmethods that map data to a ``weights`` tensor. They also implement
    the ``FactorBase`` interface (``cov`` / ``matvec``).
    """

    eq: Eq
    weights: torch.Tensor

    def __init__(self, eq, weights):
        super().__init__(eq)
        # Registered as a buffer (not a Parameter): it is part of state_dict and
        # follows .to(), but is not returned by .parameters().
        self.register_buffer("weights", weights)

    @classmethod
    def from_param(cls, eq, params: tuple[torch.Tensor, torch.Tensor]):
        """Construct a factor whose weights are estimated from a vector pair."""
        weight = cls.outer_estimate(eq, params)
        return cls(eq, weight)

    @classmethod
    def from_cov(cls, eq, cov: torch.Tensor):
        """Construct a factor whose weights are estimated from a covariance block."""
        weight = cls.cov_estimate(eq, cov)
        return cls(eq, weight)

    def outer_estimate_(self, lhs, rhs):
        """Re-estimate ``weights`` in place from the vector pair ``(lhs, rhs)``."""
        weight = self.outer_estimate(self.eq, (lhs, rhs))
        # Write into this factor's own buffer, mirroring ``cov_estimate_``.
        self.weights.copy_(weight)

    def cov_estimate_(self, cov, surrogate: bool = False):
        """Re-estimate ``weights`` in place from the covariance block ``cov``."""
        weight = self.cov_estimate(self.eq, cov, surrogate=surrogate)
        self.weights.copy_(weight)

    @classmethod
    @abstractmethod
    def outer_estimate(
        cls, eq: Eq, vectors: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        """Return weights estimated from the outer product of ``vectors``."""

    @classmethod
    @abstractmethod
    def cov_estimate(
        cls, eq: Eq, cov: torch.Tensor, surrogate: bool = False
    ) -> torch.Tensor:
        """Return weights estimated from the covariance block ``cov``."""


# Factors dispatcher


@plum.overload
def factor(
    eq: Eq[O, O] | Eq[B, B],  # type: ignore
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``OOrB_Factor``.

    ``source`` and ``value`` are forwarded to ``_factor_from_source``.
    """
    return _factor_from_source(OOrB_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[S, S],  # type: ignore
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``S_Factor``."""
    return _factor_from_source(S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[S["N"], S["M"]],  # type: ignore  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_Sm_Factor``."""
    return _factor_from_source(Sn_Sm_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[I["N"], I["M"]] | Eq[I, I],  # type: ignore
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_Factor``."""
    return _factor_from_source(I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[S["N"], I] | Eq[S, I],  # type: ignore  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_I_Factor``."""
    return _factor_from_source(S_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[I, S["M"]] | Eq[I, S],  # type: ignore  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_S_Factor``."""
    return _factor_from_source(I_S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(I, I), (I, I)] | Eq[(I["N"], I), (I["N"], I)] | Eq[(I, I["N"]), (I, I["N"])],  # type: ignore  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_I_Factor``."""
    return _factor_from_source(I_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(O["N"], O), (O["N"], O)]
        | Eq[(O["N"], B), (O["N"], B)]
        | Eq[(B["N"], O), (B["N"], O)]
        | Eq[(B["N"], B), (B["N"], B)]
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``OnOrBn_OmOrBm_Factor``."""
    return _factor_from_source(OnOrBn_OmOrBm_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(B, B), (B, B)],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Bn_Bn_Factor``."""
    return _factor_from_source(Bn_Bn_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(O, O), (O, O)],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``On_On_Factor``."""
    return _factor_from_source(On_On_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(O["N"], S), (O["N"], S)] | Eq[(B["N"], S), (B["N"], S)],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``OOrB_S_Factor``."""
    return _factor_from_source(OOrB_S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S["N"], O), (S["N"], O)] | Eq[(S["N"], B), (S["N"], B)],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_OOrB_Factor``."""
    return _factor_from_source(S_OOrB_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S["N"], O), (O, S["N"])] | Eq[(S["N"], B), (B, S["N"])],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_OOrB_OOrB_S_Factor``."""
    return _factor_from_source(S_OOrB_OOrB_S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S["N"], S), (S["N"], S)],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_Sm_Sn_Sm_Factor``."""
    return _factor_from_source(Sn_Sm_Sn_Sm_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S["N"], S), (S, S["N"])],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_Sm_Sm_Sn_Factor``."""
    return _factor_from_source(Sn_Sm_Sm_Sn_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S["N"], S["M"]), (S["N"], S["L"])],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_Sm_Sn_Sl_Factor``."""
    return _factor_from_source(Sn_Sm_Sn_Sl_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S["N"], S["M"]), (S["K"], S["M"])],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_Sm_Sk_Sm_Factor``."""
    return _factor_from_source(Sn_Sm_Sk_Sm_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S["N"], S["M"]), (S["M"], S["L"])],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_Sm_Sm_Sk_Factor``."""
    return _factor_from_source(Sn_Sm_Sm_Sk_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S, S), (S, S)],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_Sn_Factor``."""
    return _factor_from_source(Sn_Sn_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(I["N"], O), (I["N"], O)]
        | Eq[(I["N"], O), (I["L"], O)]
        | Eq[(I["N"], B), (I["N"], B)]
        | Eq[(I["N"], B), (I["L"], B)]
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_OorB_Factor``."""
    return _factor_from_source(I_OorB_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(O["N"], I), (O["N"], I)]
        | Eq[(O["N"], I), (O["N"], I["K"])]
        | Eq[(B["N"], I), (B["N"], I)]
        | Eq[(B["N"], I), (B["N"], I["K"])]
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``OorB_I_Factor``."""
    return _factor_from_source(OorB_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(I["N"], S), (I["N"], S)] | Eq[(I["N"], S), (I["L"], S)],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_S_I_S_Factor``."""
    return _factor_from_source(I_S_I_S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S["N"], I), (S["N"], I)] | Eq[(S["N"], I), (S["N"], I["K"])],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_I_S_I_Factor``."""
    return _factor_from_source(S_I_S_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(I, O["M"]), (O["M"], I)]
        | Eq[(I, B["M"]), (B["M"], I)]
        | Eq[(I, O["M"]), (O["M"], I["K"])]
        | Eq[(I, B["M"]), (B["M"], I["K"])]
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_OnorBn_OnorBn_I_Factor``."""
    return _factor_from_source(I_OnorBn_OnorBn_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(O["M"], I), (I, O["M"])]  # type: ignore  # noqa: F821
        | Eq[(B["M"], I), (I, B["M"])]  # type: ignore  # noqa: F821
        | Eq[(O["M"], I), (I["K"], O["M"])]  # type: ignore  # noqa: F821
        | Eq[(B["M"], I), (I["K"], B["M"])]  # type: ignore  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``OnorBn_I_I_OnorBn_Factor``."""
    return _factor_from_source(OnorBn_I_I_OnorBn_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(I, S["M"]), (S["M"], I)] | Eq[(I, S["M"]), (S["M"], I["L"])],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_Sm_Sm_I_Factor``."""
    return _factor_from_source(I_Sm_Sm_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(I, S), (S["K"], I)]
        | Eq[(I, S["M"]), (S, I)]
        | Eq[(I, S["M"]), (S["K"], I)]
        | Eq[(I, S["M"]), (S["K"], I["L"])]
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_Sm_Sk_I_Factor``."""
    return _factor_from_source(I_Sm_Sk_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S, I["M"]), (I["K"], S)],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_I_I_Sn_Factor``."""
    return _factor_from_source(Sn_I_I_Sn_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(S, I), (S["K"], I)]
        | Eq[(S["N"], I), (S, I)]
        | Eq[(S["N"], I), (S["K"], I)]
        | Eq[(S, I["M"]), (S["K"], I["L"])]
    ),  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``Sn_I_Sk_I_Factor``."""
    return _factor_from_source(Sn_I_Sk_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(S, I), (I, S["L"])]
        | Eq[(S["N"], I), (I, S)]
        | Eq[(S["N"], I), (I, S["L"])]
        | Eq[(S, I["M"]), (I["K"], S["L"])]
    ),  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``Sn_I_I_Sl_Factor``."""
    return _factor_from_source(Sn_I_I_Sl_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(O, S["M"]), (O, I)]
        | Eq[(O, S["M"]), (O, I["L"])]
        | Eq[(B, S["M"]), (B, I)]
        | Eq[(B, S["M"]), (B, I["L"])]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``OOrB_S_OOrB_I_Factor``."""
    return _factor_from_source(OOrB_S_OOrB_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(O, S["M"]), (I, O)]
        | Eq[(O, S["M"]), (I["K"], O)]
        | Eq[(B, S["M"]), (I, B)]
        # NOTE(review): the O variant above labels this I dimension "K" while
        # the B variant below uses "L" — confirm the labels are meant to differ.
        | Eq[(B, S["M"]), (I["L"], B)]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``OOrB_S_I_OOrB_Factor``."""
    return _factor_from_source(OOrB_S_I_OOrB_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S, S["M"]), (S, I["L"])] | Eq[(S, S["M"]), (S, I)],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_Sm_S_I_Factor``."""
    return _factor_from_source(S_Sm_S_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S, S["M"]), (I["K"], S)] | Eq[(S, S["M"]), (I, S)],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_Sm_I_S_Factor``."""
    return _factor_from_source(S_Sm_I_S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(S["N"], O), (O, I)]
        | Eq[(S["N"], O), (O, I["K"])]
        | Eq[(S["N"], O), (O, I["N"])]
        | Eq[(S["N"], B), (B, I)]
        | Eq[(S["N"], B), (B, I["K"])]
        | Eq[(S["N"], B), (B, I["N"])]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_OOrB_OOrB_I_Factor``."""
    return _factor_from_source(S_OOrB_OOrB_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(S["N"], O), (I, O)]
        | Eq[(S["N"], O), (I["K"], O)]
        | Eq[(S["N"], O), (I["N"], O)]
        | Eq[(S["N"], B), (I, B)]
        | Eq[(S["N"], B), (I["K"], B)]
        | Eq[(S["N"], B), (I["N"], B)]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_OOrB_I_OOrB_Factor``."""
    return _factor_from_source(S_OOrB_I_OOrB_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(S["N"], S), (S, I)]
        | Eq[(S["N"], S), (S, I["N"])]
        | Eq[(S["N"], S), (S, I["L"])]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``Sn_S_S_I_Factor``."""
    return _factor_from_source(Sn_S_S_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(S["N"], S), (I, S)]
        | Eq[(S["N"], S), (I["N"], S)]
        | Eq[(S["N"], S), (I["K"], S)]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``Sn_S_I_S_Factor``."""
    return _factor_from_source(Sn_S_I_S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S, S), (S, I["K"])] | Eq[(S, S), (S, I)],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_S_S_I_Factor``."""
    return _factor_from_source(S_S_S_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S, S), (I["L"], S)] | Eq[(S, S), (I, S)],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_S_I_S_Factor``."""
    return _factor_from_source(S_S_I_S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(S["N"], S["M"]), (S["K"], I)]
        | Eq[(S["N"], S["M"]), (S["K"], I["N"])]
        | Eq[(S["N"], S["M"]), (S["K"], I["M"])]
        | Eq[(S["N"], S["M"]), (S["K"], I["K"])]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``Sn_Sm_Sk_I_Factor``."""
    return _factor_from_source(Sn_Sm_Sk_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[(S["N"], S["M"]), (I, S["L"])]
        | Eq[(S["N"], S["M"]), (I["N"], S["L"])]
        | Eq[(S["N"], S["M"]), (I["M"], S["L"])]
        | Eq[(S["N"], S["M"]), (I["L"], S["L"])]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``Sn_Sm_I_Sl_Factor``."""
    return _factor_from_source(Sn_Sm_I_Sl_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[(S["N"], S["M"]), (S["K"], S["L"])],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_Sm_Sk_Sl_Factor``."""
    return _factor_from_source(Sn_Sm_Sk_Sl_Factor, source, eq, value)


# Asymmetric [N, M, L] shapes


@plum.overload
def factor(
    # groups: Eq[I, (I, S)],
    eq: (
        Eq[I, (I, S)]
        | Eq[I["N"], (I, S)]
        | Eq[I, (I["M"], S)]
        | Eq[I["N"], (I["M"], S)]
        | Eq[I["N"], (I["N"], S)]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_I_S_Factor``."""
    return _factor_from_source(I_I_S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[I, (S, I)]
        | Eq[I["N"], (S, I)]
        | Eq[I, (S, I["L"])]
        | Eq[I["N"], (S, I["N"])]
        | Eq[I["N"], (S, I["L"])]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_S_I_Factor``."""
    return _factor_from_source(I_S_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[S, (I, I)]
        | Eq[S, (I["M"], I)]
        | Eq[S, (I, I["L"])]
        | Eq[S, (I["M"], I["L"])]
        | Eq[S, (I["M"], I["M"])]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``S_I_I_Factor``."""
    return _factor_from_source(S_I_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[S["N"], (I, S["L"])]
        | Eq[S["N"], (I["N"], S["L"])]
        | Eq[S["N"], (I["L"], S["L"])]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``Sn_I_Sl_Factor``."""
    return _factor_from_source(Sn_I_Sl_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[S["N"], (S["M"], I)]
        | Eq[S["N"], (S["M"], I["N"])]
        | Eq[S["N"], (S["M"], I["M"])]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``Sn_Sm_I_Factor``."""
    return _factor_from_source(Sn_Sm_I_Factor, source, eq, value)


@plum.overload
def factor(
    # groups: Eq[I, tuple[Sm, S]],
    eq: (
        # Eq[I, (S, S)] and Eq[I["N"], (S, S)] are deliberately not listed:
        # an earlier note records that routing them here gives a wrong answer.
        Eq[I, (S["M"], S)]
        | Eq[I["M"], (S["M"], S)]
        | Eq[I["N"], (S["M"], S)]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``I_Sm_Sl_Factor``."""
    return _factor_from_source(I_Sm_Sl_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[S["N"], (S["M"], S["L"])],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_Sm_Sl_Factor``."""
    return _factor_from_source(Sn_Sm_Sl_Factor, source, eq, value)


# Asymmetric special cases


@plum.overload
def factor(
    eq: (
        Eq[O, (I, O)]
        | Eq[B, (I, B)]
        | Eq[O, (I["N"], O)]
        | Eq[B, (I["N"], B)]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``OnorBn_I_OnorBn_Factor``."""
    return _factor_from_source(OnorBn_I_OnorBn_Factor, source, eq, value)


@plum.overload
def factor(
    eq: (
        Eq[O, (O, I)]
        | Eq[B, (B, I)]
        | Eq[O, (O, I["N"])]
        | Eq[B, (B, I["N"])]  # noqa: F821
    ),
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``OnorBn_OnorBn_I_Factor``."""
    return _factor_from_source(OnorBn_OnorBn_I_Factor, source, eq, value)


@plum.overload
def factor(
    # groups: Eq[Sn, tuple[I, Sn]],
    eq: Eq[S, (I, S)] | Eq[S, (I["N"], S)],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``Sn_I_Sn_Factor``."""
    return _factor_from_source(Sn_I_Sn_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[S, (S, I)] | Eq[S, (S, I["N"])],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern(s) to ``Sn_Sn_I_Factor``."""
    return _factor_from_source(Sn_Sn_I_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[O, (S["M"], O)],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``On_S_On_Factor``."""
    return _factor_from_source(On_S_On_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[O, (O, S["L"])],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``On_On_S_Factor``."""
    return _factor_from_source(On_On_S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[S, (S["M"], S)],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_S_Sn_Factor``."""
    return _factor_from_source(Sn_S_Sn_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[S, (S, S["L"])],  # noqa: F821
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``Sn_Sn_S_Factor``."""
    return _factor_from_source(Sn_Sn_S_Factor, source, eq, value)


@plum.overload
def factor(
    eq: Eq[S, (S, S)],
    *,
    source: FromSource,
    value: Value,
):
    """Route ``eq`` matching the annotated pattern to ``S_S_S_Factor``."""
    return _factor_from_source(S_S_S_Factor, source, eq, value)


# @dispatch(precedence=-1)  # type: ignore[arg-type]
@plum.overload
def factor(
    eq: Eq,
    *,
    source: FromSource,
    value: Value,
):
    """Fallback overload on a plain ``Eq`` (no specific group pattern).

    Returns the ``ZeroFactor`` class when ``source`` is ``"class_only"``, and a
    ``ZeroFactor(eq)`` instance for any other ``source``. ``value`` is ignored.
    """
    if source != get_args(ClassOnly)[0]:
        return ZeroFactor(eq)  # type: ignore[arg-type]
    return ZeroFactor


@dispatch
def factor(groups, *, source, value):  # type: ignore[arg-type]
    """Build a factor (or factor class) for the equivalence ``groups``.

    Multiple-dispatch entry point. The ``@plum.overload`` definitions above pick
    the factor class from the type of the positional ``groups`` argument. The
    keyword-only ``source`` / ``value`` are not dispatched on; most overloads
    forward them to ``_factor_from_source``. This body itself is a placeholder.
    """
    pass


# Factor classes


class ZeroFactor(FactorBase):
    eq: Eq

    def cov(self, shape: tuple | None = None, surrogate: bool = False) -> NDArray:
        if shape is None:
            if not surrogate:
                out = self.eq.cov_shape
            else:
                out = tuple(np.prod(dims) for dims in self.eq.stable_shape)
        else:
            ndim = len(shape)
            if ndim == 2:
                out = shape
            elif ndim == 3:
                n, m, l = shape
                out = (n, m * l)
            elif ndim == 4:
                n, m, l, k = shape
                out = (n * m, l * k)
            else:
                raise ValueError(
                    "Expected shape dim to be larger 1 and less or equal than 4"
                )
        return torch.zeros(out)

    def matvec(self, v: NDArray, transpose: bool = False):
        dtype = v.dtype
        shape0, shape1 = self.eq.shape
        if isinstance(shape0, int):
            shape0 = (shape0,)
        if isinstance(shape1, int):
            shape1 = (shape1,)
        if v.shape == shape1:
            return torch.zeros(shape0, dtype=dtype)
        elif v.shape == shape0:
            return torch.zeros(shape1, dtype=dtype)
        else:
            raise ValueError("Vector should match group shape")

    def min_dim(self):
        return 0


class OOrB_Factor(Factor):
    """Factor whose square covariance block is a scaled identity ``w * I``.

    ``weights`` holds the single scalar ``w``. The block is ``n x n``, with
    ``n`` from ``group_dims``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):
        """Create the factor with scalar weight ``init_fn(())`` (``torch.zeros`` if None)."""
        make_weight = init_fn if init_fn is not None else torch.zeros
        return cls(eq, make_weight(()))

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``w`` as the mean diagonal entry of the ``n x n`` block ``cov``."""
        size = cls.group_dims(eq, surrogate)
        assert cov.shape == (size, size)
        return torch.trace(cov) / size

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``w`` as ``<params[0], params[1]> / n``."""
        size = cls.group_dims(eq)
        assert params[0].shape[-1] == size and params[0].shape == params[1].shape
        dot = contract("i,i->", params[0], params[1])
        return dot / size

    def cov(
        self, shape: tuple[int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense ``w * I`` block; its side is ``shape[0]`` when given."""
        if shape is None:
            size = self.group_dims(self.eq, surrogate)
        else:
            size = shape[0]
        diagonal = torch.ones(size) * self.weights
        return torch.diag(diagonal)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Return ``w * v``; the block is symmetric, so ``transpose`` is unused."""
        size = self.group_dims(self.eq)
        assert v.shape[0] == size
        return self.weights * v

    def min_dim(self) -> tuple[int,]:
        return (1,)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> int:
        """Return the side length of the square block (asserting it is square)."""
        rows, cols = eq.stable_shape if surrogate else eq.shape
        assert rows == cols
        return rows


class S_Factor(Factor):
    """Factor with square covariance block ``a * I + (b / n) * ones((n, n))``.

    ``weights`` holds ``[a, b]``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):
        """Create the factor with weights ``init_fn((2,))`` (``torch.zeros`` if None)."""
        make_weight = init_fn if init_fn is not None else torch.zeros
        return cls(eq, make_weight((2,)))

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Recover ``[a, b]`` from an ``n x n`` covariance block.

        For a block of the form built by ``cov()``:
        - ``trace = n*a + b``;
        - ``sum = n*a + n*b``;
        - so ``sum - trace = b * (n - 1)``.
        """
        size = cls.group_dims(eq, surrogate)
        assert cov.shape == (size, size)

        diag_total = torch.trace(cov)
        off_diag = (torch.sum(cov) - diag_total) / (size - 1)
        scale = (diag_total - off_diag) / size
        return torch.tensor([scale, off_diag])

    @classmethod
    def outer_estimate(cls, eq: Eq, params):
        """Recover ``[a, b]`` from the outer product of ``params[0]`` and ``params[1]``."""
        size = cls.group_dims(eq)
        assert params[0].shape[-1] == size and params[1].shape == params[0].shape

        diag_total = contract("i,i->", params[0], params[1])
        grand_total = contract("i,j->", params[0], params[1])
        off_diag = (grand_total - diag_total) / (size - 1)
        scale = (diag_total - off_diag) / size
        return torch.tensor([scale, off_diag])

    def cov(
        self, shape: tuple[int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense block; its side is ``shape[0]`` when given."""
        size = self.group_dims(self.eq, surrogate) if shape is None else shape[0]
        identity_part = torch.diag(torch.ones(size) * self.weights[0])
        constant_part = torch.ones((size, size)) * self.weights[1] / size
        return identity_part + constant_part

    def matvec(self, v: NDArray, transpose: bool = False):
        """Return ``a * v + b * sum(v) / n``; the block is symmetric."""
        size = self.group_dims(self.eq)
        assert v.shape[0] == size
        return self.weights[0] * v + self.weights[1] * torch.sum(v) / size

    def min_dim(self) -> tuple[int,]:
        return (2,)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> int:
        """Return the side length of the square block (asserting it is square)."""
        rows, cols = eq.stable_shape if surrogate else eq.shape
        assert rows == cols
        return rows


class Sn_Sm_Factor(Factor):
    """Factor with a constant ``(n, m)`` covariance block.

    ``weights`` holds one scalar ``w``; every entry of the block equals
    ``w / sqrt(n * m)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):
        """Create the factor with scalar weight ``init_fn(())`` (``torch.zeros`` if None)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn(())
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``w`` from an ``(n, m)`` covariance block.

        Inverse of ``cov()``: the entries of that block sum to ``w * sqrt(n*m)``.
        """
        n, m = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m)
        sqrt_nm = torch.sqrt(torch.as_tensor(n * m, dtype=cov.dtype))
        return torch.sum(cov) / sqrt_nm

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``w`` from the outer product of ``params[0]`` and ``params[1]``."""
        n, m = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        assert p1.shape[-1] == n and p2.shape[-1] == m
        sqrt_nm = torch.sqrt(torch.as_tensor(n * m, dtype=p1.dtype))
        return contract("i,j->", p1, p2) / sqrt_nm

    def cov(
        self, shape: tuple[int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense constant block; ``shape`` overrides ``(n, m)``."""
        n, m = self.group_dims(self.eq, surrogate) if shape is None else shape
        w = self.weights
        sqrt_nm = torch.sqrt(torch.as_tensor(n * m, dtype=w.dtype))
        return torch.ones((n, m)) * w / sqrt_nm

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block (``transpose=False``, ``v`` of length ``m``) or its transpose.

        Every entry of the block is equal, so the result is the scaled sum of
        ``v`` repeated over the output length.
        """
        n, m = self.group_dims(self.eq)
        sqrt_nm = torch.sqrt(torch.as_tensor(n * m, dtype=v.dtype))

        if transpose:
            assert v.shape[0] == n
            out_len = m
        else:
            assert v.shape[0] == m
            out_len = n
        ones = torch.ones((out_len,)) / sqrt_nm
        return self.weights * torch.sum(v) * ones

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return the block shape ``(n, m)`` (from ``eq.stable_shape`` when ``surrogate``)."""
        return eq.stable_shape if surrogate else eq.shape


class I_Factor(Factor):
    """Unconstrained ``(n, m)`` factor: ``weights`` is the covariance block itself."""

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):
        """Create the factor with weights ``init_fn((n, m))`` (``torch.zeros`` if None)."""
        n, m = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((n, m))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Use ``cov`` directly as the weights after checking it is ``(n, m)``."""
        n, m = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m)
        return cov

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Return the outer product ``params[0][:, None] * params[1][None, :]``."""
        n, m = cls.group_dims(eq)
        assert params[0].shape[-1] == n and params[1].shape[-1] == m
        return contract("i,j->ij", params[0], params[1])

    def cov(
        self, shape: tuple[int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return ``weights``; a given ``shape`` must equal ``group_dims``."""
        if shape is not None:
            assert shape == self.group_dims(self.eq, surrogate)
        return self.weights

    def matvec(self, v: NDArray, transpose: bool = False):
        """Return ``weights @ v``, or ``weights.T @ v`` when ``transpose``."""
        n, m = self.group_dims(self.eq)
        w = self.weights
        if transpose:
            assert v.shape[0] == n
            return (w * v[:, None]).sum(dim=0)
        assert v.shape[0] == m
        return w @ v

    def min_dim(self) -> tuple[int,]:
        # NOTE(review): group_dims returns the pair (n, m), so this evaluates to
        # ((n, m),) rather than the annotated tuple[int]. Confirm the intent
        # before relying on it.
        n = self.group_dims(self.eq)
        return (n,)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return the block shape ``(n, m)`` (from ``eq.stable_shape`` when ``surrogate``)."""
        n, m = eq.stable_shape if surrogate else eq.shape
        return n, m


class S_I_Factor(Factor):
    """Factor whose ``(n, m)`` covariance block has identical rows.

    ``weights`` is a length-``m`` vector ``w``. Every row of the block equals
    ``w / sqrt(n)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):
        """Create the factor with weights ``init_fn((m,))`` (``torch.zeros`` if None)."""
        n, m = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((m,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``w`` from an ``(n, m)`` block: column sums divided by ``sqrt(n)``."""
        n, m = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m)
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=cov.dtype))
        return cov.sum(dim=0) / sqrt_n

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``w`` as ``sum(params[0]) * params[1] / sqrt(n)``."""
        n, m = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=p1.dtype))

        assert p1.shape[-1] == n and p2.shape[-1] == m

        return contract("i,j->j", p1, p2) / sqrt_n

    def cov(
        self, shape: tuple[int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense ``(n, m)`` block.

        A given ``shape`` may change ``n`` but its ``m`` must match ``group_dims``.
        """
        w = self.weights

        if shape is None:
            n, m = self.group_dims(self.eq, surrogate)
        else:
            n, m = shape
            assert m == self.group_dims(self.eq)[1]

        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=w.dtype))
        return torch.kron(torch.ones((n, 1)) / sqrt_n, w[None, :])

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block (``v`` of length ``m``) or its transpose (``v`` of length ``n``)."""
        n, m = self.group_dims(self.eq)
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=v.dtype))

        if transpose:
            assert v.shape[0] == n
            return self.weights * torch.sum(v) / sqrt_n

        assert v.shape[0] == m
        ones = torch.ones((n,)) / sqrt_n
        return torch.sum(self.weights * v) * ones

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return the block shape ``(n, m)`` (from ``eq.stable_shape`` when ``surrogate``)."""
        n, m = eq.stable_shape if surrogate else eq.shape
        return n, m


class I_S_Factor(Factor):
    """Factor whose ``(n, m)`` covariance block has identical columns.

    ``weights`` is a length-``n`` vector ``w``. Every column of the block
    equals ``w / sqrt(m)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):
        """Create the factor with weights ``init_fn((n,))`` (``torch.zeros`` if None)."""
        n, m = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((n,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``w`` from an ``(n, m)`` block: row sums divided by ``sqrt(m)``."""
        n, m = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m)
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=cov.dtype))
        return cov.sum(dim=1) / sqrt_m

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``w`` as ``params[0] * sum(params[1]) / sqrt(m)``."""
        n, m = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]

        assert p1.shape[-1] == n and p2.shape[-1] == m

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=p1.dtype))
        return contract("i,j->i", p1, p2) / sqrt_m

    def cov(
        self, shape: tuple[int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense ``(n, m)`` block.

        A given ``shape`` may change ``m`` but its ``n`` must match ``group_dims``.
        """
        w = self.weights

        if shape is None:
            n, m = self.group_dims(self.eq, surrogate)
        else:
            n, m = shape
            assert n == self.group_dims(self.eq)[0]

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=w.dtype))
        return torch.kron(w[:, None], torch.ones((1, m)) / sqrt_m)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block (``v`` of length ``m``) or its transpose (``v`` of length ``n``)."""
        n, m = self.group_dims(self.eq)
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=v.dtype))

        if transpose:
            assert v.shape[0] == n
            ones = torch.ones((m,)) / sqrt_m
            return torch.sum(self.weights * v) * ones

        assert v.shape[0] == m
        return self.weights * torch.sum(v) / sqrt_m

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return the block shape ``(n, m)`` (from ``eq.stable_shape`` when ``surrogate``)."""
        n, m = eq.stable_shape if surrogate else eq.shape
        return n, m


class I_I_Factor(Factor):
    """Unconstrained factor holding a dense ``(n, m, n, m)`` weight tensor.

    Reshaped to ``(n*m, n*m)``, the weight tensor is itself the covariance
    block (see ``cov``).
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):
        """Build the factor from the Gram matrix of ``init_fn((n*m, n*m))``.

        The sample ``w`` becomes ``w @ w.T / sqrt(n * (n - 1))`` reshaped to
        ``(n, m, n, m)``. ``init_fn`` defaults to ``torch.zeros``.
        """
        n, m = cls.group_dims(eq)
        size = n * m
        factory = torch.zeros if init_fn is None else init_fn
        sample = factory((size, size))
        # NOTE(review): this scale is 0 for n == 1, which yields inf/nan
        # weights (0/0 with the default zeros init). Confirm the intended
        # normalisation.
        scale = torch.sqrt(torch.as_tensor(n * (n - 1), dtype=sample.dtype))
        gram = sample @ sample.T / scale
        return cls(eq, gram.reshape(n, m, n, m))

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Return ``cov`` (shape ``(n*m, n*m)``) reshaped to ``(n, m, n, m)``."""
        n, m = cls.group_dims(eq, surrogate)
        size = n * m
        assert cov.shape == (size, size)
        return cov.reshape(n, m, n, m)

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Outer product of two ``(n, m)`` parameters as an ``(n, m, n, m)`` tensor."""
        n, m = cls.group_dims(eq)
        left, right = params[0], params[1]
        assert left.shape == (n, m) and right.shape == (n, m)
        return contract("ik,jl->ikjl", left, right)

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the weight tensor flattened to an ``(n*m, n*m)`` matrix.

        ``shape``, when given, must be ``(n, m, n, m)``. It is only checked,
        never used to resize.
        """
        n, m = self.group_dims(self.eq, surrogate)
        if shape is not None:
            assert shape[:2] == shape[2:] == (n, m)
        return self.weights.reshape(n * m, n * m)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Contract the weight tensor with an ``(n, m)`` input.

        The forward direction sums over the trailing ``(j, l)`` pair. The
        transpose sums over the leading ``(i, k)`` pair.
        """
        n, m = self.group_dims(self.eq)
        assert v.shape == (n, m)
        subscript = "ikjl,ik->jl" if transpose else "ikjl,jl->ik"
        return contract(subscript, self.weights, v)

    def min_dim(self) -> tuple[int, int]:
        """The minimum group size is the full ``(n, m)`` of this factor."""
        return self.group_dims(self.eq)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return ``(n, m)``; both groups of ``eq`` must have the same shape."""
        first, second = eq.stable_shape if surrogate else eq.shape
        n, m = first
        l, k = second
        assert (n, m) == (l, k)
        return n, m


class OnOrBn_OmOrBm_Factor(Factor):
    """Isotropic factor: the covariance block is ``w * I_{n*m}``.

    ``weights`` holds a single scalar ``w``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with a 0-d weight from ``init_fn(())``.

        ``init_fn`` defaults to ``torch.zeros``.
        """
        # Called only for its validation (the two groups must match).
        cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn(())
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Least-squares scalar for an ``(n*m, n*m)`` block: ``trace(cov) / (n*m)``."""
        n, m = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, n * m)
        return torch.trace(cov) / (n * m)

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Scalar estimate ``<p1, p2> / (n*m)`` from two ``(n, m)`` parameters."""
        n, m = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        # Both operands must share p1's shape (the old check compared p2 with itself).
        assert p1.shape[-2:] == (n, m) and p2.shape == p1.shape
        return contract("ik,ik->", p1, p2) / (n * m)

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, n*m)`` block ``w * I``.

        ``shape`` (``(n, m, n, m)``) overrides the group dimensions.
        """
        if shape is None:
            n, m = self.group_dims(self.eq, surrogate)
        else:
            assert shape[:2] == shape[2:]
            n, m = shape[:2]
        w = self.weights
        # Match the weight's dtype/device. A default float32 CPU ``ones`` would
        # downcast float64 weights (0-d tensors do not win type promotion).
        vec = torch.ones(n * m, dtype=w.dtype, device=w.device) * w
        return torch.diag(vec)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Scale an ``(n, m)`` input by ``w``. The block is symmetric, so ``transpose`` is ignored."""
        n, m = self.group_dims(self.eq)
        assert v.shape == (n, m)
        return self.weights * v

    def min_dim(self) -> tuple[int, int]:
        return 1, 1

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return ``(n, m)``; both groups of ``eq`` must have the same shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l and m == k
        return n, m


class Bn_Bn_Factor(Factor):
    """Four-parameter factor for a square ``(n, n)`` group.

    The ``(n*n, n*n)`` block is
    ``w0*I + w1*K + w2*vec(I)vec(I)^T + w3*(I * K)``. Here ``K`` is
    ``sm.commutation_matrix(n, n)`` and ``I * K`` is the elementwise product
    (see ``basis``).
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with four weights from ``init_fn((4,))`` (default zeros)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((4,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit the four weights to an ``(n*n, n*n)`` covariance block.

        ``b`` holds the inner products of ``cov`` with each basis matrix. The
        weights solve ``C w = b``, where ``C = trace(n)`` is the basis Gram
        matrix.
        """
        n = cls.group_dims(eq, surrogate)
        n2 = n * n
        assert cov.shape == (n2, n2)
        cov = cov.reshape(n, n, n, n)

        b = torch.stack(
            [
                contract("ikik->", cov),  # <I, cov>
                contract("ikki->", cov),  # <K, cov>
                contract("iijj->", cov),  # <vec(I)vec(I)^T, cov>
                contract("iiii->", cov),  # <I * K, cov>
            ]
        )

        # as_tensor is a helper defined elsewhere; presumably it casts C to cov's dtype/device.
        C = as_tensor(cls.trace(n), cov)
        return stable_solve(C, b[:, None])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same fit as ``cov_estimate``, using ``vec(p1) vec(p2)^T`` as the covariance."""
        n = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        # Both operands must share p1's shape (the old check compared p2 with itself).
        assert p1.shape[-2:] == (n, n) and p2.shape == p1.shape

        b = torch.stack(
            [
                contract("ik,ik->", p1, p2),
                contract("ik,ki->", p1, p2),
                contract("ii,jj->", p1, p2),
                contract("ii,ii->", p1, p2),
            ]
        )

        C = as_tensor(cls.trace(n), p1)
        return stable_solve(C, b[:, None])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*n, n*n)`` block as a weighted sum of ``basis(n)``."""
        if shape is None:
            n = self.group_dims(self.eq, surrogate)
        else:
            n = shape[0]
            assert n == shape[1] == shape[2] == shape[3]
        w = self.weights
        factors = as_tensor(self.basis(n), w)
        weighted_basis = w[:, None, None] * factors
        return torch.sum(weighted_basis, dim=0)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block to an ``(n, n)`` input without forming it.

        Every basis matrix is symmetric, so ``transpose`` has no effect.
        """
        n = self.group_dims(self.eq)
        assert v.shape == (n, n)

        eye = v  # I vec(v)
        k = v.T  # K vec(v) == vec(v^T)
        # Create on v's device: a CPU eye times a CUDA trace would fail.
        bb1 = torch.eye(n, dtype=v.dtype, device=v.device) * torch.trace(v)
        delta_ijk = torch.diag(torch.diag(v))  # keep only the diagonal of v

        w = self.weights
        factors = torch.stack([eye, k, bb1, delta_ijk])
        weighted_factors = w[:, None, None] * factors
        return torch.sum(weighted_factors, dim=0)

    def min_dim(self) -> tuple[int, int]:
        return 2, 2

    @classmethod
    def basis(cls, n) -> NDArray:
        """Stack the four ``(n*n, n*n)`` basis matrices: ``I, K, vec(I)vec(I)^T, I * K``."""
        eye = torch.eye(n * n)
        k = sm.commutation_matrix(n, n)
        bb1 = torch.outer(torch.eye(n).reshape(-1), torch.eye(n).reshape(-1))
        return torch.stack([eye, k, bb1, eye * k])

    @classmethod
    def trace(cls, n) -> NDArray:
        """Gram matrix ``<B_i, B_j>`` of the four basis matrices (integer entries)."""
        n2 = n * n
        C = torch.tensor(
            [
                [n2, n, n, n],
                [n, n2, n, n],
                [n, n, n2, n],
                [n, n, n, n],
            ]
        )
        return C

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> int:
        """Return ``n``; both groups of ``eq`` must have the same square ``(n, n)`` shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l and m == k
        assert n == m
        return n


class On_On_Factor(Factor):
    """Three-parameter factor for a square ``(n, n)`` group.

    The ``(n*n, n*n)`` block is ``w0*I + w1*K + w2*vec(I)vec(I)^T``, where
    ``K`` is ``sm.commutation_matrix(n, n)`` (see ``basis``).
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with three weights from ``init_fn((3,))`` (default zeros)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((3,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov, surrogate: bool = False):
        """Fit the three weights to an ``(n*n, n*n)`` covariance block.

        ``b`` holds the inner products of ``cov`` with each basis matrix. The
        weights solve ``C w = b``, where ``C = trace(n)`` is the basis Gram
        matrix.
        """
        n = cls.group_dims(eq, surrogate)
        n2 = n * n
        assert cov.shape == (n2, n2)
        cov = cov.reshape(n, n, n, n)

        b = torch.stack(
            [
                contract("ikik->", cov),  # <I, cov>
                contract("ikki->", cov),  # <K, cov>
                contract("iijj->", cov),  # <vec(I)vec(I)^T, cov>
            ]
        )

        C = as_tensor(cls.trace(n), cov)
        return stable_solve(C, b[:, None])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same fit as ``cov_estimate``, using ``vec(p1) vec(p2)^T`` as the covariance."""
        n = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]

        assert p1.shape[-2:] == (n, n) and p1.shape == p2.shape

        b = torch.stack(
            [
                contract("ik,ik->", p1, p2),
                contract("ik,ki->", p1, p2),
                contract("ii,jj->", p1, p2),
            ]
        )

        C = as_tensor(cls.trace(n), p1)
        return stable_solve(C, b[:, None])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*n, n*n)`` block as a weighted sum of ``basis(n)``."""
        if shape is None:
            n = self.group_dims(self.eq, surrogate)
        else:
            n = shape[0]
            assert n == shape[1] == shape[2] == shape[3]
        w = self.weights
        factors = as_tensor(self.basis(n), w)
        weighted_factors = w[:, None, None] * factors
        return torch.sum(weighted_factors, dim=0)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block to an ``(n, n)`` input without forming it.

        Every basis matrix is symmetric, so ``transpose`` has no effect.
        """
        n = self.group_dims(self.eq)
        assert v.shape == (n, n)

        eye = v  # I vec(v)
        k = v.T  # K vec(v) == vec(v^T)
        # Create on v's device: a CPU eye times a CUDA trace would fail.
        bb1 = torch.eye(n, dtype=v.dtype, device=v.device) * torch.trace(v)

        w = self.weights
        factors = torch.stack([eye, k, bb1])
        weighted_factors = w[:, None, None] * factors
        return torch.sum(weighted_factors, dim=0)

    def min_dim(self) -> tuple[int, int]:
        return 2, 2

    @classmethod
    def basis(cls, n) -> NDArray:
        """Stack the three ``(n*n, n*n)`` basis matrices: ``I, K, vec(I)vec(I)^T``."""
        eye = torch.eye(n * n)
        k = sm.commutation_matrix(n, n)
        bb1 = torch.outer(torch.eye(n).reshape(-1), torch.eye(n).reshape(-1))
        return torch.stack([eye, k, bb1])

    @classmethod
    def trace(cls, n) -> NDArray:
        """Gram matrix ``<B_i, B_j>`` of the three basis matrices (integer entries)."""
        C = torch.tensor(
            [
                [n * n, n, n],
                [n, n * n, n],
                [n, n, n * n],
            ],
        )
        return C

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> int:
        """Return ``n``; both groups of ``eq`` must have the same square ``(n, n)`` shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l and m == k
        assert n == m
        return n


class OOrB_S_Factor(Factor):
    """Two-parameter factor on an ``(n, m)`` group.

    ``matvec`` computes ``w0 * v + w1 * rowsum(v) / m``. The second term
    couples entries in the same row ``i``. ``weights`` is ``[w0, w1]``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with two weights from ``init_fn((2,))`` (default zeros)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit ``[w0, w1]`` to an ``(n*m, n*m)`` covariance block.

        ``T`` is the trace and ``S`` the sum over same-row entry pairs. The two
        moment equations give ``w1 = (S - T) / (n (m - 1))`` and
        ``w0 = T / (n m) - w1 / m``. Requires ``m >= 2``.
        """
        n, m = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, n * m)
        cov = cov.reshape(n, m, n, m)
        S = contract("ikil->", cov)
        T = contract("ikik->", cov)

        b = (S - T) / (n * (m - 1))
        a = T / (n * m) - b / m
        # torch.stack keeps the inputs' device (and autograd), like the other
        # estimators. torch.tensor([...]) would copy the values to the CPU.
        return torch.stack([a, b])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same fit as ``cov_estimate``, using ``vec(p1) vec(p2)^T`` as the covariance."""
        n, m = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        assert p1.shape[-2:] == (n, m) and p1.shape == p2.shape

        S = contract("ik,il->", p1, p2)
        T = contract("ik,ik->", p1, p2)

        b = (S - T) / (n * (m - 1))
        a = T / (n * m) - b / m
        return torch.stack([a, b])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, n*m)`` block.

        The second term comes from the ``iden_kron_ones`` helper, defined
        elsewhere. Presumably it is ``w1 * kron(I_n, ones(m, m))``; confirm.
        """
        if shape is None:
            n, m = self.group_dims(self.eq, surrogate)
        else:
            assert shape[:2] == shape[2:]
            n, m = shape[:2]
        w = self.weights
        # Match the weight's dtype/device (a default float32 CPU ``ones`` would
        # downcast float64 weights).
        vec_nm = torch.ones(n * m, dtype=w.dtype, device=w.device) * w[0]
        eye_nm = torch.diag(vec_nm)
        eye_kron_ones = iden_kron_ones(w[1], n, m) / m
        return eye_nm + eye_kron_ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block to an ``(n, m)`` input. The block is symmetric, so ``transpose`` is ignored."""
        n, m = self.group_dims(self.eq)
        assert v.shape == (n, m)

        w = self.weights
        eye_nm = w[0] * v
        eye_kron_ones = w[1] * torch.sum(v, dim=1, keepdim=True) / m
        return eye_nm + eye_kron_ones

    def min_dim(self) -> tuple[int, int]:
        # The w1 estimate divides by (m - 1), so the second group needs m >= 2.
        return 1, 2

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return ``(n, m)``; both groups of ``eq`` must have the same shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l and m == k
        return n, m


class S_OOrB_Factor(Factor):
    """Two-parameter factor on an ``(n, m)`` group.

    ``matvec`` computes ``w0 * v + w1 * colsum(v) / n``. The second term
    couples entries in the same column ``k``. ``weights`` is ``[w0, w1]``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with two weights from ``init_fn((2,))`` (default zeros)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit ``[w0, w1]`` to an ``(n*m, n*m)`` covariance block.

        ``T`` is the trace and ``S`` the sum over same-column entry pairs, giving
        ``w1 = (S - T) / (m (n - 1))`` and ``w0 = T / (n m) - w1 / n``.
        Requires ``n >= 2``.
        """
        n, m = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, n * m)
        cov = cov.reshape(n, m, n, m)
        S = contract("ikjk->", cov)
        T = contract("ikik->", cov)

        b = (S - T) / (m * (n - 1))
        a = T / (n * m) - b / n
        # torch.stack keeps the inputs' device (and autograd), like the other
        # estimators. torch.tensor([...]) would copy the values to the CPU.
        return torch.stack([a, b])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same fit as ``cov_estimate``, using ``vec(p1) vec(p2)^T`` as the covariance."""
        n, m = cls.group_dims(eq)
        p1 = params[0]
        # Second operand (previously params[0] was used twice, ignoring params[1]).
        p2 = params[1]

        assert p1.shape[-2:] == (n, m) and p1.shape == p2.shape

        S = contract("ik,jk->", p1, p2)
        T = contract("ik,ik->", p1, p2)

        b = (S - T) / (m * (n - 1))
        a = T / (n * m) - b / n
        return torch.stack([a, b])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, n*m)`` block.

        The second term comes from the ``ones_kron_iden`` helper, defined
        elsewhere. Presumably it is ``w1 * kron(ones(n, n), I_m)``; confirm.
        """
        if shape is None:
            n, m = self.group_dims(self.eq, surrogate)
        else:
            assert shape[:2] == shape[2:]
            n, m = shape[:2]
        w = self.weights
        # Match the weight's dtype/device (avoid float32 CPU downcast).
        vec_nm = torch.ones(n * m, dtype=w.dtype, device=w.device) * w[0]
        eye_nm = torch.diag(vec_nm)
        ones_kron_eye = ones_kron_iden(w[1], n, m) / n
        return eye_nm + ones_kron_eye

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block to an ``(n, m)`` input. The block is symmetric, so ``transpose`` is ignored."""
        n, m = self.group_dims(self.eq)
        assert v.shape == (n, m)

        w = self.weights
        eye_nm = w[0] * v
        ones_kron_eye = w[1] * torch.sum(v, dim=0, keepdim=True) / n
        return eye_nm + ones_kron_eye

    def min_dim(self) -> tuple[int, int]:
        return 2, 1

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return ``(n, m)``; both groups of ``eq`` must have the same shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l and m == k
        return n, m


class S_OOrB_OOrB_S_Factor(Factor):
    """Two-parameter cross factor between an ``(n, m)`` and an ``(m, n)`` group.

    The block is ``(w0 * I + w1 * ones_kron_iden / n) @ commutation_matrix(n, m)``
    and has shape ``(n*m, m*n)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with two weights from ``init_fn((2,))`` (default zeros)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit ``[w0, w1]`` to an ``(n*m, m*n)`` cross-covariance block.

        Uses ``T = sum cov[i,k,k,i]`` and ``S = sum cov[i,k,k,l]`` after
        reshaping to ``(n, m, m, n)``. Requires ``n >= 2``.
        """
        n, m = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, m * n)
        cov = cov.reshape(n, m, m, n)
        S = contract("ikkl->", cov)
        T = contract("ikki->", cov)

        b = (S - T) / (m * (n - 1))
        a = T / (n * m) - b / n
        # torch.stack keeps the inputs' device (and autograd), like the other
        # estimators. torch.tensor([...]) would copy the values to the CPU.
        return torch.stack([a, b])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same fit as ``cov_estimate`` from an ``(n, m)`` and an ``(m, n)`` parameter."""
        n, m = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (m, n)

        S = contract("ik,kl->", p1, p2)
        T = contract("ik,ki->", p1, p2)

        b = (S - T) / (m * (n - 1))
        a = T / (n * m) - b / n
        return torch.stack([a, b])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, m*n)`` block.

        ``shape`` must be ``(n, m, m, n)`` when given.
        """
        if shape is None:
            n, m = self.group_dims(self.eq, surrogate)
        else:
            # The last two dims reversed must equal the first two. The old
            # ``shape[2::-1]`` produced a 3-tuple, so this check always failed.
            assert tuple(shape[:2]) == tuple(shape[2:][::-1])
            n, m = shape[:2]
        w = self.weights
        # Match the weight's dtype/device (avoid float32 CPU downcast).
        vec_nm = torch.ones(n * m, dtype=w.dtype, device=w.device) * w[0]
        eye_nm = torch.diag(vec_nm)
        ones_kron_eye = ones_kron_iden(w[1], n, m) / n
        return (eye_nm + ones_kron_eye) @ sm.commutation_matrix(n, m)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block to an ``(m, n)`` input, or its transpose to an ``(n, m)`` input."""
        n, m = self.group_dims(self.eq)
        w = self.weights

        if not transpose:
            assert v.shape == (m, n)
            eye_nm = w[0] * v.T
            ones_kron_eye = w[1] * torch.sum(v.T, dim=0, keepdim=True) / n
        else:
            assert v.shape == (n, m)
            eye_nm = w[0] * v.T
            ones_kron_eye = w[1] * torch.sum(v.T, dim=1, keepdim=True) / n
        return eye_nm + ones_kron_eye

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return ``(n, m)``; the second group of ``eq`` must be ``(m, n)``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == k and m == l
        return n, m


class Sn_Sm_Sn_Sm_Factor(Factor):
    """Four-parameter factor for a pair of ``(n, m)`` groups.

    ``matvec`` computes
    ``a*v + c*rowsum(v)/m + d*colsum(v)/n + b*sum(v)/(n m)``.
    ``weights`` is stored in the order ``[a, c, d, b]``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with four weights from ``init_fn((4,))`` (default zeros)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((4,))
        return cls(eq, weight)

    @classmethod
    def _solve(cls, S, T, B, D, n: int, m: int):
        """Solve the four moment equations for ``[a, c, d, b]`` (needs ``n, m >= 2``)."""
        b = (S - B - D + T) / ((n - 1) * (m - 1))
        c = (B - T) / (n * (m - 1)) - b / n
        d = (D - T) / (m * (n - 1)) - b / m
        a = T / (n * m) - b / (n * m) - c / m - d / n
        # torch.stack keeps the inputs' device (and autograd), like the other
        # estimators. torch.tensor([...]) would copy the values to the CPU.
        return torch.stack([a, c, d, b])

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit ``[a, c, d, b]`` to an ``(n*m, n*m)`` covariance block."""
        n, m = cls.group_dims(eq, surrogate)

        assert cov.shape == (n * m, n * m)
        cov = cov.reshape(n, m, n, m)

        S = contract("ikjl->", cov)  # total sum
        T = contract("ikik->", cov)  # trace
        B = contract("ikil->", cov)  # same-row pairs
        D = contract("ikjk->", cov)  # same-column pairs
        return cls._solve(S, T, B, D, n, m)

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same fit as ``cov_estimate``, using ``vec(p1) vec(p2)^T`` as the covariance."""
        n, m = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        assert p1.shape[-2:] == (n, m) and p1.shape == p2.shape

        S = contract("ik,jl->", p1, p2)
        T = contract("ik,ik->", p1, p2)
        B = contract("ik,il->", p1, p2)
        D = contract("ik,jk->", p1, p2)
        return cls._solve(S, T, B, D, n, m)

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, n*m)`` block."""
        if shape is None:
            n, m = self.group_dims(self.eq, surrogate)
        else:
            assert shape[:2] == shape[2:]
            n, m = shape[:2]

        w = self.weights
        # Match the weight's dtype/device (avoid float32 CPU downcast).
        opts = {"dtype": w.dtype, "device": w.device}

        vec_nm = torch.ones(n * m, **opts) * w[0]
        eye_nm = torch.diag(vec_nm)

        eye_kron_ones = iden_kron_ones(w[1], n, m) / m
        ones_kron_eye = ones_kron_iden(w[2], n, m) / n

        ones = w[3] * torch.ones((n * m, n * m), **opts) / (n * m)
        return eye_nm + eye_kron_ones + ones_kron_eye + ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block to an ``(n, m)`` input. The block is symmetric, so ``transpose`` is ignored."""
        n, m = self.group_dims(self.eq)
        assert v.shape == (n, m)

        eye_nm = v
        ones = torch.ones((n, m), dtype=v.dtype, device=v.device) * torch.sum(v) / (n * m)
        eye_kron_ones = torch.sum(v, dim=1, keepdim=True) / m
        ones_kron_eye = torch.sum(v, dim=0, keepdim=True) / n

        w = self.weights
        return w[0] * eye_nm + w[1] * eye_kron_ones + w[2] * ones_kron_eye + w[3] * ones

    def min_dim(self) -> tuple[int, int]:
        return 2, 2

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return ``(n, m)``; both groups of ``eq`` must have the same shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l and m == k
        return n, m


class Sn_Sm_Sm_Sn_Factor(Factor):
    """Four-parameter cross factor between an ``(n, m)`` and an ``(m, n)`` group.

    The block has shape ``(n*m, m*n)``. Its terms are
    ``[K, (I⊗J)/m @ K, (J⊗I)/n @ K, J/(n m)]``, with ``K`` from
    ``sm.commutation_matrix(n, m)``. ``weights`` is ``[a, c, d, b]``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with four weights from ``init_fn((4,))`` (default zeros)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((4,))
        return cls(eq, weight)

    @classmethod
    def _solve(cls, S, T, B, D, n: int, m: int):
        """Solve the four moment equations for ``[a, c, d, b]`` (needs ``n, m >= 2``)."""
        b = (S - B - D + T) / ((n - 1) * (m - 1))
        c = (B - T) / (n * (m - 1)) - b / n
        d = (D - T) / (m * (n - 1)) - b / m
        a = T / (n * m) - b / (n * m) - c / m - d / n
        # torch.stack keeps the inputs' device (and autograd), like the other
        # estimators. torch.tensor([...]) would copy the values to the CPU.
        return torch.stack([a, c, d, b])

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit ``[a, c, d, b]`` to an ``(n*m, m*n)`` block reshaped to ``(n, m, m, n)``."""
        n, m = cls.group_dims(eq, surrogate)

        assert cov.shape == (n * m, m * n)
        cov = cov.reshape(n, m, m, n)

        S = contract("ikjl->", cov)
        T = contract("ikki->", cov)
        B = contract("ikji->", cov)
        D = contract("ikkl->", cov)
        return cls._solve(S, T, B, D, n, m)

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same fit as ``cov_estimate`` from an ``(n, m)`` and an ``(m, n)`` parameter."""
        n, m = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (m, n)

        S = contract("ik,jl->", p1, p2)
        T = contract("ik,ki->", p1, p2)
        B = contract("ik,ji->", p1, p2)
        D = contract("ik,kl->", p1, p2)
        return cls._solve(S, T, B, D, n, m)

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, m*n)`` block. ``shape`` must be ``(n, m, m, n)``."""
        if shape is None:
            n, m = self.group_dims(self.eq, surrogate)
        else:
            assert shape[:2] == shape[2:][::-1]
            n, m = shape[:2]
        comm = sm.commutation_matrix(n, m)
        w = self.weights

        K_nm = comm * w[0]
        eye_kron_ones_comm = (iden_kron_ones(w[1], n, m) / m) @ comm
        ones_kron_eye_comm = (ones_kron_iden(w[2], n, m) / n) @ comm
        # Match the weight's dtype/device (avoid float32 CPU downcast).
        ones = w[3] * torch.ones((n * m, n * m), dtype=w.dtype, device=w.device) / (n * m)

        return K_nm + eye_kron_ones_comm + ones_kron_eye_comm + ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block to an ``(m, n)`` input, or its transpose to an ``(n, m)`` input."""
        n, m = self.group_dims(self.eq)
        w = self.weights
        # Keep the constant term in v's dtype/device (float32 CPU ones would
        # lose float64 precision and fail on CUDA).
        opts = {"dtype": v.dtype, "device": v.device}

        if not transpose:
            assert v.shape == (m, n)

            K_nm = v.T
            ones = torch.ones((n, m), **opts) * torch.sum(v) / (n * m)
            eye_kron_ones_comm = torch.sum(v.T, dim=1, keepdim=True) / m
            ones_kron_eye_comm = torch.sum(v.T, dim=0, keepdim=True) / n
        else:
            assert v.shape == (n, m)

            K_nm = v.T
            ones = torch.ones((m, n), **opts) * torch.sum(v) / (n * m)
            eye_kron_ones_comm = torch.sum(v, dim=1, keepdim=True).T / m
            ones_kron_eye_comm = torch.sum(v, dim=0, keepdim=True).T / n

        return (
            w[0] * K_nm
            + w[1] * eye_kron_ones_comm
            + w[2] * ones_kron_eye_comm
            + w[3] * ones
        )

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return ``(n, m)``; the second group of ``eq`` must be ``(m, n)``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == k and m == l
        return n, m


class Sn_Sm_Sk_Sm_Factor(Factor):
    """Two-parameter cross factor between an ``(n, m)`` and an ``(l, m)`` group.

    The ``(n*m, l*m)`` block is
    ``a * kron(J_{n,l} / sqrt(n l), I_m) + b * J / (sqrt(n l) m)``, where
    ``J`` is an all-ones matrix.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with two weights from ``init_fn((2,))`` (default zeros)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit ``[a, b]`` to an ``(n*m, l*m)`` block (requires ``m >= 2``)."""
        n, m, l, k = cls.group_dims(eq, surrogate)

        assert cov.shape == (n * m, l * k)
        cov = cov.reshape(n, m, l, k)

        S = contract("ikjl->", cov)  # total sum
        D = contract("ikjk->", cov)  # pairs sharing the m index

        sqrt_nl = torch.sqrt(torch.as_tensor(n * l, dtype=cov.dtype, device=cov.device))
        b = (S - D) / (sqrt_nl * (m - 1))
        a = (D / sqrt_nl - b) / m
        # torch.stack keeps the inputs' device (and autograd), like the other
        # estimators. torch.tensor([...]) would copy the values to the CPU.
        return torch.stack([a, b])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same fit as ``cov_estimate`` from an ``(n, m)`` and an ``(l, m)`` parameter."""
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        S = contract("ik,jl->", p1, p2)
        D = contract("ik,jk->", p1, p2)

        sqrt_nl = torch.sqrt(torch.as_tensor(n * l, dtype=p1.dtype, device=p1.device))
        b = (S - D) / (sqrt_nl * (m - 1))
        a = (D / sqrt_nl - b) / m
        return torch.stack([a, b])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*k)`` block. ``shape`` is ``(n, m, l, k)`` with ``m == k``."""
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert m == k
        w = self.weights
        # Allocate in the weight's dtype/device (avoid float32 CPU downcast).
        opts = {"dtype": w.dtype, "device": w.device}
        sqrt_nl = torch.sqrt(torch.as_tensor(n * l, **opts))

        ones_kron_iden = w[0] * torch.kron(
            torch.ones((n, l), **opts) / sqrt_nl, torch.eye(m, **opts)
        )
        ones = w[1] * torch.ones((n * m, l * k), **opts) / (sqrt_nl * m)
        return ones_kron_iden + ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block to an ``(l, m)`` input, or its transpose to an ``(n, m)`` input."""
        n, m, l, k = self.group_dims(self.eq)
        opts = {"dtype": v.dtype, "device": v.device}
        sqrt_nl = torch.sqrt(torch.as_tensor(n * l, **opts))

        if not transpose:
            assert v.shape == (l, k)
            ones = torch.ones((n, m), **opts) * torch.sum(v) / (sqrt_nl * m)
        else:
            assert v.shape == (n, m)
            ones = torch.ones((l, k), **opts) * torch.sum(v) / (sqrt_nl * m)

        w = self.weights
        # Column sums, broadcast over the output rows (valid since m == k).
        ones_kron_iden = torch.sum(v, dim=0, keepdim=True) / sqrt_nl
        return w[0] * ones_kron_iden + w[1] * ones

    @classmethod
    def group_dims(
        cls, eq: Eq, surrogate: bool = False
    ) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` for groups ``(n, m)`` and ``(l, k)``, requiring ``m == k``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert m == k
        return n, m, l, k


class Sn_Sm_Sn_Sl_Factor(Factor):
    """Two-parameter cross factor between an ``(n, m)`` and an ``(n, k)`` group.

    The ``(n*m, n*k)`` block is
    ``a * kron(I_n, J_{m,k} / sqrt(m k)) + b * J / (sqrt(m k) n)``, where
    ``J`` is an all-ones matrix.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with two weights from ``init_fn((2,))`` (default zeros)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit ``[a, b]`` to an ``(n*m, n*k)`` block (requires ``n >= 2``)."""
        n, m, l, k = cls.group_dims(eq, surrogate)

        assert cov.shape == (n * m, l * k)
        cov = cov.reshape(n, m, l, k)

        S = contract("ikjl->", cov)  # total sum
        D = contract("ikil->", cov)  # pairs sharing the n index

        sqrt_mk = torch.sqrt(torch.as_tensor(m * k, dtype=cov.dtype, device=cov.device))
        b = (S - D) / (sqrt_mk * (n - 1))
        a = (D / sqrt_mk - b) / n
        # torch.stack keeps the inputs' device (and autograd), like the other
        # estimators. torch.tensor([...]) would copy the values to the CPU.
        return torch.stack([a, b])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same fit as ``cov_estimate`` from an ``(n, m)`` and an ``(n, k)`` parameter."""
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        S = contract("ik,jl->", p1, p2)
        D = contract("ik,il->", p1, p2)

        sqrt_mk = torch.sqrt(torch.as_tensor(m * k, dtype=p1.dtype, device=p1.device))
        b = (S - D) / (sqrt_mk * (n - 1))
        a = (D / sqrt_mk - b) / n
        return torch.stack([a, b])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*k)`` block. ``shape`` is ``(n, m, l, k)`` with ``n == l``."""
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert n == l
        w = self.weights
        # Allocate in the weight's dtype/device (avoid float32 CPU downcast).
        opts = {"dtype": w.dtype, "device": w.device}
        sqrt_mk = torch.sqrt(torch.as_tensor(m * k, **opts))

        ones_kron_iden = w[0] * torch.kron(
            torch.eye(n, **opts), torch.ones((m, k), **opts) / sqrt_mk
        )
        ones = w[1] * torch.ones((n * m, l * k), **opts) / (sqrt_mk * n)
        return ones_kron_iden + ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block to an ``(n, k)`` input, or its transpose to an ``(n, m)`` input."""
        n, m, l, k = self.group_dims(self.eq)
        opts = {"dtype": v.dtype, "device": v.device}
        sqrt_mk = torch.sqrt(torch.as_tensor(m * k, **opts))

        if not transpose:
            assert v.shape == (l, k)
            ones = torch.ones((n, m), **opts) * torch.sum(v) / (sqrt_mk * n)
        else:
            assert v.shape == (n, m)
            ones = torch.ones((l, k), **opts) * torch.sum(v) / (sqrt_mk * n)

        w = self.weights
        # Row sums, broadcast over the output columns (valid since n == l).
        ones_kron_iden = torch.sum(v, dim=1, keepdim=True) / sqrt_mk

        return w[0] * ones_kron_iden + w[1] * ones

    @classmethod
    def group_dims(
        cls, eq: Eq, surrogate: bool = False
    ) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` for groups ``(n, m)`` and ``(l, k)``, requiring ``n == l``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l
        return n, m, l, k


class Sn_Sm_Sm_Sk_Factor(Factor):
    """Two-parameter cross factor between an ``(n, m)`` and an ``(m, k)`` group.

    The ``(n*m, m*k)`` block is
    ``a * kron(J_{n,k} / sqrt(n k), I_m) @ commutation_matrix(k, m) + b * J / (sqrt(n k) m)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with two weights from ``init_fn((2,))`` (default zeros)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit ``[a, b]`` to an ``(n*m, m*k)`` block (requires ``m >= 2``)."""
        n, m, l, k = cls.group_dims(eq, surrogate)

        assert cov.shape == (n * m, l * k)
        cov = cov.reshape(n, m, l, k)

        S = contract("ikjl->", cov)  # total sum
        DK = contract("ikkl->", cov)  # first group's m index == second group's m index

        sqrt_nk = torch.sqrt(torch.as_tensor(n * k, dtype=cov.dtype, device=cov.device))
        b = (S - DK) / (sqrt_nk * (m - 1))
        a = (DK / sqrt_nk - b) / m
        # torch.stack keeps the inputs' device (and autograd), like the other
        # estimators. torch.tensor([...]) would copy the values to the CPU.
        return torch.stack([a, b])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same fit as ``cov_estimate`` from an ``(n, m)`` and an ``(m, k)`` parameter."""
        n, m, l, k = cls.group_dims(eq)

        p1 = params[0]
        p2 = params[1]

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        S = contract("ik,jl->", p1, p2)
        DK = contract("ik,kl->", p1, p2)

        sqrt_nk = torch.sqrt(torch.as_tensor(n * k, dtype=p1.dtype, device=p1.device))
        b = (S - DK) / (sqrt_nk * (m - 1))
        a = (DK / sqrt_nk - b) / m
        return torch.stack([a, b])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*k)`` block. ``shape`` is ``(n, m, l, k)`` with ``m == l``."""
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert m == l
        w = self.weights
        # Allocate in the weight's dtype/device (avoid float32 CPU downcast).
        opts = {"dtype": w.dtype, "device": w.device}
        sqrt_nk = torch.sqrt(torch.as_tensor(n * k, **opts))

        ones_kron_iden = w[0] * torch.kron(
            torch.ones((n, k), **opts) / sqrt_nk, torch.eye(m, **opts)
        )
        ones_kron_iden_comm = ones_kron_iden @ sm.commutation_matrix(k, m)
        ones = w[1] * torch.ones((n * m, l * k), **opts) / (sqrt_nk * m)
        return ones_kron_iden_comm + ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the block to an ``(m, k)`` input, or its transpose to an ``(n, m)`` input."""
        n, m, l, k = self.group_dims(self.eq)
        opts = {"dtype": v.dtype, "device": v.device}
        sqrt_nk = torch.sqrt(torch.as_tensor(n * k, **opts))
        w = self.weights

        if not transpose:
            assert v.shape == (l, k)
            ones_kron_iden_comm = torch.sum(v.T, dim=0, keepdim=True) / sqrt_nk
            ones = torch.ones((n, m), **opts) * torch.sum(v) / (sqrt_nk * m)
        else:
            assert v.shape == (n, m)
            ones_kron_iden_comm = torch.sum(v.T, dim=1, keepdim=True) / sqrt_nk
            ones = torch.ones((l, k), **opts) * torch.sum(v) / (sqrt_nk * m)
        return w[0] * ones_kron_iden_comm + w[1] * ones

    @classmethod
    def group_dims(
        cls, eq: Eq, surrogate: bool = False
    ) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` for groups ``(n, m)`` and ``(l, k)``, requiring ``m == l``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert m == l
        return n, m, l, k


class Sn_Sn_Factor(Factor):
    """Factor on square ``(n, n)`` blocks parameterised by 15 scalar weights.

    The covariance is ``sum_p (weights[p] / scale(n)[p]) * basis(n)[p]``, a
    weighted sum of the 15 fixed ``(n*n, n*n)`` matrices from :meth:`basis`.
    ``group_dims`` requires all four dimensions of ``eq.shape`` to be equal.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with 15 weights drawn from ``init_fn`` (zeros if None)."""
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((15,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit the 15 weights to a dense ``(n*n, n*n)`` covariance.

        Each term below is ``<basis[p], cov> / scale[p]`` (sum of ``cov`` over the
        index pattern where basis matrix ``p`` is one). Together with the scaled
        Gram matrix from :meth:`trace` these form the normal equations of the
        least-squares projection onto the basis, solved with ``stable_solve``.
        """
        n = cls.group_dims(eq, surrogate)
        n2 = n * n
        assert cov.shape == (n2, n2)
        # cov[a, b, c, d]: row index (a, b), column index (c, d).
        cov = cov.reshape(n, n, n, n)

        T = contract("ikik->", cov)  # eye
        S = contract("ikjl->", cov)  # ones
        K = contract("ikki->", cov)  # commutation
        IS = contract("iijj->", cov)  # bb1
        B = contract("ikil->", cov)  # delta_ij
        D = contract("ikjk->", cov)  # delta_kl
        BK = contract("ikji->", cov)  # delta_il
        DK = contract("ikkl->", cov)  # delta_jk
        BI = contract("ikjj->", cov)  # delta_jl
        ID = contract("iijl->", cov)  # delta_ik
        BID = contract("iiil->", cov)  # delta_ij * delta_ik
        BBK = contract("ikii->", cov)  # delta_ij * delta_il
        IDBK = contract("iiji->", cov)  # delta_ik * delta_il
        DKBI = contract("ikkk->", cov)  # delta_jk * delta_jl
        A = contract("iiii->", cov)  # all indices equal

        s = as_tensor(cls.scale(n), cov)
        C = as_tensor(cls.trace(n, s), cov)

        terms = [T, S, K, IS, B, D, BK, DK, BI, ID, BID, BBK, IDBK, DKBI, A]
        b = torch.stack(terms) / s
        return stable_solve(C, b[:, None])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Fit the 15 weights to the outer product of two ``(n, n)`` matrices.

        Same projection as :meth:`cov_estimate`, with ``cov`` replaced by
        ``outer(vec(params[0]), vec(params[1]))``; each pattern sum is evaluated
        directly from the two factors without forming the outer product.
        """
        n = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]

        T = contract("ik,ik->", p1, p2)
        S = contract("ik,jl->", p1, p2)
        K = contract("ik,ki->", p1, p2)
        IS = contract("ii,jj->", p1, p2)
        B = contract("ik,il->", p1, p2)
        D = contract("ik,jk->", p1, p2)
        BK = contract("ik,ji->", p1, p2)
        DK = contract("ik,kl->", p1, p2)
        BI = contract("ik,jj->", p1, p2)
        ID = contract("ii,jl->", p1, p2)
        BID = contract("ii,il->", p1, p2)
        BBK = contract("ik,ii->", p1, p2)
        IDBK = contract("ii,ji->", p1, p2)
        DKBI = contract("ik,kk->", p1, p2)
        A = contract("ii,ii->", p1, p2)

        s = as_tensor(cls.scale(n), p1)
        C = as_tensor(cls.trace(n, s), p1)

        terms = [T, S, K, IS, B, D, BK, DK, BI, ID, BID, BBK, IDBK, DKBI, A]
        b = torch.stack(terms) / s
        return stable_solve(C, b[:, None])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*n, n*n)`` covariance ``sum_p (w_p / s_p) basis_p``.

        Args:
            shape: optional ``(n, n, n, n)``; all four entries must be equal.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        n = self.group_dims(self.eq, surrogate) if shape is None else shape[0]
        if shape is not None:
            assert n == shape[1] == shape[2] == shape[3]

        w = self.weights

        factors = as_tensor(self.basis(n), w)
        s = as_tensor(self.scale(n), w)

        c = w / s
        weighted_factors = c[:, None, None] * factors
        return torch.sum(weighted_factors, dim=0)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Multiply the covariance (or its transpose) with an ``(n, n)`` matrix.

        Each basis matrix acts on ``v`` through cheap reductions (traces, row and
        column sums, diagonals), so the dense covariance is never formed.

        Args:
            v: ``(n, n)`` input.
            transpose: apply the transposed covariance. Several basis matrices
                are not symmetric, so this differs from the forward product.
        """
        n = self.group_dims(self.eq)
        assert v.shape == (n, n)

        trace = torch.trace(v)
        sum_all = torch.sum(v)
        sum_col = torch.sum(v, dim=1)  # per-row sums, indexed by row
        sum_row = torch.sum(v, dim=0)  # per-column sums, indexed by column
        diag = torch.diag(v)

        eye = v
        k = v.T
        bb1 = torch.eye(n) * trace
        ones = torch.ones((n, n)) * sum_all

        delta_ij = sum_col[:, None]
        delta_kl = sum_row[None, :]
        delta_il = sum_row[:, None]
        delta_jk = sum_col[None, :]
        delta_jl = torch.ones((n, n)) * trace
        delta_ik = torch.eye(n) * sum_all

        delta_ij_delta_ik = torch.diag(sum_col)
        delta_ik_delta_il = torch.diag(sum_row)
        delta_ij_delta_il = diag[:, None]
        delta_jk_delta_jl = diag[None, :]
        delta_ij_delta_jk_delta_jl = torch.diag(diag)

        w = self.weights
        dtype = w.dtype
        s = torch.as_tensor(self.scale(n), dtype=dtype)

        c = w / s
        if transpose:
            # Transposing maps basis matrices onto each other pairwise:
            # delta_il <-> delta_jk, delta_jl <-> delta_ik,
            # delta_ij*delta_ik <-> delta_ij*delta_il and
            # delta_ik*delta_il <-> delta_jk*delta_jl (the rest are symmetric).
            # Hence A^T v equals A v with those coefficient pairs swapped.
            c = c[[0, 1, 2, 3, 4, 5, 7, 6, 9, 8, 11, 10, 13, 12, 14]]

        return (
            c[0] * eye
            + c[1] * ones
            + c[2] * k
            + c[3] * bb1
            + c[4] * delta_ij
            + c[5] * delta_kl
            + c[6] * delta_il
            + c[7] * delta_jk
            + c[8] * delta_jl
            + c[9] * delta_ik
            + c[10] * delta_ij_delta_ik
            + c[11] * delta_ij_delta_il
            + c[12] * delta_ik_delta_il
            + c[13] * delta_jk_delta_jl
            + c[14] * delta_ij_delta_jk_delta_jl
        )

    def min_dim(self) -> tuple[int, int]:
        """Return the minimum block shape ``(4, 4)``."""
        return 4, 4

    @classmethod
    def basis(cls, n) -> NDArray:
        """Return the 15 ``(n*n, n*n)`` basis matrices stacked along dim 0.

        Names use delta notation on the row index (i, j) and column index
        (k, l) of the covariance; the order matches ``scale``/``trace`` and the
        weight vector.
        """
        eye = torch.eye(n * n)
        k = sm.commutation_matrix(n, n)
        bb1 = torch.outer(torch.eye(n).reshape(-1), torch.eye(n).reshape(-1))
        ones = torch.ones((n * n, n * n))
        delta_ij = torch.kron(torch.diag(torch.ones(n)), torch.ones((n, n)))
        delta_kl = torch.kron(torch.ones((n, n)), torch.diag(torch.ones(n)))
        delta_il = delta_ij @ k
        delta_jk = delta_kl @ k
        delta_jl = delta_ij @ bb1
        delta_ik = bb1 @ delta_kl
        delta_ij_delta_ik = delta_ij * delta_ik
        delta_ij_delta_il = delta_ij * delta_il
        delta_ik_delta_il = delta_ik * delta_il
        delta_jk_delta_jl = delta_jk * delta_jl
        delta_ij_delta_jk_delta_jl = delta_ij * delta_jk * delta_jl

        return torch.stack(
            [
                eye,
                ones,
                k,
                bb1,
                delta_ij,
                delta_kl,
                delta_il,
                delta_jk,
                delta_jl,
                delta_ik,
                delta_ij_delta_ik,
                delta_ij_delta_il,
                delta_ik_delta_il,
                delta_jk_delta_jl,
                delta_ij_delta_jk_delta_jl,
            ]
        )

    @classmethod
    def scale(cls, n) -> NDArray:
        """Return the per-basis scale factors; weights are divided by these."""
        return torch.as_tensor(
            [1.0, n * n, 1.0, n, n, n, n, n, n, n, 1.0, 1.0, 1.0, 1.0, 1.0]
            # Alternative scaling needed to pass the factor_v1_vs_v2 test:
            # [1.0, n * n, 1.0, 1.0, n, n, n, n, n, n, n**0.5, n**0.5, n**0.5, n**0.5, 1.0]
        )

    @classmethod
    def trace(cls, n, s) -> NDArray:
        """Return the Gram matrix ``<basis_p, basis_q> / (s_p * s_q)``.

        Entry ``(p, q)`` is the Frobenius inner product (number of shared
        non-zero entries) of basis matrices ``p`` and ``q``, divided by the
        scales from ``s``.
        """
        n2 = n * n
        n3 = n * n * n
        n4 = n * n * n * n

        C = as_tensor(
            torch.as_tensor(
                [
                    [n2, n2, n, n, n2, n2, n, n, n, n, n, n, n, n, n],
                    [n2, n4, n2, n2, n3, n3, n3, n3, n3, n3, n2, n2, n2, n2, n],
                    [n, n2, n2, n, n, n, n2, n2, n, n, n, n, n, n, n],
                    [n, n2, n, n2, n, n, n, n, n2, n2, n, n, n, n, n],
                    [n2, n3, n, n, n3, n2, n2, n2, n2, n2, n2, n2, n, n, n],
                    [n2, n3, n, n, n2, n3, n2, n2, n2, n2, n, n, n2, n2, n],
                    [n, n3, n2, n, n2, n2, n3, n2, n2, n2, n, n2, n2, n, n],
                    [n, n3, n2, n, n2, n2, n2, n3, n2, n2, n2, n, n, n2, n],
                    [n, n3, n, n2, n2, n2, n2, n2, n3, n2, n, n2, n, n2, n],
                    [n, n3, n, n2, n2, n2, n2, n2, n2, n3, n2, n, n2, n, n],
                    [n, n2, n, n, n2, n, n, n2, n, n2, n2, n, n, n, n],
                    [n, n2, n, n, n2, n, n2, n, n2, n, n, n2, n, n, n],
                    [n, n2, n, n, n, n2, n2, n, n, n2, n, n, n2, n, n],
                    [n, n2, n, n, n, n2, n, n2, n2, n, n, n, n, n2, n],
                    [n, n, n, n, n, n, n, n, n, n, n, n, n, n, n],
                ]
            ),
            s,
        )
        C /= torch.outer(s, s)
        return C

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> int:
        """Return the common side ``n`` of ``eq``'s ``((n, n), (n, n))`` shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == m and m == l and l == k
        return n


class I_OorB_Factor(Factor):
    """Factor with covariance ``kron(W, I_m)`` for a free ``(n, l)`` weight ``W``.

    Rows are indexed by an ``(n, m)`` matrix and columns by an ``(l, m)``
    matrix; ``group_dims`` asserts that both share the trailing size ``m``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor whose ``(n, l)`` weight comes from ``init_fn`` (zeros if None)."""
        n, m, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((n, l))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense ``(n*m, l*m)`` covariance onto the ``(n, l)`` weight.

        Returns ``W[i, j] = mean_k cov[i, k, j, k]``.
        """
        n, m, l = cls.group_dims(eq, surrogate)

        assert cov.shape == (n * m, l * m)
        cov = cov.reshape(n, m, l, m)
        return contract("ikjk->ij", cov) / m

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the weight from parameter matrices shaped ``(n, m)`` and ``(l, m)``."""
        n, m, l = cls.group_dims(eq)
        assert params[0].shape[-2:] == (n, m) and params[1].shape[-2:] == (l, m)

        return contract("ik,jk->ij", params[0], params[1]) / m

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense covariance ``kron(W, I_m)`` of shape ``(n*m, l*m)``.

        Args:
            shape: optional ``(n, m, l, k)``; requires ``m == k`` and ``n`` equal
                to this factor's ``n``. Only ``m`` may differ from ``group_dims``.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        # shape is (n, m, l, k); take the first three (previously shape[:2],
        # which could never be unpacked into three names).
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape[:3]
        if shape is not None:
            assert shape[1] == shape[3]
            assert n == self.group_dims(self.eq)[0]

        w = self.weights
        assert w.shape == (n, l)
        return torch.kron(w, torch.eye(m))

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance to ``v`` (``(l, m)``), or its transpose to ``(n, m)``."""
        n, m, l = self.group_dims(self.eq)
        if not transpose:
            assert v.shape == (l, m)
            return self.weights @ v
        else:
            assert v.shape == (n, m)
            return self.weights.T @ v

    def min_dim(self) -> tuple[int, int]:
        """Return the minimum block shape ``(n, 1)``."""
        n, _, _ = self.group_dims(self.eq)
        return n, 1

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq``'s ``((n, m), (l, m))`` shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert m == k
        return n, m, l


class OorB_I_Factor(Factor):
    """Factor with covariance ``kron(I_n, W)`` for a free ``(m, k)`` weight ``W``.

    Rows are indexed by an ``(n, m)`` matrix and columns by an ``(n, k)``
    matrix; ``group_dims`` asserts that both share the leading size ``n``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor whose ``(m, k)`` weight comes from ``init_fn`` (zeros if None)."""
        n, m, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((m, k))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense ``(n*m, n*k)`` covariance onto the ``(m, k)`` weight.

        Returns ``W[a, b] = mean_i cov[i, a, i, b]``.
        """
        n, m, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, n * k)
        cov = cov.reshape(n, m, n, k)
        return contract("ikil->kl", cov) / n

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the weight from parameter matrices shaped ``(n, m)`` and ``(n, k)``."""
        n, m, k = cls.group_dims(eq)
        assert params[0].shape[-2:] == (n, m) and params[1].shape[-2:] == (n, k)

        return contract("ik,il->kl", params[0], params[1]) / n

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense covariance ``kron(I_n, W)`` of shape ``(n*m, n*k)``.

        Args:
            shape: optional ``(n, m, l, k)``; requires ``n == l`` and ``m`` equal
                to this factor's ``m``. Only ``n`` may differ from ``group_dims``.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        if shape is None:
            n, m, k = self.group_dims(self.eq, surrogate)
        else:
            # shape is (n, m, l, k) with l == n; previously shape[:2] was unpacked
            # into three names, which always raised.
            n, m, k = shape[0], shape[1], shape[3]
            assert shape[0] == shape[2]
            assert m == self.group_dims(self.eq)[1]
        w = self.weights
        assert w.shape == (m, k)
        return torch.kron(torch.eye(n), w)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance to ``v`` (``(n, k)``), or its transpose to ``(n, m)``."""
        n, m, k = self.group_dims(self.eq)
        if not transpose:
            assert v.shape == (n, k)
            return v @ self.weights.T
        else:
            assert v.shape == (n, m)
            return v @ self.weights

    def min_dim(self) -> tuple[int, int]:
        """Return the minimum block shape ``(1, m)``."""
        # group_dims returns three values; unpacking into two always raised.
        _, m, _ = self.group_dims(self.eq)
        return 1, m

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, k)`` from ``eq``'s ``((n, m), (n, k))`` shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l
        return n, m, k


class I_S_I_S_Factor(Factor):
    """Factor with covariance ``kron(W0, I_m) + kron(W1, 1_{m x m}) / m``.

    ``weights`` has shape ``(2, n, l)`` holding ``W0`` and ``W1``. Rows are
    indexed by an ``(n, m)`` matrix and columns by an ``(l, m)`` matrix;
    ``group_dims`` asserts the shared trailing size ``m``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq, init_fn: InitFn = None):  # type: ignore[arg-type]
        """Create a factor with ``(2, n, l)`` weights from ``init_fn`` (zeros if None)."""
        n, m, l = cls.group_dims(eq)

        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2, n, l))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Recover ``(W0, W1)`` from a dense ``(n*m, l*m)`` covariance.

        For this model ``T = sum_k cov[i,k,j,k] = m W0 + W1`` and
        ``S = sum_{k,l} cov[i,k,j,l] = m W0 + m W1``, which is solved below.
        Divides by ``m - 1``, so ``m`` must exceed 1.
        """
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * m)
        cov = cov.reshape(n, m, l, m)
        S = contract("ikjl->ij", cov)
        T = contract("ikjk->ij", cov)

        v = (S - T) / (m - 1)
        w = (T - v) / m
        return torch.stack([w, v])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Recover ``(W0, W1)`` from parameter matrices shaped ``(n, m)`` and ``(l, m)``."""
        n, m, l = cls.group_dims(eq)

        p1 = params[0]
        p2 = params[1]
        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, m)

        S = contract("ik,jl->ij", p1, p2)
        T = contract("ik,jk->ij", p1, p2)

        v = (S - T) / (m - 1)
        w = (T - v) / m
        return torch.stack([w, v])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*m)`` covariance.

        Args:
            shape: optional ``(n, m, l, k)``; requires ``m == k`` and ``n`` equal
                to this factor's ``n``.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        # shape is (n, m, l, k); take the first three (shape[:2] never unpacked).
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape[:3]
        if shape is not None:
            assert shape[1] == shape[3]
            assert n == self.group_dims(self.eq)[0]

        w = self.weights
        assert w.shape == (2, n, l)
        return torch.kron(w[0], torch.eye(m)) + torch.kron(w[1], torch.ones((m, m))) / m

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance to ``v`` (``(l, m)``), or its transpose to ``(n, m)``."""
        n, m, l = self.group_dims(self.eq)
        w = self.weights

        if not transpose:
            assert v.shape == (l, m)
            eyes = w[0] @ v
            ones = w[1] @ v.sum(dim=1, keepdim=True) / m
        else:
            assert v.shape == (n, m)
            eyes = w[0].T @ v
            ones = w[1].T @ v.sum(dim=1, keepdim=True) / m
        return eyes + ones

    def min_dim(self) -> tuple[int, int]:
        """Return the minimum block shape ``(n, 2)``."""
        n, m, _ = self.group_dims(self.eq)
        return n, 2

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq``'s ``((n, m), (l, m))`` shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert m == k
        return n, m, l


class S_I_S_I_Factor(Factor):
    """Factor with covariance ``kron(I_n, W0) + kron(1_{n x n}, W1) / n``.

    ``weights`` has shape ``(2, m, k)`` holding ``W0`` and ``W1``. Rows are
    indexed by an ``(n, m)`` matrix and columns by an ``(n, k)`` matrix;
    ``group_dims`` asserts the shared leading size ``n``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn = None):  # type: ignore[arg-type]
        """Create a factor with ``(2, m, k)`` weights from ``init_fn`` (zeros if None)."""
        n, m, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2, m, k))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Recover ``(W0, W1)`` from a dense ``(n*m, n*k)`` covariance.

        For this model ``T = sum_i cov[i,a,i,b] = n W0 + W1`` and
        ``S = sum_{i,j} cov[i,a,j,b] = n W0 + n W1``. Divides by ``n - 1``.
        """
        n, m, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, n * k)
        cov = cov.reshape(n, m, n, k)
        S = contract("ikjl->kl", cov)
        T = contract("ikil->kl", cov)

        v = (S - T) / (n - 1)
        w = (T - v) / n
        return torch.stack([w, v])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Recover ``(W0, W1)`` from parameter matrices shaped ``(n, m)`` and ``(n, k)``."""
        n, m, k = cls.group_dims(eq)
        assert params[0].shape[-2:] == (n, m) and params[1].shape[-2:] == (n, k)

        S = contract("ik,jl->kl", params[0], params[1])
        T = contract("ik,il->kl", params[0], params[1])

        v = (S - T) / (n - 1)
        w = (T - v) / n
        return torch.stack([w, v])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, n*k)`` covariance.

        Args:
            shape: optional ``(n, m, l, k)``; requires ``n == l`` and ``m`` equal
                to this factor's ``m``.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        if shape is None:
            n, m, k = self.group_dims(self.eq, surrogate)
        else:
            # shape is (n, m, l, k) with l == n; shape[:2] never unpacked.
            n, m, k = shape[0], shape[1], shape[3]
            assert shape[0] == shape[2]
            assert m == self.group_dims(self.eq)[1]

        w = self.weights
        assert w.shape == (2, m, k)
        return torch.kron(torch.eye(n), w[0]) + torch.kron(torch.ones((n, n)), w[1]) / n

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance to ``v`` (``(n, k)``), or its transpose to ``(n, m)``."""
        n, m, k = self.group_dims(self.eq)
        w = self.weights
        if not transpose:
            assert v.shape == (n, k)
            eyes = v @ w[0].T
            ones = v.sum(dim=0, keepdim=True) @ w[1].T / n
        else:
            assert v.shape == (n, m)
            eyes = v @ w[0]
            ones = v.sum(dim=0, keepdim=True) @ w[1] / n
        return eyes + ones

    def min_dim(self) -> tuple[int, int]:
        """Return the minimum block shape ``(2, m)``."""
        # group_dims returns three values; unpacking into two always raised.
        _, m, _ = self.group_dims(self.eq)
        return 2, m

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, k)`` from ``eq``'s ``((n, m), (n, k))`` shape."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l
        return n, m, k


class I_OnorBn_OnorBn_I_Factor(Factor):
    """Factor with covariance ``kron(W, I_m) @ K(k, m)`` for an ``(n, k)`` weight.

    Rows are indexed by an ``(n, m)`` matrix and columns by an ``(m, k)``
    matrix (``group_dims`` asserts ``m == l``); ``K`` is
    ``sm.commutation_matrix``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor whose ``(n, k)`` weight comes from ``init_fn`` (zeros if None)."""
        dims = cls.group_dims(eq)
        make = torch.zeros if init_fn is None else init_fn
        return cls(eq, make((dims[0], dims[3])))

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense ``(n*m, m*k)`` covariance onto the ``(n, k)`` weight.

        Returns ``W[i, j] = mean_a cov[i, a, a, j]``.
        """
        n, m, _, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, m * k)
        blocks = cov.reshape(n, m, m, k)
        return contract("ikkj->ij", blocks) / m

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the weight from parameter matrices shaped ``(n, m)`` and ``(l, k)``."""
        n, m, l, k = cls.group_dims(eq)
        left, right = params[0], params[1]
        assert left.shape[-2:] == (n, m) and right.shape[-2:] == (l, k)
        return contract("ik,kj->ij", left, right) / m

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, m*k)`` covariance.

        Args:
            shape: optional ``(n, m, l, k)``; requires ``m == l`` and ``n``/``k``
                equal to this factor's.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        dims = self.group_dims(self.eq, surrogate) if shape is None else shape
        n, m, l, k = dims
        assert m == l
        reference = self.group_dims(self.eq)
        assert n == reference[0]
        assert k == reference[-1]

        w = self.weights
        assert w.shape == (n, k)
        return torch.kron(w, torch.eye(m)) @ sm.commutation_matrix(k, m)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance to ``v`` (``(l, k)``), or its transpose to ``(n, m)``."""
        n, m, l, k = self.group_dims(self.eq)
        if not transpose:
            assert v.shape == (l, k)
            return self.weights @ v.T
        assert v.shape == (n, m)
        return v.T @ self.weights

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` from ``eq``'s shape, asserting ``m == l``."""
        source = eq.stable_shape if surrogate else eq.shape
        (n, m), (l, k) = source
        assert m == l
        return n, m, l, k


class OnorBn_I_I_OnorBn_Factor(Factor):
    """Factor with covariance ``kron(I_n, W) @ K(n, l)`` for an ``(m, l)`` weight.

    Rows are indexed by an ``(n, m)`` matrix and columns by an ``(l, n)``
    matrix (``group_dims`` asserts ``n == k``); ``K`` is
    ``sm.commutation_matrix``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor whose ``(m, l)`` weight comes from ``init_fn`` (zeros if None)."""
        dims = cls.group_dims(eq)
        make = torch.zeros if init_fn is None else init_fn
        return cls(eq, make((dims[1], dims[2])))

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense ``(n*m, l*n)`` covariance onto the ``(m, l)`` weight.

        Returns ``W[a, b] = mean_i cov[i, a, b, i]``.
        """
        n, m, l, _ = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * n)
        blocks = cov.reshape(n, m, l, n)
        return contract("ikli->kl", blocks) / n

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the weight from parameter matrices shaped ``(n, m)`` and ``(l, k)``."""
        n, m, l, k = cls.group_dims(eq)
        left, right = params[0], params[1]
        assert left.shape[-2:] == (n, m) and right.shape[-2:] == (l, k)
        return contract("ik,li->kl", left, right) / n

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*n)`` covariance.

        Args:
            shape: optional ``(n, m, l, k)``; requires ``n == k`` and ``m``/``l``
                equal to this factor's.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        dims = self.group_dims(self.eq, surrogate) if shape is None else shape
        n, m, l, k = dims
        assert n == k
        reference = self.group_dims(self.eq)
        assert m == reference[1]
        assert l == reference[2]

        w = self.weights
        assert w.shape == (m, l)
        return torch.kron(torch.eye(n), w) @ sm.commutation_matrix(n, l)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance to ``v`` (``(l, k)``), or its transpose to ``(n, m)``."""
        n, m, l, k = self.group_dims(self.eq)
        if not transpose:
            assert v.shape == (l, k)
            return (self.weights @ v).T
        assert v.shape == (n, m)
        return (v @ self.weights).T

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` from ``eq``'s shape, asserting ``n == k``."""
        source = eq.stable_shape if surrogate else eq.shape
        (n, m), (l, k) = source
        assert n == k
        return n, m, l, k


class I_Sm_Sm_I_Factor(Factor):
    """Factor with covariance ``(kron(W0, I_m) + kron(W1, 1_{m x m}) / m) @ K(k, m)``.

    ``weights`` has shape ``(2, n, k)``. Rows are indexed by an ``(n, m)``
    matrix and columns by an ``(m, k)`` matrix (``group_dims`` asserts
    ``m == l``); ``K`` is ``sm.commutation_matrix``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(2, n, k)`` weights from ``init_fn`` (zeros if None)."""
        dims = cls.group_dims(eq)
        make = torch.zeros if init_fn is None else init_fn
        return cls(eq, make((2, dims[0], dims[3])))

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Recover ``(W0, W1)`` from a dense ``(n*m, m*k)`` covariance.

        With ``diag_sum = m W0 + W1`` and ``full_sum = m W0 + m W1`` under this
        model; divides by ``m - 1``.
        """
        n, m, _, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, m * k)
        blocks = cov.reshape(n, m, m, k)

        diag_sum = contract("ikkj->ij", blocks)
        full_sum = contract("iklj->ij", blocks)

        w_ones = (full_sum - diag_sum) / (m - 1)
        w_eye = (diag_sum - w_ones) / m
        return torch.stack([w_eye, w_ones])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Recover ``(W0, W1)`` from parameter matrices shaped ``(n, m)`` and ``(l, k)``."""
        n, m, l, k = cls.group_dims(eq)
        left, right = params[0], params[1]
        assert left.shape[-2:] == (n, m) and right.shape[-2:] == (l, k)

        diag_sum = contract("ik,kj->ij", left, right)
        full_sum = contract("ik,lj->ij", left, right)

        w_ones = (full_sum - diag_sum) / (m - 1)
        w_eye = (diag_sum - w_ones) / m
        return torch.stack([w_eye, w_ones])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, m*k)`` covariance.

        Args:
            shape: optional ``(n, m, l, k)``; requires ``m == l`` and ``n``/``k``
                equal to this factor's.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        dims = self.group_dims(self.eq, surrogate) if shape is None else shape
        n, m, l, k = dims
        assert m == l
        reference = self.group_dims(self.eq)
        assert n == reference[0]
        assert k == reference[-1]

        w = self.weights
        assert w.shape == (2, n, k)
        mixed = torch.kron(w[0], torch.eye(m)) + torch.kron(w[1], torch.ones((m, m))) / m
        return mixed @ sm.commutation_matrix(k, m)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance to ``v`` (``(l, k)``), or its transpose to ``(n, m)``."""
        n, m, l, k = self.group_dims(self.eq)
        w_eye, w_ones = self.weights[0], self.weights[1]
        if transpose:
            assert v.shape == (n, m)
            vt = v.T
            return vt @ w_eye + vt.sum(dim=0, keepdim=True) @ w_ones / m
        assert v.shape == (l, k)
        vt = v.T
        return w_eye @ vt + w_ones @ vt.sum(dim=1, keepdim=True) / m

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` from ``eq``'s shape, asserting ``m == l``."""
        source = eq.stable_shape if surrogate else eq.shape
        (n, m), (l, k) = source
        assert m == l
        return n, m, l, k


class I_Sm_Sk_I_Factor(Factor):
    """Factor with covariance ``kron(W, 1_{m x l} / sqrt(m*l)) @ K(k, l)``.

    ``weights`` has shape ``(n, k)``; rows are indexed by an ``(n, m)`` matrix
    and columns by an ``(l, k)`` matrix. ``K`` is ``sm.commutation_matrix``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor whose ``(n, k)`` weight comes from ``init_fn`` (zeros if None)."""
        dims = cls.group_dims(eq)
        make = torch.zeros if init_fn is None else init_fn
        return cls(eq, make((dims[0], dims[3])))

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense ``(n*m, l*k)`` covariance onto the ``(n, k)`` weight.

        Sums ``cov[i, a, b, j]`` over ``a`` and ``b`` and divides by ``sqrt(m*l)``.
        """
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        norm = torch.sqrt(torch.as_tensor(m * l, dtype=cov.dtype))
        return contract("iklj->ij", cov.reshape(n, m, l, k)) / norm

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the weight from parameter matrices shaped ``(n, m)`` and ``(l, k)``."""
        n, m, l, k = cls.group_dims(eq)
        left, right = params[0], params[1]
        assert left.shape[-2:] == (n, m) and right.shape[-2:] == (l, k)
        norm = torch.sqrt(torch.as_tensor(m * l, dtype=left.dtype))
        return contract("ik,lj->ij", left, right) / norm

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*k)`` covariance.

        Args:
            shape: optional ``(n, m, l, k)``; ``n`` and ``k`` must equal this
                factor's.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        dims = self.group_dims(self.eq, surrogate) if shape is None else shape
        n, m, l, k = dims
        reference = self.group_dims(self.eq)
        assert n == reference[0]
        assert k == reference[-1]
        w = self.weights
        norm = torch.sqrt(torch.as_tensor(m * l, dtype=w.dtype))
        assert w.shape == (n, k)
        return torch.kron(w, torch.ones((m, l)) / norm) @ sm.commutation_matrix(k, l)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) to a matrix-shaped ``v``.

        Forward input is expected to be ``(l, k)`` and gives ``(n, m)``; the
        transpose maps ``(n, m)`` to ``(l, k)``. No shape check is performed.
        """
        n, m, l, k = self.group_dims(self.eq)
        # TODO(bla): validate v's shape per direction (was: assert v.shape == (l, k)).
        norm = torch.sqrt(torch.as_tensor(m * l, dtype=v.dtype))
        w = self.weights

        if not transpose:
            collapsed = w @ (v.T).sum(dim=1, keepdim=True)
            return torch.tile(collapsed, (1, m)) / norm
        collapsed = (v.T).sum(dim=0, keepdim=True) @ w
        return torch.tile(collapsed, (l, 1)) / norm

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` from ``eq``'s ``((n, m), (l, k))`` shape."""
        source = eq.stable_shape if surrogate else eq.shape
        (n, m), (l, k) = source
        return n, m, l, k


class Sn_I_I_Sn_Factor(Factor):
    """Factor with covariance ``(kron(I_n, W0) + kron(1_{n x n}, W1) / n) @ K(n, l)``.

    ``weights`` has shape ``(2, m, l)``. Rows are indexed by an ``(n, m)``
    matrix and columns by an ``(l, n)`` matrix (``group_dims`` asserts
    ``n == k``); ``K`` is ``sm.commutation_matrix``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(2, m, l)`` weights from ``init_fn`` (zeros if None)."""
        dims = cls.group_dims(eq)
        make = torch.zeros if init_fn is None else init_fn
        return cls(eq, make((2, dims[1], dims[2])))

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Recover ``(W0, W1)`` from a dense ``(n*m, l*n)`` covariance.

        Under this model ``diag_sum = n W0 + W1`` and ``full_sum = n W0 + n W1``;
        divides by ``n - 1``.
        """
        n, m, l, _ = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * n)
        blocks = cov.reshape(n, m, l, n)

        diag_sum = contract("ikli->kl", blocks)
        full_sum = contract("iklj->kl", blocks)

        w_ones = (full_sum - diag_sum) / (n - 1)
        w_eye = (diag_sum - w_ones) / n
        return torch.stack([w_eye, w_ones])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Recover ``(W0, W1)`` from parameter matrices shaped ``(n, m)`` and ``(l, k)``."""
        n, m, l, k = cls.group_dims(eq)
        left, right = params[0], params[1]
        assert left.shape[-2:] == (n, m) and right.shape[-2:] == (l, k)

        diag_sum = contract("ik,li->kl", left, right)
        full_sum = contract("ik,lj->kl", left, right)

        w_ones = (full_sum - diag_sum) / (n - 1)
        w_eye = (diag_sum - w_ones) / n
        return torch.stack([w_eye, w_ones])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*n)`` covariance.

        Args:
            shape: optional ``(n, m, l, k)``; requires ``n == k`` and ``m``/``l``
                equal to this factor's.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        dims = self.group_dims(self.eq, surrogate) if shape is None else shape
        n, m, l, k = dims
        assert n == k
        reference = self.group_dims(self.eq)
        assert m == reference[1]
        assert l == reference[2]

        w = self.weights
        assert w.shape == (2, m, l)
        mixed = torch.kron(torch.eye(n), w[0]) + torch.kron(torch.ones((n, n)), w[1]) / n
        return mixed @ sm.commutation_matrix(n, l)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance to ``v`` (``(l, k)``), or its transpose to ``(n, m)``."""
        n, m, l, k = self.group_dims(self.eq)
        w_eye, w_ones = self.weights[0], self.weights[1]

        if transpose:
            assert v.shape == (n, m)
            vt = v.T
            return (v @ w_eye).T + w_ones.T @ vt.sum(dim=1, keepdim=True) / n
        assert v.shape == (l, k)
        vt = v.T
        return vt @ w_eye.T + vt.sum(dim=0, keepdim=True) @ w_ones.T / n

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` from ``eq``'s shape, asserting ``n == k``."""
        source = eq.stable_shape if surrogate else eq.shape
        (n, m), (l, k) = source
        assert n == k
        return n, m, l, k


class Sn_I_Sk_I_Factor(Factor):
    """Factor with covariance ``kron(1_{n x l} / sqrt(n*l), W)`` for an ``(m, k)`` weight.

    Rows are indexed by an ``(n, m)`` matrix and columns by an ``(l, k)``
    matrix.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor whose ``(m, k)`` weight comes from ``init_fn`` (zeros if None)."""
        dims = cls.group_dims(eq)
        make = torch.zeros if init_fn is None else init_fn
        return cls(eq, make((dims[1], dims[3])))

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense ``(n*m, l*k)`` covariance onto the ``(m, k)`` weight.

        Sums ``cov[i, a, j, b]`` over ``i`` and ``j`` and divides by ``sqrt(n*l)``.
        """
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        norm = torch.sqrt(torch.as_tensor(n * l, dtype=cov.dtype))
        return contract("ikjl->kl", cov.reshape(n, m, l, k)) / norm

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the weight from parameter matrices shaped ``(n, m)`` and ``(l, k)``."""
        n, m, l, k = cls.group_dims(eq)
        left, right = params[0], params[1]
        assert left.shape[-2:] == (n, m) and right.shape[-2:] == (l, k)
        norm = torch.sqrt(torch.as_tensor(n * l, dtype=left.dtype))
        return contract("ik,jl->kl", left, right) / norm

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*k)`` covariance.

        Args:
            shape: optional ``(n, m, l, k)``; defaults to ``group_dims``.
            surrogate: forwarded to ``group_dims`` when ``shape`` is None.
        """
        dims = self.group_dims(self.eq, surrogate) if shape is None else shape
        n, m, l, k = dims
        w = self.weights
        norm = torch.sqrt(torch.as_tensor(n * l, dtype=w.dtype))
        assert w.shape == (m, k)
        return torch.kron(torch.ones((n, l)) / norm, w)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance to ``v`` (``(l, k)``), or its transpose to ``(n, m)``."""
        n, m, l, k = self.group_dims(self.eq)
        w = self.weights
        norm = torch.sqrt(torch.as_tensor(n * l, dtype=v.dtype))

        if transpose:
            assert v.shape == (n, m)
            collapsed = v.sum(dim=0, keepdim=True) @ w
            repeats = (l, 1)
        else:
            assert v.shape == (l, k)
            collapsed = v.sum(dim=0, keepdim=True) @ w.T
            repeats = (n, 1)
        return torch.tile(collapsed, repeats) / norm

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` from ``eq``'s ``((n, m), (l, k))`` shape."""
        source = eq.stable_shape if surrogate else eq.shape
        (n, m), (l, k) = source
        return n, m, l, k


class Sn_I_I_Sl_Factor(Factor):
    """Factor that couples only the ``m`` row axis with the ``l`` column axis.

    For ``eq.shape == ((n, m), (l, k))`` the covariance is modelled as
    ``C[(i, a), (j, b)] = weights[a, j] / sqrt(n * k)``. It is constant along the
    ``n`` and ``k`` axes, and ``weights`` has shape ``(m, l)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with ``(m, l)`` weights from ``init_fn`` (zeros if None)."""
        n, m, l, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((m, l))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense ``(n*m, l*k)`` covariance onto ``(m, l)`` weights.

        Sums over the ``n`` and ``k`` axes and divides by ``sqrt(n * k)``. This
        recovers the weights exactly for covariances of this factor's form.
        """
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        dtype = cov.dtype

        cov = cov.reshape(n, m, l, k)

        sqrt_nk = torch.sqrt(torch.as_tensor(n * k, dtype=dtype))
        return contract("iklj->kl", cov) / sqrt_nk

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(m, l)`` weights from the outer product of two parameters.

        ``params[0]`` must end in ``(n, m)`` and ``params[1]`` in ``(l, k)``.
        """
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        sqrt_nk = torch.sqrt(torch.as_tensor(n * k, dtype=dtype))
        return contract("ik,lj->kl", p1, p2) / sqrt_nk

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*k)`` covariance of this factor.

        Helper tensors are created with the weights' dtype and device, so the
        result keeps the weights' precision and works off-CPU.
        """
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        w = self.weights
        assert w.shape == (m, l)

        dtype, device = w.dtype, w.device
        sqrt_nk = torch.sqrt(torch.as_tensor(n * k, dtype=dtype, device=device))
        ones_nk = torch.ones((n, k), dtype=dtype, device=device) / sqrt_nk
        # kron yields columns ordered (k, l). The commutation matrix from `sm` is
        # expected to permute them into (l, k) order. It is cast so the matmul
        # does not mix dtypes/devices.
        comm = torch.as_tensor(sm.commutation_matrix(k, l), dtype=dtype, device=device)
        return torch.kron(ones_nk, w) @ comm

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) without materialising it.

        ``transpose=False`` maps ``v`` of shape ``(l, k)`` to ``(n, m)``.
        ``transpose=True`` maps ``v`` of shape ``(n, m)`` to ``(l, k)``.
        """
        n, m, l, k = self.group_dims(self.eq)
        # assert v.shape == (l, k)  # TODO(bla): shapes can be different
        dtype = v.dtype
        sqrt_nk = torch.sqrt(torch.as_tensor(n * k, dtype=dtype))
        w = self.weights

        if transpose:
            # (sum over n of v) @ w -> (1, l); transpose to a column and repeat over k.
            reduce_l = (v.sum(dim=0, keepdim=True) @ w).T
            return torch.tile(reduce_l, (1, k)) / sqrt_nk
        # Row sums of v (over k) contracted with w over l -> (1, m); repeat over n.
        reduce_m = (v.T).sum(dim=0, keepdim=True) @ w.T
        return torch.tile(reduce_m, (n, 1)) / sqrt_nk

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` from ``eq.shape`` (or ``eq.stable_shape``)."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        return n, m, l, k


class OOrB_S_OOrB_I_Factor(Factor):
    """Factor that is diagonal on the shared ``n == l`` axis and weights the ``k`` axis.

    For ``eq.shape == ((n, m), (n, k))`` the covariance is modelled as
    ``C[(i, a), (j, b)] = delta_ij * weights[b] / sqrt(m)``, where ``weights``
    has shape ``(k,)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with ``(k,)`` weights from ``init_fn`` (zeros if None).

        Returns an instance of ``cls`` (not the bare weight tensor), matching the
        other factor constructors used by ``_factor_from_source``.
        """
        n, m, l, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((k,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense ``(n*m, n*k)`` covariance onto ``(k,)`` weights.

        Takes the ``i == j`` diagonal, sums over ``i`` and ``a``, and divides by
        ``sqrt(m) * n``.
        """
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        dtype = cov.dtype

        cov = cov.reshape(n, m, l, k)

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))
        return contract("ikil->l", cov) / (sqrt_m * n)

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(k,)`` weights from two parameters sharing the leading ``n`` axis."""
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))
        return contract("ik,il->l", p1, p2) / (sqrt_m * n)

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, n*k)`` covariance ``I_n (x) (1_m w^T / sqrt(m))``.

        Helper tensors use the weights' dtype/device (a default float32 CPU helper
        would downcast the scale and fail for non-CPU weights).
        """
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert n == l
        assert k == self.group_dims(self.eq)[-1]
        w = self.weights
        assert w.shape == (k,)

        dtype, device = w.dtype, w.device
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype, device=device))
        eye_n = torch.eye(n, dtype=dtype, device=device)
        one_m = torch.ones((m, 1), dtype=dtype, device=device) / sqrt_m
        return torch.kron(eye_n, torch.kron(one_m, w[None, :]))

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) without materialising it.

        ``transpose=False`` maps ``(l, k) -> (n, m)``; ``transpose=True`` maps
        ``(n, m) -> (l, k)``.
        """
        n, m, l, k = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))

        if not transpose:
            assert v.shape == (l, k)
            # Contract each row of v with w, then repeat that value across m.
            reduce_k = v @ w[:, None]
            return torch.tile(reduce_k, (1, m)) / sqrt_m
        assert v.shape == (n, m)
        # Row sums of v (over m) scaled by w along k: outer(row_sums, w).
        reduce_m = v.sum(dim=1, keepdim=True) / sqrt_m
        return torch.kron(w[None, :], reduce_m)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)``; this factor requires ``n == l``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l
        return n, m, l, k


class OOrB_S_I_OOrB_Factor(Factor):
    """Factor that ties the row ``n`` axis to the column ``k`` axis (``n == k``).

    For ``eq.shape == ((n, m), (l, n))`` the covariance is modelled as
    ``C[(i, a), (j, b)] = delta_ib * weights[j] / sqrt(m)``, where ``weights``
    has shape ``(l,)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with ``(l,)`` weights from ``init_fn`` (zeros if None)."""
        n, m, l, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((l,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense ``(n*m, l*n)`` covariance onto ``(l,)`` weights.

        Takes the ``i == b`` diagonal, sums over ``i`` and ``a``, and divides by
        ``sqrt(m) * n``.
        """
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        dtype = cov.dtype

        cov = cov.reshape(n, m, l, k)

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))
        return contract("ikji->j", cov) / (sqrt_m * n)

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(l,)`` weights from two parameters (``p2``'s last axis matches ``p1``'s first)."""
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))
        return contract("ik,ji->j", p1, p2) / (sqrt_m * n)

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*n)`` covariance of this factor.

        Helper tensors and the commutation matrix are created or cast with the
        weights' dtype/device, so the product neither downcasts nor mixes devices.
        """
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert n == k
        assert l == self.group_dims(self.eq)[2]
        w = self.weights
        assert w.shape == (l,)

        dtype, device = w.dtype, w.device
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype, device=device))
        eye_n = torch.eye(n, dtype=dtype, device=device)
        one_m = torch.ones((m, 1), dtype=dtype, device=device) / sqrt_m
        # kron yields columns in (n, l) order. The `sm` commutation matrix is
        # expected to reorder them to (l, n).
        comm = torch.as_tensor(sm.commutation_matrix(n, l), dtype=dtype, device=device)
        return torch.kron(eye_n, torch.kron(one_m, w[None, :])) @ comm

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) without materialising it.

        ``transpose=False`` maps ``(l, k) -> (n, m)``; ``transpose=True`` maps
        ``(n, m) -> (l, k)``.
        """
        n, m, l, k = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))

        if not transpose:
            assert v.shape == (l, k)
            # Contract v with w over l (giving one value per k == n), then repeat across m.
            reduce_l = v.T @ w[:, None]
            return torch.tile(reduce_l, (1, m)) / sqrt_m
        assert v.shape == (n, m)
        # outer(w, row sums of v over m) -> (l, n) == (l, k)
        reduce_m = v.sum(dim=1, keepdim=True) / sqrt_m
        return torch.kron(w[:, None], reduce_m.T)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)``; this factor requires ``n == k``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == k
        return n, m, l, k


class S_Sm_S_I_Factor(Factor):
    """Two-term factor on the shared ``n == l`` axis with ``k``-dependent weights.

    For ``eq.shape == ((n, m), (n, k))`` the covariance is modelled as
    ``C[(i, a), (j, b)] = (delta_ij * w0[b] + w1[b] / n) / sqrt(m)``, where
    ``weights == stack([w0, w1])`` has shape ``(2, k)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with ``(2, k)`` weights from ``init_fn`` (zeros if None)."""
        n, m, l, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2, k))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Recover ``(w0, w1)`` from a dense ``(n*m, n*k)`` covariance.

        ``T`` sums the ``i == j`` diagonal and ``S`` sums every entry. For this
        factor's form the two linear equations below invert exactly.
        """
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        dtype = cov.dtype

        cov = cov.reshape(n, m, l, k)

        T = contract("ikil->l", cov)
        S = contract("ikjl->l", cov)

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))
        v = (S - T) / (sqrt_m * (n - 1))
        w = (T / sqrt_m - v) / n
        return torch.stack([w, v])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same estimator as :meth:`cov_estimate`, applied to the outer product of ``params``."""
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        T = contract("ik,il->l", p1, p2)
        S = contract("ik,jl->l", p1, p2)

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))
        v = (S - T) / (sqrt_m * (n - 1))
        w = (T / sqrt_m - v) / n
        return torch.stack([w, v])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, n*k)`` covariance (helpers use the weights' dtype/device)."""
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert n == l
        assert k == self.group_dims(self.eq)[-1]
        w = self.weights
        assert w.shape == (2, k)

        dtype, device = w.dtype, w.device
        w0 = w[0:1]
        w1 = w[1:2]
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype, device=device))

        eye_n = torch.eye(n, dtype=dtype, device=device)
        ones_m = torch.ones((m, 1), dtype=dtype, device=device) / sqrt_m
        ones_n_n = torch.ones((n, n), dtype=dtype, device=device) / n

        kron_eye = torch.kron(eye_n, torch.kron(ones_m, w0))
        kron_ones = torch.kron(ones_n_n, torch.kron(ones_m, w1))
        return kron_eye + kron_ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) without materialising it.

        ``transpose=False`` maps ``(l, k) -> (n, m)``; ``transpose=True`` maps
        ``(n, m) -> (l, k)``. The ``w1`` term is constant over the output grid,
        so it is added by broadcasting. The original code multiplied by a default
        float32 ``torch.ones``, which silently computed this term in float32
        and required CPU tensors.
        """
        n, m, l, k = self.group_dims(self.eq)
        dtype = v.dtype
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))

        w = self.weights
        if not transpose:
            assert v.shape == (l, k)
            reduce_k = v @ w[0:1, :].T
            pop_m = torch.tile(reduce_k, (1, m)) / sqrt_m

            reduce = torch.sum(v.sum(dim=0) * w[1])
            return pop_m + reduce / (sqrt_m * n)

        assert v.shape == (n, m)
        reduce_m = v.sum(dim=1, keepdim=True) / sqrt_m
        pop_k = torch.kron(w[0:1], reduce_m)

        reduce = torch.sum(v) / (sqrt_m * n)
        # (1, k) row broadcast over the l == n output rows
        return pop_k + reduce * w[1:2]

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)``; this factor requires ``n == l``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == l
        return n, m, l, k


class S_Sm_I_S_Factor(Factor):
    """Two-term factor tying row ``n`` to column ``k`` (``n == k``) with ``l``-dependent weights.

    For ``eq.shape == ((n, m), (l, n))`` the covariance is modelled as
    ``C[(i, a), (j, b)] = (delta_ib * w0[j] + w1[j] / n) / sqrt(m)``, where
    ``weights == stack([w0, w1])`` has shape ``(2, l)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with ``(2, l)`` weights from ``init_fn`` (zeros if None)."""
        n, m, l, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2, l))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Recover ``(w0, w1)`` from a dense covariance.

        ``T`` sums the ``i == b`` diagonal and ``S`` sums every entry.
        """
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        dtype = cov.dtype

        cov = cov.reshape(n, m, l, k)

        T = contract("ikji->j", cov)
        S = contract("ikjl->j", cov)

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))
        v = (S - T) / (sqrt_m * (n - 1))
        w = (T / sqrt_m - v) / n
        return torch.stack([w, v])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same estimator as :meth:`cov_estimate`, applied to the outer product of ``params``."""
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        T = contract("ik,ji->j", p1, p2)
        S = contract("ik,jl->j", p1, p2)

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))
        v = (S - T) / (sqrt_m * (n - 1))
        w = (T / sqrt_m - v) / n
        return torch.stack([w, v])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*n)`` covariance (helpers use the weights' dtype/device)."""
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert n == k
        assert l == self.group_dims(self.eq)[2]
        w = self.weights
        assert w.shape == (2, l)

        dtype, device = w.dtype, w.device
        w0 = w[0:1]
        w1 = w[1:2]
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype, device=device))

        eye_n = torch.eye(n, dtype=dtype, device=device)
        ones_m = torch.ones((m, 1), dtype=dtype, device=device) / sqrt_m
        ones_n_n = torch.ones((n, n), dtype=dtype, device=device) / n

        kron_eye = torch.kron(eye_n, torch.kron(ones_m, w0))
        # Depends only on the column index j, so the kron's row ordering is irrelevant.
        kron_ones = torch.kron(torch.kron(ones_m, w1), ones_n_n)
        comm = torch.as_tensor(sm.commutation_matrix(n, l), dtype=dtype, device=device)

        return kron_eye @ comm + kron_ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) without materialising it.

        ``transpose=False`` maps ``(l, k) -> (n, m)``; ``transpose=True`` maps
        ``(n, m) -> (l, k)``. The ``w1`` term is added by broadcasting rather than
        by multiplying a default-dtype, CPU-only ``torch.ones``.
        """
        n, m, l, k = self.group_dims(self.eq)

        dtype = v.dtype
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))

        w = self.weights
        if not transpose:
            assert v.shape == (l, k)
            reduce_l = v.T @ w[0:1, :].T
            pop_m = torch.tile(reduce_l, (1, m)) / sqrt_m

            reduce = torch.sum(v.sum(dim=1) * w[1])
            return pop_m + reduce / (sqrt_m * n)

        assert v.shape == (n, m)
        reduce_m = v.sum(dim=1, keepdim=True) / sqrt_m
        pop_k = torch.kron(w[0:1].T, reduce_m.T)

        reduce = torch.sum(v) / (sqrt_m * n)
        # (l, 1) column broadcast over the k == n output columns
        return pop_k + reduce * w[1:2].T

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)``; this factor requires ``n == k``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == k
        return n, m, l, k


class S_OOrB_OOrB_I_Factor(Factor):
    """Factor tying row ``m`` to column ``l`` (``m == l``) with ``k``-dependent weights.

    For ``eq.shape == ((n, m), (m, k))`` the covariance is modelled as
    ``C[(i, a), (j, b)] = delta_aj * weights[b] / sqrt(n)``, where ``weights``
    has shape ``(k,)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with ``(k,)`` weights from ``init_fn`` (zeros if None)."""
        n, m, l, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((k,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense covariance onto ``(k,)`` weights via the ``a == j`` diagonal."""
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        dtype = cov.dtype
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        cov = cov.reshape(n, m, l, k)

        return contract("ikkl->l", cov) / (sqrt_n * m)

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(k,)`` weights from two parameters (``p1``'s last axis matches ``p2``'s first)."""
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        return contract("ik,kl->l", p1, p2) / (sqrt_n * m)

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, m*k)`` covariance.

        Helpers and the commutation matrix use the weights' dtype/device.
        """
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert m == l
        assert k == self.group_dims(self.eq)[-1]
        w = self.weights
        assert w.shape == (k,)

        dtype, device = w.dtype, w.device
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype, device=device))
        eye_m = torch.eye(m, dtype=dtype, device=device)
        one_n = torch.ones((n, 1), dtype=dtype, device=device) / sqrt_n
        # kron yields columns in (k, m) order. The `sm` commutation matrix is
        # expected to reorder them to (m, k).
        comm = torch.as_tensor(sm.commutation_matrix(k, m), dtype=dtype, device=device)
        return torch.kron(torch.kron(one_n, w[None, :]), eye_m) @ comm

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) without materialising it.

        ``transpose=False`` maps ``(l, k) -> (n, m)``; ``transpose=True`` maps
        ``(n, m) -> (l, k)``.
        """
        n, m, l, k = self.group_dims(self.eq)
        dtype = v.dtype
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        w = self.weights
        if not transpose:
            assert v.shape == (l, k)
            # Contract v with w over k -> one (1, m) row, repeated across n.
            reduce_k = w[None, :] @ v.T
            return torch.tile(reduce_k, (n, 1)) / sqrt_n
        assert v.shape == (n, m)
        # outer(column sums of v over n, w) -> (m, k) == (l, k)
        reduce_n = v.sum(dim=0, keepdim=True) / sqrt_n
        return torch.kron(w[None, :], reduce_n.T)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)``; this factor requires ``m == l``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert m == l
        return n, m, l, k


class S_OOrB_I_OOrB_Factor(Factor):
    """Factor tying row ``m`` to column ``k`` (``m == k``) with ``l``-dependent weights.

    For ``eq.shape == ((n, m), (l, m))`` the covariance is modelled as
    ``C[(i, a), (j, b)] = delta_ab * weights[j] / sqrt(n)``, where ``weights``
    has shape ``(l,)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with ``(l,)`` weights from ``init_fn`` (zeros if None)."""
        n, m, l, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((l,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Project a dense covariance onto ``(l,)`` weights via the ``a == b`` diagonal."""
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        dtype = cov.dtype
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        cov = cov.reshape(n, m, l, k)

        return contract("ikjk->j", cov) / (sqrt_n * m)

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(l,)`` weights from two parameters sharing their last axis."""
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        return contract("ik,jk->j", p1, p2) / (sqrt_n * m)

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*m)`` covariance ``(1_n w^T / sqrt(n)) (x) I_m``.

        Helper tensors use the weights' dtype/device.
        """
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert m == k
        assert l == self.group_dims(self.eq)[2]
        w = self.weights
        assert w.shape == (l,)

        dtype, device = w.dtype, w.device
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype, device=device))
        eye_m = torch.eye(m, dtype=dtype, device=device)
        one_n = torch.ones((n, 1), dtype=dtype, device=device) / sqrt_n
        return torch.kron(torch.kron(one_n, w[None, :]), eye_m)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) without materialising it.

        ``transpose=False`` maps ``(l, k) -> (n, m)``; ``transpose=True`` maps
        ``(n, m) -> (l, k)``.
        """
        n, m, l, k = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        if not transpose:
            assert v.shape == (l, k)
            # w-weighted sum of v's rows -> one (1, m) row, repeated across n.
            reduce_l = w[None, :] @ v
            return torch.tile(reduce_l, (n, 1)) / sqrt_n
        assert v.shape == (n, m)
        # outer(w, column sums of v over n) -> (l, m) == (l, k)
        reduce_n = v.sum(dim=0, keepdim=True) / sqrt_n
        return torch.kron(w[:, None], reduce_n)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)``; this factor requires ``m == k``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert m == k
        return n, m, l, k


class Sn_S_S_I_Factor(Factor):
    """Two-term factor on the shared ``m == l`` axis with ``k``-dependent weights.

    For ``eq.shape == ((n, m), (m, k))`` the covariance is modelled as
    ``C[(i, a), (j, b)] = (delta_aj * w0[b] + w1[b] / m) / sqrt(n)``, where
    ``weights == stack([w0, w1])`` has shape ``(2, k)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with ``(2, k)`` weights from ``init_fn`` (zeros if None)."""
        n, m, l, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2, k))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Recover ``(w0, w1)`` from a dense covariance.

        ``T`` sums the ``a == j`` diagonal and ``S`` sums every entry.
        """
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        dtype = cov.dtype

        cov = cov.reshape(n, m, l, k)

        T = contract("ikkl->l", cov)
        S = contract("ikjl->l", cov)

        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))
        v = (S - T) / (sqrt_n * (m - 1))
        w = (T / sqrt_n - v) / m
        return torch.stack([w, v])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same estimator as :meth:`cov_estimate`, applied to the outer product of ``params``."""
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        T = contract("ik,kl->l", p1, p2)
        S = contract("ik,jl->l", p1, p2)

        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))
        v = (S - T) / (sqrt_n * (m - 1))
        w = (T / sqrt_n - v) / m
        return torch.stack([w, v])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, m*k)`` covariance (helpers use the weights' dtype/device)."""
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert m == l
        assert k == self.group_dims(self.eq)[-1]

        w = self.weights
        assert w.shape == (2, k)

        dtype, device = w.dtype, w.device
        w0 = w[0:1]
        w1 = w[1:2]
        eye_m = torch.eye(m, dtype=dtype, device=device)

        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype, device=device))
        ones_n = torch.ones((n, 1), dtype=dtype, device=device) / sqrt_n
        ones_m_m = torch.ones((m, m), dtype=dtype, device=device) / m

        kron_ones_n_w0 = torch.kron(ones_n, w0)
        kron_ones_n_w1 = torch.kron(ones_n, w1)

        kron_eye = torch.kron(kron_ones_n_w0, eye_m)
        # Depends only on the column index b, so the kron's row ordering is irrelevant.
        kron_ones = torch.kron(ones_m_m, kron_ones_n_w1)
        comm = torch.as_tensor(sm.commutation_matrix(k, m), dtype=dtype, device=device)

        return kron_eye @ comm + kron_ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) without materialising it.

        ``transpose=False`` maps ``(l, k) -> (n, m)``; ``transpose=True`` maps
        ``(n, m) -> (l, k)``. The ``w1`` term is added by broadcasting rather than
        by multiplying a default-dtype, CPU-only ``torch.ones``.
        """
        n, m, l, k = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        if not transpose:
            assert v.shape == (l, k)
            reduce_k = w[0:1] @ v.T
            pop_n = torch.tile(reduce_k, (n, 1)) / sqrt_n

            reduce = torch.sum(v.sum(dim=0) * w[1])
            return pop_n + reduce / (sqrt_n * m)

        assert v.shape == (n, m)
        reduce_n = v.sum(dim=0, keepdim=True) / sqrt_n
        pop_k = torch.kron(w[0:1], reduce_n.T)

        reduce = torch.sum(v) / (sqrt_n * m)
        # (1, k) row broadcast over the l == m output rows
        return pop_k + reduce * w[1:2]

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)``; this factor requires ``m == l``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert m == l
        return n, m, l, k


class Sn_S_I_S_Factor(Factor):
    """Two-term factor tying row ``m`` to column ``k`` (``m == k``) with ``l``-dependent weights.

    For ``eq.shape == ((n, m), (l, m))`` the covariance is modelled as
    ``C[(i, a), (j, b)] = (delta_ab * w0[j] + w1[j] / m) / sqrt(n)``, where
    ``weights == stack([w0, w1])`` has shape ``(2, l)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with ``(2, l)`` weights from ``init_fn`` (zeros if None)."""
        n, m, l, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((2, l))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Recover ``(w0, w1)`` from a dense covariance.

        ``T`` sums the ``a == b`` diagonal and ``S`` sums every entry.
        """
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        dtype = cov.dtype
        cov = cov.reshape(n, m, l, k)

        T = contract("ikjk->j", cov)
        S = contract("ikjl->j", cov)

        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))
        v = (S - T) / (sqrt_n * (m - 1))
        w = (T / sqrt_n - v) / m
        return torch.stack([w, v])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same estimator as :meth:`cov_estimate`, applied to the outer product of ``params``."""
        n, m, l, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        T = contract("ik,jk->j", p1, p2)
        S = contract("ik,jl->j", p1, p2)

        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))
        v = (S - T) / (sqrt_n * (m - 1))
        w = (T / sqrt_n - v) / m
        return torch.stack([w, v])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*m)`` covariance (helpers use the weights' dtype/device)."""
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert m == k
        assert l == self.group_dims(self.eq)[2]
        w = self.weights
        assert w.shape == (2, l)

        dtype, device = w.dtype, w.device
        w0 = w[0:1]
        w1 = w[1:2]
        eye_m = torch.eye(m, dtype=dtype, device=device)

        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype, device=device))
        ones_n = torch.ones((n, 1), dtype=dtype, device=device) / sqrt_n
        ones_m_m = torch.ones((m, m), dtype=dtype, device=device) / m

        kron_ones_n_w0 = torch.kron(ones_n, w0)
        kron_ones_n_w1 = torch.kron(ones_n, w1)

        kron_eye = torch.kron(kron_ones_n_w0, eye_m)
        kron_ones = torch.kron(kron_ones_n_w1, ones_m_m)

        return kron_eye + kron_ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) without materialising it.

        ``transpose=False`` maps ``(l, k) -> (n, m)``; ``transpose=True`` maps
        ``(n, m) -> (l, k)``. The ``w1`` term is added by broadcasting rather than
        by multiplying a default-dtype, CPU-only ``torch.ones``.
        """
        n, m, l, k = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        if not transpose:
            assert v.shape == (l, k)
            reduce_l = w[0:1] @ v
            pop_n = torch.tile(reduce_l, (n, 1)) / sqrt_n

            reduce = torch.sum(v.sum(dim=1) * w[1])
            return pop_n + reduce / (sqrt_n * m)

        assert v.shape == (n, m)
        reduce_n = v.sum(dim=0, keepdim=True) / sqrt_n
        pop_l = torch.kron(w[0:1].T, reduce_n)

        reduce = torch.sum(v) / (sqrt_n * m)
        # (l, 1) column broadcast over the k == m output columns
        return pop_l + reduce * w[1:2].T

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)``; this factor requires ``m == k``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert m == k
        return n, m, l, k


class S_S_S_I_Factor(Factor):
    """Factor for an ``(n, n)`` row block against an ``(n, k)`` column block.

    Viewing the dense ``(n*n, n*k)`` covariance as ``cov[i, a, j, b]`` (the einsum
    subscripts below call these ``i, k, j, l``), it is modelled as

        cov = sum_q basis[q][(i, a), j] * weights[q, b] / scale[q]

    with five patterns ``delta_ij``, ``delta_ia``, ``delta_aj``, all-ones and
    ``delta_ij * delta_ia``. In :meth:`basis` these are named ``delta_ij``,
    ``delta_ik``, ``delta_jk``, ``all_ones`` and ``delta_ij_delta_ik``.
    ``weights`` has shape ``(5, k)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with ``(5, k)`` weights from ``init_fn`` (zeros if None)."""
        n, k = cls.group_dims(eq)

        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((5, k))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Fit the ``(5, k)`` weights to a dense ``(n*n, n*k)`` covariance.

        Each row of ``b`` is ``cov`` summed over one basis pattern. The scaled
        system ``trace(n, scale) @ weights = b / scale`` is then solved.
        ``as_tensor``/``stable_solve`` are module helpers defined elsewhere
        (presumably casting to ``cov``'s dtype/device and solving the system).
        """
        n, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * n, n * k)
        cov = cov.reshape(n, n, n, k)

        b = torch.stack(
            [
                contract("ikil->l", cov),
                contract("iijl->l", cov),
                contract("ikkl->l", cov),
                contract("ikjl->l", cov),
                contract("iiil->l", cov),
            ]
        )

        scale = as_tensor(cls.scale(n), cov)
        C = as_tensor(cls.trace(n, scale), cov)

        return stable_solve(C, b / scale[:, None])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same estimator as :meth:`cov_estimate`, applied to the outer product of ``params``."""
        n, k = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]

        b = torch.stack(
            [
                contract("ik,il->l", p1, p2),
                contract("ii,jl->l", p1, p2),
                contract("ik,kl->l", p1, p2),
                contract("ik,jl->l", p1, p2),
                contract("ii,il->l", p1, p2),
            ]
        )

        scale = as_tensor(cls.scale(n), p1)
        C = as_tensor(cls.trace(n, scale), p1)

        return stable_solve(C, b / scale[:, None])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*n, n*k)`` covariance ``sum_q basis[q] (x) w[q] / scale[q]``."""
        n = self.group_dims(self.eq, surrogate)[0] if shape is None else shape[0]
        if shape is not None:
            assert n == shape[1] == shape[2]
            assert shape[-1] == self.group_dims(self.eq)[1]

        w = self.weights
        factors = as_tensor(self.basis(n), w)
        s = as_tensor(self.scale(n), w)

        # Same left-to-right accumulation as summing the five terms explicitly.
        return sum(
            torch.kron(factor, w_q[None]) / s_q
            for factor, w_q, s_q in zip(factors, w, s)
        )

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) without materialising it.

        ``transpose=False`` maps ``(n, k) -> (n, n)``; ``transpose=True`` maps
        ``(n, n) -> (n, k)``. Helper tensors use ``v``'s dtype/device. With
        default float32 helpers, terms such as ``torch.sum(...) * eye`` would
        be computed in float32 even for float64 inputs, and CUDA inputs would fail.
        """
        n, k = self.group_dims(self.eq)

        dtype, device = v.dtype, v.device
        eye = torch.eye(n, dtype=dtype, device=device)
        one = torch.ones((n, 1), dtype=dtype, device=device)
        ones = torch.ones((n, n), dtype=dtype, device=device)
        w = self.weights
        s = as_tensor(self.scale(n), w)
        inv_s = 1 / s

        if not transpose:
            assert v.shape == (n, k)
            return (
                inv_s[0] * torch.kron((v @ w[0])[:, None], one.T)
                + inv_s[1] * torch.sum(v @ w[1]) * eye
                + inv_s[2] * torch.kron(one, (v @ w[2])[None, :])
                + inv_s[3] * torch.sum(v @ w[3]) * ones
                + inv_s[4] * torch.diag(v @ w[4])
            )

        assert v.shape == (n, n)
        return (
            inv_s[0] * torch.kron(w[0:1], v.sum(dim=1, keepdim=True))
            + inv_s[1] * torch.kron(one, w[1:2]) * torch.trace(v)
            + inv_s[2] * torch.kron(w[2:3], v.sum(dim=0, keepdim=True).T)
            + inv_s[3] * torch.kron(one, w[3:4]) * torch.sum(v)
            + inv_s[4] * torch.kron(torch.diag(v)[:, None], w[4:5])
        )

    @classmethod
    def basis(cls, n) -> NDArray:
        """Return the five ``(n*n, n)`` indicator patterns, stacked as ``(5, n*n, n)``."""
        comm = sm.commutation_matrix(n, n)
        bbI = torch.outer(torch.eye(n).reshape(-1), torch.eye(n).reshape(-1))
        ones = torch.ones((n, 1))

        all_ones = torch.kron(torch.ones((n, n)), ones)
        delta_ij = torch.kron(torch.eye(n), ones)  # row-outer index == column index
        delta_jk = comm @ delta_ij  # row-inner index == column index
        delta_ik = bbI @ delta_ij  # row-outer index == row-inner index
        delta_ij_delta_ik = delta_ij * delta_ik

        return torch.stack(
            [
                delta_ij,
                delta_ik,
                delta_jk,
                all_ones,
                delta_ij_delta_ik,
            ]
        )

    @classmethod
    def scale(cls, n, dtype=None) -> NDArray:
        """Per-pattern normalisers; ``dtype`` (default: torch's inference) is now honoured."""
        return torch.as_tensor(
            [
                np.sqrt(n),
                np.sqrt(n),
                np.sqrt(n),
                n * np.sqrt(n),
                1.0,  # NOTE(review): flagged "check this!" in the original.
            ],
            dtype=dtype,
        )

    @classmethod
    def trace(cls, n, s) -> NDArray:
        """Gram matrix of the basis patterns, divided elementwise by ``outer(s, s)``."""
        n2 = n * n
        n3 = n * n * n

        C = as_tensor(
            torch.as_tensor(
                [
                    [n2, n, n, n2, n],
                    [n, n2, n, n2, n],
                    [n, n, n2, n2, n],
                    [n2, n2, n2, n3, n],
                    [n, n, n, n, n],
                ]
            ),
            s,
        )
        C /= torch.outer(s, s)
        return C

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return ``(n, k)``; requires ``eq.shape == ((n, n), (n, k))``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == m and m == l
        return n, k


class S_S_I_S_Factor(Factor):
    """Factor for an equivariance shape ``((n, n), (l, n))``.

    ``weights`` has shape ``(5, l)``: row ``p`` holds the length-``l``
    coefficients of the ``p``-th pattern of ``basis(n)``, which is divided by
    ``scale(n)[p]``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(5, l)`` weights from ``init_fn(shape)``.

        ``init_fn`` defaults to ``torch.zeros``.
        """
        _, l = cls.group_dims(eq)

        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((5, l))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``(5, l)`` weights from an ``(n*n, l*n)`` covariance.

        Each basis pattern is contracted against ``cov`` to build a right-hand
        side that is solved against the scaled ``trace`` matrix.
        """
        n, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * n, l * n)
        cov = cov.reshape(n, n, l, n)

        b = torch.stack(
            [
                contract("ikli->l", cov),
                contract("iilj->l", cov),
                contract("iklk->l", cov),
                contract("iklj->l", cov),
                contract("iili->l", cov),
            ]
        )

        s = as_tensor(cls.scale(n), cov)
        C = as_tensor(cls.trace(n, s), cov)
        return stable_solve(C, b / s[:, None])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(5, l)`` weights from the outer product of two params.

        The contractions index ``params[0]`` as ``(n, n)`` and ``params[1]``
        as ``(l, n)``.
        """
        n, _ = cls.group_dims(eq)
        p1, p2 = params[0], params[1]

        b = torch.stack(
            [
                contract("ik,li->l", p1, p2),
                contract("ii,lj->l", p1, p2),
                contract("ik,lk->l", p1, p2),
                contract("ik,lj->l", p1, p2),
                contract("ii,li->l", p1, p2),
            ]
        )

        s = as_tensor(cls.scale(n), p1)
        C = as_tensor(cls.trace(n, s), p1)
        return stable_solve(C, b / s[:, None])

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*n, l*n)`` matrix represented by this factor.

        ``shape`` (``(n, n, l, n)``) optionally overrides ``n``; ``l`` must
        still match ``eq``.
        """
        if shape is None:
            n, l = self.group_dims(self.eq, surrogate)
        else:
            n, l = shape[0], shape[2]
            assert n == shape[1] == shape[3]
            assert l == self.group_dims(self.eq)[1]

        w = self.weights
        s = as_tensor(self.scale(n), w)
        factors = as_tensor(self.basis(n), w)
        comm = as_tensor(sm.commutation_matrix(n, l), w)

        # sum_p kron(pattern_p, w_p) / s_p, then permute columns.
        mat = sum(torch.kron(factors[p], w[p, None, :]) / s[p] for p in range(5))
        return mat @ comm

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(l, n)``) or its transpose (``v`` is ``(n, n)``)."""
        n, l = self.group_dims(self.eq)

        w = self.weights
        # Constants must live on w's dtype/device: CPU/float32 constants break
        # on GPU and, because 0-dim tensors don't drive type promotion, would
        # silently downcast terms like ``sum * eye`` to float32.
        opts = {"dtype": w.dtype, "device": w.device}
        eye = torch.eye(n, **opts)
        one = torch.ones((n, 1), **opts)
        ones = torch.ones((n, n), **opts)
        s = as_tensor(self.scale(n), w)
        inv_s = 1 / s

        if not transpose:
            assert v.shape == (l, n)
            v = v.T
            return (
                inv_s[0] * torch.kron((v @ w[0])[:, None], one.T)
                + inv_s[1] * torch.sum(v @ w[1]) * eye
                + inv_s[2] * torch.kron(one, (v @ w[2])[None, :])
                + inv_s[3] * torch.sum(v @ w[3]) * ones
                + inv_s[4] * torch.diag(v @ w[4])
            )

        assert v.shape == (n, n)
        return (
            inv_s[0] * torch.kron(w[0:1].T, v.sum(dim=1, keepdim=True).T)
            + inv_s[1] * torch.kron(one.T, w[1:2].T) * torch.trace(v)
            + inv_s[2] * torch.kron(w[2:3].T, v.sum(dim=0, keepdim=True))
            + inv_s[3] * torch.kron(one.T, w[3:4].T) * torch.sum(v)
            + inv_s[4] * torch.kron(torch.diag(v)[None, :], w[4:5].T)
        )

    @classmethod
    def basis(cls, n) -> NDArray:
        """Return the five ``(n*n, n)`` basis patterns stacked along dim 0.

        Order: ``delta_ij, delta_ik, delta_jk, all_ones, delta_ij * delta_ik``,
        aligned with ``scale`` and ``trace``.
        """
        k = sm.commutation_matrix(n, n)
        bbI = torch.outer(torch.eye(n).reshape(-1), torch.eye(n).reshape(-1))
        ones = torch.ones((n, 1))

        all_ones = torch.kron(torch.ones((n, n)), ones)
        delta_jk = k @ torch.kron(torch.eye(n), ones)
        delta_ij = torch.kron(torch.eye(n), ones)
        delta_ik = bbI @ torch.kron(torch.eye(n), ones)
        delta_ij_delta_ik = delta_ij * delta_ik

        return torch.stack(
            [
                delta_ij,
                delta_ik,
                delta_jk,
                all_ones,
                delta_ij_delta_ik,
            ]
        )

    @classmethod
    def scale(cls, n, dtype=None) -> NDArray:
        """Return the length-5 normalisation constants aligned with ``basis``.

        ``dtype`` is optional (default keeps torch's inferred dtype).
        """
        return torch.as_tensor(
            [
                np.sqrt(n),
                np.sqrt(n),
                np.sqrt(n),
                n * np.sqrt(n),
                1.0,
            ],
            dtype=dtype,
        )

    @classmethod
    def trace(cls, n, s) -> NDArray:
        """Return the 5x5 basis inner-product table divided by ``outer(s, s)``."""
        n2 = n * n
        n3 = n * n * n

        C = as_tensor(
            torch.tensor(
                [
                    [n2, n, n, n2, n],
                    [n, n2, n, n2, n],
                    [n, n, n2, n2, n],
                    [n2, n2, n2, n3, n],
                    [n, n, n, n, n],
                ]
            ),
            s,
        )
        C /= torch.outer(s, s)
        return C

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int]:
        """Return ``(n, l)`` for shape ``((n, n), (l, n))``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        assert n == m and m == k
        return n, l


class Sn_Sm_Sk_I_Factor(Factor):
    """Factor for shape ``((n, m), (l, k))`` with weights of shape ``(k,)``.

    The represented ``(n*m, l*k)`` matrix has entry ``w[b] / sqrt(n*m*l)`` at
    every row and every column ``(a, b)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(k,)`` weights from ``init_fn(shape)`` (default zeros)."""
        _, _, _, k = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((k,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``(k,)`` weights by summing an ``(n*m, l*k)`` covariance over n, m, l."""
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        cov = cov.reshape(n, m, l, k)

        sqrt_nml = torch.sqrt(torch.as_tensor(n * m * l, dtype=cov.dtype))
        return torch.sum(cov, dim=(0, 1, 2)) / sqrt_nml

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(k,)`` weights from ``params[0]`` ``(n, m)`` and ``params[1]`` ``(l, k)``."""
        n, m, l, k = cls.group_dims(eq)
        p1, p2 = params[0], params[1]
        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        sqrt_nml = torch.sqrt(torch.as_tensor(n * m * l, dtype=p1.dtype))
        # Sums all of p1 and p2's first axis; keeps p2's last (length-k) axis.
        return contract("ik,jl->l", p1, p2) / sqrt_nml

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*k)`` matrix; ``k`` must match ``eq``."""
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert k == self.group_dims(self.eq)[-1]

        w = self.weights
        dtype = w.dtype
        assert w.shape == (k,)

        # Constants follow the weights' dtype/device (GPU / float64 safe).
        opts = {"dtype": dtype, "device": w.device}
        sqrt_nl = torch.sqrt(torch.as_tensor(n * l, dtype=dtype))
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))

        ones_nl = torch.ones((n, l), **opts) / sqrt_nl
        ones_m = torch.ones((m, 1), **opts) / sqrt_m
        ones = torch.kron(ones_nl, ones_m)

        return torch.kron(ones, w[None, :])

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(l, k)``) or its transpose (``v`` is ``(n, m)``)."""
        n, m, l, k = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        opts = {"dtype": dtype, "device": v.device}
        sqrt_nlm = torch.sqrt(torch.as_tensor(n * l * m, dtype=dtype))

        if not transpose:
            assert v.shape == (l, k)
            ones = torch.ones((n, m), **opts) / sqrt_nlm
            return torch.sum(v @ w) * ones

        assert v.shape == (n, m)
        ones = torch.ones((l, 1), **opts) / sqrt_nlm
        return torch.sum(v) * torch.kron(ones, w[None, :])

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` from ``eq.stable_shape`` or ``eq.shape``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        return n, m, l, k


class Sn_Sm_I_Sl_Factor(Factor):
    """Factor for shape ``((n, m), (l, k))`` with weights of shape ``(l,)``.

    The represented ``(n*m, l*k)`` matrix has entry ``w[a] / sqrt(n*m*k)`` at
    every row and every column ``(a, b)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(l,)`` weights from ``init_fn(shape)`` (default zeros)."""
        _, _, l, _ = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((l,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``(l,)`` weights by summing an ``(n*m, l*k)`` covariance over n, m, k."""
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        cov = cov.reshape(n, m, l, k)

        sqrt_nmk = torch.sqrt(torch.as_tensor(n * m * k, dtype=cov.dtype))
        return torch.sum(cov, dim=(0, 1, 3)) / sqrt_nmk

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(l,)`` weights from ``params[0]`` ``(n, m)`` and ``params[1]`` ``(l, k)``."""
        n, m, l, k = cls.group_dims(eq)
        p1, p2 = params[0], params[1]
        sqrt_nmk = torch.sqrt(torch.as_tensor(n * m * k, dtype=p1.dtype))

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)

        # Sums all of p1 and p2's last axis; keeps p2's first (length-l) axis.
        return contract("ik,jl->j", p1, p2) / sqrt_nmk

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense ``(n*m, l*k)`` matrix; ``l`` must match ``eq``."""
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        assert l == self.group_dims(self.eq)[2]

        w = self.weights
        dtype = w.dtype
        assert w.shape == (l,)

        # Constants follow the weights' dtype/device (GPU / float64 safe).
        opts = {"dtype": dtype, "device": w.device}
        sqrt_mk = torch.sqrt(torch.as_tensor(m * k, dtype=dtype))
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))
        ones_n = torch.ones((n, 1), **opts) / sqrt_n
        ones_mk = torch.ones((m, k), **opts) / sqrt_mk

        kron_ones_n_w = torch.kron(ones_n, w[None, :])
        return torch.kron(kron_ones_n_w, ones_mk)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(l, k)``) or its transpose (``v`` is ``(n, m)``)."""
        n, m, l, k = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        opts = {"dtype": dtype, "device": v.device}
        sqrt_nmk = torch.sqrt(torch.as_tensor(n * m * k, dtype=dtype))

        if not transpose:
            assert v.shape == (l, k)
            ones = torch.ones((n, m), **opts) / sqrt_nmk
            return torch.sum(v.T @ w) * ones

        assert v.shape == (n, m)
        ones = torch.ones((1, k), **opts) / sqrt_nmk
        return torch.sum(v) * torch.kron(w[:, None], ones)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` from ``eq.stable_shape`` or ``eq.shape``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        return n, m, l, k


class Sn_Sm_Sk_Sl_Factor(Factor):
    """Factor for shape ``((n, m), (l, k))`` with a single scalar weight.

    The represented ``(n*m, l*k)`` matrix is constant, ``w / sqrt(n*m*l*k)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with a 0-dim weight from ``init_fn(())`` (default zeros)."""
        cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn(())
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate the scalar weight as the scaled sum of an ``(n*m, l*k)`` covariance."""
        n, m, l, k = cls.group_dims(eq, surrogate)
        assert cov.shape == (n * m, l * k)
        cov = cov.reshape(n, m, l, k)
        sqrt_nmkl = torch.sqrt(torch.as_tensor(n * m * k * l, dtype=cov.dtype))
        return torch.sum(cov) / sqrt_nmkl

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the scalar weight from ``params[0]`` ``(n, m)`` and ``params[1]`` ``(l, k)``."""
        n, m, l, k = cls.group_dims(eq)
        p1, p2 = params[0], params[1]

        assert p1.shape[-2:] == (n, m) and p2.shape[-2:] == (l, k)
        sqrt_nmkl = torch.sqrt(torch.as_tensor(n * m * k * l, dtype=p1.dtype))
        return contract("ik,jl->", p1, p2) / sqrt_nmkl

    def cov(
        self,
        shape: tuple[int, int, int, int] | None = None,
        surrogate: bool = False,
    ) -> NDArray:
        """Return the dense constant ``(n*m, l*k)`` matrix."""
        n, m, l, k = self.group_dims(self.eq, surrogate) if shape is None else shape
        w = self.weights
        dtype = w.dtype
        sqrt_nmkl = torch.sqrt(torch.as_tensor(n * m * k * l, dtype=dtype))

        # A 0-dim weight does not drive type promotion, so ``ones`` itself must
        # carry the weight's dtype/device (otherwise the result is CPU float32).
        ones = torch.ones((n * m, l * k), dtype=dtype, device=w.device)
        ones = ones / sqrt_nmkl
        return w * ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(l, k)``) or its transpose (``v`` is ``(n, m)``)."""
        n, m, l, k = self.group_dims(self.eq)
        dtype = v.dtype
        opts = {"dtype": dtype, "device": v.device}
        sqrt_nmkl = torch.sqrt(torch.as_tensor(n * m * k * l, dtype=dtype))

        if transpose:
            assert v.shape == (n, m)
            ones = torch.ones((l, k), **opts) / sqrt_nmkl
        else:
            assert v.shape == (l, k)
            ones = torch.ones((n, m), **opts) / sqrt_nmkl
        return self.weights * torch.sum(v) * ones

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int, int]:
        """Return ``(n, m, l, k)`` from ``eq.stable_shape`` or ``eq.shape``."""
        (n, m), (l, k) = eq.stable_shape if surrogate else eq.shape
        return n, m, l, k


class I_I_S_Factor(Factor):
    """Factor for shape ``(n, (m, l))`` with weights of shape ``(n, m)``.

    The represented ``(n, m*l)`` matrix has entry ``w[i, j] / sqrt(l)`` at
    row ``i`` and column ``(j, a)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(n, m)`` weights from ``init_fn(shape)`` (default zeros)."""
        n, m, _ = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((n, m))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``(n, m)`` weights by summing an ``(n, m*l)`` covariance over ``l``."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)

        cov = cov.reshape(n, m, l)
        sqrt_l = torch.sqrt(torch.as_tensor(l, dtype=cov.dtype))
        return torch.sum(cov, dim=2) / sqrt_l

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(n, m)`` weights from ``params[0]`` (length ``n``) and ``params[1]`` ``(m, l)``."""
        n, m, l = cls.group_dims(eq)
        p1, p2 = params[0], params[1]

        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        sqrt_l = torch.sqrt(torch.as_tensor(l, dtype=p1.dtype))
        return contract("i,kl->ik", p1, p2) / sqrt_l

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense ``(n, m*l)`` matrix; ``n`` and ``m`` must match ``eq``."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            # Compare as a tuple: ``assert n, m == ...`` only tests that n is truthy.
            assert (n, m) == self.group_dims(self.eq)[:2]
        w = self.weights
        dtype = w.dtype

        assert w.shape == (n, m)
        sqrt_l = torch.sqrt(torch.as_tensor(l, dtype=dtype))
        ones_l = torch.ones((1, l), dtype=dtype, device=w.device)
        return torch.kron(w, ones_l) / sqrt_l

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(m, l)``) or its transpose (``v`` is ``(n,)``)."""
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        sqrt_l = torch.sqrt(torch.as_tensor(l, dtype=dtype))

        if not transpose:
            assert v.shape == (m, l)
            reduce_l = v.sum(dim=1) / sqrt_l
            return w @ reduce_l

        assert v.shape == (n,)
        ones_l = torch.ones((l,), dtype=dtype, device=v.device) / sqrt_l
        return torch.outer(w.T @ v, ones_l)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq.stable_shape`` or ``eq.shape``."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        return n, m, l


class I_S_I_Factor(Factor):
    """Factor for shape ``(n, (m, l))`` with weights of shape ``(n, l)``.

    The represented ``(n, m*l)`` matrix has entry ``w[i, a] / sqrt(m)`` at
    row ``i`` and column ``(j, a)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(n, l)`` weights from ``init_fn(shape)`` (default zeros)."""
        n, _, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((n, l))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``(n, l)`` weights by summing an ``(n, m*l)`` covariance over ``m``."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)

        cov = cov.reshape(n, m, l)
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=cov.dtype))
        return torch.sum(cov, dim=1) / sqrt_m

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(n, l)`` weights from ``params[0]`` (length ``n``) and ``params[1]`` ``(m, l)``."""
        n, m, l = cls.group_dims(eq)
        p1, p2 = params[0], params[1]

        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=p1.dtype))
        return contract("i,kl->il", p1, p2) / sqrt_m

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense ``(n, m*l)`` matrix; ``n`` and ``l`` must match ``eq``."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            eq_n, _, eq_l = self.group_dims(self.eq)
            assert n == eq_n and l == eq_l
        w = self.weights
        dtype = w.dtype
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))

        assert w.shape == (n, l)
        ones_m = torch.ones((1, m), dtype=dtype, device=w.device)
        return torch.kron(ones_m, w) / sqrt_m

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(m, l)``) or its transpose (``v`` is ``(n,)``)."""
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))

        if not transpose:
            assert v.shape == (m, l)
            reduce_m = v.sum(dim=0) / sqrt_m
            return w @ reduce_m

        assert v.shape == (n,)
        ones_m = torch.ones((m,), dtype=dtype, device=v.device) / sqrt_m
        return torch.outer(ones_m, w.T @ v)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq.stable_shape`` or ``eq.shape``."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        return n, m, l


class S_I_I_Factor(Factor):
    """Factor for shape ``(n, (m, l))`` with weights of shape ``(m, l)``.

    The represented ``(n, m*l)`` matrix has entry ``w[j, a] / sqrt(n)`` at
    every row ``i`` and column ``(j, a)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(m, l)`` weights from ``init_fn(shape)`` (default zeros)."""
        _, m, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((m, l))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``(m, l)`` weights by summing an ``(n, m*l)`` covariance over ``n``."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=cov.dtype))

        cov = cov.reshape(n, m, l)
        return torch.sum(cov, dim=0) / sqrt_n

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(m, l)`` weights from ``params[0]`` (length ``n``) and ``params[1]`` ``(m, l)``."""
        n, m, l = cls.group_dims(eq)
        p1, p2 = params[0], params[1]
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=p1.dtype))

        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        return contract("i,kl->kl", p1, p2) / sqrt_n

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense ``(n, m*l)`` matrix; ``m`` and ``l`` must match ``eq``."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            # Compare as a tuple: ``assert m, l == ...`` only tests that m is truthy.
            assert (m, l) == self.group_dims(self.eq)[1:]
        w = self.weights
        dtype = w.dtype
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        assert w.shape == (m, l)
        # matmul does not promote dtypes, so the column of ones must match w.
        ones_n = torch.ones((n, 1), dtype=dtype, device=w.device)
        return ones_n @ w.reshape(-1)[None, :] / sqrt_n

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(m, l)``) or its transpose (``v`` is ``(n,)``)."""
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        if not transpose:
            assert v.shape == (m, l)
            ones = torch.ones((n,), dtype=dtype, device=v.device) / sqrt_n
            return torch.sum(w * v) * ones

        assert v.shape == (n,)
        reduce = torch.sum(v) / sqrt_n
        return reduce * w

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq.stable_shape`` or ``eq.shape``."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        return n, m, l


class Sn_I_Sl_Factor(Factor):
    """Factor for shape ``(n, (m, l))`` with weights of shape ``(m,)``.

    The represented ``(n, m*l)`` matrix has entry ``w[j] / sqrt(n*l)`` at
    every row ``i`` and column ``(j, a)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(m,)`` weights from ``init_fn(shape)`` (default zeros)."""
        _, m, _ = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((m,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``(m,)`` weights by summing an ``(n, m*l)`` covariance over ``n`` and ``l``."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        sqrt_nl = torch.sqrt(torch.as_tensor(n * l, dtype=cov.dtype))

        cov = cov.reshape(n, m, l)
        return torch.sum(cov, dim=(0, 2)) / sqrt_nl

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(m,)`` weights from ``params[0]`` (length ``n``) and ``params[1]`` ``(m, l)``."""
        n, m, l = cls.group_dims(eq)
        p1, p2 = params[0], params[1]

        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        sqrt_nl = torch.sqrt(torch.as_tensor(n * l, dtype=p1.dtype))
        return contract("i,kl->k", p1, p2) / sqrt_nl

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense ``(n, m*l)`` matrix; ``m`` must match ``eq``."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            assert m == self.group_dims(self.eq)[1]
        w = self.weights
        dtype = w.dtype
        sqrt_nl = torch.sqrt(torch.as_tensor(n * l, dtype=dtype))

        assert w.shape == (m,)
        ones_nl = torch.ones((n, l), dtype=dtype, device=w.device) / sqrt_nl
        return torch.kron(w[None], ones_nl)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(m, l)``) or its transpose (``v`` is ``(n,)``)."""
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        opts = {"dtype": dtype, "device": v.device}
        sqrt_nl = torch.sqrt(torch.as_tensor(n * l, dtype=dtype))

        if not transpose:
            assert v.shape == (m, l)
            ones = torch.ones((n,), **opts) / sqrt_nl
            return ones * torch.sum(w * v.sum(dim=1))

        assert v.shape == (n,)
        ones = torch.ones((l,), **opts) / sqrt_nl
        return torch.outer(w, torch.sum(v) * ones)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq.stable_shape`` or ``eq.shape``."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        return n, m, l


class Sn_Sm_I_Factor(Factor):
    """Factor for shape ``(n, (m, l))`` with weights of shape ``(l,)``.

    The represented ``(n, m*l)`` matrix has entry ``w[a] / sqrt(n*m)`` at
    every row ``i`` and column ``(j, a)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(l,)`` weights from ``init_fn(shape)`` (default zeros)."""
        _, _, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((l,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``(l,)`` weights by summing an ``(n, m*l)`` covariance over ``n`` and ``m``."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        sqrt_nm = torch.sqrt(torch.as_tensor(n * m, dtype=cov.dtype))

        cov = cov.reshape(n, m, l)
        return torch.sum(cov, dim=(0, 1)) / sqrt_nm

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(l,)`` weights from ``params[0]`` (length ``n``) and ``params[1]`` ``(m, l)``."""
        n, m, l = cls.group_dims(eq)
        p1, p2 = params[0], params[1]

        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        sqrt_nm = torch.sqrt(torch.as_tensor(n * m, dtype=p1.dtype))
        return contract("i,kl->l", p1, p2) / sqrt_nm

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense ``(n, m*l)`` matrix; ``l`` must match ``eq``."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            assert l == self.group_dims(self.eq)[2]
        w = self.weights
        dtype = w.dtype
        sqrt_nm = torch.sqrt(torch.as_tensor(n * m, dtype=dtype))

        assert w.shape == (l,)
        ones_nm = torch.ones((n, m), dtype=dtype, device=w.device) / sqrt_nm
        return torch.kron(ones_nm, w[None])

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(m, l)``) or its transpose (``v`` is ``(n,)``)."""
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        opts = {"dtype": dtype, "device": v.device}
        sqrt_nm = torch.sqrt(torch.as_tensor(n * m, dtype=dtype))

        if not transpose:
            assert v.shape == (m, l)
            ones = torch.ones((n,), **opts) / sqrt_nm
            return ones * torch.sum(w * v.sum(dim=0))

        assert v.shape == (n,)
        ones = torch.ones((m,), **opts) / sqrt_nm
        return torch.outer(torch.sum(v) * ones, w)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq.stable_shape`` or ``eq.shape``."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        return n, m, l


class I_Sm_Sl_Factor(Factor):
    """Factor for shape ``(n, (m, l))`` with weights of shape ``(n,)``.

    The represented ``(n, m*l)`` matrix has entry ``w[i] / sqrt(m*l)`` at
    row ``i`` and every column.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(n,)`` weights from ``init_fn(shape)`` (default zeros)."""
        n, _, _ = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((n,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``(n,)`` weights by summing an ``(n, m*l)`` covariance over ``m`` and ``l``."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        sqrt_ml = torch.sqrt(torch.as_tensor(m * l, dtype=cov.dtype))

        cov = cov.reshape(n, m, l)
        return torch.sum(cov, dim=(1, 2)) / sqrt_ml

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(n,)`` weights from ``params[0]`` (length ``n``) and ``params[1]`` ``(m, l)``."""
        n, m, l = cls.group_dims(eq)
        p1, p2 = params[0], params[1]

        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        sqrt_ml = torch.sqrt(torch.as_tensor(m * l, dtype=p1.dtype))
        return contract("i,kl->i", p1, p2) / sqrt_ml

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense ``(n, m*l)`` matrix; ``n`` must match ``eq``."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            assert n == self.group_dims(self.eq)[0]
        w = self.weights
        dtype = w.dtype
        sqrt_ml = torch.sqrt(torch.as_tensor(m * l, dtype=dtype))

        assert w.shape == (n,)
        # matmul does not promote dtypes, so the row of ones must match w.
        ones_ml = torch.ones((1, m * l), dtype=dtype, device=w.device)
        return w[:, None] @ ones_ml / sqrt_ml

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(m, l)``) or its transpose (``v`` is ``(n,)``)."""
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        dtype = v.dtype
        sqrt_ml = torch.sqrt(torch.as_tensor(m * l, dtype=dtype))

        if not transpose:
            assert v.shape == (m, l)
            reduce = torch.sum(v) / sqrt_ml
            return reduce * w

        assert v.shape == (n,)
        ones = torch.ones((m, l), dtype=dtype, device=v.device) / sqrt_ml
        return torch.sum(v * w) * ones

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq.stable_shape`` or ``eq.shape``."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        return n, m, l


class Sn_Sm_Sl_Factor(Factor):
    """Factor for shape ``(n, (m, l))`` with a single scalar weight.

    The represented ``(n, m*l)`` matrix is constant, ``w / sqrt(n*m*l)``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with a 0-dim weight from ``init_fn(())`` (default zeros)."""
        cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn(())
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate the scalar weight as the scaled sum of an ``(n, m*l)`` covariance."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)

        sqrt_nml = torch.sqrt(torch.as_tensor(n * m * l, dtype=cov.dtype))
        return torch.sum(cov) / sqrt_nml

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the scalar weight from ``params[0]`` (length ``n``) and ``params[1]`` ``(m, l)``."""
        n, m, l = cls.group_dims(eq)
        p1, p2 = params[0], params[1]
        sqrt_nml = torch.sqrt(torch.as_tensor(n * m * l, dtype=p1.dtype))

        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        return contract("i,kl->", p1, p2) / sqrt_nml

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense constant ``(n, m*l)`` matrix."""
        w = self.weights
        dtype = w.dtype

        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        sqrt_nml = torch.sqrt(torch.as_tensor(n * m * l, dtype=dtype))
        # A 0-dim weight does not drive type promotion, so ``ones`` must carry
        # the weight's dtype/device (otherwise the result is CPU float32).
        ones = torch.ones((n, m * l), dtype=dtype, device=w.device)
        return w * ones / sqrt_nml

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(m, l)``) or its transpose (``v`` is ``(n,)``)."""
        n, m, l = self.group_dims(self.eq)
        dtype = v.dtype
        opts = {"dtype": dtype, "device": v.device}
        sqrt_ml = torch.sqrt(torch.as_tensor(m * l, dtype=dtype))
        sqrt_n = torch.sqrt(torch.as_tensor(n, dtype=dtype))

        if not transpose:
            assert v.shape == (m, l)
            reduce = torch.sum(v) / sqrt_ml
            ones = torch.ones((n,), **opts) / sqrt_n
        else:
            assert v.shape == (n,)
            reduce = torch.sum(v) / sqrt_n
            ones = torch.ones((m, l), **opts) / sqrt_ml
        return self.weights * reduce * ones

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq.stable_shape`` or ``eq.shape``."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        return n, m, l


class OnorBn_I_OnorBn_Factor(Factor):
    """Factor for shape ``(n, (m, n))`` with weights of shape ``(m,)``.

    The represented ``(n, m*n)`` matrix has entry ``w[j]`` at row ``i`` and
    column ``(j, a)`` when ``a == i``, and zero otherwise (no normalisation).
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create a factor with ``(m,)`` weights from ``init_fn(shape)`` (default zeros)."""
        _, m, _ = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((m,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``(m,)`` weights as the mean over ``k`` of ``cov[k, i, k]``."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        cov = cov.reshape(n, m, l)
        return contract("kik->i", cov) / n

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate ``(m,)`` weights from ``params[0]`` (length ``n``) and ``params[1]`` ``(m, n)``."""
        n, m, l = cls.group_dims(eq)
        assert params[0].shape[-1] == n and params[1].shape[-2:] == (m, l)

        return contract("k,ik->i", params[0], params[1]) / n

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Return the dense ``(n, m*n)`` matrix; ``m`` must match ``eq``."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            assert n == l
            assert m == self.group_dims(self.eq)[1]
        w = self.weights

        assert w.shape == (m,)
        eye = torch.eye(n, dtype=w.dtype, device=w.device)
        return torch.kron(w[None], eye)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the factor (``v`` is ``(m, n)``) or its transpose (``v`` is ``(n,)``)."""
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        if not transpose:
            assert v.shape == (m, l)
            return v.T @ w
        assert v.shape == (n,)
        return torch.outer(w, v)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq.stable_shape`` or ``eq.shape``; asserts ``n == l``."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        assert n == l
        return n, m, l


class OnorBn_OnorBn_I_Factor(Factor):
    """Factor whose dense covariance is ``kron(eye(n), weights[None])``.

    Shapes follow ``eq.shape == (n, (m, l))`` with ``n == m``; ``weights`` has
    shape ``(l,)``. Viewing the ``(n, m * l)`` covariance as ``cov[a, b, i]``,
    its entries are ``delta(a, b) * weights[i]``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with weights ``init_fn((l,))`` (zeros if omitted)."""
        n, m, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((l,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate the weights from a dense ``(n, m * l)`` covariance.

        Returns ``sum_k cov[k, k, i] / n`` for every ``i`` (shape ``(l,)``).
        """
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        cov = cov.reshape(n, m, l)
        return contract("kki->i", cov) / n

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the weights from a rank-one pair ``(p0, p1)``.

        Same result as :meth:`cov_estimate` on the covariance with entries
        ``p0[a] * p1[b, i]``; ``p0`` has length ``n`` and ``p1`` shape ``(m, l)``.
        """
        n, m, l = cls.group_dims(eq)
        assert params[0].shape[-1] == n and params[1].shape[-2:] == (m, l)
        return contract("k,ki->i", params[0], params[1]) / n

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Materialise the dense ``(n, m * l)`` covariance.

        ``shape`` overrides ``(n, m, l)``; it must keep ``n == m`` and the
        factor's own ``l``.
        """
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            assert n == m
            assert l == self.group_dims(self.eq)[2]
        w = self.weights

        assert w.shape == (l,)
        # Create the identity alongside the weights (dtype and device) instead
        # of as a default float32 CPU tensor.
        eye = torch.eye(n, dtype=w.dtype, device=w.device)
        return torch.kron(eye, w[None])

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) to ``v`` without building it.

        Not transposed: ``v`` has shape ``(m, l)`` and the result length ``n``.
        Transposed: ``v`` has length ``n`` and the result shape ``(m, l)``.
        """
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        if not transpose:
            assert v.shape == (m, l)
            return v @ w
        else:
            assert v.shape == (n,)
            return torch.outer(v, w)

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq`` (stable shape if ``surrogate``)."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        assert n == m
        return n, m, l


class Sn_I_Sn_Factor(Factor):
    """Factor with covariance ``kron(w[0], I_n) + kron(w[1], ones(n, n)) / n``.

    Shapes follow ``eq.shape == (n, (m, l))`` with ``n == l``; ``weights`` has
    shape ``(2, m)``. Viewing the ``(n, m * l)`` covariance as ``cov[a, i, b]``,
    its entries are ``w[0, i] * delta(a, b) + w[1, i] / n``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with weights ``init_fn((2, m))`` (zeros if omitted)."""
        n, m, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((2, m))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate the ``(2, m)`` weights from a dense ``(n, m * l)`` covariance."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        cov = cov.reshape(n, m, l)

        cov_kk = contract("kik->i", cov)  # sum over the a == b diagonal
        cov_r = torch.sum(cov, dim=(0, 2))  # sum over all (a, b)
        # Under the model above cov_kk = n*w0 + w1 and cov_r = n*w0 + n*w1.
        v = (cov_r - cov_kk) / (n - 1)
        w = (cov_kk - v) / n
        return torch.stack([w, v])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the weights from a rank-one pair ``(p0, p1)``.

        Same result as :meth:`cov_estimate` on the covariance with entries
        ``p0[a] * p1[i, b]``; ``p0`` has length ``n`` and ``p1`` shape ``(m, l)``.
        """
        n, m, l = cls.group_dims(eq)
        assert params[0].shape[-1] == n and params[1].shape[-2:] == (m, l)

        cov_kk = contract("k,ik->i", params[0], params[1])
        cov_r = contract("k,ij->i", params[0], params[1])
        v = (cov_r - cov_kk) / (n - 1)
        w = (cov_kk - v) / n
        return torch.stack([w, v])

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Materialise the dense ``(n, m * l)`` covariance.

        ``shape`` overrides ``(n, m, l)``; it must keep ``n == l`` and the
        factor's own ``m``.
        """
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            assert n == l
            assert m == self.group_dims(self.eq)[1]
        w = self.weights

        assert w.shape == (2, m)
        # Helpers follow the weights' dtype/device instead of float32 CPU.
        eye = torch.eye(n, dtype=w.dtype, device=w.device)
        ones = torch.ones((n, n), dtype=w.dtype, device=w.device)
        return torch.kron(w[:1], eye) + torch.kron(w[1:2], ones) / n

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) to ``v`` without building it.

        Not transposed: ``v`` has shape ``(m, l)`` and the result length ``n``.
        Transposed: ``v`` has length ``n`` and the result shape ``(m, l)``.
        """
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        # A default float32 CPU ``ones`` would round the rank-one term of a
        # float64 product (0-dim operands do not promote) and fail on GPU.
        ones_n = torch.ones((n,), dtype=v.dtype, device=v.device)
        if not transpose:
            assert v.shape == (m, l)
            kron_eye = v.T @ w[0]
            kron_one = torch.sum(v.sum(dim=1) * w[1]) * ones_n / n
            return kron_eye + kron_one
        else:
            assert v.shape == (n,)
            kron_eye = torch.outer(w[0], v)
            kron_one = torch.outer(w[1], torch.sum(v) * ones_n) / n
            return kron_eye + kron_one

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq`` (stable shape if ``surrogate``)."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        assert n == l
        return n, m, l


class Sn_Sn_I_Factor(Factor):
    """Factor with covariance ``kron(I_n, w[0]) + kron(ones(n, n), w[1]) / n``.

    Shapes follow ``eq.shape == (n, (m, l))`` with ``n == m``; ``weights`` has
    shape ``(2, l)``. Viewing the ``(n, m * l)`` covariance as ``cov[a, b, i]``,
    its entries are ``w[0, i] * delta(a, b) + w[1, i] / n``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with weights ``init_fn((2, l))`` (zeros if omitted)."""
        n, m, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((2, l))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate the ``(2, l)`` weights from a dense ``(n, m * l)`` covariance."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        cov = cov.reshape(n, m, l)

        cov_kk = contract("kki->i", cov)  # sum over the a == b diagonal
        cov_r = torch.sum(cov, dim=(0, 1))  # sum over all (a, b)
        # Under the model above cov_kk = n*w0 + w1 and cov_r = n*w0 + n*w1.
        v = (cov_r - cov_kk) / (n - 1)
        w = (cov_kk - v) / n
        return torch.stack([w, v])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Estimate the weights from a rank-one pair ``(p1, p2)``.

        Same result as :meth:`cov_estimate` on the covariance with entries
        ``p1[a] * p2[b, i]``; ``p1`` has length ``n`` and ``p2`` shape ``(m, l)``.
        """
        n, m, l = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        cov_kk = contract("k,ki->i", p1, p2)
        cov_r = contract("k,ji->i", p1, p2)
        v = (cov_r - cov_kk) / (n - 1)
        w = (cov_kk - v) / n
        return torch.stack([w, v])

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Materialise the dense ``(n, m * l)`` covariance.

        ``shape`` overrides ``(n, m, l)``; it must keep ``n == m`` and the
        factor's own ``l``.
        """
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            assert n == m
            assert l == self.group_dims(self.eq)[2]
        w = self.weights

        assert w.shape == (2, l)
        # Helpers follow the weights' dtype/device instead of float32 CPU.
        eye = torch.eye(n, dtype=w.dtype, device=w.device)
        ones = torch.ones((n, n), dtype=w.dtype, device=w.device)
        return torch.kron(eye, w[:1]) + torch.kron(ones, w[1:2]) / n

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) to ``v`` without building it.

        Not transposed: ``v`` has shape ``(m, l)`` and the result length ``n``.
        Transposed: ``v`` has length ``n`` and the result shape ``(m, l)``.
        """
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        # A default float32 CPU ``ones`` would round the rank-one term of a
        # float64 product (0-dim operands do not promote) and fail on GPU.
        ones_n = torch.ones((n,), dtype=v.dtype, device=v.device)
        if not transpose:
            assert v.shape == (m, l)
            kron_eye = v @ w[0]
            kron_one = torch.sum(v.sum(dim=0) * w[1]) * ones_n / n
            return kron_eye + kron_one
        else:
            assert v.shape == (n,)
            kron_eye = torch.outer(v, w[0])
            kron_one = torch.outer(torch.sum(v) * ones_n, w[1]) / n
            return kron_eye + kron_one

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq`` (stable shape if ``surrogate``)."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        assert n == m
        return n, m, l


class S_S_S_Factor(Factor):
    """Factor for the fully symmetric structure ``eq.shape == (n, (n, n))``.

    Reshaping the ``(n, n * n)`` covariance to ``cov[i, j, k]``, it is a
    weighted sum of five basis tensors, in this fixed order:
    ``delta_ij``, ``delta_jk``, ``delta_ik``, all-ones and ``delta_ijk``.
    ``weights`` (shape ``(5,)``) stores the coefficients multiplied by
    :meth:`scale`, i.e. the coefficient of basis ``a`` is
    ``weights[a] / scale(n)[a]``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with weights ``init_fn((5,))`` (zeros if omitted)."""
        n = cls.group_dims(eq)  # validates that all three dims agree

        init_fn = torch.zeros if init_fn is None else init_fn
        weight = init_fn((5,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Least-squares projection of a dense ``(n, n * n)`` covariance.

        ``b`` holds the inner products of ``cov`` with each basis tensor and
        :meth:`trace` their (scaled) Gram matrix; the scaled normal equations
        are solved with :func:`stable_solve`.
        """
        n = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, n * n)
        cov = cov.reshape(n, n, n)

        b = torch.stack(
            [
                contract("iik->", cov),  # <cov, delta_ij>
                contract("ijj->", cov),  # <cov, delta_jk>
                contract("iji->", cov),  # <cov, delta_ik>
                contract("ijk->", cov),  # <cov, all-ones>
                contract("iii->", cov),  # <cov, delta_ijk>
            ]
        )

        scale = as_tensor(cls.scale(n), cov)
        C = as_tensor(cls.trace(n, scale), cov)
        return stable_solve(C, (b / scale)[:, None])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same projection as :meth:`cov_estimate` for ``cov[i, j, k] = p1[i] * p2[j, k]``."""
        n = cls.group_dims(eq)
        assert params[0].shape[-1] == n and params[1].shape[-2:] == (n, n)

        p1 = params[0]
        p2 = params[1]

        b = torch.stack(
            [
                contract("i,ik->", p1, p2),
                contract("i,jj->", p1, p2),
                contract("i,ji->", p1, p2),
                contract("i,jk->", p1, p2),
                contract("i,ii->", p1, p2),
            ]
        )

        scale = as_tensor(cls.scale(n), p1)
        C = as_tensor(cls.trace(n, scale), p1)
        return stable_solve(C, (b / scale)[:, None])

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Materialise the dense ``(n, n * n)`` covariance.

        ``shape`` overrides ``n``; all three of its entries must be equal.
        """
        n = self.group_dims(self.eq, surrogate) if shape is None else shape[0]
        if shape is not None:
            assert n == shape[1] == shape[2]

        w = self.weights
        factors = as_tensor(self.basis(n), w)
        c = w / as_tensor(self.scale(n), w)
        return torch.sum(c[:, None, None] * factors, dim=0)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) to ``v`` without building it.

        Not transposed: ``v`` has shape ``(n, n)`` and the result length ``n``.
        Transposed: ``v`` has length ``n`` and the result shape ``(n, n)``.
        """
        n = self.group_dims(self.eq)

        w = self.weights
        s = as_tensor(self.scale(n), w)
        c = w / s

        # Helper tensors follow v's dtype/device: default float32 CPU tensors
        # would round float64 results (0-dim operands do not promote) and
        # fail for GPU inputs.
        opts = dict(dtype=v.dtype, device=v.device)
        one = torch.ones((n,), **opts)
        if not transpose:
            assert v.shape == (n, n)
            c = c[:, None]
            factors = [
                v.sum(dim=1),
                one * torch.trace(v),
                v.sum(dim=0),
                one * torch.sum(v),
                torch.diag(v),
            ]
        else:
            assert v.shape == (n,)
            c = c[:, None, None]
            factors = [
                torch.outer(v, one),
                torch.eye(n, **opts) * torch.sum(v),
                torch.outer(one, v),
                torch.ones((n, n), **opts) * torch.sum(v),
                torch.diag(v),
            ]

        return torch.sum(c * torch.stack(factors), dim=0)

    @classmethod
    def basis(cls, n) -> torch.Tensor:
        """Return the five basis tensors stacked into shape ``(5, n, n * n)``.

        Column ``j * n + k`` of each slice corresponds to ``cov[i, j, k]``.
        The order (``delta_ij``, ``delta_jk``, ``delta_ik``, all-ones,
        ``delta_ijk``) must match :meth:`scale` and :meth:`trace`.
        """
        k = sm.commutation_matrix(n, n)
        vec_eye = torch.eye(n).reshape(-1)
        bbI = torch.outer(vec_eye, vec_eye)
        ones = torch.ones((1, n))
        # kron(ones(1, n), eye(n))[i, j * n + k] == delta(i, k)
        eye_tiled = torch.kron(ones, torch.eye(n))

        all_ones = torch.kron(ones, torch.ones((n, n)))
        delta_ij = eye_tiled @ k  # swapping (j, k) turns delta_ik into delta_ij
        delta_jk = eye_tiled @ bbI
        delta_ik = eye_tiled
        delta_ijk = delta_ij * delta_jk

        return torch.stack([delta_ij, delta_jk, delta_ik, all_ones, delta_ijk])

    @classmethod
    def scale(cls, n) -> NDArray:
        """Per-basis scale factors used to condition the normal equations."""
        return torch.as_tensor(
            [
                np.sqrt(n),
                np.sqrt(n),
                np.sqrt(n),
                n * np.sqrt(n),
                1.0,
            ]
        )

    @classmethod
    def trace(cls, n, s) -> NDArray:
        """Gram matrix ``<basis_a, basis_b>`` divided by ``outer(s, s)``."""
        n3 = n * n * n
        n2 = n * n
        C = as_tensor(
            torch.tensor(
                [
                    [n2, n, n, n2, n],
                    [n, n2, n, n2, n],
                    [n, n, n2, n2, n],
                    [n2, n2, n2, n3, n],
                    [n, n, n, n, n],
                ]
            ),
            s,
        )
        C /= torch.outer(s, s)
        return C

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> int:
        """Return the common dimension ``n`` of ``eq.shape == (n, (n, n))``."""
        # TODO(bla): return `(n, m)`?
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        assert n == m and n == l

        return n


class Sn_Sn_S_Factor(Factor):
    """Factor with two scalar weights and ``eq.shape == (n, (m, l))``, ``n == m``.

    Viewing the ``(n, m * l)`` covariance as ``cov[a, b, i]``, its entries are
    ``w[0] * delta(a, b) / sqrt(l) + w[1] / (n * sqrt(l))``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with weights ``init_fn((2,))`` (zeros if omitted)."""
        n, m, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((2,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``[w0, w1]`` from a dense ``(n, m * l)`` covariance."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        dtype = cov.dtype

        cov = cov.reshape(n, m, l)
        sqrt_l = torch.sqrt(torch.as_tensor(l, dtype=dtype))

        cov_kk = contract("kki->i", cov)
        cov_r = torch.sum(cov, dim=(0, 1))
        b = torch.sum(cov_r - cov_kk) / ((n - 1) * sqrt_l)
        a = (torch.sum(cov_kk) / sqrt_l - b) / n
        # torch.stack keeps dtype, device and autograd; torch.tensor([a, b])
        # would copy the values out of the graph.
        return torch.stack([a, b])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same projection as :meth:`cov_estimate` for entries ``p1[a] * p2[b, i]``."""
        n, m, l = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype

        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)
        sqrt_l = torch.sqrt(torch.as_tensor(l, dtype=dtype))

        cov_kk = contract("k,ki->i", p1, p2)
        cov_r = contract("k,ji->i", p1, p2)
        b = torch.sum(cov_r - cov_kk) / ((n - 1) * sqrt_l)
        a = (torch.sum(cov_kk) / sqrt_l - b) / n
        return torch.stack([a, b])

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Materialise the dense ``(n, n * l)`` covariance (``shape`` needs ``n == m``)."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            assert n == m
        w = self.weights

        assert w.shape == (2,)
        # Build helpers in the weights' dtype/device: float32 ``ones`` times a
        # 0-dim float64 weight would otherwise stay float32.
        opts = dict(dtype=w.dtype, device=w.device)
        sqrt_l = torch.sqrt(torch.as_tensor(l, **opts))
        vec = torch.ones((1, l), **opts) * w[0] / sqrt_l
        ones = torch.ones((n, n * l), **opts) * w[1] / (n * sqrt_l)
        return torch.kron(torch.eye(n, **opts), vec) + ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) to ``v`` without building it.

        Not transposed: ``v`` has shape ``(m, l)`` and the result length ``n``.
        Transposed: ``v`` has length ``n`` and the result shape ``(m, l)``.
        """
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        opts = dict(dtype=v.dtype, device=v.device)
        sqrt_l = torch.sqrt(torch.as_tensor(l, dtype=v.dtype))

        a = sqrt_l
        b = n * sqrt_l

        if not transpose:
            assert v.shape == (m, l)
            kron_eye = w[0] * v.sum(dim=1) / a
            kron_one = w[1] * torch.sum(v) * torch.ones((n,), **opts) / b
            return kron_eye + kron_one
        else:
            assert v.shape == (n,)
            kron_eye = w[0] * torch.outer(v, torch.ones((l,), **opts)) / a
            kron_one = w[1] * torch.sum(v) * torch.ones((m, l), **opts) / b
            return kron_eye + kron_one

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq`` (stable shape if ``surrogate``)."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        assert n == m
        return n, m, l


class Sn_S_Sn_Factor(Factor):
    """Factor with two scalar weights and ``eq.shape == (n, (m, l))``, ``n == l``.

    Viewing the ``(n, m * l)`` covariance as ``cov[a, i, b]``, its entries are
    ``w[0] * delta(a, b) / sqrt(m) + w[1] / (n * sqrt(m))``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with weights ``init_fn((2,))`` (zeros if omitted)."""
        n, m, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn((2,))
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate ``[w0, w1]`` from a dense ``(n, m * l)`` covariance."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        dtype = cov.dtype
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))

        cov = cov.reshape(n, m, l)
        cov_kk = contract("kik->i", cov)
        cov_r = torch.sum(cov, dim=(0, 2))
        b = torch.sum(cov_r - cov_kk) / ((n - 1) * sqrt_m)
        a = (torch.sum(cov_kk) / sqrt_m - b) / n
        # torch.stack keeps dtype, device and autograd; torch.tensor([a, b])
        # would copy the values out of the graph.
        return torch.stack([a, b])

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same projection as :meth:`cov_estimate` for entries ``p1[a] * p2[i, b]``."""
        n, m, l = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        dtype = p1.dtype
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=dtype))

        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        cov_kk = contract("k,ik->i", p1, p2)
        cov_r = contract("k,ij->i", p1, p2)
        b = torch.sum(cov_r - cov_kk) / ((n - 1) * sqrt_m)
        a = (torch.sum(cov_kk) / sqrt_m - b) / n
        return torch.stack([a, b])

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Materialise the dense ``(n, m * n)`` covariance (``shape`` needs ``n == l``)."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        if shape is not None:
            assert n == l
        w = self.weights

        assert w.shape == (2,)
        # Build helpers in the weights' dtype/device: float32 ``ones`` times a
        # 0-dim float64 weight would otherwise stay float32.
        opts = dict(dtype=w.dtype, device=w.device)
        sqrt_m = torch.sqrt(torch.as_tensor(m, **opts))

        vec = torch.ones((1, m), **opts) * w[0] / sqrt_m
        ones = torch.ones((n, n * m), **opts) * w[1] / (n * sqrt_m)
        return torch.kron(vec, torch.eye(n, **opts)) + ones

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) to ``v`` without building it.

        Not transposed: ``v`` has shape ``(m, l)`` and the result length ``n``.
        Transposed: ``v`` has length ``n`` and the result shape ``(m, l)``.
        """
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        opts = dict(dtype=v.dtype, device=v.device)
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=v.dtype))

        if not transpose:
            assert v.shape == (m, l)
            kron_eye = w[0] * v.sum(dim=0) / sqrt_m
            kron_one = w[1] * torch.sum(v) * torch.ones((n,), **opts) / (n * sqrt_m)
            return kron_eye + kron_one
        else:
            assert v.shape == (n,)
            kron_eye = w[0] * torch.outer(torch.ones((m,), **opts), v) / sqrt_m
            kron_one = w[1] * torch.sum(v) * torch.ones((m, l), **opts) / (n * sqrt_m)
            return kron_eye + kron_one

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq`` (stable shape if ``surrogate``)."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        assert n == l
        return n, m, l


class On_On_S_Factor(Factor):
    """Factor with one scalar weight and ``eq.shape == (n, (m, l))``, ``n == m``.

    Dense covariance: ``w * kron(eye(n), ones((1, l)) / sqrt(l))``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with a 0-dim weight ``init_fn(())`` (zero if omitted)."""
        n, m, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn(())
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate the weight as ``sum_{i,k} cov[i, i, k] / (n * sqrt(l))``."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)
        cov = cov.reshape(n, m, l)

        sqrt_l = torch.sqrt(torch.as_tensor(l, dtype=cov.dtype))
        return contract("iik->", cov) / (n * sqrt_l)

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same projection as :meth:`cov_estimate` for entries ``p1[a] * p2[b, k]``.

        Applies the same ``1 / (n * sqrt(l))`` normalisation as
        :meth:`cov_estimate`, like every other factor's ``outer_estimate``.
        """
        n, m, l = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        sqrt_l = torch.sqrt(torch.as_tensor(l, dtype=p1.dtype))
        return contract("i,ik->", p1, p2) / (n * sqrt_l)

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Materialise the dense ``(n, n * l)`` covariance (``shape`` needs ``n == m``)."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        w = self.weights

        if shape is not None:
            assert n == m

        # Build helpers in the weights' dtype/device: the float32 kron would
        # otherwise keep a float64 weight's product in float32.
        opts = dict(dtype=w.dtype, device=w.device)
        sqrt_l = torch.sqrt(torch.as_tensor(l, **opts))
        one = torch.ones((1, l), **opts) / sqrt_l
        eye = torch.eye(n, **opts)
        return w * torch.kron(eye, one)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) to ``v`` without building it.

        Not transposed: ``v`` has shape ``(m, l)`` and the result length ``n``.
        Transposed: ``v`` has length ``n`` and the result shape ``(m, l)``.
        """
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        sqrt_l = torch.sqrt(torch.as_tensor(l, dtype=v.dtype))

        if not transpose:
            assert v.shape == (m, l)
            return w * v.sum(dim=1) / sqrt_l
        else:
            assert v.shape == (n,)
            ones_l = torch.ones((l,), dtype=v.dtype, device=v.device)
            return w * torch.outer(v, ones_l) / sqrt_l

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq`` (stable shape if ``surrogate``)."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        assert n == m
        return n, m, l


class On_S_On_Factor(Factor):
    """Factor with one scalar weight and ``eq.shape == (n, (m, l))``, ``n == l``.

    Dense covariance: ``w * kron(ones((1, m)) / sqrt(m), eye(n))``.
    """

    eq: Eq
    weights: torch.Tensor

    @classmethod
    def from_init_fn(cls, eq: Eq, init_fn: InitFn | None = None):  # type: ignore[arg-type]
        """Create the factor with a 0-dim weight ``init_fn(())`` (zero if omitted)."""
        n, m, l = cls.group_dims(eq)
        init_fn = torch.zeros if init_fn is None else init_fn

        weight = init_fn(())
        return cls(eq, weight)

    @classmethod
    def cov_estimate(cls, eq: Eq, cov: NDArray, surrogate: bool = False):
        """Estimate the weight as ``sum_{i,j} cov[i, j, i] / (n * sqrt(m))``."""
        n, m, l = cls.group_dims(eq, surrogate)
        assert cov.shape == (n, m * l)

        cov = cov.reshape(n, m, l)
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=cov.dtype))
        return contract("iji->", cov) / (n * sqrt_m)

    @classmethod
    def outer_estimate(cls, eq: Eq, params: tuple):
        """Same projection as :meth:`cov_estimate` for entries ``p1[a] * p2[j, b]``.

        Applies the same ``1 / (n * sqrt(m))`` normalisation as
        :meth:`cov_estimate`, like every other factor's ``outer_estimate``.
        """
        n, m, l = cls.group_dims(eq)
        p1 = params[0]
        p2 = params[1]
        assert p1.shape[-1] == n and p2.shape[-2:] == (m, l)

        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=p1.dtype))
        return contract("i,ji->", p1, p2) / (n * sqrt_m)

    def cov(
        self, shape: tuple[int, int, int] | None = None, surrogate: bool = False
    ) -> NDArray:
        """Materialise the dense ``(n, m * n)`` covariance (``shape`` needs ``n == l``)."""
        n, m, l = self.group_dims(self.eq, surrogate) if shape is None else shape
        w = self.weights

        if shape is not None:
            assert n == l

        # Build helpers in the weights' dtype/device: the float32 kron would
        # otherwise keep a float64 weight's product in float32.
        opts = dict(dtype=w.dtype, device=w.device)
        sqrt_m = torch.sqrt(torch.as_tensor(m, **opts))
        one = torch.ones((1, m), **opts) / sqrt_m
        eye = torch.eye(n, **opts)
        return w * torch.kron(one, eye)

    def matvec(self, v: NDArray, transpose: bool = False):
        """Apply the covariance (or its transpose) to ``v`` without building it.

        Not transposed: ``v`` has shape ``(m, l)`` and the result length ``n``.
        Transposed: ``v`` has length ``n`` and the result shape ``(m, l)``.
        """
        n, m, l = self.group_dims(self.eq)
        w = self.weights
        sqrt_m = torch.sqrt(torch.as_tensor(m, dtype=v.dtype))

        if not transpose:
            assert v.shape == (m, l)
            return w * v.sum(dim=0) / sqrt_m
        else:
            assert v.shape == (n,)
            ones_m = torch.ones((m,), dtype=v.dtype, device=v.device)
            return w * torch.outer(ones_m, v) / sqrt_m

    @classmethod
    def group_dims(cls, eq: Eq, surrogate: bool = False) -> tuple[int, int, int]:
        """Return ``(n, m, l)`` from ``eq`` (stable shape if ``surrogate``)."""
        n, (m, l) = eq.stable_shape if surrogate else eq.shape
        assert n == l
        return n, m, l


# Auxiliary functions


def stable_solve(A, b):
    """Solve ``A x = b`` after symmetric diagonal (Jacobi) rescaling of ``A``.

    ``A`` is rescaled to ``D A D`` with ``D = diag(A) ** -0.5``, and the solution
    is mapped back. ``b`` is expected to be a column ``(k, 1)``; the result is
    squeezed.
    """
    inv_root_diag = 1 / torch.sqrt(torch.diag(A))
    column = inv_root_diag[:, None]
    balanced = torch.einsum("k,kl,l->kl", inv_root_diag, A, inv_root_diag)
    solution = torch.linalg.solve(balanced, b * column)
    return torch.squeeze(solution * column)


def flatten(t):
    """Yield the leaves of a nested tuple/list structure in depth-first order."""
    for item in t:
        if not isinstance(item, (tuple, list)):
            yield item
            continue
        yield from flatten(item)


def nested_order(seq, nest):
    """Regroup the leading items of ``seq`` to follow the shape of ``nest``.

    Every non-list/tuple leaf of ``nest`` consumes one item of ``seq``; lists
    and tuples become tuples. Returns ``(grouped, remaining_seq)``.
    """
    if not isinstance(nest, (list, tuple)):
        head, rest = seq[0], seq[1:]
        return head, rest
    grouped = []
    remaining = seq
    for branch in nest:
        item, remaining = nested_order(remaining, branch)
        grouped.append(item)
    return tuple(grouped), remaining


def iden_kron_ones(scale: float | torch.Tensor, n: int, m: int):
    vec_n = torch.ones(n) * scale
    diag_n = torch.diag(vec_n)
    return torch.kron(diag_n, torch.ones((m, m)))


def ones_kron_iden(scale: float | torch.Tensor, n: int, m: int):
    vec_m = torch.ones(m) * scale
    diag_m = torch.diag(vec_m)
    return torch.kron(torch.ones((n, n)), diag_m)


def cov_shape(groups: Eq):
    """Return the covariance shape implied by ``groups``.

    The flat ``groups.group_dims`` are regrouped following the plum type
    parameters of ``groups``, and the dimensions inside each top-level group
    are multiplied together.
    """
    flat_dims = groups.group_dims
    params = plum.type_parameter(groups)
    structure = params if isinstance(params, tuple) else (params,)

    grouped, _ = nested_order(flat_dims, structure)
    return tuple(np.prod(dims) for dims in grouped)


def as_tensor(tensor, like: torch.Tensor) -> torch.Tensor:
    """Convert ``tensor`` to a torch tensor with the dtype and device of ``like``."""
    target = {"dtype": like.dtype, "device": like.device}
    return torch.as_tensor(tensor, **target)
