"""Flexible answer matcher that can handle various answer formats."""

import ast
import csv
import json
import logging
import re
from pathlib import Path

import pandas as pd
from rich.logging import RichHandler

# Configure the root logger as a side effect of importing this module:
# INFO level, message-only format, and output rendered by a RichHandler
# with markup enabled.
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    datefmt='[%X]',
    handlers=[RichHandler(markup=True)],
)


def load_country_code_map():
    """Read ``resources/un_m49_cleaned.csv`` into a code -> name dictionary.

    The path is relative to the current working directory. Each CSV row must
    provide ``country_code`` and ``country_name`` columns; codes are stripped
    and upper-cased, names are stripped, and rows where either ends up empty
    are ignored.

    Loading is best-effort: any exception is logged as a warning and whatever
    was collected up to that point (possibly nothing) is returned.

    Returns
    -------
    dict
        Mapping of upper-cased country code to country name.

    """
    mapping = {}
    csv_path = Path('resources', 'un_m49_cleaned.csv')
    try:
        with csv_path.open(encoding='utf-8') as handle:
            for record in csv.DictReader(handle):
                code = record['country_code'].strip().upper()
                name = record['country_name'].strip()
                if not (code and name):
                    continue
                mapping[code] = name
    except Exception as e:
        logging.warning(f'Could not load country code map: {e}')
    return mapping


# Upper-cased country code -> country name, built once at import time.
# Stays empty when load_country_code_map() cannot read its CSV file.
COUNTRY_CODE_MAP = load_country_code_map()


class Matcher:
    """Flexible answer matcher using answer_format from metadata."""

    def __init__(self, percent_tolerance: float = 0.01):
        """Initialize the Matcher.

        Parameters
        ----------
        percent_tolerance : float
            Tolerance expressed in percent (the default ``0.01`` means 0.01 %).
            Stored unchanged on the instance as ``percent_tolerance``.

        """
        self.percent_tolerance = percent_tolerance

    def extract_final_answer(self, messages):
        """Extract the final answer from a list of messages (dicts)."""
        final_answer = None
        for message in messages:
            if message.get('role') == 'assistant' and message.get('tool_calls'):
                for tool_call in message['tool_calls']:
                    if tool_call.get('function', {}).get('name') == 'final_answer':
                        arguments = tool_call['function']['arguments']
                        # Fix: Only parse if arguments is a string
                        if isinstance(arguments, str):
                            parsed_args = json.loads(arguments)
                        elif isinstance(arguments, dict):
                            parsed_args = arguments
                        else:
                            parsed_args = {}
                        final_answer = parsed_args.get('answer')
                        break
            if final_answer is not None:
                break
        return final_answer

    def match(
        self,
        pred,
        gold: str,
        answer_format: str | None = None,
    ) -> tuple[bool, float]:
        """Match predicted and gold answers using the specified answer format.

        If pred is a list of messages, extract the final answer.
        """
        # If pred is a list of dicts (messages), extract the final answer
        if isinstance(pred, list) and pred and isinstance(pred[0], dict):
            pred = self.extract_final_answer(pred)
        # Accept both 'list' and 'list[str]' as list formats
        if answer_format in ('list', 'list[str]', 'list[int]', 'list[float]'):
            return self.match_list(pred, gold)
        elif answer_format == 'float':
            return self.match_float(pred, gold)
        elif answer_format == 'bool':
            return self.match_bool(pred, gold)
        elif answer_format == 'int':
            return self.match_int(pred, gold)
        elif answer_format == 'str':
            return self.match_str(pred, gold)
        else:
            return self.match_fallback(pred, gold)

    def match_float(
        self,
        pred: str,
        gold: str,
    ) -> tuple[bool, float]:
        """Match float values with percent error.

        Returns
        -------
        (bool, float)
            Tuple of (correct, percent_error).

        """
        logging.info(f'🔬 Matcher().match_float(pred={pred!r}, gold={gold!r})')

        # Attempt to parse both pred and gold as floats
        try:
            pred_f = float(ast.literal_eval(pred))
            gold_f = float(gold)
            logging.info(f'🔬 Parsed pred: {pred_f}')
            logging.info(f'🔬 Parsed gold: {gold_f}')
            percent_error = abs(pred_f - gold_f) * 100 if gold_f == 0 else abs(pred_f - gold_f) / abs(gold_f) * 100
            correct = percent_error <= 0.01

        # If parsing fails, try to handle common cases
        except Exception as e:
            logging.warning(f'🔬 Exception parsing float values: {e}')

            # Fallback: check if gold value as string is present in pred string
            gold_str = str(gold).strip()
            if gold_str in str(pred):
                logging.info(f'✅ Correct: gold value {gold_str!r} found in prediction string (fallback).')
                return True, 0.0

            # Further fallback: extract any float from pred and compare
            try:
                gold_f = float(gold)
                float_pattern = r'[-+]?\d*\.\d+|\d+'
                found_floats = [float(x) for x in re.findall(float_pattern, str(pred.replace(',', '')))]
                for f in found_floats:
                    percent_error = abs(f - gold_f) * 100 if gold_f == 0 else abs(f - gold_f) / abs(gold_f) * 100
                    if percent_error <= 0.01:
                        logging.info(f'✅ Found matching float {f} in prediction string (regex fallback).')
                        return True, 0.0
                logging.warning('❌ No matching float found in prediction string (regex fallback).')
            except Exception as e2:
                logging.warning(f'🔬 Exception in regex float extraction: {e2}')

            percent_error = 100.0
            correct = False
            pred_f = None
            gold_f = None

        if correct and percent_error == 0.0:
            logging.info('✅ Correct: exact match.')
        elif correct:
            logging.info(f'✅ Correct within {self.percent_tolerance}% tolerance. Percent error: {percent_error:.5f}%')
        else:
            logging.warning(f'❌ Incorrect. Answer {pred_f!r} differs from gold {gold_f!r} by {percent_error:.5f}%')

        return correct, percent_error

    def match_bool(self, pred: str, gold: str) -> tuple[bool, float]:
        """Match boolean values.

        Returns
        -------
        (bool, float)
            Tuple of (correct, percent_error).

        """
        logging.info(f'🔬 Matcher().match_bool(pred={pred!r}, gold={gold!r})')
        bool_map = {
            'true': True,
            'false': False,
            'Yes': True,
            'No': False,
            'yes': True,
            'no': False,
        }

        try:
            pred_val = ast.literal_eval(pred)
        except Exception as e:
            logging.warning(f'🔬 Exception parsing pred: {e}. Falling back to mapping.')
            pred_val = pred

        gold_val = gold

        pred_bool = bool_map.get(str(pred_val).strip().lower(), pred_val)
        logging.info(f"🔬 Parsed pred '{pred}' -> {pred_bool}")

        correct = bool(pred_bool) == bool(gold_val)
        if correct:
            logging.info('✅ Correct boolean match.')
        else:
            logging.warning(f'❌ Incorrect boolean match. Received: {pred_bool!r}, expected: {gold_val!r}')

        percent_error = 0.0 if correct else 100.0
        return correct, percent_error

    def match_list(
        self,
        pred: str,
        gold: str,
    ) -> tuple[bool, float]:
        """Match list-formatted answers, e.g., "['a', 'b']", or comma-separated strings e.g., "a, b".

        Returns
        -------
        (bool, float)
            Tuple of (correct, percent_error). If not correct, percent_error is delta.

        """
        logging.info(f'🔬 Matcher().match_list(pred={pred!r}, gold={gold!r})')

        if not gold:
            logging.warning('🔬 Gold answer is empty or None. Cannot match lists.')
            return False, 100.0

        # Parse pred to list
        try:
            pred_list = ast.literal_eval(pred.strip()) if isinstance(pred, str) else pred
            if not isinstance(pred_list, list):
                logging.warning('🔬 Cannot parse list from predicted answer.')
                return False, 100.0

        except Exception as e:
            logging.warning(f'🔬 Failed to parse pred: {e}. Using fallback parsing.')
            pred_list = [item.strip() for item in str(pred).strip('[](){}').split(',') if item.strip()]

        # Gold may already be a list, or a string representation of a list
        if isinstance(gold, list):
            gold_list = gold
        else:
            try:
                gold_list = ast.literal_eval(str(gold).strip())
            except Exception as e:
                logging.warning(f'🔬 Failed to parse gold: {e}. Using fallback parsing.')
                gold_list = [item.strip() for item in str(gold).strip('[](){}').split(',') if item.strip()]

        # --- Country code fallback for lists ---
        def normalize_country(val):
            val = str(val).strip()
            # If 3-letter uppercase, try to map to country name
            if len(val) == 3 and val.isupper():
                mapped = COUNTRY_CODE_MAP.get(val)
                if mapped:
                    return mapped.strip().lower()
            return val.lower()

        pred_list_norm = [normalize_country(x) for x in pred_list]
        gold_list_norm = [normalize_country(x) for x in gold_list]

        # Log parsed values
        logging.info(f'🔬 Parsed pred_list: {pred_list}')
        logging.info(f'🔬 Parsed gold_list: {gold_list}')
        logging.info(f'🔬 Normalized pred_list: {pred_list_norm}')
        logging.info(f'🔬 Normalized gold_list: {gold_list_norm}')

        # Compare as sets (order-insensitive)
        if isinstance(pred_list_norm, list) and isinstance(gold_list_norm, list):
            pred_set = set(pred_list_norm)
            gold_set = set(gold_list_norm)
            correct = pred_set == gold_set
            percent_error = 0.0 if correct else float(abs(len(pred_list_norm) - len(gold_list_norm)))

            if correct:
                logging.info('✅ Correct set match.')
            else:
                missing = gold_set - pred_set
                extra = pred_set - gold_set
                logging.warning('❌ Set mismatch')
                logging.warning(f'🔬 Correct: {gold_set & pred_set}')
                logging.warning(f'🔬 Missing: {missing}')
                logging.warning(f'🔬 Extra: {extra}')

            return correct, percent_error

        else:
            logging.warning('🔬 One or both values are not lists after parsing.')
            return False, 100.0

    def match_int(self, pred: str, gold: str) -> tuple[bool, float]:
        """Match integer values.

        Returns
        -------
        (bool, float)
            Tuple of (correct, percent_error). If not correct, percent_error is delta.

        """
        logging.info(f'🔬 Matcher().match_int(pred={pred!r}, gold={gold!r})')

        try:
            pred_i = int(ast.literal_eval(pred))
            gold_i = int(gold)
            logging.info(f'🔬 Parsed pred: {pred_i}')
            logging.info(f'🔬 Parsed gold: {gold_i}')
            correct = pred_i == gold_i
            percent_error = 0.0 if correct else float(pred_i - gold_i)
            if correct:
                logging.info('✅ Correct match.')
            else:
                logging.warning(f'❌ Incorrect match. Predicted: {pred_i}, Gold: {gold_i}, Delta: {percent_error}')

            return correct, percent_error

        except Exception as e:
            logging.warning(f'🔬 Exception parsing int values: {e}')
            # Fallback: search for any integer in pred string and compare
            try:
                gold_i = int(gold)
                int_pattern = r'-?\d+'
                found_ints = [int(x) for x in re.findall(int_pattern, str(pred))]
                for i in found_ints:
                    if i == gold_i:
                        logging.info(f'✅ Found matching int {i} in prediction string (regex fallback).')
                        return True, 0.0
                logging.warning('❌ No matching int found in prediction string (regex fallback).')
            except Exception as e2:
                logging.warning(f'🔬 Exception in regex int extraction: {e2}')
            return False, 100.0

    def match_str(self, pred: str, gold: str) -> tuple[bool, float]:
        """Match string values (case-insensitive, whitespace-stripped).

        Returns
        -------
        (bool, float)
            Tuple of (correct, percent_error).

        """
        logging.info(f'🔬 Matcher().match_str(pred={pred!r}, gold={gold!r})')
        pred_str = str(pred).strip().lower()
        gold_str = str(gold).strip().lower()
        logging.info(f"🔬 Parsed pred: '{pred_str}'")
        logging.info(f"🔬 Parsed gold: '{gold_str}'")
        correct = pred_str == gold_str
        percent_error = 0.0 if correct else 100.0
        if correct:
            logging.info('✅ Exact string match.')
        else:
            # Fallback: check if gold string is present in pred string
            if gold_str and gold_str in pred_str:
                logging.info('✅ Gold string found in prediction string (fallback).')
                return True, 0.0

            # Fallback: if pred is a 3-letter uppercase country code, try to convert to country name
            pred_code = str(pred).strip()
            if len(pred_code) == 3 and pred_code.isupper():
                country_name = COUNTRY_CODE_MAP.get(pred_code)
                if country_name:
                    logging.info(f"🔬 Fallback: Converted code '{pred_code}' to country name '{country_name}'")
                    if country_name.strip().lower() == gold_str:
                        logging.info('✅ Pred country code matches gold country name.')
                        return True, 0.0
                    # Also try if gold is a code and pred is a name
                    if gold_str in COUNTRY_CODE_MAP and COUNTRY_CODE_MAP[gold_str.upper()].strip().lower() == pred_str:
                        logging.info('✅ Gold code matches predicted country name.')
                        return True, 0.0

            logging.warning(f'❌ Incorrect string match. Predicted: {pred_str!r}, Gold: {gold_str!r}')

        return correct, percent_error

    def match_fallback(self, pred: str, gold: str) -> tuple[bool, float]:
        """Try all matchers in order until one matches exactly.

        Returns
        -------
        (bool, float)
            Tuple of (correct, percent_error).

        """
        logging.info(f'🔬 Matcher() match_fallback | Trying all matchers for pred: {pred!r} | gold: {gold!r}')
        for fmt in ('float', 'bool', 'list', 'int', 'str'):
            correct, percent_error = self.match(pred, gold, fmt)
            if correct:
                logging.info(f'🔬 Matcher() match_fallback | Matched using format: {fmt}')
                return True, percent_error
        logging.warning('🔬 Matcher() match_fallback | No matcher succeeded.')
        return False, 100.0

    def match_row(self, row: pd.Series):
        """Match a DataFrame row by extracting messages, gold answer, and answer format.

        Parameters
        ----------
        row : pd.Series
            A row from a DataFrame containing at least 'messages', 'answer', and optionally 'metadata'.

        Returns
        -------
        (bool, float)
            Tuple of (correct, percent_error).

        """
        messages = row.get('messages')
        gold = row.get('answer')
        answer_format = None
        metadata = row.get('metadata')
        if isinstance(metadata, dict):
            answer_format = metadata.get('answer_format')
        return self.match(messages, gold, answer_format)


if __name__ == '__main__':
    # Example usage: run the list matcher on a scalar-looking prediction
    # against a list-formatted gold answer. The outcome is only reported
    # through the matcher's log output; the returned tuple is not used.
    matcher = Matcher()
    pred = '4325.0'
    gold = "['Canada', 'USA']"
    correct, percent_error = matcher.match(pred, gold, 'list')
