from typing import Dict
from pathlib import Path
from .base import RouterBase
from ..pipeline_factory import PipelineLevel
from ..core.config import Config
from sklearn.neighbors import KNeighborsClassifier
from ..core.utils import load_json, load_jsonl

def extract_features(data):
    """Build the feature matrix and label list from labeled records.

    Each record contributes one two-element feature row,
    ``[number of tables, total number of columns]``, taken from
    ``record["enhanced_linked_schema_wo_info"]["tables"]``, and one label
    taken from ``record["pipeline_type"]``. The labels ``"ADVANCED"`` and
    ``"UNSOLVED"`` are merged into the single class ``"ADV_UNSOLVED"``;
    every other label is passed through unchanged.

    Args:
        data: Iterable of record dicts with the keys described above.

    Returns:
        A ``(features, labels)`` tuple of two parallel lists.

    Raises:
        KeyError: If a record or one of its tables lacks an expected key.
    """
    merged_labels = ("ADVANCED", "UNSOLVED")
    rows, targets = [], []

    # Single pass on purpose: `data` may be a one-shot iterable.
    for record in data:
        table_list = record["enhanced_linked_schema_wo_info"]["tables"]
        table_count = len(table_list)
        column_count = 0
        for tbl in table_list:
            column_count += len(tbl["columns"])
        rows.append([table_count, column_count])

        raw_label = record["pipeline_type"]
        targets.append("ADV_UNSOLVED" if raw_label in merged_labels else raw_label)

    return rows, targets

class KNNClassifierRouter(RouterBase):
    """Router that chooses a pipeline level with a k-nearest-neighbours model.

    The classifier is trained in the constructor on two features per labeled
    example (number of linked tables, total number of linked columns) and is
    then queried once per routed query with the same two features computed
    from the schema-linking output.
    """

    def __init__(self, name: str = "KNNClassifierRouter", seed: int = 42, train_file_path: str = None):
        """Load the labeled training file and fit the classifier.

        Args:
            name: Router name forwarded to ``RouterBase.__init__``.
            seed: Accepted for interface compatibility; not used by this
                class (the classifier below takes no seed).
            train_file_path: Path to the labeled JSONL training file. When
                falsy, ``data/labeled/bird_train_pipeline_label.jsonl`` is
                used.
        """
        super().__init__(name)
        self.config = Config()
        self.train_file_path = Path(train_file_path) if train_file_path else Path("data/labeled/bird_train_pipeline_label.jsonl")
        self.train_data = load_jsonl(self.train_file_path)
        X_train, y_train = extract_features(self.train_data)

        print("Training the KNN classifier...")
        self.knn_classifier = KNeighborsClassifier(n_neighbors=5)
        self.knn_classifier.fit(X_train, y_train)
        print("Training Finished")

    def _predict(self, question: str, schema: dict) -> int:
        """Classify a linked schema into a numeric difficulty class.

        Args:
            question: The user query. Unused; the model only looks at the
                schema size.
            schema: Linked schema dict. Its ``"tables"`` entry is expected to
                be a list of dicts that each carry a ``"columns"`` list. A
                missing ``"tables"`` key is treated as an empty table list.

        Returns:
            ``0`` when the classifier predicts ``"BASIC"``, ``1`` for
            ``"INTERMEDIATE"``, and ``2`` for any other predicted label.

        Raises:
            KeyError: If a table entry has no ``"columns"`` key.
        """
        # `route` passes `{}` when no linked schema is available, so a missing
        # "tables" key must behave like an empty list rather than raise.
        tables = schema.get("tables", [])
        num_tables = len(tables)
        num_columns = sum(len(table["columns"]) for table in tables)

        # Same [num_tables, num_columns] layout the classifier was fitted on.
        feature_vector = [[num_tables, num_columns]]
        ans = self.knn_classifier.predict(feature_vector)[0]

        if ans == "BASIC":
            return 0
        if ans == "INTERMEDIATE":
            return 1
        # Every remaining label (e.g. the merged "ADV_UNSOLVED" class) maps
        # to the highest class.
        return 2

    async def route(self, query: str, schema_linking_output: Dict, query_id: str) -> str:
        """Return the pipeline level value selected for ``query``.

        Args:
            query: The user query, forwarded to ``_predict``.
            schema_linking_output: Schema-linking result; its
                ``"linked_schema"`` entry is used, defaulting to ``{}`` when
                absent.
            query_id: Identifier of the query. Not used by this router.

        Returns:
            The ``.value`` of ``PipelineLevel.BASIC``,
            ``PipelineLevel.INTERMEDIATE`` or ``PipelineLevel.ADVANCED``.

        Raises:
            ValueError: If ``_predict`` returns anything other than 0, 1 or 2.
        """
        linked_schema = schema_linking_output.get("linked_schema", {})

        predicted_class = self._predict(query, linked_schema)

        if predicted_class == 0:
            return PipelineLevel.BASIC.value
        if predicted_class == 1:
            return PipelineLevel.INTERMEDIATE.value
        if predicted_class == 2:
            return PipelineLevel.ADVANCED.value
        # Defensive guard in case `_predict` is overridden or changed.
        raise ValueError(f"Invalid prediction from classifier: {predicted_class}")