import os
import torch
import numpy as np

from llm_router.router.base_router import BaseRouter, RouterInput, RouterOutput, PREFERENCES
from llm_router.data.utils import get_costs
from llm_router.model.registry import get_model
from llm_router.model.utils import Trainer

class ClsRouter(BaseRouter):
    """Router that delegates the choice of candidate model to a learned model.

    ``fit`` builds the model returned by ``get_model`` and trains it, or
    reloads it from an existing checkpoint. It also records the largest cost
    in the training benchmark. ``route`` asks the model for the index of the
    candidate to use and wraps the choice in a ``RouterOutput``.
    """

    @property
    def name(self):
        """Identifier of this router."""
        return "cls_router"

    def fit(self, trainset, valset, configs):
        """Prepare the routing model, training it only if no checkpoint exists.

        Args:
            trainset: Training dataset. It must provide ``get_benchmark()``.
                It is also passed to ``model.initialize`` and to the ``Trainer``.
            valset: Validation dataset handed to the ``Trainer``.
            configs: Configuration object. ``configs.model`` selects the model
                and supplies ``device`` and ``ckpt``. ``configs.training`` is
                passed to the ``Trainer`` and supplies ``output_dir``.
        """
        self.benchmark = trainset.get_benchmark()

        # Largest cost in the benchmark, converted to a Python scalar.
        # route() divides the preference by it to build the output key.
        self.max_cost = self.benchmark["costs"].max().item()

        # Build the model, move it to the configured device and let it
        # initialise itself from the training data.
        self.model = get_model(configs.model)
        self.model.to(configs.model.device)
        self.model.initialize(trainset)

        trainer = Trainer(self.model, trainset, valset, configs.training)
        ckpt_dir = os.path.join(configs.training.output_dir, "model", configs.model.ckpt)
        # Training is skipped when the checkpoint directory already exists.
        # NOTE(review): the checkpoint is loaded unconditionally below, so this
        # assumes trainer.train() writes its checkpoint to ckpt_dir — confirm.
        if not os.path.exists(ckpt_dir):
            trainer.train()
        trainer.load(ckpt_dir)

    def route(self, router_input: RouterInput):
        """Route one input to the candidate model predicted by the model.

        ``fit`` must have been called first, because this reads ``self.model``
        and ``self.max_cost``.

        Args:
            router_input: Input whose ``routing_config`` lists the candidate
                models. Its ``idx``, ``scores``, ``costs``, ``input_tokens`` and
                ``output_tokens`` are copied into the output unchanged.

        Returns:
            dict: A single entry mapping ``self.config.preference /
            self.max_cost`` to the ``RouterOutput`` for the chosen model.
        """
        routing_id = self.model.predict(router_input)
        # The predicted id is used as a position in routing_config's key order.
        # NOTE(review): assumes that order matches the order of the model's
        # output classes — confirm against the model/dataset code.
        routing_model = list(router_input.routing_config.keys())[routing_id]

        # self.config is not assigned in this class; presumably set by
        # BaseRouter — verify.
        preference = self.config.preference / self.max_cost

        outputs = {}
        outputs[preference] = RouterOutput(
            idx=router_input.idx,
            routing_config=router_input.routing_config,
            scores=router_input.scores,
            costs=router_input.costs,
            input_tokens=router_input.input_tokens,
            output_tokens=router_input.output_tokens,
            routing_id=routing_id,
            routing_model=routing_model,
            info={
                "max_cost": self.max_cost
            },
        )

        return outputs