import re
from typing import Dict, Tuple, Optional
import openai
from openai import OpenAI
import debugpy
import time
import random

def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
    """Pull the final ``<answer>`` payload out of a raw model response.

    The text following the first assistant header ("Assistant:" is tried
    before "<|im_start|>assistant") is treated as the model output; when
    neither header occurs, the whole input is used unchanged.

    Args:
        solution_str: Raw response string from the language model.

    Returns:
        A ``(answer, processed_str)`` pair. ``answer`` is the stripped content
        of the last ``<answer>...</answer>`` pair in the processed text, or
        ``None`` when no such pair exists. ``processed_str`` is the text that
        was searched.
    """
    processed_str = solution_str
    # Headers are listed in priority order; the first one present wins.
    for header in ("Assistant:", "<|im_start|>assistant"):
        _, found, tail = solution_str.partition(header)
        if found:
            processed_str = tail
            break

    # With a single capture group, findall yields the captured text directly.
    answers = re.findall(r'<answer>(.*?)</answer>', processed_str, re.DOTALL)
    if answers:
        # The last tagged answer is taken as the final one.
        return answers[-1].strip(), processed_str

    print("[Error] No valid answer tags found")
    return None, processed_str

def validate_response_structure(processed_str: str) -> bool:
    """Check that a response follows the ``<think>...</think><answer>...</answer>`` layout.

    Each of the four tags must occur exactly once, and their first
    occurrences must appear in the order ``<think>``, ``</think>``,
    ``<answer>``, ``</answer>``. Diagnostic lines are printed to stdout for
    every tag and for the ordering check.

    Args:
        processed_str: Processed response string from the model.

    Returns:
        True when every tag count and the tag order are correct, else False.
    """
    print("\n[Structure Validation]")
    validation_passed = True

    # Only the "System 2" layout (think section followed by answer section) is
    # enforced. An alternative "adaptive" policy -- accepting a bare <answer>
    # block, or a <social context understanding> block before the answer, in
    # addition to think + answer -- is not enforced here.
    tags = {
        'think_start': ('<think>', 1),
        'think_end': ('</think>', 1),
        'answer_start': ('<answer>', 1),
        'answer_end': ('</answer>', 1)
    }

    positions = {}
    for tag_name, (tag_str, expected_count) in tags.items():
        count = processed_str.count(tag_str)
        positions[tag_name] = pos = processed_str.find(tag_str)

        print(f"  {tag_str}: count={count}, position={pos}")

        if count != expected_count:
            print(f"  [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    # str.find returns -1 for a missing tag; comparing such positions would be
    # meaningless (and could report a "passed" order), so only compare when
    # every tag was actually found.
    if any(pos < 0 for pos in positions.values()):
        print("  [Error] Cannot verify tag order: at least one required tag is missing")
        validation_passed = False
    elif not (positions['think_start'] < positions['think_end']
              < positions['answer_start'] < positions['answer_end']):
        print("  [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
        validation_passed = False
    else:
        print("  Tag sequence validation passed")

    return validation_passed

def getOutput(prompt: str, max_retries=30, max_backoff: float = 60.0) -> str:
    """Send ``prompt`` to the chat-completions endpoint and return the reply text.

    The request uses the 'deepseek-v3' model with a fixed system prompt and
    temperature 0. The reply is printed before being returned. Failed attempts
    are retried with exponential backoff.

    NOTE(review): ``api_key`` and ``base_url`` are empty strings here --
    presumably placeholders to be filled in or loaded from configuration.

    Args:
        prompt: User message content to send.
        max_retries: Total number of attempts; must be at least 1.
        max_backoff: Upper bound, in seconds, on the exponential part of the
            delay between attempts. Without a cap, ``2 ** attempt`` grows to
            impractically long sleeps well before 30 attempts are used up.

    Returns:
        The content of the first choice of the completion.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The error from the final attempt, re-raised unchanged.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    # Build the client once; it does not need to be recreated per attempt.
    client = OpenAI(
        api_key="",
        base_url="",
    )

    for attempt in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model='deepseek-v3',
                messages=[
                    {'role': 'system', 'content': 'You are a helpful assistant.'},
                    {'role': 'user', 'content': prompt}
                ],
                temperature=0
            )
            output = completion.choices[0].message.content
            print(output)
            return output
        except Exception:
            if attempt == max_retries - 1:
                # Out of attempts: surface the last error to the caller.
                raise
            # Capped exponential backoff plus up to one second of random
            # jitter so concurrent callers do not retry in lockstep.
            sleep_time = min(2 ** attempt, max_backoff) + random.random()
            time.sleep(sleep_time)


def compute_score(solution_str: str,
                 ground_truth: str,
                 format_reward: int = 1,
                 answer_reward: float = 1.0) -> Dict:
    """Score a model response on format compliance and answer correctness.

    The response is first reduced to the assistant output and its final
    ``<answer>`` content via ``extract_solution``; the structure is then
    checked with ``validate_response_structure``. The answer is compared
    with ``ground_truth`` only when the format is valid and an answer was
    found.

    Scoring:
        * Format: ``format_reward`` when valid, ``-abs(format_reward)`` otherwise.
        * Answer: ``2`` when the answer equals ``ground_truth`` (either directly,
          or as the first word after an option prefix such as ``"A. "``),
          ``-1.5`` on a mismatch, ``-2`` when the comparison is skipped.

    Args:
        solution_str: Raw model response string.
        ground_truth: Expected answer string.
        format_reward: Points awarded/deducted for format correctness.
        answer_reward: Currently unused; the answer scores are fixed constants.

    Returns:
        A dict with ``"score"`` (format score plus answer score) and
        ``"extra_info"`` holding the 0/1 indicators ``"outcome_score"`` and
        ``"format_score"``.
    """
    print("\n" + "="*80)
    print(" Testing Processing New Sample ".center(80, '='))

    answer_text, processed_str = extract_solution(solution_str)
    print(f"\n[Model Response]\n{processed_str}")

    # Validate response structure
    format_correct = validate_response_structure(processed_str)
    format_score = format_reward if format_correct else -abs(format_reward)
    print(f"\n  Format validation: {'PASS' if format_correct else 'FAIL'}")
    print(f"  Format score: {format_score}")

    # The indicator follows the validation result itself, so it stays correct
    # even if a caller passes a zero or negative format_reward.
    format_score_scaled = 1 if format_correct else 0

    # Validate answer content
    if format_correct and answer_text:

        print("-----------------------")
        print("Answer Text\n")
        print(answer_text)
        print("Ground Truth\n")
        print(ground_truth)
        print("-----------------------")

        # Drop "A: " .. "D: " option markers wherever they occur.
        answer_text = answer_text.replace("A: ", "").replace("B: ", "").replace("C: ", "").replace("D: ", "")
        print(answer_text)
        # Matches answers of the form "<LETTER>. <word>", e.g. "B. some_label".
        pattern = r"^[A-Z]\. ([\w_]+)"

        if answer_text == ground_truth:
            answer_score = 2
            answer_score_scaled = 1
            print("  Content validation: FULL MATCH")
        elif re.match(pattern, answer_text):
            # The pattern guarantees a space after the option letter, so the
            # second space-separated token exists.
            if answer_text.split(" ")[1] == ground_truth:
                answer_score = 2
                answer_score_scaled = 1
                print("  Content validation: MATCH (option-prefixed answer)")
            else:
                answer_score = -1.5
                answer_score_scaled = 0
                print("  Content validation: MISMATCH zhengze")
        else:
            answer_score = -1.5
            answer_score_scaled = 0
            print("  Content validation: MISMATCH")
    else:
        answer_score = -2
        answer_score_scaled = 0
        print("\n[Content Validation] Skipped due to format errors or missing answer")

    total_score = format_score + answer_score
    print("\n" + "-"*80)
    print(" Final Score ".center(80, '-'))
    print(f"  Format: {format_score}")
    print(f"  Answer: {answer_score}")
    print(f"  Total: {total_score}")
    print("="*80 + "\n")

    output = {
        "score": total_score,
        "extra_info": {
            "outcome_score": answer_score_scaled,
            "format_score": format_score_scaled
        }
    }

    return output


# NOTE: The string literal below is a disabled earlier variant of
# compute_score, kept for reference only; it is never executed. Unlike the
# live function, it performs no format validation: it strips newlines,
# "<|im_end|>" and "A: ".."D: " markers from the whole processed response and
# scores 1 on an exact match with ground_truth, 0 otherwise.
'''
def compute_score(solution_str: str, 
                 ground_truth: str,
                 format_reward: int = 1,
                 answer_reward: float = 1.0) :
    """Computes comprehensive score for model response.
    
    Args:
        solution_str: Raw model response string
        ground_truth: Dictionary containing ground truth data
        format_reward: Points awarded/deducted for format correctness
        answer_reward: Points awarded/deducted for answer correctness
        
    Returns:
        Total score (sum of format and answer rewards)
    """
    print("\n" + "="*80)
    print(" Testing Processing New Sample ".center(80, '='))
    
    # Parse ground truth data 
    #solution_text = ground_truth.get('solution_text_format', '')
    #gt_status = parse_solution_text_format(solution_text)
    #expected_names = list(gt_status.keys())
    #print(f"[Ground Truth] Final identities: {gt_status}")

    # Extract model answer
    answer_text, processed_str = extract_solution(solution_str)
    print(f"\n[Model Response]\n{processed_str}")

    processed_str = processed_str.strip().replace("\n", "").replace("<|im_end|>", "").replace("A: ", "").replace("B: ", "").replace("C: ", "").replace("D: ", "")
    print(f"\n[Model Response]\n{processed_str}")
    # Validate answer content
    answer_score = 0
    #print(answer_text)
    #print(ground_truth)
    if processed_str == ground_truth:
        #pred_status = parse_model_answer(answer_text, expected_names)
        #if pred_status:
            #print(f"\n[Content Validation]")
            #print(f"  Expected: {gt_status}")
            #print(f"  Predicted: {pred_status}")
        
        #judge = f"""
        #The correct answer is:
        #{ground_truth}
        #Model's predict answer is:
        #{answer_text}
        #Is the model's predict answer correct? Output 'True' or 'False' only.
        #"""
        #judge_result = getOutput(judge)
        
        #if 'True' in judge_result:
        answer_score = 1
        answer_score_scaled = 1
        #print(answer_text)
        #print(ground_truth)
        print("  Content validation: FULL MATCH")
        
    else:
        answer_score = 0
        answer_score_scaled = 0
        print("  Content validation: MISMATCH")
    #else:
        #answer_score = -2
        #answer_score_scaled = 0
        #print( "Fail to parse answer")

    total_score = answer_score
    print("\n" + "-"*80)
    print(f" Final Score ".center(80, '-'))
    print(f"  Answer: {answer_score}")
    print(f"  Total: {total_score}")
    print("="*80 + "\n")
    
    #output = {
        #"score": total_score,
        #"extra_info": {
            #"outcome_score": answer_score_scaled,
            #"format_score": format_score_scaled
        #}
    #}

    output = {
        "score": total_score,
        "extra_info": {
            "outcome_score": answer_score_scaled
        }
    }

    return output
'''