from abc import ABC
from typing import Union
from dataclasses import dataclass, field
import torch

@dataclass
class RouterInput:
    """One query to be routed, bundled with its per-model metadata.

    NOTE(review): the per-field semantics below are inferred from field names
    (and from the parallel fields on ``RouterOutput``); confirm against the
    code that constructs these objects.
    """

    # Identifier of this query — presumably its index within the dataset.
    idx: int
    # Nested str -> (str -> float) mapping; presumably per-model settings such
    # as prices — TODO confirm schema.
    routing_config: dict[str, dict[str, float]]
    # Raw prompt text, or a list of dicts (presumably chat-style messages).
    prompt: Union[str, list[dict]]
    # Precomputed embedding, presumably of ``prompt``; shape/dtype not fixed
    # here — TODO confirm.
    embedding: torch.Tensor
    # Presumably one quality score per candidate model.
    scores: list[float]
    # Presumably one cost per candidate model.
    costs: list[float]
    # Presumably per-model input token counts (stored as floats).
    input_tokens: list[float]
    # Presumably per-model output token counts (stored as floats).
    output_tokens: list[float]
    
@dataclass
class RouterOutput:
    """Result of routing one query.

    Carries the same ``idx``/``routing_config``/``scores``/``costs``/token
    fields as ``RouterInput`` (without ``prompt`` and ``embedding``), plus the
    route that was chosen and free-form router information.
    """

    # Identifier of the routed query — presumably matches ``RouterInput.idx``.
    idx: int
    # Nested str -> (str -> float) mapping; presumably per-model settings —
    # TODO confirm schema.
    routing_config: dict[str, dict[str, float]]
    # Presumably one quality score per candidate model.
    scores: list[float]
    # Presumably one cost per candidate model.
    costs: list[float]
    # Presumably per-model input token counts (stored as floats).
    input_tokens: list[float]
    # Presumably per-model output token counts (stored as floats).
    output_tokens: list[float]
    # Presumably the position of the selected model within the per-model
    # lists above — confirm.
    routing_id: int
    # Name of the selected model (presumably a key of ``routing_config``).
    routing_model: str
    # Extra router-specific details; no schema is enforced here.
    info: dict
    
# Ascending sweep of 33 preference values, from 0.0 up to 1e9.
# NOTE(review): presumably cost-vs-quality trade-off weights at which routers
# are evaluated — confirm against the code that consumes this list.
PREFERENCES = [
    0.0,
    # small weights (mostly one decade apart)
    1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 5e-3, 1e-2, 1e-1,
    # around unity
    0.5, 1.0, 1.5, 2.0,
    # tens and hundreds
    1e1, 5e1, 1e2, 5e2,
    # thousands
    1e3, 2e3, 3e3, 4e3, 5e3, 7.5e3,
    # tens of thousands
    1e4, 2.5e4, 5e4, 7.5e4,
    # very large weights
    1e5, 1e6, 1e9,
]
    
class BaseRouter(ABC):
    """Common base class for routers operating on ``RouterInput`` objects.

    Subclasses provide ``name``, ``fit`` and ``route``; the base ``fit`` and
    ``route`` only raise ``NotImplementedError``.

    NOTE(review): no method here is decorated with ``@abstractmethod``, so this
    class can be instantiated and a missing override only surfaces when the
    method is called. Adding ``@abstractmethod`` would catch that earlier but
    would also break any existing subclass that does not override every method.
    """

    def __init__(self, config):
        """Store the router configuration.

        Args:
            config: Router-specific configuration, stored unchanged on
                ``self.config``. Its type is not constrained here.
        """
        super().__init__()

        self.config = config

    @property
    def name(self):
        """Identifier of this router (presumably overridden by subclasses)."""
        return "base_router"

    def fit(self, trainset, valset, configs):
        """Train the router; must be overridden by subclasses.

        Args:
            trainset: Training data; format defined by the subclass.
            valset: Validation data; format defined by the subclass.
            configs: Extra configuration; format defined by the subclass.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement fit()"
        )

    def route(self, router_input: RouterInput):
        """Route a single input; must be overridden by subclasses.

        Args:
            router_input: The query to route.

        Returns:
            Defined by the subclass — presumably a ``RouterOutput``; confirm.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # Name the concrete class so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement route()"
        )