# verl/verl/utils/reward_score/factor.py
import json
import re
import numpy as np
import requests
import logging

logger = logging.getLogger(__name__)

def compute_score(solution_str, ground_truth=None, method="flexible", format_score=0.0, score=1.0, **kwargs):
    """
    score the solution string from the agent
    """
    response = solution_str
    params = {}
    structured_reward = 0.0  # structured reward (0~0.3)
    # semantic_reward = 0.0     # semantic reward (0~0.2)
    # backtest_reward = 0.0     # backtest reward (0~0.5)

    # === Stage 1: Structured Reward (Parse <tool_call>) ===
    tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', response, re.DOTALL)
    if tool_call_match:
        try:
            tool_call = json.loads(tool_call_match.group(1))
            params = tool_call.get('parameters', {})
            structured_reward = 1  #  
        except json.JSONDecodeError:
            # fix common JSON issues
            try:
                fixed_str = tool_call_match.group(1).replace("'", '"').replace("{{", "{").replace("}}", "}")
                tool_call = json.loads(fixed_str)
                params = tool_call.get('parameters', {})
                structured_reward = 0.2  #  
            except:
                pass

    # # === Stage 2: Semantic Reward (Check Factor Parameters) ===
    # factor_name = params.get('factor_name', '')
    # factor_expr = params.get('factor_expr', '')
    
    # # check validity of factor name and expression
    # if factor_name and factor_expr:
    #     semantic_reward = 0.2
    #     #  
    #     if any(keyword in factor_expr for keyword in ["TS_", "RANK(", "FILTER("]):
    #         semantic_reward += 0.1  # 

    # # === Stage 3: Backtest Reward (Optional) ===
    # if kwargs.get("enable_backtest", True) and factor_name and factor_expr:
    #     test_request = {
    #         "exprs": {factor_name: factor_expr},
    #         "date_split": {  #  
    #             "train_start_time": "2018-01-01",
    #             "train_end_time": "2020-12-31"
    #         }
    #     }
    #     try:
    #         api_response = requests.post("http://localhost:8001/backtest", json=test_request, timeout=10)
    #         if api_response.status_code == 200:
    #             result = api_response.json()
    #             metric_value = result.get('data', {}).get('metrics', {}).get('Information_Ratio_with_cost', 0.0)
    #             backtest_reward = min(0.5, max(0, metric_value / 2.0))  #   0~0.5
    #     except Exception as e:
    #         logger.warning(f"Failed: {e}")

    # === Total Reward ===
    total_reward = structured_reward # + semantic_reward + backtest_reward

    logger.info(f"Total Reward: {total_reward}")
    return min(1.0, total_reward)  #   1.0