import torch
import torch.nn.functional as F
from typing import Dict
from pathlib import Path
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from peft import PeftModel
from .base import RouterBase
from ..pipeline_factory import PipelineLevel
from ..core.config import Config
from ..sft.instruction_templates import PipelineClassificationTemplates

class RoBERTaCascadeClassifier:
    """Binary RoBERTa classifier (base model + PEFT adapter) for one cascade stage.

    The classifier answers "can this pipeline handle the query?" by comparing
    the softmax probability of label index 1 with ``confidence_threshold``.
    """

    def __init__(self, pipeline_type: str, model_name: str, model_path: Path, confidence_threshold: float = 0.5):
        """Load the tokenizer, base model and adapter for one pipeline stage.

        Args:
            pipeline_type: Label of the pipeline this classifier guards; stored
                on the instance as-is.
            model_name: Suffix of the adapter directory; the adapter is read
                from ``model_path / f"final_model_{model_name}"``.
            model_path: Directory that contains the adapter directory.
            confidence_threshold: Minimum probability of label index 1 for
                ``classify`` to report that the pipeline can handle a query.

        Raises:
            FileNotFoundError: If the adapter directory does not exist.
        """
        self.config = Config()
        self.pipeline_type = pipeline_type
        # Path() lets callers pass a plain string without breaking the `/` join.
        self.model_path = Path(model_path) / f"final_model_{model_name}"
        self.confidence_threshold = confidence_threshold
        self.templates = PipelineClassificationTemplates()
        # Pick the device once so loading and inference always agree, and so
        # the classifier still works on hosts without CUDA.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._load_model_and_tokenizer()

    def _load_model_and_tokenizer(self):
        """Build tokenizer and model, move the model to ``self.device``, set eval mode.

        Raises:
            FileNotFoundError: If ``self.model_path`` does not exist.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        self.tokenizer = RobertaTokenizer.from_pretrained(
            self.config.roberta_dir,
            padding_side="right"
        )

        base_model = RobertaForSequenceClassification.from_pretrained(
            self.config.roberta_dir,
            num_labels=2,
            problem_type="single_label_classification"
        )

        # NOTE(review): torch_dtype is forwarded to PEFT unchanged; confirm it
        # has the intended effect on the loaded weights.
        self.model = PeftModel.from_pretrained(
            base_model,
            self.model_path,
            torch_dtype=torch.float16
        )

        self.model = self.model.to(self.device)
        # eval() recurses into every submodule (dropout, layer norm, ...), so
        # no per-module pass is needed afterwards.
        self.model.eval()

    def classify(self, question: str, schema: dict) -> tuple[bool, float]:
        """Score one question/schema pair.

        Args:
            question: Query text handed to the prompt template.
            schema: Schema object handed to the prompt template.

        Returns:
            ``(can_handle, positive_prob)`` where ``positive_prob`` is the
            softmax probability of label index 1 and ``can_handle`` is
            ``positive_prob >= self.confidence_threshold``.
        """
        input_text = self.templates.create_classifier_prompt(question, schema)

        inputs = self.tokenizer(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        # Tensors must live on the same device as the model.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits
            probs = F.softmax(logits, dim=-1)
            # A single prompt is encoded, so row 0 is the only batch entry.
            positive_prob = float(probs[0][1])

        can_handle = positive_prob >= self.confidence_threshold
        return can_handle, positive_prob

class RoBERTaCascadeRouter(RouterBase):
    """Router that cascades through the basic, then the intermediate classifier.

    A query is sent to the first pipeline level whose classifier accepts it;
    when neither classifier accepts, the advanced level is returned.
    """

    def __init__(
        self,
        name: str = "RoBERTaCascadeRouter",
        confidence_threshold: float = 0.5,
        model_path: str = None,
        seed: int = 42
    ):
        """Seed the RNGs and load both stage classifiers.

        Args:
            name: Router name forwarded to ``RouterBase``.
            confidence_threshold: Acceptance threshold shared by both
                classifiers.
            model_path: Directory holding the adapters; when falsy,
                ``Config().cascade_roberta_save_dir`` is used instead.
            seed: Seed applied to the Python, NumPy and torch RNGs before the
                classifiers are loaded.
        """
        super().__init__(name)
        self.config = Config()
        if model_path:
            self.model_path = Path(model_path)
        else:
            self.model_path = self.config.cascade_roberta_save_dir

        self._set_seed(seed)

        # Both stages read from the same directory with the same threshold.
        shared_kwargs = {
            "model_path": self.model_path,
            "confidence_threshold": confidence_threshold,
        }

        self.basic_classifier = RoBERTaCascadeClassifier(
            pipeline_type="basic",
            model_name="basic_classifier",
            **shared_kwargs
        )
        print("[i] Successfully load and initialize classifier for basic pipeline")

        self.intermediate_classifier = RoBERTaCascadeClassifier(
            pipeline_type="intermediate",
            model_name="intermediate_classifier",
            **shared_kwargs
        )
        print("[i] Successfully load and initialize classifier for intermediate pipeline")

    def _set_seed(self, seed: int):
        """Seed Python, NumPy and torch (CPU and CUDA) and pin cuDNN settings."""
        import random
        import numpy as np

        seeders = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed_all,
        )
        for apply_seed in seeders:
            apply_seed(seed)

        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    async def route(self, query: str, schema_linking_output: Dict, query_id: str) -> str:
        """Return the ``.value`` of the pipeline level chosen for ``query``.

        Args:
            query: Query text passed to the classifiers.
            schema_linking_output: Mapping whose ``"linked_schema"`` entry
                (default ``{}``) is passed to the classifiers.
            query_id: Identifier used only in the log messages.

        Returns:
            ``PipelineLevel.BASIC.value``, ``PipelineLevel.INTERMEDIATE.value``
            or ``PipelineLevel.ADVANCED.value``.
        """
        schema = schema_linking_output.get("linked_schema", {})

        # Stage 1: basic pipeline.
        accepted, confidence = self.basic_classifier.classify(query, schema)
        if accepted:
            self.logger.info(f"Query {query_id} routed to BASIC pipeline (confidence: {confidence:.3f})")
            return PipelineLevel.BASIC.value

        # Stage 2: intermediate pipeline.
        accepted, confidence = self.intermediate_classifier.classify(query, schema)
        if accepted:
            self.logger.info(f"Query {query_id} routed to INTERMEDIATE pipeline (confidence: {confidence:.3f})")
            return PipelineLevel.INTERMEDIATE.value

        # Fallback: the logged confidence is the intermediate classifier's score.
        self.logger.info(
            f"Query {query_id} routed to ADVANCED pipeline (confidence: {confidence:.3f})"
        )
        return PipelineLevel.ADVANCED.value
        