from typing import Any, Dict, List, Tuple


class PolicyCollator:
    """Collate prompt/response records into a tokenized batch with labels.

    The collator relies on ``self.query_builder``, an object exposing
    ``build_queries(questions)``, ``build_query_responses(questions,
    responses)`` and a callable ``tokenizer``. It can be passed to the
    constructor or attached to the instance (or provided by a subclass)
    before the collator is called.
    """

    # Label value written at positions that must not be used as targets.
    # -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
    IGNORE_INDEX = -100

    def __init__(self, cfg, query_builder=None):
        """Store the configuration and, optionally, the query builder.

        Args:
            cfg: Configuration object; stored as-is on ``self.cfg``.
            query_builder: Optional builder used by ``__call__``. When omitted,
                ``self.query_builder`` must be supplied some other way before
                the collator is called.
        """
        self.cfg = cfg
        # Assign only when given so a ``query_builder`` defined on a subclass
        # (class attribute or property) is not shadowed by ``None``.
        if query_builder is not None:
            self.query_builder = query_builder

    def __call__(self, outputs: List[Dict[str, Any]]) -> Tuple[Dict[str, Any], List[Any]]:
        """Format a batch into chat text, tokenize it, and build masked labels.

        Args:
            outputs: Records that each contain the keys ``'prompt'``,
                ``'response'`` and ``'id'``.

        Returns:
            A 2-tuple ``(batch, ids)``. ``batch`` is whatever the tokenizer
            returned for the query+response texts, with an added ``"labels"``
            entry: a copy of ``"input_ids"`` where query tokens and padding
            are replaced by ``IGNORE_INDEX`` so that only response tokens
            remain as targets. ``ids`` lists ``output['id']`` in batch order.
        """
        builder = self.query_builder
        questions = [output['prompt'] for output in outputs]
        responses = [output['response'] for output in outputs]

        query_responses = builder.build_query_responses(questions, responses)
        batch = builder.tokenizer(query_responses, return_tensors="pt", padding=True)

        # Tokenize the queries alone just to learn how many tokens each query
        # occupies.
        # NOTE(review): assumes a query tokenizes to a prefix of its
        # query+response tokenization - confirm for the tokenizer in use.
        queries = builder.build_queries(questions)
        tokenized_queries = builder.tokenizer(queries, return_tensors="pt", padding=True)
        query_lengths = tokenized_queries["attention_mask"].sum(dim=1)

        attention_mask = batch["attention_mask"]
        # 1-based rank of every real token within its row; padding placed
        # before the text gets rank 0. Comparing ranks instead of column
        # indices masks exactly the query tokens for either padding side.
        token_rank = attention_mask.cumsum(dim=1)
        is_query = token_rank <= query_lengths.unsqueeze(1)
        # Use the attention mask rather than the pad token id so that a real
        # token sharing the pad id is not masked by accident.
        is_padding = attention_mask == 0

        labels = batch["input_ids"].clone()
        labels[is_query | is_padding] = self.IGNORE_INDEX
        batch["labels"] = labels

        return batch, [output['id'] for output in outputs]