import logging
import math
from typing import Callable
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn

logger = logging.Logger(__file__)


def remove_key_prefix_factory(prefix: str = "module."):
    def func(
        model_dict: dict[str, torch.Tensor], state_dict: dict[str,
                                                              torch.Tensor]
    ) -> dict[str, torch.Tensor]:

        state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items() if key.startswith(prefix)
        }
        return state_dict

    return func


def merge_matched_keys(
    model_dict: dict[str, torch.Tensor], state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """
    Args:
    model_dict:
        The state dict of the current model, which is going to load pretrained parameters
    state_dict:
        A dictionary of parameters from a pre-trained model.

    Returns:
        dict[str, torch.Tensor]:
            The updated state dict, where parameters with matched keys and shape are 
            updated with values in `state_dict`.
    """
    pretrained_dict = {}
    mismatch_keys = []
    for key, value in state_dict.items():
        if key in model_dict and model_dict[key].shape == value.shape:
            pretrained_dict[key] = value
        else:
            mismatch_keys.append(key)
    logger.info(
        f"Loading pre-trained model, with mismatched keys {mismatch_keys}"
    )
    model_dict.update(pretrained_dict)
    return model_dict


def load_pretrained_model(
    model: nn.Module,
    ckpt_or_state_dict: str | Path | dict[str, torch.Tensor],
    state_dict_process_fn: Callable = merge_matched_keys
) -> None:
    state_dict = ckpt_or_state_dict
    if not isinstance(state_dict, dict):
        state_dict = torch.load(ckpt_or_state_dict, "cpu")

    model_dict = model.state_dict()
    state_dict = state_dict_process_fn(model_dict, state_dict)
    model.load_state_dict(state_dict)


def create_mask_from_length(
    lengths: torch.Tensor, max_length: int | None = None
):
    if max_length is None:
        max_length = max(lengths)
    idxs = torch.arange(max_length).reshape(1, -1)  # (1, max_length)
    mask = idxs.to(lengths.device) < lengths.view(-1, 1)
    # (1, max_length) < (batch_size, 1) -> (batch_size, max_length)
    return mask


def loss_with_mask(
    loss: torch.Tensor,
    mask: torch.Tensor,
    reduce: bool = True
) -> torch.Tensor:
    """
    Apply a mask to the loss tensor and optionally reduce it.

    Args:
        loss: Tensor of shape (b, t, ...) representing the loss values.
        mask: Tensor of shape (b, t) where 1 indicates valid positions and 0 indicates masked positions.
        reduce: If True, return a single scalar value; otherwise, return a tensor of shape (b,).

    Returns:
        torch.Tensor: A scalar if reduce is True, otherwise a tensor of shape (b,).
    """
    expanded_mask = mask[(..., ) + (None, ) * (loss.ndim - mask.ndim)]
    expanded_mask = expanded_mask.expand_as(loss)
    masked_loss = loss * expanded_mask

    sum_dims = tuple(range(1, loss.ndim))
    loss_sum = masked_loss.sum(dim=sum_dims)
    mask_sum = expanded_mask.sum(dim=sum_dims)
    loss = loss_sum / mask_sum

    if reduce:
        return loss.mean()
    else:
        return loss


def convert_pad_shape(pad_shape: list[list[int]]):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def create_alignment_path(duration: torch.Tensor, mask: torch.Tensor):
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = create_mask_from_length(cum_duration_flat, t_y).float()
    path = path.view(b, t_x, t_y)
    # take the diff on the `t_x` axis
    path = path - torch.nn.functional.pad(
        path, convert_pad_shape([[0, 0], [1, 0], [0, 0]])
    )[:, :-1]
    path = path * mask
    return path


def trim_or_pad_length(x: torch.Tensor, target_length: int, length_dim: int):
    """
    Adjusts the size of the specified dimension of tensor x to match `target_length`.
    
    Args:
        x:
            Input tensor.
        target_length: 
            Desired size of the specified dimension.
        length_dim: 
            The dimension to modify.
    
    Returns:
        torch.Tensor: The adjusted tensor.
    """
    current_length = x.shape[length_dim]

    if current_length > target_length:
        # Truncate the tensor
        slices = [slice(None)] * x.ndim
        slices[length_dim] = slice(0, target_length)
        return x[tuple(slices)]

    elif current_length < target_length:
        # Pad the tensor with zeros
        pad_shape = list(x.shape)
        pad_length = target_length - current_length

        pad_shape[length_dim] = pad_length  # Shape for left padding
        padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)

        return torch.cat([x, padding], dim=length_dim)

    return x


def concat_non_padding(
    seq1: torch.Tensor, mask1: torch.BoolTensor, seq2: torch.Tensor,
    mask2: torch.BoolTensor
):
    """
    Args
        seq1 : Tensor (B, L1, E)
            First sequence.
        mask1 : BoolTensor (B, L1)
            True for valid tokens in seq1, False for padding.
        seq2 : Tensor (B, L2, E)
            Second sequence.
        mask2 : BoolTensor (B, L2)
            True for valid tokens in seq2, False for padding.

    Returns
        concat_seq : Tensor (B, L1+L2, E)
            Both sequences concatenated; valid tokens are left-aligned,
            padding on the right is 0.
        concat_mask: BoolTensor (B, L1+L2)
            Mask for the concatenated sequence.
        perm : LongTensor (B, L1+L2)
            Permutation that maps **original indices → new indices**.
            Needed for restoring the original sequences.
    """
    mask1, mask2 = mask1.bool(), mask2.bool()
    B, L1, E = seq1.shape
    L2 = seq2.size(1)
    L = L1 + L2

    seq_cat = torch.cat([seq1, seq2], dim=1)  # (B, L, E)
    mask_cat = torch.cat([mask1, mask2], dim=1)  # (B, L)

    # ----- Key step: stable sort so that all valid tokens move to the left -----
    # Padding positions get +L, guaranteeing the largest “score” → sorted to the end.
    positions = torch.arange(L, device=seq_cat.device).unsqueeze(0)  # (1, L)
    sort_score = positions + (~mask_cat) * L
    perm = sort_score.argsort(dim=1, stable=True)  # (B, L)

    # Build concatenated sequence & mask
    gather_idx = perm.unsqueeze(-1).expand(-1, -1, E)  # (B, L, E)
    concat_seq = seq_cat.gather(1, gather_idx)
    concat_mask = mask_cat.gather(1, perm)

    # Explicitly zero out the right-hand padding region for safety
    concat_seq = concat_seq * concat_mask.unsqueeze(-1)

    return concat_seq, concat_mask, perm


def restore_from_concat(
    concat_seq: torch.Tensor, mask1: torch.BoolTensor, mask2: torch.BoolTensor,
    perm: torch.LongTensor
):
    """
    Restore (seq1, seq2) from the concatenated sequence produced by
    `concat_non_padding`, using the returned permutation `perm`.
    Fully vectorised — no Python loops.
    """
    mask1, mask2 = mask1.bool(), mask2.bool()
    B, L1 = mask1.shape
    L2 = mask2.size(1)
    E = concat_seq.size(-1)

    # Inverse permutation: maps **new_idx → old_idx**
    inv_perm = torch.empty_like(perm)
    inv_perm.scatter_(
        1, perm,
        torch.arange(L1 + L2, device=perm.device).unsqueeze(0).expand(B, -1)
    )

    # Bring tokens back to their original order
    gather_idx = inv_perm.unsqueeze(-1).expand(-1, -1, E)
    seq_cat_rec = concat_seq.gather(1, gather_idx)  # (B, L1+L2, E)

    # Split back into the two sequences and mask out padding positions
    seq1_restore, seq2_restore = seq_cat_rec.split([L1, L2], dim=1)
    seq1_restore = seq1_restore * mask1.unsqueeze(-1)
    seq2_restore = seq2_restore * mask2.unsqueeze(-1)

    return seq1_restore, seq2_restore


def contains_nan(data):
    """check if data contains NaN"""
    if isinstance(data, torch.Tensor):
        return torch.isnan(data).any().item()
    elif isinstance(data, np.ndarray):
        return np.isnan(data).any()
    elif isinstance(data, float):
        return math.isnan(data)
    elif isinstance(data, (list, tuple)):
        return any(contains_nan(x) for x in data)
    elif isinstance(data, dict):
        return any(contains_nan(v) for v in data.values())
    return False


def check_nan_in_batch(batch):
    """check if batch contains NaN and return nan audio ids"""
    assert type(batch)==dict,"batch type error" 
    nan_audio_ids=[]
    audio_ids=batch["audio_id"]
    audio_id2content={}
    for idx,audio_id in enumerate(audio_ids):
        content=[]
        for k,v in batch.items():
            if k=="audio_id":
                continue
            content.append(v[idx])
        audio_id2content[audio_id]=content
    
    for audio_id,content in audio_id2content.items():
        if contains_nan(content):
            nan_audio_ids.append(audio_id)
            print(f"{audio_id} contains NaN")
    return nan_audio_ids
    
