# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Chain of Thought (CoT) Quality Scoring Module

This module provides functionality to score the quality of reasoning processes
in mathematical problem-solving responses using LLM-based annotation.
"""

import hashlib
import json
import logging
import os
import random
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple
# Import the OpenAI client wrapper. Scoring cannot work without it, so a failed import is re-raised rather than treated as optional.
try:
    from verl.utils.openai_utils import OpenaiClient
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
    raise ImportError("openai_utils not available. CoT quality scoring will be disabled.")
    # print("Warning: openai_utils not available. CoT quality scoring will be disabled.")

# Default prompt template for CoT quality assessment.
# Placeholders {problem} and {reasoning_process} are filled with str.format by
# CoTQualityScorer.score_single. The judge is asked for an integer score in
# [0, 5] on the first line and a justification on the second line, which is the
# layout CoTQualityScorer._parse_response_default expects.
DEFAULT_COT_PROMPT = """Read the problem and the corresponding reasoning process that reaches the correct answer. Judge the quality of the reasoning process. A high-quality reasoning process should be logically coherent and consistent.

[Problem]
{problem}

[Reasoning Process]
{reasoning_process}

Output an integer score between 0 and 5 to indicate quality. 0 means the reasoning process is completely wrong, 5 means the reasoning process is perfect. Your output should contain two lines: the first line is the score, the second line is the justification for the score."""

# Prompt asking the judge to classify the primary error in a reasoning trace.
# Placeholders {problem} and {reasoning_process} are filled with str.format.
# The reply is expected to contain "Error Category:" / "Justification:" lines,
# or a bare "Correct" / "Other"; CoTQualityScorer._parse_response_error_type
# parses that format. The numbered categories mirror the module's err_map.
ERROR_TYPE_PROMPT = """Your task is to analyze the provided mathematical reasoning, identify the primary error, and classify it.

**Error Categories & Examples**
1. Misapplication of Math Concepts
Description: Incorrectly stating or applying a mathematical rule, theorem, or formula.
Example: "Using the Pythagorean theorem, a² + b² = c², on a triangle that is not a right triangle."

2. Logical Leaps / Unjustified Claims
Description: The reasoning contains logical leaps or makes "correct" claims that are not justified step by step.
Example: "From the diagram, we can see that triangle ABC is isosceles, so AB = BC." (This is not proven).

3. Incomplete Analysis / Neglected Cases
Description: The reasoning identifies the correct answer but fails to consider all possible scenarios or conditions required by the problem, thus making the analysis incomplete.
Example: "If x² = 25, then x must be 5." (Ignores the x = -5 solution).

4. Calculation / Factual Error
Description: Mistakes in arithmetic, algebraic manipulation, or recalling a known mathematical fact.
Example: "Simplifying the equation 2x + 1 = 5 leads to 2x = 6, so x = 3." (Incorrect subtraction).

5. Incorrect Problem Modeling
Description: The initial setup or interpretation of the problem is flawed, causing the model to solve the wrong problem.
Example: "The problem asks for the perimeter, so we calculate the area by multiplying length and width."

6. Failure to Adhere to Constraints
Description: The solution or intermediate steps violate one or more conditions given in the problem statement.
Example: "The problem states x must be an integer, but the final answer is x = 2.5."

7. Incorrect/Nonsensical Justification
Description: The justification provided is incorrect or nonsensical.
Example: The reasoning is hand-wavy and illogical. It introduces a confusing, incorrect step about "subtracting segments" before concluding without proper geometric basis that one can "simply add the circumference."

**Output Format**
Error Category: The name of the category that best describes the error
Justification: A brief, one-sentence explanation of the specific mistake and why it fits the chosen category.

**Note**: If the reasoning is **completely correct**, output "Correct". Or if the reasoning contains incorrect steps but does not fall into any of the categories, output "Other". If there are multiple error types, separate them with a comma.

Below is the problem and reasoning:
problem: {problem}
reasoning: {reasoning_process}

Output:
"""

# Category numbers used in the error-type prompt, mapped to their labels.
err_map = {
    1: "Misapplication of Math Concepts",
    2: "Logical Leaps / Unjustified Claims",
    3: "Incomplete Analysis / Neglected Cases",
    4: "Calculation / Factual Error",
    5: "Incorrect Problem Modeling",
    6: "Failure to Adhere to Constraints",
    7: "Incorrect/Nonsensical Justification",
}

# Every error category label followed by the "Correct" verdict.
all_error_types = list(err_map.values()) + ["Correct"]

# A leading list index such as "2. ", "2) " or "2." that is directly followed by
# label text. Requiring a non-digit, non-space character next leaves values such
# as "2.5" untouched.
_LEADING_INDEX_RE = re.compile(r"^\d+\s*[.)]\s*(?=[^\d\s])")


def _normalize_label(label):
    """Lower-case *label* and drop spaces around '/' for lenient comparison."""
    return re.sub(r"\s*/\s*", "/", label.strip().lower())


def extract_error_type(text):
    """Split the judge's "Error Category" text into a list of error labels.

    The text is split on commas. Each piece is stripped of surrounding
    whitespace and periods, and empty pieces are dropped. A leading list index
    is removed ("2. Logical Leaps" -> "Logical Leaps"), and a bare number is
    mapped through ``err_map``. Unrecognised labels are kept verbatim. A
    warning is printed when none of the labels is a known category, "Correct"
    or "Other" (compared case-insensitively).

    Args:
        text: Category text extracted from the judge's response; may be None.

    Returns:
        Non-empty list of label strings, in the order they appeared.

    Raises:
        ValueError: If ``text`` is empty/None or yields no labels.
    """
    if not text:
        raise ValueError(f"No error type found in: {text}")

    pieces = [piece.strip().strip(".").strip() for piece in text.split(",")]
    pieces = [piece for piece in pieces if piece]

    labels = []
    for piece in pieces:
        # Only strip an index at the very start; labels containing ". "
        # elsewhere (e.g. "... (e.g. wrong formula)") are kept intact.
        piece = _LEADING_INDEX_RE.sub("", piece)
        try:
            labels.append(err_map[int(piece)])
        except (ValueError, KeyError):
            # Not a number, or a number with no category: keep the text as-is.
            labels.append(piece)

    if not labels:
        raise ValueError(f"No error type found in: {text}")

    # "Other" is a verdict the prompt explicitly allows, so it is not unknown.
    known = {_normalize_label(label) for label in all_error_types}
    known.add("other")
    if not any(_normalize_label(label) in known for label in labels):
        print(f"[Warning] All Error types not in error list: {labels}")
    return labels
    


class CoTQualityScorer:
    """Score the quality of chain-of-thought reasoning with an LLM judge.

    Two prompt modes are supported, selected by ``prompt_template``:

    * ``"default"``: the judge returns an integer score in ``[0, 5]``.
      Scores below ``_low_quality_threshold`` are low quality.
    * ``"error_type"``: the judge classifies the primary error. The parsed
      result is a list of labels and is low quality if any label is listed in
      ``_low_quality_error_type``.

    Parsed results can be cached in memory and persisted to a JSON file with
    :meth:`save_cache`. The cache is keyed by model, prompt, problem and
    reasoning.
    """

    # Default mode: scores strictly below this value are low quality.
    _low_quality_threshold = 4
    # Error-type mode: lower-cased labels that mark reasoning as low quality.
    # Two spellings of the "Incorrect/Nonsensical Justification" label are listed.
    _low_quality_error_type = [
        "misapplication of math concepts",
        "calculation / factual error",
        "incorrect problem modeling",
        "incorrect/nonsensical justification",
        "other",
        "incorrect / nonsensical justification",
    ]
    # Error-type mode: lower-cased labels considered acceptable. Not read by
    # any method of this class.
    _high_quality_error_type = [
        "correct",
        "logical leaps / unjustified claims",
        "incomplete analysis / neglected cases",
        "failure to adhere to constraints",
    ]

    def __init__(self,
                 model_name: str = "gpt-5-nano",
                 prompt_template: str = "default",
                 max_workers: int = 16,
                 cache_file: Optional[str] = None,
                 enable_caching: bool = True,
                 max_tokens: int = 15000,
                 temperature: float = 1.0):
        """
        Initialize the CoT quality scorer.

        Args:
            model_name: Name of the LLM model to use for scoring.
            prompt_template: Prompt mode, ``"default"`` or ``"error_type"``.
                The template strings ``DEFAULT_COT_PROMPT`` and
                ``ERROR_TYPE_PROMPT`` are accepted as well.
            max_workers: Maximum number of parallel workers for API calls.
            cache_file: Optional path of the JSON cache file. Its parent
                directory is created if needed.
            enable_caching: Whether to enable caching of scores.
            max_tokens: ``max_tokens`` passed to the judge model per query.
            temperature: Sampling temperature passed to the judge model.

        Raises:
            ValueError: If ``prompt_template`` is not recognised or the OpenAI
                client is unavailable.
        """
        self.model_name = model_name
        # Accept both the short mode names and the full template strings.
        if prompt_template in ("default", DEFAULT_COT_PROMPT):
            self.prompt_template = DEFAULT_COT_PROMPT
        elif prompt_template in ("error_type", ERROR_TYPE_PROMPT):
            self.prompt_template = ERROR_TYPE_PROMPT
        else:
            raise ValueError(f"Invalid prompt template: {prompt_template}")
        self.max_workers = max_workers
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.cache_file = cache_file
        # os.path.dirname(None) raises and os.makedirs("") fails, so only
        # create a directory when the cache path actually contains one.
        cache_dir = os.path.dirname(cache_file) if cache_file else ""
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        self.enable_caching = enable_caching

        # OPENAI_AVAILABLE is set by the module-level import guard.
        if OPENAI_AVAILABLE:
            self.model = OpenaiClient(model=model_name)
        else:
            self.model = None
            raise ValueError("OpenAI client not available. CoT quality scoring disabled.")

        # Load a previously saved cache if one exists (best effort).
        self.score_cache: Dict[str, Dict[str, Any]] = {}
        if self.enable_caching and self.cache_file and os.path.exists(self.cache_file):
            try:
                with open(self.cache_file, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
            except (OSError, ValueError) as e:
                # A corrupt or unreadable cache must not prevent scoring.
                print(f"Warning: Failed to load cache file {self.cache_file}: {e}")
            else:
                if isinstance(loaded, dict):
                    self.score_cache = loaded
                    print(f"Loaded {len(self.score_cache)} cached CoT quality scores from {self.cache_file}")
                else:
                    print(f"Warning: Failed to load cache file {self.cache_file}: expected a JSON object")

    def _parse_response_error_type(self, response: str) -> Tuple[Optional[List[str]], Optional[str]]:
        """Parse an error-type reply into ``(labels, justification)``.

        Recognises ``Error Category:`` and ``Justification:`` lines, and a bare
        ``Correct`` / ``Other`` line. When several lines match, the last one
        wins. The category text is split into labels by ``extract_error_type``.
        Returns ``(None, None)`` and logs a warning if no category can be
        extracted.
        """
        try:
            error_type = None
            justification = None
            for raw_line in response.strip().split('\n'):
                line = raw_line.strip()
                if line.startswith('Error Category:'):
                    error_type = line.replace('Error Category:', '').strip()
                elif line.startswith('Justification:'):
                    justification = line.replace('Justification:', '').strip()
                elif line.lower() == 'correct':
                    error_type = 'Correct'
                elif line.lower() == 'other':
                    error_type = 'Other'
            return extract_error_type(error_type), justification
        except Exception:
            # Best-effort parse: the caller treats (None, None) as a failed parse.
            logging.warning(f"Failed to parse error type response: {response}")
            return None, None

    def _parse_response_default(self, response: str) -> Tuple[Optional[int], Optional[str]]:
        """
        Parse a default-mode reply into ``(score, justification)``.

        The first non-blank line must contain an integer in ``[0, 5]``; the
        second non-blank line is taken as the justification.

        Args:
            response: Raw response from the LLM.

        Returns:
            Tuple of (score, justification) or (None, None) if parsing fails.
        """
        try:
            # Ignore blank lines so "4\n\nJustification" still parses.
            lines = [line.strip() for line in response.strip().split('\n') if line.strip()]
            if len(lines) < 2:
                print(f"Failed to parse CoT quality response: {response}")
                return None, None

            score_match = re.search(r'(\d+)', lines[0])
            if not score_match:
                print(f"Failed to parse CoT quality response: {response}")
                return None, None

            score = int(score_match.group(1))
            if not 0 <= score <= 5:
                print(f"Failed to parse CoT quality response: {response}")
                return None, None

            return score, lines[1]

        except Exception as e:
            logging.warning(f"Failed to parse CoT quality response: {e}")
            return None, None

    def _parse_response(self, response: str) -> Tuple[Optional[Any], Optional[str]]:
        """Dispatch to the parser matching the configured prompt template."""
        if self.prompt_template == ERROR_TYPE_PROMPT:
            return self._parse_response_error_type(response)
        return self._parse_response_default(response)

    def _create_cache_key(self, problem: str, reasoning_process: str) -> str:
        """Return a stable cache key for a problem/reasoning pair.

        The key is a SHA-256 hex digest, so it is the same in every interpreter
        run. The built-in ``hash()`` of a str is salted per process, which made
        keys saved to the cache file useless after a restart. Fields are
        JSON-encoded as a list, which makes the boundary between ``problem``
        and ``reasoning_process`` unambiguous. The model name and prompt
        template are included so results from a different judge or prompt mode
        (int scores vs. label lists) are never reused.
        """
        payload = json.dumps([self.model_name, self.prompt_template, problem, reasoning_process])
        return hashlib.sha256(payload.encode('utf-8')).hexdigest()

    def _get_cached_score(self, problem: str, reasoning_process: str) -> Optional[Dict[str, Any]]:
        """Return the cached entry for the pair, or None if absent or caching is off."""
        if not self.enable_caching:
            return None
        return self.score_cache.get(self._create_cache_key(problem, reasoning_process))

    def _cache_score(self, problem: str, reasoning_process: str, score: Any, justification: Optional[str]):
        """Store a parsed result in the in-memory cache (no-op if caching is off)."""
        if not self.enable_caching:
            return
        self.score_cache[self._create_cache_key(problem, reasoning_process)] = {
            'problem': problem,
            'reasoning_process': reasoning_process,
            'score': score,
            'justification': justification,
            'timestamp': time.time()
        }

    def save_cache(self):
        """Write the in-memory cache to ``cache_file`` (if caching is enabled and a path is set)."""
        if self.enable_caching and self.cache_file:
            try:
                with open(self.cache_file, 'w', encoding='utf-8') as f:
                    json.dump(self.score_cache, f, indent=2)
                print(f"Saved {len(self.score_cache)} cached scores to {self.cache_file}")
            except Exception as e:
                logging.error(f"Failed to save cache: {e}")

    def score_single(self, problem: str, reasoning_process: str) -> Dict[str, Any]:
        """
        Score a single reasoning process.

        Args:
            problem: The original problem text.
            reasoning_process: The model's reasoning process.

        Returns:
            Dictionary with ``problem``, ``reasoning_process``, ``score`` (int in
            default mode, list of labels in error-type mode, None on failure),
            ``justification``, ``model`` and ``cached``. Failures also carry
            ``error`` (``'parsing_failed'`` or ``'api_error'``).

        Raises:
            ValueError: If no model client is available.
        """
        cached_result = self._get_cached_score(problem, reasoning_process)
        if cached_result:
            return {
                'problem': problem,
                'reasoning_process': reasoning_process,
                'score': cached_result['score'],
                'justification': cached_result['justification'],
                'cached': True,
                'model': self.model_name
            }

        if not self.model:
            raise ValueError("OpenAI client not available. CoT quality scoring disabled.")

        try:
            prompt = self.prompt_template.format(
                problem=problem,
                reasoning_process=reasoning_process
            )
            response = self.model.query(prompt, max_tokens=self.max_tokens, temperature=self.temperature)
            score, justification = self._parse_response(response)

            if score is None:
                return {
                    'problem': problem,
                    'reasoning_process': reasoning_process,
                    'score': None,
                    'justification': f'Failed to parse response: {str(response)[:100]}...',
                    'cached': False,
                    'model': self.model_name,
                    'error': 'parsing_failed'
                }

            # Occasionally log a full prompt/response pair for spot checks.
            if random.random() < 1e-3:
                print(f"========== Annotation Prompt==========\n{prompt}")
                print(f"========== Response==========\n{score}\n{justification}")
                print("========== End ==========")

            self._cache_score(problem, reasoning_process, score, justification)
            return {
                'problem': problem,
                'reasoning_process': reasoning_process,
                'score': score,
                'justification': justification,
                'cached': False,
                'model': self.model_name
            }

        except Exception as e:
            logging.error(f"Error scoring CoT quality: {e}")
            return {
                'problem': problem,
                'reasoning_process': reasoning_process,
                'score': None,
                'justification': f'Error during scoring: {str(e)}',
                'cached': False,
                'model': self.model_name,
                'error': 'api_error'
            }

    def is_low_quality(self, score: Any) -> bool:
        """
        Return True if a parsed judge result indicates low-quality reasoning.

        Args:
            score: In error-type mode, the list of labels from the parser;
                otherwise the integer score. A ``None`` score (failed parse)
                is not handled and raises ``TypeError``.

        Returns:
            Error-type mode: True if any label (lower-cased, stripped) is in
            ``_low_quality_error_type``. Default mode: True if ``score`` is
            below ``_low_quality_threshold``.
        """
        if self.prompt_template == ERROR_TYPE_PROMPT:
            # Strict (conservative) string matching after normalising case/whitespace.
            return any(item.lower().strip() in self._low_quality_error_type for item in score)
        return score < self._low_quality_threshold

    def score_batch(self,
                    problems: List[str],
                    reasoning_processes: List[str],
                    show_progress: bool = True) -> List[Dict[str, Any]]:
        """
        Score multiple reasoning processes in parallel threads.

        Args:
            problems: List of problem texts.
            reasoning_processes: List of reasoning processes, aligned with
                ``problems``.
            show_progress: Accepted for API compatibility; currently unused.

        Returns:
            List of result dicts (see :meth:`score_single`) in input order.
            An item whose worker raised gets ``error='processing_error'``.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(problems) != len(reasoning_processes):
            raise ValueError("Problems and reasoning_processes must have the same length")

        # Pre-sized so each result lands at its input index as futures finish
        # out of order.
        results: List[Optional[Dict[str, Any]]] = [None] * len(problems)
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_index = {
                executor.submit(self.score_single, problem, reasoning): i
                for i, (problem, reasoning) in enumerate(zip(problems, reasoning_processes))
            }
            for future in as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    results[index] = future.result()
                except Exception as e:
                    logging.error(f"Error processing item {index}: {e}")
                    results[index] = {
                        'problem': problems[index],
                        'reasoning_process': reasoning_processes[index],
                        'score': None,
                        'justification': f'Error: {str(e)}',
                        'model': self.model_name,
                        'error': 'processing_error'
                    }
        return results

    def get_cost(self) -> float:
        """Return the model client's accumulated API cost, or 0.0 if it does not track one."""
        if hasattr(self.model, 'get_cost'):
            return self.model.get_cost()
        return 0.0


def compute_cot_quality_score(problem: str,
                              reasoning_process: str,
                              model_name: str = "gpt-5-nano",
                              **kwargs) -> Dict[str, Any]:
    """
    Score one problem/reasoning pair with a freshly built scorer.

    A new ``CoTQualityScorer`` is constructed on every call and its cache is
    not saved here.

    Args:
        problem: The original problem text.
        reasoning_process: The model's reasoning process.
        model_name: Name of the LLM judge model.
        **kwargs: Forwarded unchanged to ``CoTQualityScorer``.

    Returns:
        The result dict produced by ``CoTQualityScorer.score_single``.
    """
    return CoTQualityScorer(model_name=model_name, **kwargs).score_single(
        problem, reasoning_process
    )


def compute_cot_quality_score_batch(problems: List[str],
                                    reasoning_processes: List[str],
                                    model_name: str = "gpt-5-nano",
                                    **kwargs) -> List[Dict[str, Any]]:
    """
    Score many problem/reasoning pairs in parallel with a freshly built scorer.

    A new ``CoTQualityScorer`` is constructed on every call and its cache is
    not saved here.

    Args:
        problems: Problem texts.
        reasoning_processes: Reasoning processes aligned with ``problems``.
        model_name: Name of the LLM judge model.
        **kwargs: Forwarded unchanged to ``CoTQualityScorer``.

    Returns:
        The list produced by ``CoTQualityScorer.score_batch``, in input order.
    """
    batch_scorer = CoTQualityScorer(model_name=model_name, **kwargs)
    scored = batch_scorer.score_batch(problems, reasoning_processes)
    return scored
