import numpy as np

from llm_router.router.base_router import BaseRouter, RouterInput, RouterOutput, PREFERENCES
from llm_router.data.utils import get_costs

class OracleRouter(BaseRouter):
    """Router that picks a model per example from the scores carried on the input.

    For every trade-off value in ``PREFERENCES`` it selects the model that
    maximizes ``score - preference * cost``. Scores and costs are read
    directly from the ``RouterInput``. Presumably these are the observed
    per-model results, which would make this an "oracle" upper bound; confirm
    against the data pipeline. There is nothing to learn, so ``fit`` is a no-op.
    """

    @property
    def name(self):
        """Identifier of this router."""
        return "oracle_router"

    def fit(self, trianset, valset, configs):
        """No-op: routing uses only the scores/costs present on each input.

        The misspelled ``trianset`` parameter name is kept deliberately so
        that any caller passing it by keyword keeps working.
        """
        pass

    def route(self, router_input: RouterInput):
        """Route one example under every preference in ``PREFERENCES``.

        Args:
            router_input: Supplies ``scores`` and ``costs`` (one entry per
                candidate model) and ``routing_config``. The code assumes the
                keys of ``routing_config`` are positionally aligned with
                ``scores``/``costs``: the i-th key names the i-th model.

        Returns:
            dict mapping each preference to a ``RouterOutput``. Each output's
            ``routing_id`` is the plain-``int`` index of the model with the
            highest ``score - preference * cost``. Ties go to the lowest index,
            per ``np.argmax``. Its ``routing_model`` is the matching
            ``routing_config`` key.

        Raises:
            ValueError: if ``scores`` and ``costs`` differ in length.
        """
        scores = router_input.scores
        costs = router_input.costs
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``,
        # after which ``zip`` would silently truncate to the shorter sequence.
        if len(scores) != len(costs):
            raise ValueError(
                f"scores and costs must have the same length, "
                f"got {len(scores)} and {len(costs)}"
            )

        # Model names in routing_config order; built once, not once per preference.
        model_names = list(router_input.routing_config.keys())

        outputs = {}
        for preference in PREFERENCES:
            # Per-model utility: score penalized by preference-weighted cost.
            utility = [s - preference * c for s, c in zip(scores, costs)]
            # np.argmax returns np.int64; cast to int so the stored id is a
            # plain Python int (e.g. stdlib json cannot serialize np.int64).
            routing_id = int(np.argmax(utility))
            outputs[preference] = RouterOutput(
                idx=router_input.idx,
                routing_config=router_input.routing_config,
                scores=scores,
                costs=costs,
                input_tokens=router_input.input_tokens,
                output_tokens=router_input.output_tokens,
                routing_id=routing_id,
                routing_model=model_names[routing_id],
                info={},
            )

        return outputs