import os
import pickle
import torch
import numpy as np
from sklearn.neural_network import MLPClassifier

from llm_router.router.base_router import BaseRouter, RouterInput, RouterOutput, PREFERENCES
from llm_router.data.utils import get_costs

class MLPClassifierRouter(BaseRouter):
    """Router that selects a model with a scikit-learn ``MLPClassifier``.

    During ``fit`` each training query is labelled with the model index that
    maximises ``score - preference * cost / max_cost`` on the benchmark, and an
    MLP is trained to predict that label from the query embedding. ``route``
    uses the predicted index to pick an entry of ``router_input.routing_config``.
    """

    @property
    def name(self):
        """Identifier of this router."""
        return "mlp_classifier_router"

    def _cost_scale(self):
        """Return the divisor used to normalise costs.

        Normally this is ``self.max_cost`` (the largest benchmark cost). When
        it is zero (e.g. every model is free), dividing by it would raise
        ``ZeroDivisionError`` in ``route`` and produce NaN utilities in
        ``fit``; fall back to 1.0 so the (all-zero) cost term simply vanishes.
        """
        return self.max_cost if self.max_cost != 0 else 1.0

    def fit(self, trainset, valset, configs):
        """Train the routing MLP, or load a previously cached one.

        Args:
            trainset: object exposing ``get_benchmark()``, returning a mapping
                with ``"scores"``, ``"costs"`` and ``"embeddings"`` tensors;
                scores/costs are indexed ``[query, model]``.
            valset: not used by this router.
            configs: run configuration; the fitted classifier is cached as
                ``mlp.pkl`` under ``configs.training.output_dir``.
        """
        self.benchmark = trainset.get_benchmark()
        self.num_models = self.benchmark["scores"].shape[1]
        self.max_cost = self.benchmark["costs"].max().item()

        output_dir = configs.training.output_dir
        cache_path = os.path.join(output_dir, "mlp.pkl")

        # NOTE(review): the cache file name ignores preference / hidden_dims /
        # lr / max_iter, so an mlp.pkl trained under a different config is
        # reused silently. pickle.load can execute arbitrary code — only point
        # output_dir at trusted locations.
        if os.path.exists(cache_path):
            with open(cache_path, "rb") as f:
                self.mlp = pickle.load(f)
            return

        scores = self.benchmark["scores"].cpu().numpy()  # [B, M]
        costs = self.benchmark["costs"].cpu().numpy()  # [B, M]

        # Per-query utility of each model: quality minus a preference-weighted,
        # normalised cost penalty. The training label is the best model.
        utility = scores - self.config.preference * costs / self._cost_scale()
        routing_label = np.argmax(utility, axis=1)

        self.mlp = MLPClassifier(
            hidden_layer_sizes=self.config.hidden_dims,
            max_iter=self.config.max_iter,
            learning_rate_init=self.config.lr,
        ).fit(self.benchmark["embeddings"].cpu().numpy(), routing_label)

        os.makedirs(output_dir, exist_ok=True)
        with open(cache_path, "wb") as f:
            pickle.dump(self.mlp, f)

    def route(self, router_input: RouterInput):
        """Route a single query to the model predicted by the MLP.

        Args:
            router_input: query embedding plus the routing configuration and
                the per-model scores/costs/token counts echoed into the output.

        Returns:
            A dict with one entry, keyed by the normalised preference, whose
            value is the ``RouterOutput`` for this query.
        """
        embedding = router_input.embedding.unsqueeze(0).cpu().numpy()
        # predict() returns class labels (here: model indices) as numpy
        # integers; convert to a plain int for indexing and serialisation.
        routing_id = int(self.mlp.predict(embedding)[0])
        routing_model = list(router_input.routing_config.keys())[routing_id]

        preference = self.config.preference / self._cost_scale()
        return {
            preference: RouterOutput(
                idx=router_input.idx,
                routing_config=router_input.routing_config,
                scores=router_input.scores,
                costs=router_input.costs,
                input_tokens=router_input.input_tokens,
                output_tokens=router_input.output_tokens,
                routing_id=routing_id,
                routing_model=routing_model,
                info={
                    "max_cost": self.max_cost,
                },
            )
        }