import os
from typing import Any, Dict, Optional
from unittest.result import failfast

import torch
from torch.nn.parallel import DistributedDataParallel


def svd_decompose(
    weight_tensor: torch.Tensor, rank: int, tau: Optional[int] = 32
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Factor a 2-D weight tensor into ``U`` and ``V`` using truncated SVD.

    The singular values are folded into ``V`` (each row of ``V`` is scaled by
    its singular value), so ``U @ V`` approximates ``weight_tensor``. The
    decomposition itself is computed on CPU in float32; both factors are cast
    back to the input's device and dtype before being returned.

    If ``rank`` exceeds the number of singular values ``d``, the factors are
    padded up to ``rank``: ``U`` receives zero columns and ``V`` receives the
    leading (largest-singular-value) rows of ``V`` divided by ``tau``, cycling
    through the rows again when more than ``d`` extra rows are required.
    Because the new columns of ``U`` are zero, the product ``U @ V`` is not
    changed by the padding.

    Args:
        weight_tensor (`torch.Tensor`)
            The 2-D weight tensor to decompose.
        rank (`int`)
            The rank input for truncated SVD (number of columns of ``U`` and
            rows of ``V`` in the result).
        tau (`int`, *optional*, defaults to 32)
            Divisor applied to the rows of ``V`` that are copied as padding.
            ``None`` falls back to the default of 32.
    Returns:
        U (`torch.Tensor`), V (`torch.Tensor`)
            The decomposed matrices where V is multiplied with the singular values S.
    Raises:
        ValueError: If ``rank`` is negative.
    """
    if rank < 0:
        # A negative rank would be interpreted as a "drop the last n" slice.
        raise ValueError(f"`rank` must be non-negative, got {rank}.")
    if tau is None:
        tau = 32

    original_device = weight_tensor.device
    original_dtype = weight_tensor.dtype
    # `.to` is a no-op when the tensor already lives on CPU in float32.
    work = weight_tensor.to(device="cpu", dtype=torch.float32)

    U, S, Vh = torch.linalg.svd(work, full_matrices=False)
    V = Vh * S.unsqueeze(1)  # fold the singular values into the rows of V
    d = V.shape[0]  # number of singular values, min(m, n) for an (m, n) input

    if rank > d:
        k = rank - d
        u_padding = U.new_zeros(U.shape[0], k)
        if d > 0:
            # Rows 0..k-1 of the padding are V[0], V[1], ..., wrapping around
            # after d rows, each divided by tau.
            v_padding = V[torch.arange(k) % d] / tau
        else:
            # No rows to copy from; pad with zeros instead of looping forever.
            v_padding = V.new_zeros(k, V.shape[1])
        U = torch.cat([U, u_padding], dim=1)
        V = torch.cat([V, v_padding], dim=0)
    else:
        U = U[:, :rank]
        V = V[:rank, :]

    # Cast back to original dtype and device
    U = U.contiguous().to(original_device, dtype=original_dtype)
    V = V.contiguous().to(original_device, dtype=original_dtype)
    return U, V


def compute_sparse_r(
    n_blocks: int,
    d: int,
    act_d: int,
    p_budget: float,
    sparsity: float,
    num_ffs: int,
    rounding: bool = False,
):
    """Compute the rank ``R`` of the compressed, sparse factorisation.

    With ``A = act_d * num_ffs * n_blocks``, the returned value is the integer
    part of ``A * d * p_budget / (d + A * (1 - sparsity))``, i.e. the ``R``
    solving ``R * (d + A * (1 - sparsity)) = p_budget * A * d``.

    Args:
        n_blocks (int): Number of blocks to be concatenated.
        d (int): Model dimension.
        act_d (int): Width multiplied with ``num_ffs`` and ``n_blocks`` to
            form ``A`` (presumably the MLP hidden dimension).
        p_budget (float): Parameter budget of FCs.
        sparsity (float): Sparsity of the model.
        num_ffs (int): Multiplier entering ``A`` (presumably the number of
            feed-forward layers per block).
        rounding (bool): When true, round the rank down to a multiple of 64.
    Returns:
        int: Rank of the compressed and sparse model.
    """
    concat_width = act_d * num_ffs * n_blocks
    numerator = concat_width * d * p_budget
    denominator = d + concat_width * (1 - sparsity)
    rank = int(numerator / denominator)
    if not rounding:
        return rank
    # Drop the remainder to land on the previous multiple of 64.
    return rank - rank % 64


def compute_sparse_r_with_bit_masks(
    n_blocks: int,
    d: int,
    act_d: int,
    p_budget: float,
    sparsity: float,
    num_ffs: int,
    weight_precision_bits: int = 32,
    rounding: bool = False,
) -> int:
    """Compute the rank (``R``) for the compressed & sparse model **including**
    the storage overhead of binary sparsity masks.

    :func:`compute_sparse_r` treats a *parameter* as the atomic unit of
    storage. When binary masks describe the sparsity pattern of the *V*
    factor, every weight element additionally needs **one extra bit**. With
    weights stored in ``weight_precision_bits`` bits (e.g. 32 for *fp32*, 8
    for *int8*, 4 for *int4*), ``weight_precision_bits`` mask bits cost as
    much as one additional parameter. The mask therefore adds
    ``A / weight_precision_bits`` parameter-equivalents per unit of rank,
    where ``A = act_d * num_ffs * n_blocks``, and the rank is chosen so that
    weights **plus** masks stay within the parameter budget::

        R = A * d * p_budget / (d + A * (1 - sparsity) + A / weight_precision_bits)

    Args:
        n_blocks (int): Number of blocks to be concatenated.
        d (int): Model dimension.
        act_d (int): Activation/hidden dimension of the MLPs.
        p_budget (float): Parameter budget for the feed-forward layers expressed
            as a fraction of the original parameter count.
        sparsity (float): Fraction in [0, 1] of zeros in the *V* factor.
        num_ffs (int): Number of feed-forward layers per block (2 or 3).
        weight_precision_bits (int): Bit-width used to store **one** weight
            value. Typical values are 32 (*fp32*), 16 (*fp16*), 8 (*int8*), or 4
            (*int4*).
        rounding (bool, optional): If *True*, round the resulting rank down to
            the nearest multiple of 64.
    Returns:
        int: The rank ``R`` that satisfies the budget.
    """
    A = act_d * num_ffs * n_blocks
    # Per-unit-of-rank cost: U row (d) + kept V entries + V's bit mask.
    cost_per_rank = d + A * (1 - sparsity) + A / weight_precision_bits
    R = int(A * d * p_budget / cost_per_rank)

    if rounding:
        R = 64 * (R // 64)

    return R


class DDPWrappedGetAttr(DistributedDataParallel):
    """``DistributedDataParallel`` that forwards unknown attributes to the wrapped module.

    Attribute lookups that the DDP wrapper itself cannot resolve are retried
    on ``self.module``, so attributes and methods of the wrapped model stay
    reachable through the wrapper.
    """

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If `module` itself is missing (e.g. on a partially constructed
            # instance), `self.module` below would re-enter this method and
            # recurse without bound; surface the AttributeError instead.
            if name == "module":
                raise
            return getattr(self.module, name)


def get_single_process_model_state_from_distributed_state(
    model_state: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Returns model state from a distributed training run in a format
    suitable for a single process model.

    In distributed training, ``module.`` is prepended to every parameter
    name. If we wish to test/train this model further in a single process,
    we strip that leading ``module.`` prefix to match the keys expected by
    the model. Keys that do not start with ``module.`` are returned
    unchanged, so passing an already-stripped state dict is harmless.

    Args:
        model_state (Dict[str, torch.Tensor]): State dict whose keys may carry
            the ``module.`` prefix.

    Returns:
        Dict[str, torch.Tensor]: Model state dict, in the original key order,
            for loading in a single process.
    """
    prefix = "module."
    # Only one leading occurrence is removed, and only when it is really there.
    return {key.removeprefix(prefix): value for key, value in model_state.items()}


def get_world_size() -> int:
    """Return the ``WORLD_SIZE`` environment variable as an int, or 1 if it is unset."""
    configured = os.environ.get("WORLD_SIZE")
    if configured is None:
        return 1
    return int(configured)


def get_mlp_compression_rate(
    model: torch.nn.Module,
    total_compression_rate: float,
    num_params_for_bit_masks: int,
) -> float:
    """
    Compute the compression rate for the MLP layers of the model.

    A parameter counts as an MLP parameter when its name contains any of
    ``gate_proj``, ``up_proj``, ``down_proj``, ``fc1`` or ``fc2``. Each
    parameter is counted once, even if its name matches several markers.

    Args:
        model (torch.nn.Module): The model whose parameters are counted.
        total_compression_rate (float): Rate applied to the total parameter
            count of the model.
        num_params_for_bit_masks (int): Parameter-equivalent count added on
            top of ``total * total_compression_rate`` (presumably the storage
            overhead of the sparsity bit masks; confirm with callers).

    Returns:
        float: ``(#total_params * total_compression_rate
        + num_params_for_bit_masks) / #MLP_params``.

    Raises:
        ZeroDivisionError: If the model has no MLP parameters.
    """
    mlp_markers = ("gate_proj", "up_proj", "down_proj", "fc1", "fc2")

    num_total_params = 0
    num_mlp_params = 0
    for name, param in model.named_parameters():
        param_count = param.numel()
        num_total_params += param_count
        # `any` ensures a name matching two markers is not counted twice.
        if any(marker in name for marker in mlp_markers):
            num_mlp_params += param_count
    return (
        num_total_params * total_compression_rate + num_params_for_bit_masks
    ) / num_mlp_params


def get_model_compression_rate(
    model: torch.nn.Module,
    mlp_compression_rate: float,
) -> float:
    """
    Compute the overall compression rate for the model given MLP compression rate.

    A parameter counts as an MLP parameter when its name contains any of
    ``gate_proj``, ``up_proj``, ``down_proj``, ``fc1`` or ``fc2``. Each
    parameter is counted once, even if its name matches several markers.
    Non-MLP parameters are kept in full, MLP parameters are scaled by
    ``mlp_compression_rate``.

    Args:
        model (torch.nn.Module): The model whose parameters are counted.
        mlp_compression_rate (float): Factor applied to the MLP parameter
            count.

    Returns:
        float: ``(#non_MLP_params + #MLP_params * mlp_compression_rate)
        / #total_params``.

    Raises:
        ZeroDivisionError: If the model has no parameters.
    """
    mlp_markers = ("gate_proj", "up_proj", "down_proj", "fc1", "fc2")

    num_total_params = 0
    num_mlp_params = 0
    for name, param in model.named_parameters():
        param_count = param.numel()
        num_total_params += param_count
        # `any` ensures a name matching two markers is not counted twice.
        if any(marker in name for marker in mlp_markers):
            num_mlp_params += param_count

    num_non_mlp_params = num_total_params - num_mlp_params
    compressed_mlp_params = num_mlp_params * mlp_compression_rate
    total_compressed_params = num_non_mlp_params + compressed_mlp_params

    return total_compressed_params / num_total_params


def get_mlp_fraction_of_model(
    model: torch.nn.Module,
    mlp_fraction: float = 1.0,
) -> float:
    """Return the share of the *whole* model accounted for by a given fraction of
    its MLP parameters.

    A parameter counts as an MLP parameter when its name contains any of
    ``gate_proj``, ``up_proj``, ``down_proj``, ``fc1`` or ``fc2``.

    Args:
        model (torch.nn.Module): The model whose parameters are examined.
        mlp_fraction (float): Fraction in [0, 1] of the MLP parameters we are
            interested in. For example, ``0.9`` means *90 % of the parameters in
            the MLP layers*.

    Returns:
        float: The ratio
            ``mlp_fraction * #MLP_parameters / #total_parameters``.

    Raises:
        ValueError: If ``mlp_fraction`` lies outside [0, 1].
        ZeroDivisionError: If the model has no parameters.

    Example:
        >>> ratio = get_mlp_fraction_of_model(model, mlp_fraction=0.9)
        >>> print(f"90 % of the MLP weights correspond to {ratio:.2%} of the whole model")
    """

    if not (0.0 <= mlp_fraction <= 1.0):
        raise ValueError("`mlp_fraction` must be between 0 and 1 (inclusive).")

    mlp_markers = ("gate_proj", "up_proj", "down_proj", "fc1", "fc2")

    num_total_params = 0
    num_mlp_params = 0

    for name, param in model.named_parameters():
        param_count = param.numel()
        num_total_params += param_count
        if any(marker in name for marker in mlp_markers):
            num_mlp_params += param_count

    # How much of the whole model those `mlp_fraction` * MLP parameters represent.
    return (mlp_fraction * num_mlp_params) / num_total_params
