from typing import Any, Dict, List, Optional, Union

from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase
import torch
from collections import defaultdict


@dataclass
class RewardScoreDataCollatorWithPadding:
    r"""
    Reward Score DataCollator class that pads the inputs to the maximum length of the batch, and handles the score.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for encoding the data.
        padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
            padding_strategy to pass to the tokenizer.
        pad_to_multiple_of (`Optional[int]`, `optional`, defaults to `None`):
            If set will pad the sequence to a multiple of the provided value.
        return_tensors (`str`, `optional`, defaults to `"pt"`):
            The tensor type to use.
        torch_dtype (`Union[torch.dtype, str]`, `optional`, defaults to `torch.float32`):
            dtype of the collated score tensors (the keys in `features_to_collate`).
            Accepts a `torch.dtype` or the name of one, with or without the
            `"torch."` prefix (e.g. `"bfloat16"` or `"torch.float16"`).
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"
    torch_dtype: Union[torch.dtype, str] = torch.float32

    # Unannotated on purpose: this is a plain class attribute, not a dataclass
    # field. Every key listed here is gathered from each feature (missing
    # keys default to 0) and emitted as a float tensor in the batch.
    features_to_collate = [
        "score",
        "popularity",
        "score_clean",
        "score_no_popularity",
        "ctr",
        "ctr_scaled",
    ]

    def _target_dtype(self) -> torch.dtype:
        """Resolve ``self.torch_dtype`` (a ``torch.dtype`` or its name) to a ``torch.dtype``.

        Raises:
            ValueError: if the name does not refer to a ``torch.dtype``.
        """
        if isinstance(self.torch_dtype, torch.dtype):
            return self.torch_dtype
        # Accept both "float32" and "torch.float32" style names.
        name = str(self.torch_dtype).split(".")[-1]
        dtype = getattr(torch, name, None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(
                f"Unsupported torch_dtype {self.torch_dtype!r}; "
                "expected a torch.dtype or the name of one."
            )
        return dtype

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pad a list of tokenized examples into a single batch.

        Args:
            features: examples that each provide ``input_ids`` and
                ``attention_mask``. Optionally they provide ``prompt`` together
                with ``prompt_input_ids`` / ``prompt_attention_mask``, plus any
                of the keys in ``features_to_collate``.

        Returns:
            A dict with padded ``input_ids`` / ``attention_mask``,
            ``return_loss=True``, the padded prompt tensors and raw ``prompt``
            list when prompts are present, and one tensor of dtype
            ``torch_dtype`` per key in ``features_to_collate``. Each such
            tensor is ``unsqueeze(1)``-ed, i.e. shape ``(batch, 1)`` when the
            per-example values are scalars.

        Raises:
            ValueError: if ``torch_dtype`` is not a valid dtype, or if only some
                of the features carry a ``prompt``. A mixed batch would produce
                prompt tensors misaligned with ``input_ids``.
        """
        target_dtype = self._target_dtype()

        features_output = []
        prompts = []
        prompt_features = []
        other_variables = defaultdict(list)
        for feature in features:
            if "prompt" in feature:
                prompt_features.append(
                    {
                        "input_ids": feature["prompt_input_ids"],
                        "attention_mask": feature["prompt_attention_mask"],
                    }
                )
                prompts.append(feature["prompt"])
            features_output.append(
                {
                    "input_ids": feature["input_ids"],
                    "attention_mask": feature["attention_mask"],
                }
            )
            for key in self.features_to_collate:
                # Missing score-like values fall back to 0 so every row is filled.
                other_variables[key].append(feature.get(key, 0))

        if prompts and len(prompts) != len(features_output):
            raise ValueError(
                "Either every feature or no feature in a batch must contain 'prompt'; "
                f"got {len(prompts)} of {len(features_output)}."
            )

        padded = self.tokenizer.pad(
            features_output,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

        batch = {
            "input_ids": padded["input_ids"],
            "attention_mask": padded["attention_mask"],
            "return_loss": True,
        }

        if prompt_features:
            prompt_batch = self.tokenizer.pad(
                prompt_features,
                padding=self.padding,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors,
            )
            batch["prompt_input_ids"] = prompt_batch["input_ids"]
            batch["prompt_attention_mask"] = prompt_batch["attention_mask"]
            batch["prompt"] = prompts

        for key, value in other_variables.items():
            batch[key] = torch.tensor(value, dtype=target_dtype).unsqueeze(1)

        return batch
