from __future__ import annotations

from .factorized import CustomFactorizedCovariance


class CustomLowRankCovariance(CustomFactorizedCovariance):
    r"""Covariance of a Gaussian parameter with low-rank structure.

    Assumes the covariance is factorized as a product of a matrix
    :math:`\mathbf{S} \in \mathbb{R}^{P \times R}` and its transpose:

    .. math::
        \mathbf{\Sigma} = \mathbf{S} \mathbf{S}^\top

    All constructor arguments are forwarded unchanged, by keyword, to
    :class:`CustomFactorizedCovariance`. After that, the class object is
    renamed to ``"LowRankCovariance"`` (see the note in ``__init__``).

    :param rank: Rank of the covariance matrix. If `None`, the rank is set to
        the total number of mean parameters.
    :param scale_init_weight: Forwarded to the parent constructor. Its meaning
        is defined there (presumably an initialization scale for the weight
        factor; confirm against ``CustomFactorizedCovariance``).
    :param scale_init_bias: Forwarded to the parent constructor (presumably an
        initialization scale for the bias factor; confirm).
    :param scale_lr_weight: Forwarded to the parent constructor (presumably a
        learning-rate scale for the weight factor; confirm).
    :param scale_lr_bias: Forwarded to the parent constructor (presumably a
        learning-rate scale for the bias factor; confirm).
    :param scale_forward_weight: Forwarded to the parent constructor
        (presumably a forward-pass scale for the weight factor; confirm).
    :param scale_forward_bias: Forwarded to the parent constructor
        (presumably a forward-pass scale for the bias factor; confirm).
    """

    def __init__(
        self,
        rank: int | None,
        scale_init_weight,
        scale_init_bias,
        scale_lr_weight,
        scale_lr_bias,
        scale_forward_weight,
        scale_forward_bias,
    ) -> None:
        super().__init__(
            rank=rank,
            scale_init_weight=scale_init_weight,
            scale_init_bias=scale_init_bias,
            scale_lr_weight=scale_lr_weight,
            scale_lr_bias=scale_lr_bias,
            scale_forward_weight=scale_forward_weight,
            scale_forward_bias=scale_forward_bias,
        )
        # NOTE(review): ``self.__class__`` is the class object itself, so this
        # renames the class for every instance, not just ``self``. If a
        # subclass is being instantiated, that subclass is renamed instead.
        # ``__qualname__`` is left untouched. Presumably this is done so code
        # that inspects ``type(obj).__name__`` treats this like a
        # ``LowRankCovariance``; confirm with callers before changing.
        self.__class__.__name__ = "LowRankCovariance"
