"""
Direct PRM scorer that bypasses the stat calculator pipeline for efficient stepwise scoring
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Any, Optional
import logging
from transformers import AutoTokenizer, AutoModel

from lm_polygraph import WhiteboxModel
from synthetic_dataset_generation.utils.steps_extractor import StepsExtractor
from baselines.prm import load_prm_calculator_by_model_path
from .base import RewardBasedScorer

log = logging.getLogger(__name__)


class DirectPRMScorer(RewardBasedScorer):
    """
    Direct PRM scorer that applies Process Reward Model without stat calculator pipeline.

    This implementation:
    1. Uses the factory function to load the appropriate PRM calculator
    2. Extracts claims/steps from candidates
    3. Delegates to the calculator's get_rewards method
    4. Returns reward scores (higher = better)

    Scoring is best-effort: any failure while tokenizing, extracting claims or
    querying the PRM is logged and mapped to a neutral reward instead of being
    raised. Only a failure to load the PRM calculator itself propagates.

    Supports all PRM models that baselines supports: Qwen, RLHFlow, MathShepherd, Skywork, GenPRM
    """

    # Reward reported whenever a candidate cannot be scored (empty text, no
    # claims extracted, tokenization failure, or an error while scoring).
    _NEUTRAL_REWARD = 0.0

    def __init__(
        self,
        model: WhiteboxModel,
        prm_model_path: str = "Qwen/Qwen2.5-Math-7B-PRM800K",
        device: str = "cuda",
        batch_size: int = 8,
        prompt_template: Optional[str] = None
    ):
        """
        Store configuration; the PRM itself is not loaded here.

        Args:
            model: Whitebox model whose tokenizer is used to split candidates
                into claims/steps.
            prm_model_path: Path or hub id handed to the PRM calculator factory.
            device: Device string handed to the PRM calculator factory.
            batch_size: Stored on the instance; not used by the methods of
                this class.
            prompt_template: Stored on the instance (defaults to "{q}" when
                None or empty); not used by the methods of this class.
        """
        super().__init__("DirectPRM")
        self.model = model
        self.prm_model_path = prm_model_path
        self.device = device
        self.batch_size = batch_size
        self.prompt_template = prompt_template or "{q}"
        self.steps_extractor = StepsExtractor(progress_bar=False)

        # The PRM calculator is created lazily by prepare_model(), so that
        # constructing the scorer does not load the reward model.
        self.prm_calculator = None

    def prepare_model(self):
        """Initialize PRM calculator if not already loaded (idempotent)."""
        if self.prm_calculator is None:
            log.info(f"Loading PRM calculator for {self.prm_model_path}")
            self.prm_calculator = load_prm_calculator_by_model_path(
                prompt_path=None,  # We handle prompts separately
                model_path=self.prm_model_path,
                device=self.device
            )
            # Initialize the model
            self.prm_calculator.init()

    def cleanup(self):
        """
        Free PRM model memory.

        Calls the calculator's own ``cleanup`` when it has one, drops the
        model/tokenizer references this class knows about, and releases cached
        CUDA memory. The references are dropped even if the calculator's
        ``cleanup`` raises; that exception still propagates to the caller.
        """
        # Local import: gc is only needed here.
        import gc

        calculator = self.prm_calculator
        if calculator is None:
            return

        try:
            # Some calculators might have cleanup methods
            if hasattr(calculator, 'cleanup'):
                calculator.cleanup()
        finally:
            # Clear references under the attribute names different
            # calculators may use.
            for attr in ('model', 'prm_model', 'tokenizer', 'prm_tokenizer'):
                if hasattr(calculator, attr):
                    setattr(calculator, attr, None)
            self.prm_calculator = None
            del calculator
            # Collect unreachable objects first so that the memory they hold
            # can actually be handed back by empty_cache().
            gc.collect()
            torch.cuda.empty_cache()

    def compute_claim_rewards(
        self,
        trajectory: str,
        candidates: List[str],
        **kwargs
    ) -> List[List[float]]:
        """
        Compute reward scores for claims in each candidate.

        The PRM calculator is loaded on demand, right before the first
        non-empty candidate is scored, so calls with nothing to score do not
        load the reward model.

        Args:
            trajectory: Current trajectory text
            candidates: List of candidate next steps
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            List of claim reward lists (one per candidate, in input order).
            Empty/whitespace-only candidates and candidates whose scoring
            failed get a single neutral reward.
        """
        if not candidates:
            return []

        # For PRM, the "question" is the entire trajectory up to this point
        # This gives the PRM full context to evaluate the next step

        all_rewards = []

        for candidate in candidates:
            # Handle empty candidates
            if not candidate or not candidate.strip():
                log.debug("Empty candidate, returning neutral score 0.0")
                all_rewards.append([self._NEUTRAL_REWARD])
                continue

            # Cheap no-op once loaded. Deliberately outside the try below so
            # that a model-loading failure is raised, not turned into a
            # neutral score.
            self.prepare_model()

            try:
                rewards = self._score_single_candidate(trajectory, candidate)
                all_rewards.append(rewards)
            except Exception as e:
                log.warning(f"Failed to score candidate: {e}")
                all_rewards.append([self._NEUTRAL_REWARD])  # Neutral reward

            # Clean up memory after each candidate
            torch.cuda.empty_cache()

        return all_rewards

    def _extract_question(self, trajectory: str) -> str:
        """
        Extract the original question from the trajectory.

        The text before the first end-of-question marker is taken as the
        question. Markers are tried in the order listed below (not by position
        in the text); a marker is skipped when nothing but whitespace precedes
        it. If no marker applies, the whole trajectory is used. When the
        result still contains chat-template tags, the content of the first
        user turn is returned instead, provided both its opening and closing
        tags are present.
        """
        # Look for common patterns that indicate end of question
        end_patterns = [
            "Reasoning Steps:",
            "Solution:",
            "Answer:",
            "\n\n",
            "- Step"
        ]

        question = trajectory
        for pattern in end_patterns:
            head, found, _ = trajectory.partition(pattern)
            head = head.strip()
            if found and head:
                question = head
                break

        # Remove any system prompts if present
        if "<|im_start|>" in question:
            # Extract content between user tags
            user_tag = "<|im_start|>user"
            start = question.find(user_tag)
            end = question.find("<|im_end|>", start)
            if start != -1 and end != -1:
                question = question[start + len(user_tag):end].strip()

        return question

    def _score_single_candidate(
        self,
        trajectory: str,
        candidate: str
    ) -> List[float]:
        """Score a single candidate using PRM

        Expects the PRM calculator to be loaded already (see prepare_model).
        Never raises for tokenization, claim-extraction or PRM errors: those
        are logged and a single neutral reward is returned.

        Args:
            trajectory: The full trajectory up to this point (provides context)
            candidate: The candidate next step to evaluate

        Returns:
            Rewards for the claims extracted from the candidate, or a single
            neutral reward when nothing could be scored.
        """
        neutral = [self._NEUTRAL_REWARD]

        # Extract claims from candidate only
        try:
            candidate_tokens = self.model.tokenize([candidate])
            if candidate_tokens is None or 'input_ids' not in candidate_tokens:
                log.warning(f"Failed to tokenize candidate: {candidate[:50]}...")
                return neutral

            claims = self.steps_extractor.split_to_steps(
                candidate,
                candidate_tokens['input_ids'][0],
                self.model.tokenizer
            )

            if not claims:
                log.debug(f"No claims extracted from candidate: {candidate[:50]}...")
                return neutral  # Neutral score for empty/no claims

        except Exception as e:
            log.warning(f"Error extracting claims: {e}")
            return neutral

        # Get PRM rewards - pass trajectory as the "question" for full context
        try:
            rewards = self._compute_prm_rewards(trajectory, claims)
            return rewards if rewards else neutral
        except Exception as e:
            log.warning(f"Error computing PRM rewards: {e}")
            return neutral

    def _compute_prm_rewards(self, trajectory: str, claims: List[Any]) -> List[float]:
        """Compute PRM rewards for claims using the appropriate calculator

        Args:
            trajectory: The full trajectory/context up to this point
            claims: List of claims extracted from the candidate step

        Returns:
            The calculator's rewards for the claims; an empty list when there
            are no claims; one neutral reward per claim when the calculator
            raises.
        """
        if not claims:
            return []

        # Simply delegate to the calculator's get_rewards method
        # The trajectory serves as the "question" - providing full context for evaluation
        try:
            rewards = self.prm_calculator.get_rewards(trajectory, claims)
            # NOTE(review): get_rewards' return type is not visible from this
            # file. If it is an array/tensor, convert it to a plain list so
            # the truthiness check in the caller cannot raise and silently
            # discard valid rewards. Plain lists pass through untouched.
            if hasattr(rewards, 'tolist'):
                rewards = rewards.tolist()
            return rewards
        except Exception as e:
            log.warning(f"Error in PRM calculator get_rewards: {e}")
            # Return neutral scores on error
            return [self._NEUTRAL_REWARD] * len(claims)


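# NOTE: DirectPRMScorerOptimized is disabled and kept commented out for reference only.
# It is out of sync with DirectPRMScorer: _score_batch below calls
# _score_single_candidate(question, trajectory, candidate), but that method
# takes only (trajectory, candidate).
# class DirectPRMScorerOptimized(DirectPRMScorer):
#     """
#     Optimized version with better batching for multiple candidates.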
    
#     Additional optimizations:
#     1. Batch multiple candidates together when possible
#     2. Cache question extraction
#     3. Reuse tokenization results
#     """
    
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.question_cache = {}
#         self.max_cache_size = 100
        
#     def compute_claim_rewards(
#         self,
#         trajectory: str,
#         candidates: List[str],
#         **kwargs
#     ) -> List[List[float]]:
#         """Compute rewards with optimized batching"""
        
#         self.prepare_model()
        
#         if not candidates:
#             return []
        
#         # Get question (with caching)
#         trajectory_hash = hash(trajectory[:200])  # Hash prefix for stability
#         if trajectory_hash in self.question_cache:
#             question = self.question_cache[trajectory_hash]
#         else:
#             question = self._extract_question(trajectory)
#             # Cache with size limit
#             if len(self.question_cache) >= self.max_cache_size:
#                 self.question_cache.pop(next(iter(self.question_cache)))
#             self.question_cache[trajectory_hash] = question
        
#         # Process in batches for efficiency
#         all_rewards = []
#         for i in range(0, len(candidates), self.batch_size):
#             batch_candidates = candidates[i:i + self.batch_size]
#             batch_rewards = self._score_batch(question, trajectory, batch_candidates)
#             all_rewards.extend(batch_rewards)
        
#         return all_rewards
    
#     def _score_batch(
#         self,
#         question: str,
#         trajectory: str,
#         candidates: List[str]
#     ) -> List[List[float]]:
#         """Score a batch of candidates"""
#         # For now, fall back to individual scoring
#         # (PRM batching would require careful handling of different claim counts)
#         rewards = []
#         for candidate in candidates:
#             rewards.append(self._score_single_candidate(question, trajectory, candidate))
#         return rewards
    
#     def cleanup(self):
#         """Clean up resources including cache"""
#         self.question_cache.clear()
#         super().cleanup()