"""
Formula: reward = a * SemanticAccuracy + b * VerifierScore - c * eFLOPs


- a, b, c are determined by empirical tuning to balance accuracy and efficiency.
- SemanticAccuracy by Jaccard Similarity between LLM response and ground truth answer
- VerifierScore from verifier model output
- eFLOPs by selected_model and L_in/L_out 
"""
import re
import numpy as np
from typing import List, Optional, Tuple
from config import MODEL_PARAM_P, MODEL_GQA_RATIO_R, MODEL_KV_SIZE_D, ARITHMETIC_INTENSITY_I

# Separator pattern for tokenization: every run of non-word characters
# acts as a delimiter between tokens.
SPLIT_REGEX = re.compile(r'\W+')


def compute_semantic_accuracy(llm_response: Optional[str], answer: Optional[str]) -> float:
    """
    Score a response against a reference answer with Jaccard similarity.

    Both texts are lower-cased and split on SPLIT_REGEX; the score is
    |response_tokens & answer_tokens| / |response_tokens | answer_tokens|.

    Args:
        llm_response: Model output; None is treated as an empty string.
        answer: Reference answer; None means there is nothing to compare
            against and yields 1.0.

    Returns:
        1.0 when answer is None, 0.0 when the answer contains no tokens,
        otherwise the Jaccard similarity of the two token sets.
    """
    if answer is None:
        return 1.0

    def _token_set(text):
        # re.split yields '' at the edges when the text starts or ends with
        # a separator; those empty pieces are not tokens.
        return {piece for piece in SPLIT_REGEX.split((text or "").lower()) if piece != ''}

    candidate = _token_set(llm_response)
    reference = _token_set(answer)

    if not reference:
        return 0.0

    shared_count = len(candidate & reference)
    combined_count = len(candidate | reference)

    # combined_count cannot be zero here because reference is non-empty; the
    # guard is kept so the function never divides by zero.
    return float(shared_count / combined_count) if combined_count else 0.0


def _calculate_eflops_for_segment(P: float, r: float, D: int, I: float, N: int, seg_in: int, seg_out: int) -> float:
    """
    Return the eFLOPs estimate for one (input, output) segment.

    The estimate is the sum of four terms, each a product of the arguments:
    linear-layer compute, attention compute, KV-cache access cost and
    parameter access cost. The two access-cost terms are scaled by I.

    Args:
        P: Model parameter term (callers in this file pass a MODEL_PARAM_P entry).
        r: GQA ratio term (callers pass a MODEL_GQA_RATIO_R entry).
        D: KV size term (callers pass a MODEL_KV_SIZE_D entry).
        I: Arithmetic-intensity factor applied to the memory-access terms.
        N: Parallelism factor.
        seg_in: Segment input length; negative values are clamped to 0.
        seg_out: Segment output length; negative values are clamped to 0.

    Returns:
        The total cost as a float.
    """
    n_in = max(0, seg_in)
    n_out = max(0, seg_out)
    n_out_sq = n_out ** 2

    # Linear layers: proportional to the number of generated tokens.
    linear_flops = 2 * N * P * n_out

    # Attention: an input-times-output term plus a term quadratic in the output.
    attention_flops = (2 * r * N * n_in * D * n_out) + (r * N * D * n_out_sq)

    # KV cache access.
    # NOTE(review): the input-times-output term here has no N factor, unlike
    # the quadratic term and unlike the attention compute above; confirm this
    # asymmetry is intended.
    kv_access_cost = (2 * I * n_in * D * n_out) + (I * N * D * n_out_sq)

    # Parameter access.
    param_access_cost = 2 * I * N * P * n_out

    # Accumulate left to right in a fixed order so floating-point rounding is stable.
    total = linear_flops + attention_flops
    total = total + kv_access_cost
    total = total + param_access_cost
    return float(total)


def _resolve_model_params(model_name: str, missing_warning: str) -> Tuple[float, float, int]:
    """
    Look up the (P, r, D) cost-model parameters for a model name.

    If the name is absent from MODEL_PARAM_P, missing_warning is printed and
    the "qwen3-0.6b" entries are used instead (with hard-coded defaults
    should that entry be missing as well).
    """
    fallback_model = "qwen3-0.6b"
    param_term = MODEL_PARAM_P.get(model_name)
    if param_term is None:
        print(missing_warning)
        return (
            MODEL_PARAM_P.get(fallback_model, 0.6e9),
            MODEL_GQA_RATIO_R.get(fallback_model, 0.5),
            MODEL_KV_SIZE_D.get(fallback_model, 57344),
        )
    # NOTE(review): the D default for a known model (128) differs from the
    # fallback-model default (57344); kept as in the original, confirm intent.
    return (
        param_term,
        MODEL_GQA_RATIO_R.get(model_name, 0.5),
        MODEL_KV_SIZE_D.get(model_name, 128),
    )


def compute_eflops(selected_model: str, L_in: List[int], L_out: List[int], qp: int = 1, cp: int = 1, verifier_model: Optional[str] = None) -> float:
    """
    Calculate total eFLOPs.
    Consider the impact of QP (Reasoning Trees) and CP (Candidate Paths).
    N = max(1, min(QP * CP, 64))

    Args:
        selected_model: Generator model name, looked up in the config tables.
            Unknown names fall back to "qwen3-0.6b" with a printed warning.
        L_in: Per-segment input lengths. None or empty means no segments.
        L_out: Per-segment output lengths. None or empty means no segments.
        qp: Reasoning-tree count.
        cp: Candidate-path count.
        verifier_model: Optional verifier model name. It is ignored when the
            parallelism N is 1.

    Returns:
        Sum over segments of the generator cost, plus (when a verifier is in
        use) a verifier cost for each segment. If L_in and L_out differ in
        length, only the first min(len(L_in), len(L_out)) pairs are counted.
    """
    # Parallelism is resolved first so the verifier lookup can be skipped
    # entirely when it will not be used.
    N = max(1, min(qp * cp, 64))

    P_gen, r_gen, D_gen = _resolve_model_params(
        selected_model,
        f"Warning: Model '{selected_model}' not found in MODEL_PARAM_P. Using default (qwen3-0.6b).",
    )

    # According to user strategy: when parallelism is 1, do not calculate
    # Verifier's eFLOPs. Skipping the lookup also avoids printing a
    # "not found" warning for a verifier that is never costed.
    use_verifier = bool(verifier_model) and N > 1
    P_ver = r_ver = D_ver = None
    if use_verifier:
        P_ver, r_ver, D_ver = _resolve_model_params(
            verifier_model,
            f"Warning: Verifier model '{verifier_model}' not found. Using default.",
        )

    I = ARITHMETIC_INTENSITY_I

    total_cost = 0.0
    # zip stops at the shorter input, so mismatched lengths are truncated
    # to the common prefix rather than raising.
    for seg_in_i, seg_out_i in zip(L_in or [], L_out or []):
        total_cost += _calculate_eflops_for_segment(P=P_gen, r=r_gen, D=D_gen, I=I, N=N, seg_in=seg_in_i, seg_out=seg_out_i)

        if use_verifier:
            # The verifier reads the whole segment (input + output) and emits
            # a single token. It is assumed to run once per candidate path,
            # hence the same N as the generator.
            total_cost += _calculate_eflops_for_segment(P=P_ver, r=r_ver, D=D_ver, I=I, N=N, seg_in=seg_in_i + seg_out_i, seg_out=1)

    return float(total_cost)


def reward_function(user_query: str, model: str, verifier_score: float, L_in: List[int], L_out: List[int], is_correct: float, qp: int = 1, cp: int = 1, verifier_model: Optional[str] = None) -> dict:
    """
    Main entry for reward calculation.

    reward = w_acc * accuracy + w_ver * verifier_score
             - w_cost * normalized_eflops + bias

    Args:
        user_query: Not used by the calculation.
        model: Generator model name, forwarded to compute_eflops.
        verifier_score: Verifier output; anything float() accepts.
        L_in: Per-segment input lengths, forwarded to compute_eflops.
        L_out: Per-segment output lengths; their sum is the token count used
            for normalization.
        is_correct: Correctness signal, used directly as the accuracy term.
        qp: Reasoning-tree count.
        cp: Candidate-path count.
        verifier_model: Optional verifier model name, forwarded to compute_eflops.

    Returns:
        On success, a dict with 'total_reward', 'accuracy', 'verifier_score',
        'eflops' and 'normalized_eflops'. If anything raises, a dict with
        'total_reward' (0.5 * verifier_score, or 0.0 when that cannot be
        computed either) and 'error'.
    """
    # Adjusted weights to equalize 0.6B and 32B rewards based on test150_1222 logs
    # And normalized to span full [0, 1] range.
    # Ratio w_cost / (w_acc + w_ver) = 0.7257
    # Constraints:
    # 1. 2*w_acc + w_cost = 1 (Max reward = 1)
    # 2. w_acc = w_ver
    w_acc = 0.3669
    w_ver = 0.3669
    w_cost = 0.2662
    bias = 0.2662

    # Normalization map for log10(metric): [11.40, 16.77] -> [0, 1]
    # Min (0.6B, N=1, Len=1): Log10 ~ 11.4027 -> 0.0
    # Max (32B, N=64, Len=8k): Log10 ~ 16.7673 -> 1.0 (Using scale 5.3646)
    log_floor = 11.4027
    log_span = 5.3646

    try:
        # Use is_correct directly
        semantic_accuracy = float(is_correct)

        # Coerce once and use the coerced value everywhere below. Formatting
        # the raw argument with ':.2f' raises for e.g. a numeric string, which
        # previously discarded an already-computed reward in favour of the
        # fallback.
        verifier = float(verifier_score)

        eFLOPs = compute_eflops(model, L_in, L_out, qp=qp, cp=cp, verifier_model=verifier_model)

        total_output_len = sum(L_out) if L_out else 1
        # NOTE(review): unlike compute_eflops, parallelism is not capped at 64
        # here; confirm this is intended.
        parallelism = max(1, qp * cp)

        # eFLOPs divided by (output tokens / parallelism). This metric
        # penalizes Model Size and Parallelism, but is relatively invariant
        # to Length.
        metric = eFLOPs / (max(1, total_output_len) / parallelism)

        # Clamp below at 1.0 so the log is never taken of a value < 1.
        log_val = np.log10(max(1.0, metric))

        normalized_eflops = (log_val - log_floor) / log_span
        normalized_eflops = max(0.0, min(1.0, normalized_eflops))

        reward = w_acc * semantic_accuracy + w_ver * verifier - w_cost * normalized_eflops + bias

        print(f"[Reward Calc] Total: {reward:.4f} | Acc={semantic_accuracy:.2f} (w={w_acc}), Verifier={verifier:.2f} (w={w_ver}), eFLOPs={eFLOPs:.2e} (norm={normalized_eflops:.4f}, w={w_cost})")

        return {
            'total_reward': float(reward),
            'accuracy': semantic_accuracy,
            'verifier_score': verifier,
            'eflops': eFLOPs,
            'normalized_eflops': normalized_eflops,
        }

    except Exception as e:
        # Deliberate best-effort boundary: the caller always gets a reward dict.
        print(f"Error calculating reward: {e}")
        # Fallback strategy: use verifier_score only
        try:
            fallback_reward = float(verifier_score) * 0.5
            return {
                'total_reward': fallback_reward,
                'error': str(e)
            }
        except Exception:
            return {
                'total_reward': 0.0,
                'error': str(e)
            }
