from __future__ import annotations

import math
from typing import TYPE_CHECKING

import numpy as np
import torch
from torch import nn

from inferno.bnn.params.covariances.factorized import FactorizedCovariance

if TYPE_CHECKING:
    from jaxtyping import Float
    from torch import Tensor


class CustomFactorizedCovariance(FactorizedCovariance):
    r"""Covariance of a Gaussian parameter with a factorized structure and per-factor scaling.

    Assumes the covariance is factorized as a product of a factor matrix and its transpose,

    .. math::
        \mathbf{\Sigma} = \mathbf{S} \mathbf{S}^\top

    where :math:`\mathbf{S}` has one row per mean parameter and ``rank`` columns (it is
    square only when ``rank`` equals the number of mean parameters).

    The ``weight`` and ``bias`` factors each get independent multipliers for their
    initialization scale, their learning-rate scaling and the scale applied when the
    factors are stacked. Only factors named ``weight`` and ``bias`` are supported; any
    other factor name raises a ``KeyError`` in :meth:`reset_parameters`,
    :meth:`_stacked_parameters` and :attr:`lr_scaling`.

    :param rank: Rank of the covariance matrix. If `None`, the rank is set to the total number
        of mean parameters.
    :param scale_init_weight: Multiplier on the initialization standard deviation of the
        ``weight`` factor.
    :param scale_init_bias: Multiplier on the initialization standard deviation of the
        ``bias`` factor.
    :param scale_lr_weight: Multiplier on the learning-rate scaling of the ``weight`` factor.
    :param scale_lr_bias: Multiplier on the learning-rate scaling of the ``bias`` factor.
    :param scale_forward_weight: Multiplier applied to the ``weight`` factor when the
        factors are stacked in :meth:`_stacked_parameters`.
    :param scale_forward_bias: Multiplier applied to the ``bias`` factor when the factors
        are stacked in :meth:`_stacked_parameters`.
    """

    def __init__(
        self,
        rank: int | None = None,
        scale_init_weight: float = 1.0,
        scale_init_bias: float = 1.0,
        scale_lr_weight: float = 1.0,
        scale_lr_bias: float = 1.0,
        scale_forward_weight: float = 1.0,
        scale_forward_bias: float = 1.0,
    ) -> None:
        super().__init__(rank=rank)
        # NOTE(review): this renames the class object itself, not just this instance, so
        # once any instance exists the class (and every instance) reports the name
        # "FactorizedCovariance". Presumably this lets code elsewhere that dispatches on the
        # class name treat this covariance like its parent -- confirm before removing.
        self.__class__.__name__ = "FactorizedCovariance"

        # Per-factor multipliers, keyed by factor name ("weight" / "bias").
        self.scales_init = {"weight": scale_init_weight, "bias": scale_init_bias}
        self.scales_lr = {"weight": scale_lr_weight, "bias": scale_lr_bias}
        self.scales_forward = {"weight": scale_forward_weight, "bias": scale_forward_bias}

    def reset_parameters(
        self,
        mean_parameter_scales: dict[str, float] | float = 1.0,
    ) -> None:
        """Reset the parameters of the covariance matrix.

        Initializes every factor that is not ``None`` from a zero-mean normal distribution
        with standard deviation
        ``mean_parameter_scales[name] * scales_init[name] / sqrt(rank)``, i.e. the mean
        parameter scale times a per-factor multiplier times a covariance-specific scaling
        that depends on the rank of the factorization.

        :param mean_parameter_scales: Scales of the mean parameters. If a dictionary, its
            keys are the names of the mean parameters. If a number (``int`` or ``float``),
            all covariance parameters are initialized with the same scale.
        """
        # Accept any plain real scalar; an int would otherwise fall through and be
        # indexed like a dict below, raising a TypeError.
        if isinstance(mean_parameter_scales, (int, float)):
            mean_parameter_scales = {
                name: mean_parameter_scales for name in self.factor.keys()
            }

        for name, param in self.factor.items():
            if param is not None:
                nn.init.normal_(
                    param,
                    mean=0,
                    std=mean_parameter_scales[name] * self.scales_init[name] / math.sqrt(self.rank),
                )

    def _stacked_parameters(self) -> Float[Tensor, "parameter rank"]:
        """Stack the factors into a single ``(num_parameters, rank)`` tensor.

        Each non-``None`` factor is reshaped to ``(-1, rank)``, multiplied by its
        forward scale, and the results are stacked row-wise in ``self.factor`` order.
        """
        return torch.vstack(
            [
                self.scales_forward[factor_name] * factor_param.view(-1, self.rank)
                for factor_name, factor_param in self.factor.items()
                if factor_param is not None
            ]
        )

    @property
    def lr_scaling(self) -> dict[str, float]:
        """Compute the learning rate scaling for the covariance parameters.

        :return: Mapping from ``"factor.<name>"`` to ``scales_lr[name] / sqrt(rank)``,
            for every factor name in ``self.factor``.
        """
        return {
            "factor." + name: self.scales_lr[name] / math.sqrt(self.rank)
            for name in self.factor.keys()
        }
