import logging
import random
import torch
import numpy as np

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

def get_cost(num_prompt_tokens, num_completion_tokens, model_config):
    """Return the total cost of one request to a single model.

    The cost is the sum of three parts read from ``model_config``:

    * ``"prompt"``: price applied to ``num_prompt_tokens``, divided by
      1,000,000 (i.e. a per-million-token price).
    * ``"completion"``: price applied to ``num_completion_tokens``, divided
      by 1,000,000.
    * ``"request"``: flat amount added once per call.

    Any key missing from ``model_config`` is treated as ``0.0``.

    Args:
        num_prompt_tokens: Number of prompt (input) tokens.
        num_completion_tokens: Number of completion (output) tokens.
        model_config: Mapping that supports ``.get`` and holds the prices.

    Returns:
        The combined prompt, completion and per-request cost.
    """
    tokens_per_price_unit = 1000000

    prompt_price = model_config.get("prompt", 0.0)
    prompt_cost = num_prompt_tokens * prompt_price / tokens_per_price_unit

    completion_price = model_config.get("completion", 0.0)
    completion_cost = (
        num_completion_tokens * completion_price / tokens_per_price_unit
    )

    request_fee = model_config.get("request", 0.0)

    # Summation order is kept as prompt + completion + request so that
    # floating-point results stay bit-identical.
    return prompt_cost + completion_cost + request_fee

def get_costs(input_tokens, output_tokens, routing_config):
    """Compute one cost per entry of ``routing_config``.

    The i-th element of ``input_tokens`` and ``output_tokens`` is paired with
    the i-th value of ``routing_config`` (in the mapping's iteration order)
    and passed to ``get_cost``.

    Args:
        input_tokens: Sized sequence of prompt-token counts, one per entry of
            ``routing_config``.
        output_tokens: Sized sequence of completion-token counts, one per
            entry of ``routing_config``.
        routing_config: Mapping whose values are the per-model price
            configurations handed to ``get_cost``.

    Returns:
        A list with one cost per entry of ``routing_config``, in the
        mapping's iteration order.

    Raises:
        ValueError: If the three arguments do not all have the same length.
    """
    num_models = len(routing_config)
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, after which ``zip`` would silently truncate to the
    # shortest input and return a result of the wrong length.
    if len(input_tokens) != num_models or len(output_tokens) != num_models:
        raise ValueError(
            "input_tokens, output_tokens and routing_config must have the "
            f"same length, got {len(input_tokens)}, {len(output_tokens)} "
            f"and {num_models}"
        )

    return [
        get_cost(num_in, num_out, model_config)
        for num_in, num_out, model_config in zip(
            input_tokens, output_tokens, routing_config.values()
        )
    ]


class RoutingDataCollator:
    """Callable that merges a list of per-example dicts into one batch dict."""

    def __call__(self, batch: list[dict]):
        """Collate ``batch`` into a single dictionary.

        Each example is read through the keys ``"routing_config"``, ``"idx"``,
        ``"prompt"``, ``"embedding"``, ``"scores"``, ``"costs"``,
        ``"input_tokens"`` and ``"output_tokens"``.

        Args:
            batch: List of example dicts. Must be non-empty, because the
                routing configuration is read from ``batch[0]``.

        Returns:
            A dict with these keys, in this order:

            * ``"routing_config"``: taken from the first example only
              (presumably identical across the batch; confirm with the
              dataset that produces the examples).
            * ``"idx"``: ``torch.tensor`` built from every example's ``"idx"``.
            * ``"prompts"``: plain list of every example's ``"prompt"``.
            * ``"embeddings"``, ``"scores"``, ``"costs"``, ``"input_tokens"``,
              ``"output_tokens"``: ``torch.stack`` of the matching
              per-example values.
        """
        collated = {"routing_config": batch[0]["routing_config"]}
        collated["idx"] = torch.tensor([example["idx"] for example in batch])
        collated["prompts"] = [example["prompt"] for example in batch]

        # (output key, per-example key) pairs whose values are stacked along
        # a new leading dimension. Only "embeddings" differs from its source.
        stacked_fields = (
            ("embeddings", "embedding"),
            ("scores", "scores"),
            ("costs", "costs"),
            ("input_tokens", "input_tokens"),
            ("output_tokens", "output_tokens"),
        )
        for batch_key, example_key in stacked_fields:
            collated[batch_key] = torch.stack(
                [example[example_key] for example in batch]
            )

        return collated