
from typing import List, Dict, Union
import re
from src.scoring.base import BaseScoringMethod
import numpy as np
import ast

class ClassifierScoringMethod(BaseScoringMethod):
    def create_prompt(self, interpretation: str, examples: List[str]) -> str:
        examples_text = "\n".join(
            f"Example {i}: {repr(example).replace('>><<', '').replace('<bos>', '').replace('<eos>', '').replace('<pad>', '')}" 
            for i, example in enumerate(examples)
        )
        return f"Latent interpretation: {interpretation}\n\nText examples:\n{examples_text}"

    def parse_response(self, response: str, batch_size: int) -> Union[List[int], None]:
        # Extract Python list from the response using regex
        list_str = re.findall(r"\[[^\]]*\]", response)
        try:
            if list_str:
                parsed = ast.literal_eval(list_str[-1].strip())
                if isinstance(parsed, list):
                    processed = [int(bool(x)) for x in parsed]  # Ensure values are 0 or 1
                    if len(processed) == batch_size:
                        return processed
                    return None  # Length doesn't match
        except (SyntaxError, ValueError):
            pass

        # Fallback: Extract binary patterns (e.g., 0 or 1) directly
        binary_pattern = re.findall(r"\b[01]\b", response)
        if binary_pattern and len(binary_pattern) == batch_size:
            return [int(x) for x in binary_pattern]

        # No matching pattern found with correct length
        return None

    def compute_metrics(self, predictions: List[int], labels: List[bool]) -> Dict[str, float]:
        predictions = np.array(predictions).astype(bool)
        labels = np.array(labels).astype(bool)

        tp = np.sum(predictions & labels)
        tn = np.sum(~predictions & ~labels)
        fp = np.sum(predictions & ~labels)
        fn = np.sum(~predictions & labels)

        def safe_div(a, b):
            return a / b if b > 0 else 0.0

        true_positive_rate = safe_div(tp, tp + fn)
        true_negative_rate = safe_div(tn, tn + fp)
        precision = safe_div(tp, tp + fp)
        recall = safe_div(tp, tp + fn)
        f1_score = safe_div(2 * tp, 2 * tp + fp + fn)
        balanced_accuracy = (true_positive_rate + true_negative_rate) / 2

        return {
            "true_positive_rate": true_positive_rate,
            "true_negative_rate": true_negative_rate,
            "balanced_accuracy": balanced_accuracy,
            "precision": precision,
            "recall": recall,
            "f1_score": f1_score,
            "support_positive": int(tp + fn),  # Total positive examples
            "support_negative": int(tn + fp),  # Total negative examples
            "total_examples": len(labels),  # Total examples
        }, balanced_accuracy