import os
import pickle
import torch
import numpy as np
from sklearn.linear_model import Ridge

from llm_router.router.base_router import BaseRouter, RouterInput, RouterOutput, PREFERENCES
from llm_router.data.utils import get_costs

class LinearPredictorRouter(BaseRouter):
    """Router backed by per-model linear (ridge) regressors.

    For every candidate model, two ``sklearn.linear_model.Ridge`` regressors
    are fit on the benchmark embeddings: one predicts the model's score, the
    other its number of output tokens. At routing time the predicted output
    tokens are turned into costs via ``get_costs``. Then, for each value in
    ``PREFERENCES``, the model maximising
    ``predicted_score - preference * predicted_cost`` is selected.
    """

    # File name of the pickled regressors inside ``configs.training.output_dir``.
    _CACHE_FILENAME = "linear.pkl"

    @property
    def name(self):
        """Identifier of this router."""
        return "linear_predictor_router"

    def fit(self, trainset, valset, configs):
        """Fit (or load from cache) the score/token regressors for every model.

        Args:
            trainset: Object whose ``get_benchmark()`` returns a mapping with
                ``"embeddings"``, ``"scores"`` and ``"output_tokens"`` torch
                tensors. ``scores``/``output_tokens`` are indexed as
                ``[:, model_index]``, so their axis 1 enumerates models.
            valset: Unused by this router.
            configs: Config object; the fitted regressors are cached as
                ``linear.pkl`` under ``configs.training.output_dir``.
        """
        self.benchmark = trainset.get_benchmark()
        self.num_models = self.benchmark["scores"].shape[1]

        output_dir = configs.training.output_dir
        cache_path = os.path.join(output_dir, self._CACHE_FILENAME)

        if os.path.exists(cache_path):
            # NOTE: pickle.load can execute arbitrary code; output_dir must be
            # a trusted location.
            with open(cache_path, "rb") as f:
                cached = pickle.load(f)
            # A cache written for a different model pool would make route()
            # index past its end (or ignore models); refit instead of using it.
            if len(cached) == self.num_models:
                self.mlp = cached
                return

        # Convert to numpy once rather than once per model.
        features = self.benchmark["embeddings"].cpu().numpy()
        scores = self.benchmark["scores"].cpu().numpy()
        output_tokens = self.benchmark["output_tokens"].cpu().numpy()

        # One {"score": Ridge, "token": Ridge} entry per model, in column order.
        self.mlp = [
            {
                "score": Ridge().fit(features, scores[:, i]),
                "token": Ridge().fit(features, output_tokens[:, i]),
            }
            for i in range(self.num_models)
        ]

        os.makedirs(output_dir, exist_ok=True)
        with open(cache_path, "wb") as f:
            pickle.dump(self.mlp, f)

    def route(self, router_input: RouterInput):
        """Choose a model for ``router_input`` under every preference weight.

        Returns:
            dict mapping each value in ``PREFERENCES`` to a ``RouterOutput``.
            Its ``routing_id``/``routing_model`` is the argmax over models of
            ``predicted_score - preference * predicted_cost``. Ground-truth
            fields are copied from ``router_input``, and the predictions are
            stored in ``info``.
        """
        # Single-row feature matrix, built once for all regressors. Assumes
        # embedding is a 1-D torch tensor (unsqueeze(0) adds the row axis) --
        # confirm against RouterInput.
        features = router_input.embedding.unsqueeze(0).cpu().numpy()
        regressors = [self.mlp[i] for i in range(self.num_models)]

        predicted_scores = torch.tensor(
            [reg["score"].predict(features)[0] for reg in regressors]
        )
        predicted_output_tokens = torch.tensor(
            [reg["token"].predict(features)[0] for reg in regressors]
        )
        predicted_costs = get_costs(
            router_input.input_tokens,
            predicted_output_tokens,
            router_input.routing_config,
        )
        assert len(predicted_scores) == len(predicted_costs)

        # Model names in routing_config order; routing_id indexes this list.
        model_names = list(router_input.routing_config.keys())

        outputs = {}
        for preference in PREFERENCES:
            # Utility per model: predicted quality penalised by weighted cost.
            q = [s - preference * c for s, c in zip(predicted_scores, predicted_costs)]
            routing_id = np.argmax(q)
            outputs[preference] = RouterOutput(
                idx=router_input.idx,
                routing_config=router_input.routing_config,
                scores=router_input.scores,
                costs=router_input.costs,
                input_tokens=router_input.input_tokens,
                output_tokens=router_input.output_tokens,
                routing_id=routing_id,
                routing_model=model_names[routing_id],
                info={
                    "predicted_scores": predicted_scores,
                    "predicted_output_tokens": predicted_output_tokens,
                    "predicted_costs": predicted_costs,
                },
            )

        return outputs