from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from torch.nn.utils.rnn import pad_sequence
from trl.trainer.utils import pad
import torch


@dataclass
class RMAlignmentDataCollatorWithPadding:
    r"""
    Data collator that pads tokenized features to the longest sequence in the batch.

    Each key of the first feature dict is handled according to its name:

    - keys ending in `_input_ids`, `_attention_mask`, `_labels` or `_pixel_values` are converted to
      tensors and padded with `trl.trainer.utils.pad`. `prompt_input_ids` and `prompt_attention_mask`
      are padded on the left, every other key on the right.
    - keys ending in `_logps` (cached reference model log-probabilities) are stacked into one tensor.
    - any other key is collected into a plain Python list, one entry per example.

    Args:
        pad_token_id (`Optional[int]`, defaults to 0):
            The tokenizer's pad_token_id. Must not be `None` when the batch contains `*_input_ids` keys.
        label_pad_token_id (`int`, defaults to -100):
            The label used for masking.
        is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `False`):
            Whether or not you model has an encoder_decoder architecture. Encoder-decoder models are not
            supported by this collator; calling it with a truthy value raises `AssertionError`.
    """

    pad_token_id: Optional[int] = 0
    label_pad_token_id: int = -100
    is_encoder_decoder: Optional[bool] = False

    def _padding_value(self, key: str) -> int:
        """Return the fill value used to pad the sequences stored under `key`.

        Raises:
            ValueError: if `key` ends in `_input_ids` and `pad_token_id` is `None`.
        """
        if key.endswith("_input_ids"):
            if self.pad_token_id is None:
                raise ValueError(
                    "Padding is enabled, but the tokenizer is not configured with a padding token."
                    " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)"
                    " before calling the trainer."
                )
            return self.pad_token_id
        if key.endswith("_labels"):
            return self.label_pad_token_id
        # `_attention_mask` and `_pixel_values` are both padded with 0.
        # TODO: check if 0 is the correct padding value for pixel values
        return 0

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate a list of per-example feature dicts into one batch dict.

        Args:
            features: one dict per example. The set of keys is taken from the first example, so every
                example must provide those keys.

        Returns:
            A dict with the same keys as `features[0]`, holding padded tensors, a tensor of log-probs,
            or a list of raw values depending on the key (see the class docstring).

        Raises:
            AssertionError: if `is_encoder_decoder` is truthy.
            ValueError: if a `*_input_ids` key is present while `pad_token_id` is `None`.
        """
        # Raised explicitly rather than via `assert` so the guard still runs under `python -O`;
        # the exception type is kept for backward compatibility.
        if self.is_encoder_decoder:
            raise AssertionError(
                "RMAlignmentDataCollatorWithPadding does not support encoder-decoder models."
            )

        padded_batch = {}
        for k in features[0].keys():
            if k.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")):
                padding_value = self._padding_value(k)

                # Only the prompt tensors are left-padded; everything else is right-padded.
                if k in ["prompt_input_ids", "prompt_attention_mask"]:
                    padding_side = "left"
                else:
                    padding_side = "right"

                # Pixel values are floats (will be downcasted if necessary by the Trainer);
                # token ids, masks and labels are integers.
                dtype = torch.float32 if k.endswith("_pixel_values") else torch.int64

                # Convert to tensor and pad
                to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features]
                padded_batch[k] = pad(to_pad, padding_value=padding_value, padding_side=padding_side)
            elif k.endswith("_logps"):
                # the cached reference model logprobs
                padded_batch[k] = torch.tensor([ex[k] for ex in features])
            else:
                padded_batch[k] = [ex[k] for ex in features]

        return padded_batch


# Script entry point: intentionally a no-op, so running this file directly does nothing.
if __name__ == "__main__":
    pass
