from __future__ import annotations

from typing import Any, Dict, List, Optional

import torch
from torch.nn.utils.rnn import pad_sequence

# Default fill value for padded label positions. Presumably chosen to match the
# loss's ignore_index (PyTorch's CrossEntropyLoss defaults to -100) — confirm.
IGNORE_INDEX = -100


class MaskPruningCollator:
    """Collate tokenized examples into right-padded ``torch.long`` batches.

    Each feature dict must provide ``input_ids``, ``attention_mask`` and
    ``labels`` (int sequences or tensors), plus the passthrough fields
    ``prompt_text`` and ``task``, which are returned unchanged as lists.
    The feature dicts themselves are not modified.

    Args:
        pad_token_id: Fill value for padded ``input_ids`` positions.
        label_pad_token_id: Fill value for padded ``labels`` positions.
        pad_to_multiple_of: If set, the padded length is rounded up to the
            next multiple of this value; ``None`` pads to the longest
            sequence in the batch.

    Raises:
        ValueError: If ``pad_to_multiple_of`` is neither ``None`` nor >= 1.
    """

    def __init__(
        self,
        pad_token_id: int,
        label_pad_token_id: int = IGNORE_INDEX,
        pad_to_multiple_of: Optional[int] = 8,
    ):
        # Fail fast: 0 would otherwise cause a ZeroDivisionError at collate time,
        # and negative values would produce a meaningless target length.
        if pad_to_multiple_of is not None and pad_to_multiple_of < 1:
            raise ValueError(
                f"pad_to_multiple_of must be None or a positive integer, got {pad_to_multiple_of!r}"
            )
        self.pad_token_id = pad_token_id
        self.label_pad_token_id = label_pad_token_id
        self.pad_to_multiple_of = pad_to_multiple_of

    def _pad(self, sequences: List[torch.Tensor], value: int) -> torch.Tensor:
        """Right-pad ``sequences`` with ``value`` and stack them along a new batch dim.

        The result has shape ``(len(sequences), L, *trailing)``. ``L`` is the
        longest sequence length, rounded up to ``pad_to_multiple_of`` when set.
        """
        batch = pad_sequence(sequences, batch_first=True, padding_value=value)
        multiple = self.pad_to_multiple_of
        if multiple is None:
            return batch

        remainder = batch.size(1) % multiple
        if remainder:
            # new_full inherits dtype and device from ``batch``, so the concatenation
            # stays valid even when the inputs are not CPU tensors.
            extra = batch.new_full(
                (batch.size(0), multiple - remainder, *batch.shape[2:]), value
            )
            batch = torch.cat([batch, extra], dim=1)
        return batch

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build a padded batch from ``features``.

        Returns:
            A dict with ``input_ids``, ``attention_mask`` and ``labels`` as
            2-D ``torch.long`` tensors, plus ``prompt_text`` and ``task``
            as per-example lists in input order.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError("MaskPruningCollator received an empty batch")

        # Read rather than pop: popping mutated the caller's dicts, so collating
        # the same features twice raised KeyError.
        prompt_texts = [feat["prompt_text"] for feat in features]
        tasks = [feat["task"] for feat in features]

        # as_tensor accepts lists or existing tensors (no copy or warning for the latter).
        input_ids = [torch.as_tensor(feat["input_ids"], dtype=torch.long) for feat in features]
        attention_masks = [torch.as_tensor(feat["attention_mask"], dtype=torch.long) for feat in features]
        labels = [torch.as_tensor(feat["labels"], dtype=torch.long) for feat in features]

        return {
            "input_ids": self._pad(input_ids, self.pad_token_id),
            "attention_mask": self._pad(attention_masks, 0),
            "labels": self._pad(labels, self.label_pad_token_id),
            "prompt_text": prompt_texts,
            "task": tasks,
        }

