import os
import re
from typing import Dict, Tuple, Optional
import math
from sympy import Rational
import numpy as np
import sys
import re
import string
from collections import Counter, defaultdict
import pickle
from pathlib import Path
import json
from tqdm import tqdm
import openai
import time

# System prompt for the general ORM: an LLM judge used to verify the correctness of
# the model's solution by checking whether two answers are equivalent.
# The judge is disabled by default, as it doesn't help much; compute_score only
# consults it when the LLM_JUDGE environment variable is set to "1".
GENERAL_ORM_PROMPT = """You are an expert in verifying if two answers are the same.
Your input is a problem and two answers, Answer 1 and Answer 2. You need to check if they are equivalent.
Your task is to determine if two answers are equivalent, without attempting to solve the original problem.
Compare the answers to verify they represent identical values or meaning, even when written in different forms or notations.

Your output must follow the following format:
1) Provide an explanation for why the answers are equivalent or not.
2) Then provide your final answer in the form of: [[YES]] or [[NO]]
"""

# User-message template for the judge. call_reward_model fills it with the problem
# text, the model's answer (Answer 1) and the ground truth (Answer 2).
ORM_USER_TEMPLATE = """
Problem: {problem}
Answer 1: {answer_1}
Answer 2: {answer_2}
"""

def normalize_answer(s):
    """Canonicalize an answer string for lenient comparison.

    The steps are applied in this order:
      1. lowercase the text,
      2. delete every character listed in ``string.punctuation``,
      3. replace the standalone words "a", "an" and "the" with a space,
      4. collapse all runs of whitespace into single spaces and trim the ends.

    Args:
        s: Raw answer text.

    Returns:
        The normalized string (possibly empty).
    """
    lowered = s.lower()
    # One pass that deletes every ASCII punctuation character.
    without_punctuation = lowered.translate(str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)
    return ' '.join(without_articles.split())

def f1_score(prediction, ground_truth):
    """Compute token-level F1, precision and recall between two answers.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace; overlap is counted as a multiset intersection of tokens.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        A tuple ``(f1, precision, recall)``. ``(0, 0, 0)`` is returned when the
        two answers share no token, or when either normalized answer is one of
        "yes" / "no" / "noanswer" and the two normalized answers differ.
    """
    pred_norm = normalize_answer(prediction)
    gold_norm = normalize_answer(ground_truth)

    zero = (0, 0, 0)
    special_answers = ('yes', 'no', 'noanswer')

    # Special answers only score when they match exactly; partial token
    # overlap with them is not rewarded.
    if pred_norm != gold_norm and (pred_norm in special_answers or gold_norm in special_answers):
        return zero

    pred_tokens = pred_norm.split()
    gold_tokens = gold_norm.split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return zero

    precision = 1.0 * overlap / len(pred_tokens)
    recall = 1.0 * overlap / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall), precision, recall

def sub_em(prediction, ground_truth):
    """Return whether one normalized answer contains the other.

    Both inputs are normalized with ``normalize_answer`` before the check.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        True if the normalized ground truth is a substring of the normalized
        prediction or vice versa. When either side normalizes to the empty
        string, True is returned only if both are empty.
    """
    ground_truth = normalize_answer(ground_truth)
    prediction = normalize_answer(prediction)
    # The empty string is a substring of every string, so a prediction made
    # only of punctuation/articles would otherwise always "match".
    if not prediction or not ground_truth:
        return prediction == ground_truth
    return (ground_truth in prediction) or (prediction in ground_truth)

def exact_match_score(prediction, ground_truth):
    """Return True when both answers are identical after ``normalize_answer``.

    Args:
        prediction: Predicted answer text.
        ground_truth: Reference answer text.

    Returns:
        Boolean result of comparing the two normalized strings.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def update_answer(metrics, prediction, gold):
    """Score one prediction and add the results to the running totals.

    Args:
        metrics: Mutable mapping with the keys 'sub_em', 'em', 'f1', 'prec',
            'recall' and 'total_num'; it is updated in place.
        prediction: Predicted answer text.
        gold: Reference answer text.

    Returns:
        A tuple ``(em, prec, recall)`` for this single sample.
    """
    em = exact_match_score(prediction, gold)
    subem = sub_em(prediction, gold)
    f1, prec, recall = f1_score(prediction, gold)

    increments = (
        ('sub_em', subem),
        ('em', float(em)),
        ('f1', f1),
        ('prec', prec),
        ('recall', recall),
        ('total_num', 1),
    )
    for key, delta in increments:
        metrics[key] += delta

    return em, prec, recall

def calc_metrics(predictions, goldens):
    """Average the answer metrics over paired predictions and references.

    Args:
        predictions: Sequence of predicted answer strings.
        goldens: Sequence of reference answer strings, same length as
            ``predictions``.

    Returns:
        A dict with the keys 'f1', 'prec', 'recall', 'em' and 'sub_em' holding
        per-sample averages rounded to two decimals, plus 'total_num' with the
        number of samples. For empty input every metric is 0.

    Raises:
        AssertionError: If the two sequences differ in length.
    """
    assert len(predictions) == len(goldens), "predictions and goldens must have the same length"
    metrics = {'f1': 0, 'prec': 0, 'recall': 0, 'em': 0, 'sub_em': 0, 'total_num': 0}
    for pred, gold in zip(predictions, goldens):
        update_answer(metrics, pred, gold)

    total = metrics['total_num']
    if total == 0:
        # Nothing to average; avoid dividing by zero and report all zeros.
        return metrics

    for key in metrics:
        if key != 'total_num':
            metrics[key] = round(metrics[key] / total, 2)
    return metrics

def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
    """Pull the final answer text out of a raw model response.

    The answer is whatever follows the last ``</think>`` tag, stripped of
    surrounding whitespace.

    Args:
        solution_str: Raw response string from the language model.

    Returns:
        ``(answer, solution_str)``. ``answer`` is None (and an error line is
        printed) when the response contains no ``</think>`` tag; the second
        element is always the unmodified input.
    """
    closing_tag = "</think>"

    if closing_tag in solution_str:
        # Keep only the text after the last closing tag.
        _, _, tail = solution_str.rpartition(closing_tag)
        return tail.strip(), solution_str

    print("[Error] No valid answer tags found")
    return None, solution_str

def parse_model_answer(response: str) -> Optional[str]:
    """Extract the final boxed answer from the model's answer text.

    All asterisks are removed from the text first. The helpers
    ``last_boxed_only_string`` and ``remove_boxed`` from ``.longcontext_qa``
    then do the extraction (presumably locating the last ``\\boxed{...}``
    expression and unwrapping it -- confirm against that module).

    Args:
        response: Text extracted from the model's response.

    Returns:
        The value returned by ``remove_boxed`` for the last boxed expression,
        or None when ``last_boxed_only_string`` finds nothing.
    """
    # Strip markdown emphasis markers before looking for the boxed answer.
    cleaned = response.replace('*', '')

    from .longcontext_qa import last_boxed_only_string, remove_boxed

    boxed = last_boxed_only_string(cleaned)
    if boxed is None:
        return None
    return remove_boxed(boxed)
    
def validate_response_structure(processed_str: str) -> bool:
    """Check that a response contains exactly one well-ordered think block.

    The response is valid when ``<think>`` and ``</think>`` each appear exactly
    once and the opening tag comes before the closing tag. Diagnostic lines are
    printed for every check.

    Args:
        processed_str: Processed response string from the model.

    Returns:
        True if all formatting requirements are met, False otherwise.
    """
    print("\n[Structure Validation]")
    validation_passed = True

    # Required tags and how many times each must occur.
    tags = {
        'think_start': ('<think>', 1),
        'think_end': ('</think>', 1)
    }

    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        # Position of the first occurrence, or -1 when the tag is absent.
        positions[tag_name] = pos = processed_str.find(tag_str)

        print(f"  {tag_str}: count={count}, position={pos}")

        if count != expected_count:
            print(f"  [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    # Verify tag order. A missing tag has position -1, which would make the
    # plain comparison report a "passed" sequence for an invalid response, so
    # that case is handled explicitly.
    if positions['think_start'] < 0 or positions['think_end'] < 0:
        print("  [Error] Tag order not checked: a required tag is missing")
        validation_passed = False
    elif positions['think_start'] > positions['think_end']:
        print("  [Error] Incorrect tag order: Expected <think>...</think>")
        validation_passed = False
    else:
        print("  Tag sequence validation passed")

    return validation_passed

def get_azure_client():
    """Build an Azure OpenAI client authenticated through Azure AD tokens.

    Configuration is read from the environment variables ``scope``,
    ``api_version``, ``model_name``, ``model_version`` and ``endpoint``, plus
    ``DEFAULT_IDENTITY_CLIENT_ID`` for the managed identity client id.

    Returns:
        A tuple ``(client, model)`` where ``client`` is an ``AzureOpenAI``
        instance and ``model`` is the string ``"{model_name}_{model_version}"``.
    """
    # Imported lazily so the Azure packages are only required when this
    # backend is actually used.
    from openai import AzureOpenAI
    from azure.identity import (
        DefaultAzureCredential,
        ChainedTokenCredential,
        AzureCliCredential,
        get_bearer_token_provider,
    )

    scope = os.getenv("scope")
    # Try the Azure CLI login first, then DefaultAzureCredential with the
    # credential sources listed below switched off.
    credential = get_bearer_token_provider(
        ChainedTokenCredential(
            AzureCliCredential(),
            DefaultAzureCredential(
                exclude_cli_credential=True,
                exclude_environment_credential=True,
                exclude_shared_token_cache_credential=True,
                exclude_developer_cli_credential=True,
                exclude_powershell_credential=True,
                exclude_interactive_browser_credential=True,
                exclude_visual_studio_code_credentials=True,
                managed_identity_client_id=os.environ.get("DEFAULT_IDENTITY_CLIENT_ID"),
            ),
        ),
        scope,
    )

    api_version = os.getenv("api_version")  # Ensure this is a valid API version
    model_name = os.getenv("model_name")  # Ensure this is a valid model name
    model_version = os.getenv("model_version")  # Ensure this is a valid model version
    endpoint = os.getenv("endpoint")

    client = AzureOpenAI(
        azure_endpoint=endpoint,
        azure_ad_token_provider=credential,
        api_version=api_version,
    )
    # NOTE(review): the model id is returned exactly as configured (no
    # character sanitization) -- confirm it matches the deployment name.
    model = f"{model_name}_{model_version}"
    return client, model

def get_openai_client():
    """Build an OpenAI client from environment configuration.

    Reads ``OPENAI_API_KEY`` for the key, ``OPENAI_API_BASE`` for the endpoint
    (default ``https://api.openai.com/v1``) and ``OPENAI_API_MODEL`` for the
    model name (default ``gpt-4o``).

    Returns:
        A tuple ``(client, model)``.
    """
    import openai
    key = os.getenv("OPENAI_API_KEY")
    baseurl = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
    # The client constructor takes ``base_url``; ``api_base`` is not an
    # accepted keyword and raises TypeError.
    client = openai.OpenAI(api_key=key, base_url=baseurl)
    model = os.getenv("OPENAI_API_MODEL", "gpt-4o")
    return client, model

def call_oai_rm_llm(
    prompt: str,
    system_prompt: str,
    n: int = 1,
    temperature: float = 1.0,
    retry_count: int = 10000,          # maximum “true” retries
):
    """Call an OpenAI-compatible chat endpoint with robust retry logic.

    The backend is Azure OpenAI when the ``USE_AZURE`` environment variable is
    "1", otherwise the plain OpenAI client.

    429 (rate-limit), 403 (key / permission) and expired-token errors do
    **not** consume `retry_count`; all other exceptions do. Errors are
    classified by substring matching on the exception message.

    Parameters
    ----------
    prompt : str
        Content of the user message.
    system_prompt : str
        Content of the system message.
    n : int
        Number of completions to request.
    temperature : float
        Sampling temperature passed to the endpoint.
    retry_count : int
        Maximum number of retries for errors that consume the budget.

    Returns
    -------
    str | list[str]
        One string if n == 1, otherwise a list of n strings. An empty list is
        returned when `retry_count` is exhausted.
    """
    # Remember which factory built the client so that refreshes after auth
    # errors rebuild the same kind of client instead of always using Azure.
    if os.getenv("USE_AZURE", "0") == "1":
        make_client = get_azure_client
    else:
        make_client = get_openai_client
    client, model_id = make_client()

    attempts = 0          # counts only “real” retries
    backoff  = 1          # seconds to wait after 429 / 403 (exponential ≤ 64 s)

    while True:
        try:
            response = client.chat.completions.create(
                model=model_id,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user",   "content": prompt},
                ],
                temperature=temperature,
                n=n,
            )
            break   # success → exit loop

        except Exception as exc:
            msg = str(exc)

            # ----------- errors that do NOT consume retry_count -----------
            if "429" in msg:
                print(f"[429] rate limit; retrying in {backoff}s")
                time.sleep(backoff)
                backoff = min(backoff * 2, 64)    # exponential back-off
                continue                          # attempts stays the same

            if "403" in msg:
                print(f"[403] auth / quota issue; retrying in {backoff}s")
                time.sleep(backoff)
                client, model_id = make_client()  # refresh key / endpoint
                backoff = min(backoff * 2, 64)
                continue                          # attempts stays the same
            if "lifetime" in msg and "expired" in msg:
                print(f"[azure cli] token lifetime expired; refreshing")
                client, model_id = make_client()
                backoff = min(backoff * 2, 64)
                continue

            # ----------- errors that DO consume retry_count -----------
            attempts += 1
            print(f"[{attempts}/{retry_count}] other exception: {exc}")

            if attempts >= retry_count:
                return []

    # ---------------- return payload ----------------
    if n == 1:
        return response.choices[0].message.content
    return [choice.message.content for choice in response.choices]

def call_reward_model(problem: str, model_answer: str, ground_truth: str):
    """Ask the LLM judge whether the model answer matches the ground truth.

    The judge is prompted with ``GENERAL_ORM_PROMPT`` and asked to finish with
    ``[[YES]]`` or ``[[NO]]``.

    Args:
        problem: Problem statement shown to the judge (whitespace-stripped).
        model_answer: The model's answer, presented as "Answer 1".
        ground_truth: The reference answer, presented as "Answer 2".

    Returns:
        1.0 if the judge's verdict is YES, otherwise 0.0. 0.0 is also returned
        when the judge call yields no text (e.g. retries exhausted).
    """
    question = problem.strip()
    orm_response = call_oai_rm_llm(
        system_prompt=GENERAL_ORM_PROMPT,
        prompt=ORM_USER_TEMPLATE.format(problem=question, answer_1=model_answer, answer_2=ground_truth),
        temperature=0.0,
        retry_count=10,
    )
    # A failed call returns an empty list (or the content may be None).
    if not isinstance(orm_response, str):
        return 0.0

    # Use the last bracketed verdict: the explanation may quote "YES" (for
    # instance when one of the compared answers is literally "YES") even
    # though the final verdict is [[NO]].
    verdicts = re.findall(r"\[\[(YES|NO)\]\]", orm_response)
    if verdicts:
        return 1.0 if verdicts[-1] == "YES" else 0.0

    # No bracketed verdict at all: fall back to the plain substring check.
    return 1.0 if "YES" in orm_response else 0.0

def compute_score(solution_str: str,
                  ground_truth: Dict[str, str],
                  prompt_str: str,
                  format_reward: float = 0.0,
                  answer_reward: float = 1.0):
    """Score a single model response against its ground truth.

    The answer is the text after the last ``</think>`` tag, parsed with
    ``parse_model_answer`` and compared to the ground truth with the ``sub_em``
    metric. When that metric is below 1.0 and the ``LLM_JUDGE`` environment
    variable is "1", the LLM judge is consulted and the higher score is kept.
    Progress is printed along the way.

    Args:
        solution_str: Raw model response string.
        ground_truth: Reference answer. A plain string is used as is; for any
            other value its element ``[0]`` is used.
        prompt_str: The question text; printed and forwarded to the LLM judge.
        format_reward: Not used by this implementation.
        answer_reward: Not used by this implementation.

    Returns:
        A dict with 'score' (the answer score), 'acc' (True when the score
        equals 1.0) and 'pred' (the extracted answer text, or None).
    """
    banner_width = 80
    print("\n" + "=" * banner_width)
    print(" Processing New Sample ".center(banner_width, '='))

    # Split the response into the answer portion and the full text.
    answer_text, processed_str = extract_solution(solution_str)
    print(f"\n[Model Response]\n{processed_str}")

    answer_score = 0
    if not answer_text:
        print("\n[Content Validation] Skipped due to format errors or missing answer")
    else:
        try:
            predicted = parse_model_answer(answer_text)
        except Exception as e:
            print(f"[Error] Failed to parse model answer {answer_text}: {e}")
            predicted = None

        expected = ground_truth if isinstance(ground_truth, str) else ground_truth[0]

        if not predicted:
            answer_score = 0.0
            print( "Fail to parse answer")
        else:
            print("\n[Content Validation]")
            print(f"  Expected: {expected}")
            print(f"  Predicted: {predicted}")
            print(f" Prompt Input Question: {prompt_str}")
            score = calc_metrics([predicted], [expected])['sub_em']
            # Only fall back to the LLM judge when string matching failed
            # and the judge is explicitly enabled.
            if score < 1.0 and os.getenv('LLM_JUDGE') == "1":
                judge_score = call_reward_model(prompt_str, predicted, expected)
                print(f"  RM Score: {judge_score}")
                score = max(score, judge_score)
            answer_score = score
            print(f"  Answer Score: {answer_score}")

    print("\n" + "-" * banner_width)
    print(" Final Score ".center(banner_width, '-'))
    print(f"  Answer: {answer_score}")
    print("=" * banner_width + "\n")

    result = {
        "score": answer_score,
        "acc": answer_score == 1.0,
        "pred": answer_text,
    }
    return result