import sys
import os
import re
import random
import json

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))))
from utils.exact_match_utils import judge_exact_match



def extract_solution(solution_str: str, do_print: bool):
    """Return the stripped text inside the last <answer>...</answer> pair.

    The match is non-greedy and spans newlines, so only the final complete
    tag pair is used. Returns None (printing an error when ``do_print`` is
    True) if no complete pair exists.
    """
    found = re.findall(r'<answer>(.*?)</answer>', solution_str, flags=re.DOTALL)
    if not found:
        if do_print:
            print("[Error] No valid answer tags found")
        return None
    # Several answer blocks may appear; the last one is treated as final.
    return found[-1].strip()
        

def validate_response_structure(processed_str: str, do_print: bool) -> bool:
    """Check that the response has exactly one think block followed by one answer block.

    Each of ``<think>``, ``</think>``, ``<answer>`` and ``</answer>`` must occur
    exactly once, and their first occurrences must appear in that order.

    Args:
        processed_str: Full model response (including any prepended <think>).
        do_print: When True, print per-tag diagnostics.

    Returns:
        True only if both the count check and the order check pass.
    """
    if do_print:
        print("\n[Structure Validation]")

    expected = (
        ('think_start', '<think>', 1),
        ('think_end', '</think>', 1),
        ('answer_start', '<answer>', 1),
        ('answer_end', '</answer>', 1),
    )

    counts_ok = True
    first_index = {}
    for name, tag, wanted in expected:
        seen = processed_str.count(tag)
        where = processed_str.find(tag)
        first_index[name] = where

        if do_print:
            print(f"  {tag}: count={seen}, position={where}")

        if seen != wanted:
            if do_print:
                print(f"  [Error] {tag} appears {seen} times (expected {wanted})")
            counts_ok = False

    # Positions must be non-decreasing in the declared tag order.
    sequence = [first_index[name] for name, _, _ in expected]
    in_order = all(left <= right for left, right in zip(sequence, sequence[1:]))

    if not in_order:
        if do_print:
            print("  [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
        return False

    if do_print:
        print("  Tag sequence validation passed")
    return counts_ok


def check_json_format(json_str, do_print=False):
    """Return True if ``json_str`` is a JSON object containing every required key.

    Args:
        json_str: Text extracted from the <answer> tags; may be None or empty.
        do_print: When True, print a diagnostic explaining any failure.

    Returns:
        True when ``json_str`` parses to a JSON object (dict) that has an
        "answer" key. False for empty input, invalid JSON, non-object JSON,
        or a missing key.
    """
    if not json_str:
        if do_print:
            print("[Error] Empty JSON string")
        return False

    try:
        data = json.loads(json_str)
    except (TypeError, ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError, TypeError covers non-str input,
        # and RecursionError covers pathologically nested payloads.
        if do_print:
            print("[Error] JSON decoding failed")
        return False

    # `key in data` is also True for a JSON list or string that merely contains
    # the key text (e.g. '["answer"]' or '"the answer"'). Callers index the
    # parsed value with ['answer'], so an object is required explicitly.
    if not isinstance(data, dict):
        if do_print:
            print("[Error] JSON answer is not an object")
        return False

    required_keys = {"answer"}
    if not all(key in data for key in required_keys):
        if do_print:
            print("[Error] Missing required keys in JSON")
        return False

    return True


def calculate_answer_score(pred_answer, gold_answer, table_data, do_print=False):
    """Score a predicted answer against the gold answer by exact match.

    Args:
        pred_answer: Value of the "answer" field from the model's JSON. A list
            or tuple is treated as already-split answers (each element
            stringified). Any other value is stringified and split on ", ".
        gold_answer: Passed unchanged to ``judge_exact_match`` (presumably a
            list of answer strings; confirm against the dataset).
        table_data: Currently unused; kept for interface compatibility.
        do_print: When True, print a diagnostic if scoring raises.

    Returns:
        Tuple ``(answer_score, pred_results)``. ``answer_score`` is 1 on an
        exact match and 0 otherwise. ``pred_results`` is the list of predicted
        answer strings, or [] if scoring raised.
    """
    try:
        if isinstance(pred_answer, (list, tuple)):
            # str() of a list would yield "['a', 'b']" and split into
            # bracket/quote-polluted fragments that can never match.
            pred_results = [str(item) for item in pred_answer]
        else:
            pred_results = str(pred_answer).split(', ')

        answer_score = 1 if judge_exact_match(pred_results, gold_answer) else 0
    except Exception as e:
        # Best-effort: the judge is opaque, so any failure scores 0 rather
        # than aborting reward computation for the whole batch.
        if do_print:
            print(f"[Error] Error in answer parsing: {e}")
        pred_results = []
        answer_score = 0

    return answer_score, pred_results


def compute_score(solution_str, ground_truth, extra_info):
    """Reward function for table question answering with a JSON answer.

    The completion is expected to follow an opening ``<think>`` tag (this
    function prepends it) and to have the shape
    ``...</think><answer>{"answer": ...}</answer>``.

    Args:
        solution_str: The model's generated text (without the leading <think>).
        ground_truth: Mapping whose 'answer' entry holds the gold answer passed
            to the exact-match judge.
        extra_info: Mapping with 'table' (forwarded to the answer scorer) and
            'prompt_str' (only printed for debugging).

    Returns:
        -2 if the format check fails, 0 if the format is correct but the answer
        does not match, and 1.1 (0.1 format bonus + 1 answer) on an exact match.
    """
    table_data = extra_info['table']
    prompt_str = extra_info['prompt_str']

    # Log roughly 1 in 16 samples to keep training output readable.
    do_print = random.randint(1, 16) == 1

    # The prompt presumably ends with "<think>\n", so the completion does not
    # repeat it; restore it so the structure check sees the full response.
    solution_str = '<think>\n' + solution_str
    answer_text = extract_solution(solution_str, do_print)

    response_format_correct = validate_response_structure(solution_str, do_print)
    json_format_correct = check_json_format(answer_text, do_print)
    format_correct = response_format_correct and json_format_correct

    pred_answer = None
    if format_correct:
        parsed = json.loads(answer_text)
        # Guard against non-object JSON (e.g. '["answer"]') that a key
        # membership test alone accepts; indexing it with 'answer' would raise.
        if isinstance(parsed, dict) and 'answer' in parsed:
            pred_answer = parsed['answer']
        else:
            format_correct = False

    format_score = 0.1 if format_correct else -2

    answer_score = 0
    pred_results = []
    if format_correct:
        answer_score, pred_results = calculate_answer_score(
            pred_answer, ground_truth['answer'], table_data, do_print
        )

    if answer_score > 0:
        total_score = format_score + answer_score
    elif format_correct:
        # Correct format but wrong answer: no bonus and no penalty.
        total_score = 0
    else:
        total_score = format_score

    # Round to one decimal place to avoid floating-point noise in reported scores.
    format_score = round(format_score, 1)
    answer_score = round(answer_score, 1)
    total_score = round(total_score, 1)

    if do_print:
        print("--------------------------------")
        print(f"Prompt: {prompt_str}")
        print(f"Solution string: {solution_str}")
        print("--------------------------------")
        print(f"Ground truth: {ground_truth}")
        print(f"Answer text: {answer_text}")
        print(f"Cleaned answer text: {pred_results}")
        print("--------------------------------")
        print("Final Score:")
        print(f"    Format: {format_score}")
        print(f"    Answer: {answer_score}")
        print(f"    Total: {total_score}")
        print("--------------------------------")
        print("=" * 80 + "\n")

    return total_score



if __name__ == '__main__':
    # Smoke test. This response lacks </think> and its JSON has no "answer"
    # key, so it is expected to receive the -2 format penalty.
    solution_str = """<|im_start|>assistant:  <answer>{"query": "Microstructural development of human"}</answer>
"""
    # compute_score reads ground_truth['answer'] and requires extra_info with
    # 'table' and 'prompt_str'; omitting extra_info raised TypeError.
    ground_truth = {'answer': ['4983']}
    extra_info = {'table': None, 'prompt_str': ''}
    scores = compute_score(solution_str, ground_truth, extra_info)
    print(scores)