from lightning_utilities.core.rank_zero import (
    rank_zero_only,
    rank_zero_warn,
    rank_zero_info,
)
import torch
from typing import Dict
from collections import OrderedDict


def load_model_from_checkpoint(
    checkpoint_path: str, state_dict: Dict, submodule: str | None = None
) -> OrderedDict:
    """Build a state dict for a model, taking values from a checkpoint where possible.

    The result has exactly the keys of ``state_dict``, in the same order. For
    each key, the value stored in the checkpoint is used when the checkpoint
    contains it; otherwise the value from ``state_dict`` is kept. Keys whose
    name contains ``"pad_buffer"`` always keep the value from ``state_dict``,
    even when the checkpoint has an entry for them.

    No shape or dtype comparison is made between the two values.

    Args:
        checkpoint_path: Path of a file readable by ``torch.load`` whose
            top-level object has a ``"state_dict"`` entry.
        state_dict: The model's current state dict; supplies the key set and
            the fallback values.
        submodule: Optional prefix. When given, key ``k`` is looked up in the
            checkpoint as ``f"{submodule}.{k}"``.

    Returns:
        OrderedDict: The merged state dict, keyed like ``state_dict``.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
    """
    rank_zero_info(f"Loading model from {checkpoint_path}")
    # NOTE(security): weights_only=False allows torch.load to unpickle
    # arbitrary objects, which can execute code. Only load checkpoints from
    # trusted sources; switch to weights_only=True if the checkpoints in use
    # contain nothing but tensors and plain containers.
    state_dict_loaded = torch.load(
        checkpoint_path,
        map_location="cpu",
        weights_only=False,
    )["state_dict"]

    # Loop-invariant prefix applied to every key for the checkpoint lookup.
    prefix = f"{submodule}." if submodule else ""

    new_state_dict = OrderedDict()
    missing_module = set()

    for k, v in state_dict.items():
        k_new = prefix + k
        if k_new not in state_dict_loaded or "pad_buffer" in k:
            # Keep the caller's value and remember which part of the model it
            # belongs to (first three dot-separated components of the key).
            new_state_dict[k] = v
            missing_module.add(".".join(k.split(".")[:3]))
        else:
            new_state_dict[k] = state_dict_loaded[k_new]

    # Sorted so the notices come out in the same order on every run; plain
    # set iteration order varies with string hash randomization.
    for module in sorted(missing_module):
        rank_zero_info(f"\033[93mWarning: {module} not found in checkpoint\033[0m")

    return new_state_dict


def freeze_module(module: torch.nn.Module) -> torch.nn.Module:
    """Switch ``module`` to eval mode and turn off gradients for its parameters.

    The module is modified in place.

    Args:
        module: The module to freeze.

    Returns:
        torch.nn.Module: The same ``module`` object that was passed in.
    """
    module.eval()
    for weight in module.parameters():
        weight.requires_grad_(False)
    return module


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make a boolean mask that is True at padded positions.

    Position ``j`` of row ``i`` is True when ``j >= lengths[i]``.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): Number of columns of the mask. When it is not
            positive (the default), the largest value in ``lengths`` is used,
            or 0 if ``lengths`` is empty.

    Returns:
        torch.Tensor: Boolean mask of shape (B, max_len), on the same device
            as ``lengths``.

    Examples:
        >>> make_pad_mask(torch.tensor([5, 3, 2]))
        tensor([[False, False, False, False, False],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True]])
    """
    if max_len <= 0:
        # max() of an empty tensor raises, so an empty batch gets zero columns.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0
    positions = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    # (1, max_len) against (B, 1) broadcasts to (B, max_len).
    return positions.unsqueeze(0) >= lengths.unsqueeze(-1)
