import torch
import torch.nn as nn
import torch.nn.functional as F


def project_feature(feat: torch.Tensor, proj: nn.Module) -> torch.Tensor:
    """
    Map an input feature into the decoder's embedding space.

    Args:
        feat (torch.Tensor): Feature to project, e.g. a CLIP image embedding [B, input_dim]
        proj (nn.Module): Projection applied to ``feat`` (e.g. ``nn.Linear`` or ``nn.Identity``)

    Returns:
        torch.Tensor: The result of ``proj(feat)``, expected to be [B, decoder_dim]
    """
    projected = proj(feat)
    return projected


def retrieve_from_cache(query_feat: torch.Tensor, text_cache: torch.Tensor, top_k: int = 1) -> torch.Tensor:
    """
    Retrieve the most similar cached text feature(s) for each query via cosine similarity.

    Args:
        query_feat (torch.Tensor): Query features [B, D]
        text_cache (torch.Tensor): Cached text features [N, D]
        top_k (int): Number of best-matching cache entries to return per query.
            Must not exceed N (``torch.topk`` raises otherwise).

    Returns:
        torch.Tensor: The retrieved (un-normalized) cache entries [B, top_k, D], ordered
        from most to least similar. They live on ``query_feat``'s device and keep
        ``text_cache``'s dtype. The top_k axis is kept even when ``top_k == 1``.
    """
    # Move the cache once and use the moved copy both for scoring AND for gathering.
    # ``top_idx`` is created on query_feat.device, so indexing the un-moved cache
    # would fail when the cache lives on the CPU and the query on an accelerator.
    cache = text_cache.to(query_feat.device)
    q = F.normalize(query_feat, dim=-1)                    # [B, D]
    # Score in the query's dtype: matmul does not promote mixed dtypes (e.g. fp16 vs fp32).
    c = F.normalize(cache.to(q.dtype), dim=-1)             # [N, D]
    sim = q @ c.T                                          # [B, N], cosine similarity
    top_idx = sim.topk(top_k, dim=-1).indices              # [B, top_k]
    return cache[top_idx]                                  # [B, top_k, D]


def decode_category(decoder, tokenizer, text_feat: torch.Tensor, max_len: int = 77) -> list:
    """
    Greedily decode category text from a text embedding with an autoregressive decoder.

    Args:
        decoder (nn.Module): Callable ``decoder(text_feat, token_ids)`` returning token
            logits. They are indexed below as [B, vocab_size, seq_len]; this layout is
            assumed, so confirm it against the decoder implementation.
        tokenizer: Tokenizer exposing ``bos_token_id`` / ``cls_token_id`` / ``eos_token_id``
            and ``decode(ids, skip_special_tokens=...)``.
        text_feat (torch.Tensor): Text embedding [B, D]
        max_len (int): Maximum number of tokens to generate after the start token

    Returns:
        list: One decoded string per input sample, with special tokens skipped and
        surrounding whitespace stripped.
    """
    B = text_feat.size(0)

    # Use explicit None checks: 0 is a valid token id and must not fall through
    # like it would with ``a or b``.
    bos_token = tokenizer.bos_token_id
    if bos_token is None:
        bos_token = tokenizer.cls_token_id
    if bos_token is None:
        bos_token = 0
    eos_token = tokenizer.eos_token_id

    generated = torch.full((B, 1), bos_token, dtype=torch.long, device=text_feat.device)  # [B, 1]
    finished = torch.zeros(B, dtype=torch.bool, device=text_feat.device)                  # [B]

    # Greedy decoding only: the argmax output carries no gradient, so skip autograd bookkeeping.
    with torch.no_grad():
        for _ in range(max_len):
            logits = decoder(text_feat, generated)                     # [B, vocab_size, seq_len]
            next_token = logits[:, :, -1].argmax(dim=1, keepdim=True)  # [B, 1]
            if eos_token is not None:
                # A row that already emitted EOS keeps emitting EOS. This stops tokens
                # generated after EOS from leaking into its decoded text.
                next_token = next_token.masked_fill(finished.unsqueeze(1), eos_token)
                finished |= next_token.squeeze(1) == eos_token
            generated = torch.cat([generated, next_token], dim=1)      # append next token

            # Stop early once every sequence has produced EOS (at any step).
            if eos_token is not None and finished.all():
                break

    return [tokenizer.decode(seq.tolist(), skip_special_tokens=True).strip() for seq in generated]


def predict_category(raw_feat: torch.Tensor,
                     proj: nn.Module,
                     text_cache: torch.Tensor,
                     decoder,
                     tokenizer,
                     max_len: int = 77,
                     top_k: int = 1) -> list:
    """
    Run the end-to-end pipeline. The raw feature is projected, the closest cached text
    embedding is looked up, and that embedding is decoded into a category string.

    Args:
        raw_feat (torch.Tensor): Raw input feature [B, input_dim]
        proj (nn.Module): Projection into the decoder's embedding dimension
        text_cache (torch.Tensor): Cached text embeddings [N, decoder_dim]
        decoder (nn.Module): Decoder model (e.g. a TransformerDecoder)
        tokenizer: Tokenizer used to turn generated token IDs back into text
        max_len (int): Maximum length of the decoded token sequence
        top_k (int): How many cache entries to retrieve. Only the best one is decoded.

    Returns:
        list: One predicted label string per batch element
    """
    # Step 1: map the raw feature into the decoder's embedding space.
    query = project_feature(raw_feat, proj)

    # Step 2: look up the most similar cached text embeddings.
    candidates = retrieve_from_cache(query, text_cache, top_k=top_k)

    # Only the best match is decoded; drop the top_k axis when it is present.
    best = candidates[:, 0, :] if candidates.ndim == 3 else candidates

    # Step 3: turn the selected embedding into text.
    return decode_category(decoder, tokenizer, best, max_len=max_len)
