from collections import Counter
from typing import List
import math
from verl.utils.reward_score.ttrl.auto_extract import auto_extract
from verl.utils.reward_score.ttrl.auto_verify import auto_verify
from collections import defaultdict
def is_inside(point, bbox):
    """Return True if ``point`` ``(x, y)`` lies inside ``bbox`` ``(x1, y1, x2, y2)``.

    All four edges are inclusive. The box must be ordered (``x1 <= x2`` and
    ``y1 <= y2``); an inverted box contains no points.
    """
    x, y = point
    x1, y1, x2, y2 = bbox
    return x1 <= x <= x2 and y1 <= y <= y2


def get_voted_grid_bbox(bboxes, grid_size=(16, 16)):
    """Fuse candidate bounding boxes into one box by grid-cell corner voting.

    Each candidate is parsed into an ``(x1, y1, x2, y2)`` tuple. The global
    extent of all corners is split into a ``rows x cols`` grid. Every
    top-left corner votes for the cell it falls in, and every bottom-right
    corner does the same in a separate tally. The fused box runs from the
    top-left of the most-voted top-left cell to the bottom-right of the
    most-voted bottom-right cell, with each coordinate rounded to an int.
    Ties go to the cell that received its first vote earliest.

    Args:
        bboxes: Iterable of candidates. A candidate is either a 4-tuple or a
            string such as ``"(x1, y1, x2, y2)"``. Anything else, and any
            string that does not yield exactly four numbers, is ignored.
        grid_size: ``(rows, cols)`` of the voting grid; both must be positive.

    Returns:
        ``None`` if no candidate could be parsed. Otherwise a pair
        ``(final_bbox, support)``: ``final_bbox`` is the fused integer box,
        and ``support`` counts the candidates with at least one corner
        inside it.
    """

    def parse_bbox(item):
        """Convert one candidate into an (x1, y1, x2, y2) tuple, or None."""
        if isinstance(item, tuple):
            return item if len(item) == 4 else None
        if isinstance(item, str):
            # Keep only characters that can form comma-separated (possibly
            # negative / decimal) numbers, so "(1, 2, 3, 4)" and "[1,2,3,4]"
            # both parse.
            cleaned = ''.join(c for c in item if c in "0123456789.-,")
            try:
                coords = tuple(float(part) for part in cleaned.split(','))
            except ValueError:
                # Empty strings and malformed numbers are skipped.
                return None
            return coords if len(coords) == 4 else None
        return None

    valid_bboxes = [p for p in map(parse_bbox, bboxes) if p is not None]
    if not valid_bboxes:
        return None

    top_lefts = [(b[0], b[1]) for b in valid_bboxes]
    bottom_rights = [(b[2], b[3]) for b in valid_bboxes]

    # Global extent over both corners of every candidate.
    xs = [p[0] for p in top_lefts + bottom_rights]
    ys = [p[1] for p in top_lefts + bottom_rights]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    rows, cols = grid_size
    cell_width = (max_x - min_x) / cols
    cell_height = (max_y - min_y) / rows

    def point_to_cell(x, y):
        """Map a point to its clamped (row, col) grid cell."""
        # If every coordinate on an axis is equal, the cell size there is 0
        # and floor division would raise ZeroDivisionError. All points then
        # share index 0 on that axis.
        col = int((x - min_x) // cell_width) if cell_width else 0
        row = int((y - min_y) // cell_height) if cell_height else 0
        # Points on the max edge land one past the last cell; clamp them back.
        return min(rows - 1, max(0, row)), min(cols - 1, max(0, col))

    def cell_to_bbox(row, col):
        """Return the (x1, y1, x2, y2) extent of grid cell (row, col)."""
        x1 = min_x + col * cell_width
        y1 = min_y + row * cell_height
        return (x1, y1, x1 + cell_width, y1 + cell_height)

    tl_votes = Counter(point_to_cell(x, y) for x, y in top_lefts)
    br_votes = Counter(point_to_cell(x, y) for x, y in bottom_rights)

    # max() returns the first maximal key; Counter keeps insertion order, so
    # ties resolve to the cell that was voted for first.
    tl_bbox = cell_to_bbox(*max(tl_votes, key=tl_votes.get))
    br_bbox = cell_to_bbox(*max(br_votes, key=br_votes.get))

    final_bbox = (
        round(tl_bbox[0]),  # x1 (leftmost)
        round(tl_bbox[1]),  # y1 (topmost)
        round(br_bbox[2]),  # x2 (rightmost)
        round(br_bbox[3]),  # y2 (bottommost)
    )

    # A candidate supports the fused box if either of its corners falls
    # inside it. Change "or" to "and" to require full containment.
    support = sum(
        1
        for x1, y1, x2, y2 in valid_bboxes
        if is_inside((x1, y1), final_bbox) or is_inside((x2, y2), final_bbox)
    )

    return final_bbox, support

def test_time_train_metrics(
    solutions: List[str],
    ground_truth: List[str],
    task="math",
    extra_info=None,
    entropy_weight: float = 0.75,
):
    """Compute majority-vote rewards and TTRL diagnostics for one prompt's samples.

    Answers are extracted from ``solutions`` with ``auto_extract``. The most
    common extracted answer becomes the estimated (pseudo) label. Each
    sample's training reward is the vote share of its own extracted answer
    minus ``entropy_weight`` times the normalized entropy of the answer
    distribution.

    Note that the returned rewards are NOT the ``auto_verify`` results
    against the estimated label; those are only used for the
    ``majority_voting_reward`` / ``reward_accuracy`` metrics.

    Args:
        solutions: Sampled model responses for a single prompt.
        ground_truth: One reference answer per solution; all entries must be
            identical.
        task: Task name forwarded to ``auto_extract`` / ``auto_verify``.
        extra_info: Forwarded unchanged to ``auto_extract`` / ``auto_verify``.
        entropy_weight: Coefficient of the normalized-entropy penalty
            (default 0.75, the previously hard-coded value).

    Returns:
        ``(rewards_en, ttrl_metrics)``: per-solution rewards, and a dict with
        ``label_accuracy``, ``reward_accuracy``, ``majority_ratio``,
        ``ground_truth_ratio``, ``majority_voting_reward`` and
        ``pass@{len(solutions)}``.

    Raises:
        AssertionError: If the lengths disagree, the ground truth is not
            unique, or ``auto_verify`` returns a different number of verdicts.
    """
    assert len(solutions) == len(ground_truth), f"{len(solutions)} vs {len(ground_truth)}"
    assert len(set(ground_truth)) == 1, f"Ground truth is not unique: {ground_truth}"
    ground_truth = ground_truth[0]
    n = len(solutions)

    model_answers = auto_extract(task, solutions, extra_info=extra_info)
    counter = Counter(model_answers)
    total = len(model_answers)

    # Vote share of each sample's own extracted answer.
    reward_p = [counter[ans] / total for ans in model_answers]

    # Shannon entropy (nats) of the answer distribution. It is normalized by
    # its maximum, log(#distinct answers), so it lies in [0, 1]. With a
    # single distinct answer the entropy is 0 and no penalty applies.
    entropy = -sum((count / total) * math.log(count / total) for count in counter.values())
    if len(counter) > 1:
        normalized_entropy = entropy / math.log(len(counter))
    else:
        normalized_entropy = 0.0

    estimated_label, majority_count = counter.most_common(1)[0]
    majority_ratio = majority_count / n

    # Does the pseudo label itself match the ground truth?
    hit_rate = 1.0 if auto_verify(task, [estimated_label], [ground_truth], extra_info=extra_info)[0][0] else 0.0

    # Verdicts of every solution against the pseudo label and against the
    # true label; presumably 0/1 values (TODO confirm with auto_verify).
    rewards, _ = auto_verify(task, solutions, [estimated_label] * n, extra_info=extra_info)
    true_rewards, _ = auto_verify(task, solutions, [ground_truth] * n, extra_info=extra_info)
    # Check before dividing by len(rewards) below.
    assert len(rewards) == n, f"{len(rewards)} vs {len(solutions)}"

    rewards_en = [share - entropy_weight * normalized_entropy for share in reward_p]

    # Fraction of samples where the pseudo-label verdict equals the true one.
    rewards_hit_rate = sum(r == t for r, t in zip(rewards, true_rewards)) / len(rewards)

    ttrl_metrics = {
        "label_accuracy": hit_rate,
        "reward_accuracy": rewards_hit_rate,
        "majority_ratio": majority_ratio,
        "ground_truth_ratio": sum(true_rewards) / len(true_rewards),
        "majority_voting_reward": sum(rewards) / len(rewards),
        # NOTE(review): post_test_time_train_metrics uses "> 0" here; the two
        # only agree when verdicts are 0/1 — confirm which is intended.
        f"pass@{n}": 1.0 if sum(true_rewards) >= 1 else 0.0,
    }
    return rewards_en, ttrl_metrics

def post_test_time_train_metrics(
    solutions: List[str],
    ground_truth: List[str],
    pred_rewards: List,
    task="math", extra_info=None):
    """Measure how well already-assigned rewards agree with ground-truth verification.

    Args:
        solutions: Sampled model responses for a single prompt.
        ground_truth: One reference answer per solution; all entries must be
            identical.
        pred_rewards: The reward given to each solution during training,
            compared with ``==`` against the ground-truth verdicts.
        task: Task name forwarded to ``auto_verify``.
        extra_info: Forwarded unchanged to ``auto_verify``.

    Returns:
        Dict with ``post_reward_accuracy`` (fraction of exact agreements),
        ``post_ground_truth_ratio`` (mean ground-truth verdict) and
        ``post_pass@{len(solutions)}`` (1.0 if any verdict sum is positive).

    Raises:
        AssertionError: If the input lengths disagree or the ground truth is
            not unique.
    """
    assert len(solutions) == len(ground_truth), f"{len(solutions)} vs {len(ground_truth)}"
    assert len(solutions) == len(pred_rewards), f"{len(solutions)} vs {len(pred_rewards)}"
    assert len(set(ground_truth)) == 1, f"Ground truth is not unique: {ground_truth}"
    ground_truth = ground_truth[0]
    n = len(solutions)

    # Verdict of every solution against the true label.
    true_rewards, _ = auto_verify(task, solutions, [ground_truth] * n, extra_info=extra_info)

    # Fraction of samples whose training reward exactly matches the true verdict.
    rewards_hit_rate = sum(pred == true for pred, true in zip(pred_rewards, true_rewards)) / len(pred_rewards)

    post_ttrl_metrics = {
        "post_reward_accuracy": rewards_hit_rate,
        "post_ground_truth_ratio": sum(true_rewards) / len(true_rewards),
        f"post_pass@{n}": 1.0 if sum(true_rewards) > 0 else 0.0,
    }
    return post_ttrl_metrics