import json
import logging
import os
import re

from env import EnvironmentResult, Sample, TaskEnvironment

logger = logging.getLogger(__name__)


class FormulaEnvironment(TaskEnvironment):
    def __init__(self):
        pass

    def get_primary_metric_name(self) -> str:
        """Return the primary metric name for Formula environment."""
        return "accuracy"

    async def get_generator_prompt(self, sample: Sample, playbook_context: str) -> str:
        """
        Build a specialized prompt for Formula financial reasoning problems.
        
        Args:
            sample: The Formula financial reasoning sample
            playbook_context: Retrieved context/instructions for this sample
        
        Returns:
            Formatted prompt string optimized for financial reasoning
        """
        prompt = f"""You are an expert financial analyst specializing in quantitative reasoning and formula-based calculations.

Financial Reasoning Problem:
{sample.question}

Strategic Context:
{playbook_context}

Instructions:
- Read the problem carefully and identify all relevant financial concepts and formulas
- Extract all numerical values and their units from the problem statement
- Determine which formula or calculation method applies to this problem
- Show your step-by-step calculation process clearly
- Pay attention to units and conversions (e.g., millions, percentages, ratios)
- Round to the nearest hundredth if necessary (2 decimal places)
- Your answer should be a plain floating point number

You MUST respond with a valid JSON object containing exactly two fields:
1. "reasoning": Your detailed step-by-step financial analysis
2. "final_answer": The numerical answer ONLY as a plain floating point number (e.g., "0.16" or "5000000.0")

Example response format:
{{
  "reasoning": "Your step-by-step financial analysis... (less than 200 words in most cases)",
  "final_answer": "0.16"
}}"""
        return prompt

    def load_samples(self, path: str, limit: int = 10, random_sample: bool = False, shuffle: bool = False) -> list[Sample]:
        """Load samples from a JSON Lines benchmark file.

        Every non-blank line must be a JSON object with ``question`` and
        ``target`` keys. Blank or whitespace-only lines are skipped and a
        leading UTF-8 byte-order mark is tolerated.

        Args:
            path: Path to the data file to load
            limit: Maximum number of samples to return. Passing ``None``
                disables the cap (the checks below allow it even though the
                annotation says ``int``).
            random_sample: If True, read the whole file and randomly pick
                ``limit`` samples; if False, take the first ``limit`` samples
            shuffle: If True, shuffle the order of the returned samples
                (useful for mini-batching)

        Returns:
            List of Sample objects. Each sample's ``id`` is the zero-based
            index of its line in the file, ``context`` is the empty string
            and ``extras`` is an empty dict.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record lacks ``question`` or ``target``.
        """
        import random

        # Random sampling needs the full population, so the read-time cap
        # only applies when taking the first `limit` samples.
        read_cap = None if random_sample else limit

        loaded = []
        # "utf-8-sig" decodes plain UTF-8 identically but also strips a BOM,
        # which would otherwise make json.loads reject the first line.
        with open(path, encoding="utf-8-sig") as f:
            for line_index, row in enumerate(f):
                if read_cap is not None and len(loaded) >= read_cap:
                    break
                if not row.strip():
                    # Blank lines carry no record; json.loads would raise on them.
                    continue
                data = json.loads(row)
                loaded.append(
                    Sample(
                        id=line_index,
                        question=data["question"],
                        context="",  # Formula dataset has no context
                        ground_truth=data["target"],
                        extras={},
                    )
                )

        # Down-sample only when there are more samples than requested.
        if random_sample and limit is not None and len(loaded) > limit:
            samples = random.sample(loaded, limit)
        else:
            samples = loaded

        # Shuffle if requested (affects order for mini-batching)
        if shuffle:
            random.shuffle(samples)

        return samples

    def _formula_answer_is_correct(self, predicted: str, ground_truth: str) -> bool:
        """
        Formula dataset specific answer correctness check.

        ⚠️ Checks if predicted value is close to ground truth with rounding tolerance.

            Allows absolute tolerance of 0.01 to handle rounding issues
            where ground_truth might have been rounded incorrectly (e.g., 0.075 -> 0.07).
            This setup adapts to how ACE (Zhang et al., https://arxiv.org/pdf/2510.04618) preprocesses the dataset (adding "Your answer should be a plain floating point number, round to the nearest hundredth if necessary. Do the necessary conversions, for example 5 million should be 5000000.0. ").

            Only applies 0.01 tolerance when the last digit of either string is in the hundredth place (0.01).

        ⚠️ Checks if predicted value matches ground truth when converted between
        percentage and ratio formats (e.g., 16.0 vs 0.16, or 0.15 vs 15.0).
        """
        # Clean predicted value: remove all characters except digits, decimal point, and minus sign
        # Note: - must be escaped or placed at start/end of character class to avoid range interpretation
        cleaned_predicted = re.sub(r"[^0-9.\-]", "", predicted)
        cleaned_predicted = cleaned_predicted.replace(",", "")

        # Handle empty or invalid predicted values
        if not cleaned_predicted or cleaned_predicted in ("", ".", "-", "-."):
            return False

        try:
            pred_float = float(cleaned_predicted)
        except ValueError:
            return False

        gt_float = float(ground_truth)

        if pred_float == gt_float:
            return True

        def _is_close_with_rounding_tolerance(
            pred_str: str, gt_str: str, pred_float: float, gt_float: float
        ) -> bool:
            """Check if two floats are close, accounting for rounding errors.

            Allows absolute tolerance of 0.01 to handle rounding issues
            where ground_truth might have been rounded incorrectly (e.g., 0.075 -> 0.07).
            This setup adapts to how ACE (Zhang et al., https://arxiv.org/pdf/2510.04618) preprocesses the dataset (adding "Your answer should be a plain floating point number, round to the nearest hundredth if necessary. Do the necessary conversions, for example 5 million should be 5000000.0. ").

            Only applies 0.01 tolerance when at least one of the numbers has its last digit in the hundredth place (0.01).
            This handles cases where ground_truth might have more decimal places but should be rounded to hundredths.
            """

            def has_exactly_two_decimal_places(num_str: str) -> bool:
                """Check if a number string has exactly 2 decimal places.

                Returns True if the number has exactly 2 decimal places (e.g., "0.07", "16.00").
                This is used to determine when rounding tolerance should be applied.
                """
                # Remove all non-digit, non-decimal, non-minus characters
                cleaned = re.sub(r"[^0-9.\-]", "", num_str)
                cleaned = cleaned.replace(",", "")

                # Find decimal point position
                if "." not in cleaned:
                    return False  # No decimal point

                # Split by decimal point
                parts = cleaned.split(".")
                if len(parts) != 2:
                    return False

                decimal_part = parts[1]
                # Check if decimal part has exactly 2 digits
                return len(decimal_part) == 2

            abs_diff = abs(pred_float - gt_float)

            # Apply tolerance only when:
            # 1. Both numbers have exactly 2 decimal places (e.g., "0.07" vs "0.08")
            # 2. pred has exactly 2 decimal places AND gt has 2+ decimal places
            #    (e.g., "0.07" vs "0.075") - handles rounding issues where ground_truth
            #    might have been rounded incorrectly
            # We don't apply tolerance if one number has only 1 decimal place (e.g., "0.1" vs "0.11") or predicted has more than 2 decimal places
            pred_has_2_decimals = has_exactly_two_decimal_places(pred_str)
            gt_has_2_decimals = has_exactly_two_decimal_places(gt_str)

            # Check if ground_truth has 2+ decimal places
            gt_has_2plus_decimals = False
            if "." in gt_str:
                cleaned_gt = re.sub(r"[^0-9.\-]", "", gt_str).replace(",", "")
                parts = cleaned_gt.split(".")
                if len(parts) == 2 and len(parts[1]) >= 2:
                    gt_has_2plus_decimals = True

            # Use small epsilon to handle floating point precision issues
            # e.g., abs(0.48 - 0.47) = 0.010000000000000009, which is slightly > 0.01
            if abs_diff <= 0.01 + 1e-10 and (pred_has_2_decimals and gt_has_2_decimals):
                return True

            if abs_diff <= 0.005 + 1e-10 and (
                pred_has_2_decimals and gt_has_2plus_decimals
            ):
                return True

            return False

        # Check #1:
        if _is_close_with_rounding_tolerance(
            predicted, ground_truth, pred_float, gt_float
        ):
            logger.warning(
                f"⚠️⭕️ is_close_with_rounding_tolerance triggered: predicted={pred_float}, ground_truth={gt_float}"
            )
            return True

        # Check #2:
        # ⚠️ handle percentage/ratio format conversion, due to the inconsistent format in the original dataset
        # Check if predicted value matches when converted (divide or multiply by 100)
        # This handles cases like: 16.0 vs 0.16, 0.15 vs 15.0, 300.0 vs 3.0
        # Use floating point comparison with tolerance to avoid precision issues
        if pred_float / 100.0 == gt_float or round(pred_float / 100.0, 2) == gt_float:
            logger.warning(
                f"⚠️🔺 percentage/ratio conversion triggered: predicted={pred_float}, ground_truth={gt_float} (pred/100=gt)"
            )
            return True
        if pred_float * 100.0 == gt_float:
            logger.warning(
                f"⚠️🔺 percentage/ratio conversion triggered: predicted={pred_float}, ground_truth={gt_float} (pred*100=gt)"
            )
            return True

        return False

    async def aevaluate(self, sample: Sample, generator_output: str) -> EnvironmentResult:
        """Score one generator output against the sample's ground truth.

        Args:
            sample: Sample whose ``ground_truth`` is the reference answer
            generator_output: The answer text to be judged

        Returns:
            EnvironmentResult carrying a fixed feedback sentence, the
            sample's ground truth, and ``metrics={"accuracy": 1.0}`` when
            ``_formula_answer_is_correct`` accepts the output, else
            ``{"accuracy": 0.0}``.
        """
        expected = sample.ground_truth

        if self._formula_answer_is_correct(generator_output, expected):
            score = 1.0
            verdict = "Predicted answer matches ground truth"
        else:
            score = 0.0
            verdict = "Predicted answer does not match ground truth"

        return EnvironmentResult(
            feedback=verdict,
            ground_truth=expected,
            metrics={"accuracy": score},
        )