import re
import json
from scipy.optimize import linear_sum_assignment
import numpy as np
from typing import Any, Dict, List
import pdb
import random
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer, util

from rouge_score import rouge_scorer

# Shared scorer/model instances, created once at import time and reused by
# every call in this module.
# ROUGE scorer configured for ROUGE-1, ROUGE-2 and ROUGE-L with stemming.
# NOTE(review): not referenced by any function visible in this section —
# confirm whether it is still needed.
_ROUGE_SCORER = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
# Sentence-embedding model used by calc_fn_step_consistency_reward. Because it
# is constructed at module import, importing this file loads the model
# (presumably from the sentence-transformers cache or hub — confirm for
# offline environments).
SENTENCE_MODEL = SentenceTransformer('all-MiniLM-L6-v2')

import re

def extract_answer_text(predict_str: str) -> str:
    """Return the content of the first ``<answer>...</answer>`` block.

    Whitespace around the content is removed. When the response has no
    answer block, the whole response (stripped) is returned instead; a
    falsy input (``""`` or ``None``) yields ``""``.
    """
    if not predict_str:
        return ""
    found = re.search(r'<answer>\s*(.*?)\s*</answer>', predict_str, re.DOTALL)
    text = found.group(1) if found else predict_str
    return text.strip()

def calc_format_reward(predict_str: str, answer_format: str = "mmmu") -> float:
    """Score how well a response follows the expected ``<think>/<answer>`` template.

    The result is the sum of two 0/1 components:

    * thinking format: somewhere in the text a ``<think>...</think>`` block is
      followed (after optional whitespace) by an ``<answer>...</answer>`` block;
    * answer format: the stripped content of the first ``<answer>`` block
      matches the style selected by ``answer_format``:

        - ``"mmmu"`` (default): exactly one uppercase ASCII letter, e.g. ``"B"``;
        - ``"yes_no"``: ``"yes"`` or ``"no"``, case-insensitive.

    Args:
        predict_str: Raw model response.
        answer_format: Which answer-content check to apply.

    Returns:
        0.0, 1.0 or 2.0.

    Raises:
        ValueError: If ``answer_format`` is not one of the supported styles.
        TypeError: If ``predict_str`` is not a string (raised by ``re``).
    """
    # Answer-content checks, keyed by style. Each receives the stripped
    # content of the first <answer> block.
    answer_checks = {
        "mmmu": lambda content: re.fullmatch(r'[A-Z]', content) is not None,
        "yes_no": lambda content: content.lower() in ('yes', 'no'),
    }
    try:
        check_answer = answer_checks[answer_format]
    except KeyError:
        raise ValueError(f"Unsupported answer_format: {answer_format!r}") from None

    think_then_answer = re.search(r"<think>.*?</think>\s*<answer>.*?</answer>", predict_str, re.DOTALL)
    thinking_format_reward = 1.0 if think_then_answer else 0.0

    answer_match = re.search(r'<answer>\s*(.*?)\s*</answer>', predict_str, re.DOTALL)
    answer_ok = answer_match is not None and check_answer(answer_match.group(1).strip())
    answer_format_reward = 1.0 if answer_ok else 0.0

    return thinking_format_reward + answer_format_reward

def calc_accuracy_reward(predict_str: str, ground_truth: str) -> float:
    """Return 1.0 when the ``<answer>`` content equals ``ground_truth``, else 0.0.

    The comparison is case-insensitive and ignores surrounding whitespace.
    A response without an ``<answer>`` block scores 0.0, and so does any
    failure while reading the inputs (e.g. a non-string argument).
    """
    try:
        expected = ground_truth.strip().lower()
        found = re.search(r'<answer>\s*(.*?)\s*</answer>', predict_str, re.DOTALL)
        if found is None:
            return 0.0
        predicted = found.group(1).strip().lower()
        return 1.0 if predicted == expected else 0.0
    except Exception:
        return 0.0

def calc_non_repeat_reward(predict_str: str, max_repeats: int = 2) -> float:
    """Return 0.0 if the response repeats sentences too often, otherwise 1.0.

    The text is split on ``'.'`` into whitespace-stripped, non-empty
    sentences. Each sentence that exactly equals an earlier one counts as a
    repeat; as soon as ``max_repeats`` repeats have been seen the reward is
    0.0.

    Args:
        predict_str: Raw model response.
        max_repeats: Number of repeated sentences that triggers the penalty.

    Returns:
        1.0 or 0.0 (always a float). Non-string input is treated as
        non-repetitive and scores 1.0.
    """
    if not isinstance(predict_str, str):
        # Best-effort: a missing/odd response is not penalized here.
        return 1.0

    seen = set()
    repeats = 0
    for sentence in (part.strip() for part in predict_str.split('.')):
        if not sentence:
            continue
        if sentence in seen:
            repeats += 1
        if repeats >= max_repeats:
            return 0.0
        seen.add(sentence)
    return 1.0

def get_thinking_text(predict_str: str) -> str:
    """Return the stripped content of the first ``<think>...</think>`` block, or ``""`` if there is none."""
    found = re.search(r'<think>(.*?)</think>', predict_str, re.DOTALL)
    if found is None:
        return ""
    return found.group(1).strip()

def calc_fn_alignment_reward(attn_map: np.ndarray, gt_region: np.ndarray) -> float:
    """Return the mean attention value inside the ground-truth region, clipped to [0, 1].

    ``gt_region`` is binarized (entries ``> 0`` form the region), and the
    score is ``sum(attn_map * mask) / (mask.sum() + 1e-6)``. The small
    epsilon avoids division by zero for an empty region, which scores 0.0.
    Returns 0.0 if either argument is ``None``.
    """
    if attn_map is None or gt_region is None:
        return 0.0
    attention = np.asarray(attn_map)
    mask = (np.asarray(gt_region) > 0).astype(float)
    region_mean = (attention * mask).sum() / (mask.sum() + 1e-6)
    return float(np.clip(region_mean, 0, 1))

def calc_fn_step_consistency_reward(predict_str: str,
                                    repetition_threshold: float = 0.95,
                                    model: Any = None) -> float:
    """Score how coherent consecutive reasoning sentences are.

    The ``<think>`` text is split on ``'.'`` into non-empty sentences. Each
    sentence is embedded, and the cosine similarity of every adjacent pair
    (sentence i vs. sentence i+1) is computed. The reward is::

        max(0, (mean_sim + 1) / 2 - fraction_of_pairs_with_sim > repetition_threshold)

    Consecutive steps that are related score high. Near-duplicate
    consecutive steps are penalized.

    Args:
        predict_str: Raw model response.
        repetition_threshold: Adjacent-pair similarity above which the pair
            counts as a repetition.
        model: Sentence encoder exposing
            ``encode(sentences, convert_to_tensor=True)``. Defaults to the
            module-level ``SENTENCE_MODEL``.

    Returns:
        1.0 when there is no thinking text or fewer than two sentences.
        0.0 if embedding/scoring raises (a ``RuntimeWarning`` is emitted).
        Otherwise the score above, in [0, 1].
    """
    reasoning_text = get_thinking_text(predict_str)
    if not reasoning_text:
        return 1.0

    sentences = [s.strip() for s in reasoning_text.split('.') if s.strip()]
    if len(sentences) < 2:
        return 1.0

    try:
        encoder = SENTENCE_MODEL if model is None else model
        embeddings = torch.as_tensor(encoder.encode(sentences, convert_to_tensor=True))

        # Only adjacent pairs are needed, so compare row i with row i+1
        # directly. This avoids building the full (n-1) x (n-1) similarity
        # matrix just to read its diagonal.
        pair_sims = F.cosine_similarity(embeddings[:-1], embeddings[1:], dim=-1).cpu().numpy()

        # Map mean similarity from [-1, 1] to [0, 1].
        coherence_reward = (float(np.mean(pair_sims)) + 1) / 2
        # Fraction of adjacent pairs that are near-duplicates.
        repetition_penalty = float(np.sum(pair_sims > repetition_threshold)) / len(pair_sims)

        return float(max(0.0, coherence_reward - repetition_penalty))
    except Exception as exc:
        # Best-effort: a failing encoder must not crash the reward loop, but
        # the failure should not be silent either.
        import warnings
        warnings.warn(
            f"calc_fn_step_consistency_reward failed, returning 0.0: {exc!r}",
            RuntimeWarning,
        )
        return 0.0

def calc_fn_length_penalty_reward(predict_str: str, target_len: int = 1024, alpha: float = 0.3) -> float:
    """Reward thinking text whose length is close to ``target_len`` characters.

    The length is the character count of the stripped ``<think>`` content
    (0 when there is none). The reward is
    ``exp(-alpha * |length - target_len| / target_len)``: 1.0 at the target,
    decaying with the relative deviation. A non-positive ``target_len``
    disables the penalty and yields 1.0.
    """
    thinking_chars = len(get_thinking_text(predict_str))
    if target_len <= 0:
        return 1.0
    relative_gap = abs(thinking_chars - target_len) / target_len
    return float(np.exp(-alpha * relative_gap))

def compute_score(reward_inputs: List[Dict[str, Any]], has_ground_truth: bool = False, n: int = 8) -> List[Dict[str, float]]:
    """Compute weighted rewards for a batch of responses.

    Each item of ``reward_inputs`` must provide ``"response"`` and
    ``"ground_truth"`` keys. For every item, five component rewards are
    computed:

    * ``format``
    * ``accuracy``
    * ``non_repeat``
    * ``length``
    * ``consistency``

    ``overall`` is their weighted sum, using the ``W_*`` weights below.

    When ``has_ground_truth`` is True, the batch is read as consecutive
    groups of ``n + 1`` items. The last item of each group is treated
    specially (presumably a ground-truth-guided response — confirm with the
    caller). Unless its accuracy is strictly greater than the best accuracy
    among the group's other items, its component rewards are replaced by the
    group means before ``overall`` is computed. If the group is empty
    (``n == 0``), the item keeps its own scores.

    Args:
        reward_inputs: Batch of ``{"response": str, "ground_truth": str}`` dicts.
        has_ground_truth: Enable the grouped handling described above.
        n: Number of regular items preceding each special item in a group.

    Returns:
        One dict per input, with keys ``overall``, ``format``, ``accuracy``,
        ``non_repeat``, ``length`` and ``consistency``.

    Raises:
        ValueError: If ``reward_inputs`` is not a list.
        KeyError: If an item lacks ``"response"`` or ``"ground_truth"``.
    """
    if not isinstance(reward_inputs, list):
        raise ValueError("Please use `reward_type=batch` for dapo reward function.")

    W_FORMAT = 1.0
    W_ACCURACY = 3.0
    W_NON_REPEAT = 2.0
    W_LENGTH = 1.0
    W_CONSISTENCY = 0.2

    def _overall(components: Dict[str, float]) -> float:
        # Weighted sum of the component rewards.
        return (W_FORMAT * components["format"] +
                W_ACCURACY * components["accuracy"] +
                W_NON_REPEAT * components["non_repeat"] +
                W_LENGTH * components["length"] +
                W_CONSISTENCY * components["consistency"])

    scores: List[Dict[str, float]] = []
    # Component rewards of the regular items in the current group. Only read
    # (and therefore only filled) when has_ground_truth is set.
    group: List[Dict[str, float]] = []

    for idx, reward_input in enumerate(reward_inputs):
        response = reward_input["response"]
        components = {
            "format": calc_format_reward(response),
            "accuracy": calc_accuracy_reward(response, reward_input["ground_truth"]),
            "non_repeat": calc_non_repeat_reward(response),
            "length": calc_fn_length_penalty_reward(response),
            "consistency": calc_fn_step_consistency_reward(response),
        }

        if has_ground_truth and (idx + 1) % (n + 1) == 0:
            # Last item of a group of n + 1. An empty group (n == 0) has
            # nothing to compare against, so max() must not be called on it.
            if group and components["accuracy"] <= max(c["accuracy"] for c in group):
                components = {key: float(np.mean([c[key] for c in group])) for key in components}
            group = []
        elif has_ground_truth:
            group.append(components)

        scores.append({"overall": _overall(components), **components})
    return scores



if __name__ == "__main__":
    # Smoke run: a well-formed response whose answer differs from the ground
    # truth only in letter case (the accuracy check is case-insensitive).
    demo_response = """
<think>
think
</think>
<answer>
YeS
</answer>
"""
    demo_ground_truth = "Yes"
    for shown in (demo_response, demo_ground_truth):
        print(shown)
    demo_batch = [{"response": demo_response, "ground_truth": demo_ground_truth}]
    print(compute_score(demo_batch))
