from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch


@dataclass
class MultiTaskDataCollator:
    """Collate a list of multi-task examples into a single model batch.

    Every feature must provide ``"instruction"`` and ``"task_type"``. Features
    whose ``task_type`` is ``"qa"`` must also provide ``"output"`` (the target
    text). Any feature may provide ``"class_label"`` and/or
    ``"regression_value"``.

    The returned mapping is the tokenizer's encoding of the instructions,
    extended with:

    * ``labels``: token ids of the QA targets. Padding positions, and every
      position of a non-QA example, are set to ``-100`` (presumably the loss
      ignore index; confirm against the model's loss).
    * ``task_types``: the per-example ``task_type`` values, in batch order.
    * ``class_labels``: ``torch.long`` tensor, ``-100`` where missing or None.
    * ``regression_values``: ``torch.float`` tensor, ``-100.0`` where missing
      or None.

    Raises:
        ValueError: if ``features`` is empty.
    """

    # Tokenizer callable (presumably a HuggingFace tokenizer), invoked on a list
    # of strings with ``return_tensors="pt"``, ``padding``, ``truncation`` and
    # ``max_length`` keyword arguments.
    tokenizer: Any
    # Maximum token length for the instruction encoding.
    max_length: int = 512
    # Maximum token length for the QA targets; ``None`` means use ``max_length``.
    label_max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the batch for ``features``; see the class docstring for keys."""
        if not features:
            raise ValueError("MultiTaskDataCollator received an empty batch")

        inputs = self.tokenizer(
            [f["instruction"] for f in features],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )

        task_types = [f["task_type"] for f in features]
        inputs["labels"] = self._build_labels(features, task_types)
        inputs["task_types"] = task_types
        inputs["class_labels"] = torch.tensor(
            [self._value_or(f, "class_label", -100) for f in features],
            dtype=torch.long,
        )
        inputs["regression_values"] = torch.tensor(
            [self._value_or(f, "regression_value", -100.0) for f in features],
            dtype=torch.float,
        )
        return inputs

    def _build_labels(
        self, features: List[Dict[str, Any]], task_types: List[Any]
    ) -> torch.Tensor:
        """Tokenize QA targets; set padding and all non-QA rows to -100."""
        is_qa = [task_type == "qa" for task_type in task_types]
        targets = [f["output"] if qa else "" for f, qa in zip(features, is_qa)]
        label_len = (
            self.label_max_length
            if self.label_max_length is not None
            else self.max_length
        )
        encoded = self.tokenizer(
            targets,
            # Without return_tensors the ids come back as a list of lists,
            # which cannot be compared or masked element-wise.
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=label_len,
        )
        labels = encoded["input_ids"]

        # Prefer the attention mask for locating padding: when pad_token_id
        # equals eos_token_id, comparing ids would also mask the real EOS.
        if "attention_mask" in encoded:
            pad_mask = encoded["attention_mask"] == 0
        else:
            pad_mask = labels == self.tokenizer.pad_token_id
        labels = labels.masked_fill(pad_mask, -100)

        # "" may still tokenize to special tokens (e.g. EOS); make sure
        # examples without a QA target contribute no label positions at all.
        non_qa = ~torch.tensor(is_qa, dtype=torch.bool)
        return labels.masked_fill(non_qa.unsqueeze(1), -100)

    @staticmethod
    def _value_or(feature: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``feature[key]``, or ``default`` when it is missing or None."""
        value = feature.get(key)
        return default if value is None else value
