from pathlib import Path
from typing import Dict, Optional, Union

import torch
import torch.nn.functional as F
from transformers import RobertaTokenizer, RobertaForSequenceClassification

from .base import RouterBase
from ..pipeline_factory import PipelineLevel
from ..core.config import Config
from ..sft.instruction_templates import PipelineClassificationTemplates


class RoBERTaClassifierRouter(RouterBase):
    """Route a query to a pipeline level using a fine-tuned RoBERTa classifier.

    The classifier scores a prompt built from the query and its linked schema
    over three classes which map, in output-index order, to the ``basic``,
    ``intermediate`` and ``advanced`` pipeline levels.
    """

    # Class names in classifier-output index order; output index i is the
    # 1-based class id i + 1 returned by ``_predict``.
    _LABELS = ("basic", "intermediate", "advanced")

    def __init__(
        self,
        name: str = "RoBERTaClassifierRouter",
        model_path: Optional[Union[str, Path]] = None,
        seed: int = 42,
        device: Optional[Union[str, torch.device]] = None,
    ):
        """Seed the RNGs and load the tokenizer and classifier.

        Args:
            name: Router name forwarded to ``RouterBase``.
            model_path: Directory of the fine-tuned classifier. When falsy,
                defaults to ``config.roberta_save_dir / "final_model_roberta"``.
            seed: Seed applied to ``random``, NumPy and torch (CPU and CUDA).
            device: Device used for inference. Defaults to ``"cuda"`` when
                CUDA is available, otherwise ``"cpu"``.

        Raises:
            FileNotFoundError: If the model directory does not exist.
        """
        super().__init__(name)
        self.config = Config()

        self.model_path = (
            Path(model_path)
            if model_path
            else self.config.roberta_save_dir / "final_model_roberta"
        )
        self.templates = PipelineClassificationTemplates()

        # Fall back to CPU so the router also works on hosts without a GPU.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        self._set_seed(seed)
        self._load_model_and_tokenizer()

    def _set_seed(self, seed: int) -> None:
        """Seed Python, NumPy and torch RNGs and force deterministic cuDNN.

        Note: this mutates process-global state (RNG states and cuDNN flags).
        """
        import random

        import numpy as np

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Silently ignored by torch when CUDA is unavailable.
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    def _load_model_and_tokenizer(self) -> None:
        """Load the tokenizer and the fine-tuned classifier onto ``self.device``.

        Raises:
            FileNotFoundError: If ``self.model_path`` does not exist.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        # NOTE(review): the tokenizer is loaded from config.roberta_dir, not from
        # model_path — presumably the base checkpoint; confirm this is intended.
        self.tokenizer = RobertaTokenizer.from_pretrained(
            self.config.roberta_dir,
            padding_side="right",
        )

        self.model = RobertaForSequenceClassification.from_pretrained(
            self.model_path,
            num_labels=len(self._LABELS),
            problem_type="single_label_classification",
        )

        self.model = self.model.to(self.device)
        # eval() recurses into every submodule, so Dropout and any other
        # mode-dependent layer is already switched to inference behavior.
        self.model.eval()

    def _predict(self, question: str, schema: dict) -> tuple[int, Dict[str, float]]:
        """Classify a question/schema pair.

        Args:
            question: The natural-language question.
            schema: Linked schema passed to the prompt template.

        Returns:
            ``(class_id, probabilities)`` where ``class_id`` is 1-based
            (1 = basic, 2 = intermediate, 3 = advanced) and ``probabilities``
            maps each label name to its softmax probability.
        """
        input_text = self.templates.create_classifier_prompt(question, schema)

        inputs = self.tokenizer(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits
            probs = F.softmax(logits, dim=-1)
            predicted_class = torch.argmax(probs, dim=-1).item()
            # One device-to-host transfer instead of one per label.
            first_row = probs[0].tolist()

        probabilities = dict(zip(self._LABELS, first_row))
        return predicted_class + 1, probabilities

    async def route(self, query: str, schema_linking_output: Dict, query_id: str) -> str:
        """Select the pipeline level for ``query``.

        Args:
            query: The natural-language question.
            schema_linking_output: Schema-linking result; only its
                ``"linked_schema"`` entry (default ``{}``) is used.
            query_id: Identifier included in log messages.

        Returns:
            The ``.value`` of the selected ``PipelineLevel``.

        Raises:
            ValueError: If the classifier yields an unexpected class id.
        """
        linked_schema = schema_linking_output.get("linked_schema", {})

        # NOTE(review): inference runs synchronously inside this coroutine and
        # blocks the event loop for the tokenization + forward pass.
        predicted_class, probabilities = self._predict(query, linked_schema)

        self.logger.info(f"Pipeline selection probabilities for query {query_id}:")
        for pipeline, prob in probabilities.items():
            self.logger.info(f"  {pipeline}: {prob:.4f}")

        level_by_class = {
            1: PipelineLevel.BASIC,
            2: PipelineLevel.INTERMEDIATE,
            3: PipelineLevel.ADVANCED,
        }
        level = level_by_class.get(predicted_class)
        if level is None:
            raise ValueError(f"Invalid prediction from classifier: {predicted_class}")
        return level.value