# services/enhanced_query_planner.py
"""
Enhanced query decomposition planner.
Decomposes a question into sub-queries with importance scores.
"""

import os
import sys
import json
import re
import time
from typing import Dict, List, Any, Tuple
from dataclasses import dataclass
from pathlib import Path

from services.local_judge import get_judge_model
from services.api_judge import get_api_judge_model, create_api_judge_model
from config.settings import SCOPE_CONFIG, QUERY_DECOMPOSITION_CONFIG, API_CONFIG
from config.prompts import SCOPE_PROMPT


@dataclass
class QueryNode:
    """One sub-query produced by decomposing a question.

    Built by ``EnhancedQueryPlanner`` from the judge model's JSON output, or
    as a single fallback node (id ``"Q1"``) wrapping the original question.
    """
    # Node identifier such as "Q1"; other nodes' depends_on and edge
    # parent/child fields refer to nodes by this id.
    id: str
    # Sub-query text.
    text: str
    importance: int  # 1-10 importance score; out-of-range values are replaced by the configured default when parsed
    # Ids of nodes this one depends on; ids not present in the graph are
    # ignored when layers are computed.
    depends_on: List[str]
    # Dependency depth assigned by the planner's topological layering;
    # root nodes get 0.
    layer: int


@dataclass
class QueryGraph:
    """Result of decomposing one question: sub-query nodes plus edges and diagnostics."""
    # Sub-queries; sorted by (layer, id) when produced from a parsed response.
    nodes: List[QueryNode]
    # Edge dicts as normalized by the planner, with keys "parent", "child",
    # "temporal" (AFTER/BEFORE/AROUND), "window_s" and "scope" (all/any).
    edges: List[Dict[str, Any]]
    # Sum of the nodes' importance scores.
    total_importance: int
    # Elapsed seconds spent decomposing, including all retry attempts.
    decomposition_time: float
    # Diagnostics: number of attempts, per-attempt error messages,
    # success/fallback flags and the planner backend/model used.
    model_diag: Dict[str, Any]


class EnhancedQueryPlanner:
    """Decompose questions into importance-weighted sub-query graphs.

    A judge model (remote API, local, or caller-supplied) is prompted for a
    JSON decomposition. The response is validated into ``QueryNode`` objects,
    layered by dependency depth, paired with normalized edges and wrapped in a
    ``QueryGraph``. ``allocate_frames_by_importance`` then splits a frame
    budget across the nodes in proportion to their importance.
    """

    def __init__(self, judge_model=None, use_api=None, api_key=None, api_model_name=None):
        """
        Initialize planner and select the judge-model backend.

        Args:
            judge_model: Custom judge model instance (optional). When given it
                is used as-is and ``model_type`` is ``"custom"``.
            use_api: Whether to use a remote API backend (None uses
                ``QUERY_DECOMPOSITION_CONFIG["use_api"]``, defaulting to True).
            api_key: API key (only used when use_api=True).
            api_model_name: Remote model name (optional).

        If the API backend fails to initialize, the local model is used
        instead and ``model_type`` becomes ``"local_fallback"``.
        """
        self.config = SCOPE_CONFIG
        self.api_model_name = api_model_name

        if judge_model is not None:
            # Use the provided model
            self.judge_model = judge_model
            self.model_type = "custom"
        else:
            # Choose backend from config
            use_api_flag = use_api if use_api is not None else QUERY_DECOMPOSITION_CONFIG.get("use_api", True)

            if use_api_flag:
                # Remote API backend; any init failure degrades to the local model.
                try:
                    self.judge_model = create_api_judge_model(api_key=api_key, model_name=api_model_name)
                    self.model_type = "api"
                    print("[EnhancedQueryPlanner] Using API backend for query decomposition")
                except Exception as e:
                    print(f"[EnhancedQueryPlanner] Failed to initialize API backend: {e}")
                    print("[EnhancedQueryPlanner] Falling back to local model")
                    self.judge_model = get_judge_model()
                    self.model_type = "local_fallback"
            else:
                # Local backend
                self.judge_model = get_judge_model()
                self.model_type = "local"
                print("[EnhancedQueryPlanner] Using local backend for query decomposition")

    def decompose_query_with_importance(
        self, question: str, max_retries: int = 3, max_new_tokens: int = 1024
    ) -> "QueryGraph":
        """
        Decompose a question into sub-queries with importance scores.

        The judge model is queried up to ``max_retries`` times. An attempt is
        retried when the model call raises or its output cannot be parsed. If
        every attempt fails, a single-node graph wrapping ``question`` is
        returned and ``model_diag["fallback_used"]`` is True.

        Args:
            question: Input question text.
            max_retries: Maximum number of model attempts (values < 1 are
                treated as 1).
            max_new_tokens: Generation limit passed to the judge model.

        Returns:
            QueryGraph with nodes/edges and importance scores. ``model_diag``
            records the attempt count, per-attempt errors, success/fallback
            flags and backend info.
        """
        start_time = time.perf_counter()  # monotonic clock for elapsed time
        enhanced_prompt = self._build_enhanced_decomposition_prompt(question)
        max_retries = max(1, int(max_retries))
        errors: List[str] = []

        for attempt in range(1, max_retries + 1):
            print(f"[EnhancedQueryPlanner] Decomposition attempt {attempt}/{max_retries}")
            try:
                response = self.judge_model.generate_response(enhanced_prompt, max_new_tokens=max_new_tokens)
            except Exception as e:
                # Best effort: model/transport failures are retried, not raised.
                error_msg = f"Attempt {attempt}: model call failed - {str(e)}"
                errors.append(error_msg)
                print(f"[EnhancedQueryPlanner] {error_msg}; retrying...")
                continue

            result = self._parse_enhanced_response(response, question)
            diag = result["_diag"]
            if diag["fallback"]:
                # Parsing failed: record the first parser error and retry.
                parse_errors = diag["errors"]
                if parse_errors:
                    actual_error = parse_errors[0] if isinstance(parse_errors, list) else str(parse_errors)
                    error_msg = f"Attempt {attempt}: {actual_error}"
                else:
                    error_msg = f"Attempt {attempt}: invalid JSON format"
                errors.append(error_msg)
                print(f"[EnhancedQueryPlanner] {error_msg}; retrying...")
                continue

            print(f"[EnhancedQueryPlanner] Decomposition succeeded on attempt {attempt}")
            nodes = result["nodes"]
            return QueryGraph(
                nodes=nodes,
                edges=result["edges"],
                total_importance=sum(node.importance for node in nodes),
                decomposition_time=time.perf_counter() - start_time,
                model_diag={
                    "attempts": attempt,
                    "retry_errors": errors,
                    "final_success": True,
                    "fallback_used": False,
                    "planner_backend": self.model_type,
                    "planner_api_model": self.api_model_name,
                },
            )

        # All retries failed: return a single node wrapping the whole question.
        print(f"[EnhancedQueryPlanner] Decomposition failed after {max_retries} attempts; using fallback")
        fallback_node = self._fallback_node(question)
        return QueryGraph(
            nodes=[fallback_node],
            edges=[],
            total_importance=fallback_node.importance,
            decomposition_time=time.perf_counter() - start_time,
            model_diag={
                "attempts": max_retries,
                "retry_errors": errors,
                "final_success": False,
                "fallback_used": True,
                "planner_backend": self.model_type,
                "planner_api_model": self.api_model_name,
            },
        )

    def _build_enhanced_decomposition_prompt(self, question: str) -> str:
        """Fill ``SCOPE_PROMPT`` with the question text."""
        return SCOPE_PROMPT.format(question=question)

    def _fallback_node(self, question: str) -> "QueryNode":
        """Build the single root node used when no valid decomposition is available."""
        return QueryNode(
            id="Q1",
            text=question,
            importance=self.config["frame_allocation"]["default_importance"],
            depends_on=[],
            layer=0,
        )

    @staticmethod
    def _extract_json_object(response: str) -> Dict[str, Any]:
        """
        Return the first JSON object in ``response`` that has a ``"nodes"`` key.

        Every ``{`` is tried as a start position with ``JSONDecoder.raw_decode``.
        Surrounding prose, code fences or extra braces after the JSON therefore
        do not break parsing, as they would with a greedy ``{.*}`` regex.

        Raises:
            ValueError: if no such object is found.
            TypeError: if ``response`` is not a string.
        """
        decoder = json.JSONDecoder()
        for match in re.finditer(r"\{", response):
            try:
                candidate, _ = decoder.raw_decode(response, match.start())
            except json.JSONDecodeError:
                continue
            if isinstance(candidate, dict) and "nodes" in candidate:
                return candidate
        raise ValueError("No JSON object with a 'nodes' key found in response")

    @staticmethod
    def _coerce_int(value: Any, default: int) -> int:
        """Convert ``value`` to int (accepting numeric strings/floats); return ``default`` if missing or invalid."""
        if value is None or isinstance(value, bool):
            return default
        try:
            return int(float(value))
        except (TypeError, ValueError, OverflowError):
            return default

    @staticmethod
    def _normalize_depends_on(raw: Any, node_id: str) -> List[str]:
        """
        Normalize a node's ``depends_on`` field into a list of id strings.

        A bare string/int becomes a one-element list. Duplicates and
        self-references are dropped, preserving first-seen order.

        Raises:
            ValueError: if ``raw`` is neither None, a scalar id nor a list.
        """
        if raw is None:
            return []
        if isinstance(raw, (str, int)):
            # A bare string would otherwise be iterated character by character.
            raw = [raw]
        if not isinstance(raw, (list, tuple)):
            raise ValueError(f"Invalid depends_on for node {node_id}: {raw!r}")
        deps: List[str] = []
        for dep in raw:
            dep_id = str(dep)
            if dep_id != node_id and dep_id not in deps:
                deps.append(dep_id)
        return deps

    def _parse_enhanced_response(self, response: str, fallback_question: str) -> Dict[str, Any]:
        """
        Parse the model response into validated nodes and edges.

        Args:
            response: Raw model output expected to contain a JSON object with a
                non-empty ``"nodes"`` list and an optional ``"edges"`` list.
            fallback_question: Text for the single fallback node.

        Returns:
            Dict with ``"nodes"`` (QueryNode list sorted by layer/id),
            ``"edges"`` (normalized edge dicts) and ``"_diag"`` holding
            ``"errors"`` and ``"fallback"``. On malformed output (no JSON, empty
            or invalid nodes, duplicate ids, a dependency cycle), the result
            holds one fallback node and ``_diag["fallback"]`` is True.
        """
        try:
            default_importance = self.config["frame_allocation"]["default_importance"]
            raw_result = self._extract_json_object(response)

            raw_nodes = raw_result.get("nodes")
            if not isinstance(raw_nodes, list) or not raw_nodes:
                raise ValueError("JSON response has no non-empty 'nodes' list")

            nodes = []
            seen_ids = set()
            for node_data in raw_nodes:
                if not isinstance(node_data, dict):
                    raise ValueError(f"Node entry is not an object: {node_data!r}")
                node_id = str(node_data["id"])
                if node_id in seen_ids:
                    # Duplicate ids would collapse in id-keyed maps (layers, frame allocation).
                    raise ValueError(f"Duplicate node id: {node_id}")
                seen_ids.add(node_id)

                text = node_data["text"]
                if not isinstance(text, str) or not text.strip():
                    raise ValueError(f"Node {node_id} has empty or non-string text")

                # Unparseable or out-of-range importance falls back to the configured default.
                importance = self._coerce_int(node_data.get("importance"), default_importance)
                if not 1 <= importance <= 10:
                    importance = default_importance

                nodes.append(QueryNode(
                    id=node_id,
                    text=text,
                    importance=importance,
                    depends_on=self._normalize_depends_on(node_data.get("depends_on"), node_id),
                    # Provisional only: _topo_layers recomputes layers from depends_on.
                    layer=self._coerce_int(node_data.get("layer"), 0),
                ))

            nodes = self._topo_layers(nodes)
            edges = self._clean_edges(raw_result.get("edges", []), nodes)
        except (json.JSONDecodeError, ValueError, KeyError, TypeError) as e:
            print(f"[EnhancedQueryPlanner] JSON parse error: {e}")
            return {
                "nodes": [self._fallback_node(fallback_question)],
                "edges": [],
                "_diag": {
                    "errors": [str(e)],
                    "fallback": True,
                },
            }

        return {
            "nodes": nodes,
            "edges": edges,
            "_diag": {
                "errors": [],
                "fallback": False,
            },
        }

    def _topo_layers(self, nodes: List["QueryNode"]) -> List["QueryNode"]:
        """
        Assign dependency layers with Kahn's algorithm and sort by (layer, id).

        A node's layer is 0 when it has no known dependencies, otherwise one
        more than its deepest dependency. Dependencies on ids that are not in
        ``nodes`` are ignored. Sets ``layer`` on each node, sorts ``nodes`` in
        place and returns it.

        Raises:
            ValueError: if the known dependencies form a cycle.
        """
        from collections import defaultdict, deque

        known_ids = {n.id for n in nodes}
        # Use unique, known prerequisites only: a repeated entry would otherwise
        # inflate the in-degree so the node could never be released.
        prereqs = {n.id: {d for d in n.depends_on if d in known_ids} for n in nodes}
        dependents = defaultdict(list)
        for n in nodes:
            for dep in prereqs[n.id]:
                dependents[dep].append(n.id)

        indeg = {nid: len(deps) for nid, deps in prereqs.items()}
        layer = {nid: 0 for nid in indeg}
        queue = deque(nid for nid, degree in indeg.items() if degree == 0)
        processed = 0

        while queue:
            nid = queue.popleft()
            processed += 1
            for child in dependents[nid]:
                layer[child] = max(layer[child], layer[nid] + 1)
                indeg[child] -= 1
                if indeg[child] == 0:
                    queue.append(child)

        if processed != len(indeg):
            stuck = sorted(nid for nid, degree in indeg.items() if degree > 0)
            raise ValueError(f"Dependency cycle among nodes: {stuck}")

        for n in nodes:
            n.layer = layer[n.id]
        nodes.sort(key=lambda x: (x.layer, x.id))
        return nodes

    def _clean_edges(self, edges: List[Dict], nodes: List["QueryNode"]) -> List[Dict]:
        """
        Normalize and filter model-proposed edges.

        Edges are dropped when they are not dicts, or when their ``parent`` or
        ``child`` is not the id of a node in ``nodes``. For the edges kept:
        - ``temporal`` is upper-cased and becomes ``"AFTER"`` unless it is
          AFTER/BEFORE/AROUND;
        - ``scope`` is lower-cased and becomes ``"all"`` unless it is all/any;
        - ``window_s`` is coerced to int (0 when missing or unparseable);
        - AROUND edges get a window of at least 20.
        """
        MIN_AROUND_WIN = 20  # minimum AROUND window, in window_s units
        if not isinstance(edges, (list, tuple)):
            return []

        ids = {n.id for n in nodes}
        out = []
        for e in edges:
            if not isinstance(e, dict):
                continue
            p, c = e.get("parent"), e.get("child")
            try:
                linked = p in ids and c in ids
            except TypeError:  # unhashable parent/child such as a list
                linked = False
            if not linked:
                continue

            typ = str(e.get("temporal", "AFTER")).upper()
            if typ not in ("AFTER", "BEFORE", "AROUND"):
                typ = "AFTER"
            scope = str(e.get("scope", "all")).lower()
            if scope not in ("all", "any"):
                scope = "all"
            ws = self._coerce_int(e.get("window_s"), 0)
            if typ == "AROUND" and ws <= 0:
                ws = MIN_AROUND_WIN
            out.append({"parent": p, "child": c, "temporal": typ, "window_s": ws, "scope": scope})

        return out

    def allocate_frames_by_importance(self, query_graph: "QueryGraph", total_budget: int) -> Dict[str, int]:
        """
        Allocate per-query frame budget using semantic importance.

        Each query first receives ``min_frames_per_query`` frames (from config),
        capped at ``total_budget // N`` so the guarantee never exceeds the
        budget. The remaining frames are split in proportion to ``importance``
        using the largest-remainder method, with ties going to earlier nodes.
        If ``total_budget`` is smaller than the number of queries, the budget
        is split evenly instead, so some queries may get 0 frames.

        Args:
            query_graph: Query graph whose nodes provide ``id`` and ``importance``.
            total_budget: Total frame budget.

        Returns:
            Mapping from node id to allocated frames. For a non-empty graph the
            values sum to ``total_budget``.
        """
        nodes = query_graph.nodes
        N = len(nodes)

        if N == 0:
            return {}

        if total_budget < N:
            print(f"[EnhancedQueryPlanner][WARN] Budget too small ({total_budget} < {N}); using even split")
            base_frames, remainder = divmod(total_budget, N)
            return {node.id: base_frames + (1 if i < remainder else 0) for i, node in enumerate(nodes)}

        # Step 1: minimum guarantee, capped so N * minimum never exceeds the budget.
        configured_min = self.config["frame_allocation"]["min_frames_per_query"]
        min_frames_per_query = min(configured_min, total_budget // N)
        if min_frames_per_query < configured_min:
            print(
                f"[EnhancedQueryPlanner][WARN] Budget {total_budget} cannot give {N} queries "
                f"{configured_min} frames each; lowering minimum to {min_frames_per_query}"
            )
        allocation = {node.id: min_frames_per_query for node in nodes}

        # Step 2: remaining budget (non-negative thanks to the cap above).
        B_rem = total_budget - N * min_frames_per_query
        if B_rem <= 0:
            return allocation

        # Step 3: importance weights; even split if all importances are zero.
        S_sum = sum(node.importance for node in nodes)
        if S_sum == 0:
            avg_extra, remainder = divmod(B_rem, N)
            for i, node in enumerate(nodes):
                allocation[node.id] += avg_extra + (1 if i < remainder else 0)
            return allocation

        # Step 4: proportional floor shares; divmod stays exact for integer importances.
        shares: Dict[str, int] = {}
        leftovers = []
        for node in nodes:
            share, leftover = divmod(B_rem * node.importance, S_sum)
            shares[node.id] = int(share)
            leftovers.append((leftover, node.id))

        # Step 5: frames lost to flooring go to the largest leftovers
        # (the sort is stable, so ties favour earlier nodes).
        final_remainder = B_rem - sum(shares.values())
        leftovers.sort(key=lambda item: item[0], reverse=True)
        for _, node_id in leftovers[:final_remainder]:
            shares[node_id] += 1

        for node in nodes:
            allocation[node.id] += shares[node.id]

        # Internal invariant check (not input validation).
        total_allocated = sum(allocation.values())
        assert total_allocated == total_budget, f"Allocation mismatch: {total_allocated} != {total_budget}"

        return allocation


# Process-wide planner cache (lazy initialization). It starts as None and is
# populated by the first call to get_enhanced_planner();
# build_enhanced_query_graph() reuses it when no backend options are passed.
_global_enhanced_planner = None


def get_enhanced_planner(use_api=None, api_key=None, api_model_name=None) -> EnhancedQueryPlanner:
    """
    Get a global planner instance, creating it on first call.

    The arguments are applied only when the shared instance is first created.
    Later calls return the cached planner unchanged and print a warning if
    options are passed, since those options would otherwise be silently
    ignored. Use ``create_enhanced_planner`` to get a planner with a
    different backend.

    Args:
        use_api: Whether to use remote API backend (None uses config default).
        api_key: API key (only used when use_api=True).
        api_model_name: Remote model name (optional).
    """
    global _global_enhanced_planner
    if _global_enhanced_planner is None:
        _global_enhanced_planner = EnhancedQueryPlanner(
            use_api=use_api, api_key=api_key, api_model_name=api_model_name
        )
    elif use_api is not None or api_key is not None or api_model_name is not None:
        print(
            "[EnhancedQueryPlanner][WARN] Global planner already initialized; ignoring "
            "use_api/api_key/api_model_name (use create_enhanced_planner for a new instance)"
        )
    return _global_enhanced_planner


def create_enhanced_planner(use_api=None, api_key=None, api_model_name=None) -> EnhancedQueryPlanner:
    """
    Build a fresh planner that is not shared with other callers.

    Every call constructs a new ``EnhancedQueryPlanner``, so each caller can
    choose its own backend independently of the module-level cached planner.

    Args:
        use_api: Whether to use remote API backend (None uses config default).
        api_key: API key (only used when use_api=True).
        api_model_name: Remote model name (optional).
    """
    planner_options = {
        "use_api": use_api,
        "api_key": api_key,
        "api_model_name": api_model_name,
    }
    return EnhancedQueryPlanner(**planner_options)


def build_enhanced_query_graph(question: str, use_api=None, api_key=None, api_model_name=None) -> QueryGraph:
    """
    Decompose ``question`` into an importance-scored QueryGraph.

    If any backend option is supplied, a dedicated planner is built for this
    call. Otherwise the shared, lazily created planner is reused.

    Args:
        question: Input question text.
        use_api: Whether to use remote API backend (None uses config default).
        api_key: API key (only used when use_api=True).
        api_model_name: Remote model name (optional).

    Returns:
        QueryGraph.
    """
    backend_options = (use_api, api_key, api_model_name)
    needs_dedicated_planner = any(option is not None for option in backend_options)

    planner = (
        create_enhanced_planner(use_api=use_api, api_key=api_key, api_model_name=api_model_name)
        if needs_dedicated_planner
        else get_enhanced_planner()
    )
    return planner.decompose_query_with_importance(question)
