import ast
import logging
import re
import traceback
from typing import Any

import numpy as np
from sympy import Rational

from tasks.base import Task

# Module-level logger shared by the task classes below; records are emitted
# under the 'MINT' logger name.
LOGGER = logging.getLogger('MINT')


class ReasoningTask(Task):
    """Task whose solution is graded against a single reference string.

    When both reference and answer parse as floats the answer must lie within
    5% (relative to the reference) of the reference; otherwise the reference
    only has to appear as a substring of the answer.
    """

    task_name = 'reasoning'

    def __init__(self, id: str, prompt: str, reference: str, **kwargs):
        super().__init__(**kwargs)
        self._id = id
        self._prompt = prompt.strip()
        # Normalised once here so that later comparisons are case-insensitive.
        self._reference = str(reference).strip().lower()

    def extract_answer(self, solution: str) -> str | None:
        """Return the solution lower-cased with surrounding whitespace removed."""
        return solution.lower().strip()

    def compare_w_digits(self, reference: str, answer: str) -> bool:
        """Compare reference and answer numerically when possible.

        Returns:
            True if both values parse as floats and differ by at most 5% of
            the reference's magnitude, or - when either one is not a numeric
            string - if ``reference`` is a substring of ``answer``.

        Raises:
            ValueError: if converting either value fails with anything other
                than a ``ValueError`` (for example a ``TypeError`` for ``None``).
        """
        try:
            expected = float(reference)
            actual = float(answer)
        except ValueError:
            # At least one side is not a number: fall back to substring matching.
            return reference in answer
        except Exception:
            raise ValueError(f'Cannot compare {reference} and {answer}')
        tolerance = 0.05 * abs(expected)
        return abs(expected - actual) <= tolerance

    def success(self, solution: str) -> bool:
        """Return whether the extracted answer matches the stored reference."""
        return self.compare_w_digits(self._reference, self.extract_answer(solution))


class MultipleChoiceTask(Task):
    """Subclass of Task for multiple choice tasks.

    The choices are parsed from the text following ``Options: `` in the
    prompt, where they are expected to look like ``a)first , b)second , ...``.
    The reference is the key (letter) of the correct choice.
    """

    task_name = 'reasoning'

    # A stand-alone option letter followed by ')', e.g. "b)", "(b)" or "b )".
    # The look-behind rejects letters that merely end a word ("(roughly)").
    _OPTION_LETTER_RE = re.compile(r'(?<![a-z0-9])([a-z]) ?\)')

    def __init__(self, id, prompt: str, reference: str, **kwargs):
        super().__init__(**kwargs)
        self._id = id
        self.hide_options = kwargs.get('hide_options', False)
        if self.hide_options:
            # Drop everything from the "Options:" marker onwards.
            self._prompt = prompt.split('Options:')[0].strip()
        else:
            self._prompt = prompt
        self._reference = reference.strip().lower()
        self._options = self.extract_options(prompt)
        # If all options can be converted to float, strictly perform hide options.
        # Only the flag is changed here; self._prompt was already set above.
        if self._options and all(
            self._is_float(option) for option in self._options.values()
        ):
            self.hide_options = True
        self.metadata.update({'options': self._options})

    @staticmethod
    def _is_float(text: str) -> bool:
        """Return True when ``float(text)`` succeeds."""
        try:
            float(text)
        except ValueError:
            return False
        return True

    def extract_answer(self, solution: str) -> str | None:
        """Extract the selected option letter from the solution.

        Returns the first stand-alone letter that is followed by ``)`` (with an
        optional single space in between). When no such letter exists the
        whole lower-cased, stripped solution is returned so that ``success``
        can still match it against the option texts.
        """
        solution = solution.lower().strip()
        # The previous loop returned during its first iteration, so only the
        # letter 'a' was ever recognised; every letter is considered now.
        found = self._OPTION_LETTER_RE.search(solution)
        answer = found.group(1) if found else solution
        LOGGER.debug('SOLUTION %s', answer)
        return answer

    def compare_w_digits(self, reference: str, answer: str) -> bool:
        """Compare two strings, numerically when both consist only of digits.

        Digit-only strings match when they differ by at most 5% of the
        reference; otherwise ``reference`` must be a substring of ``answer``.
        """
        if reference.isdigit() and answer.isdigit():
            return abs(float(reference) - float(answer)) <= 0.05 * float(reference)
        return reference in answer

    def success(self, solution: str) -> bool:
        """Return whether the solution selects the correct option.

        The extracted answer is first compared with the reference letter. If
        that fails, the answer is compared with the option texts: mentioning
        any wrong option fails, mentioning the correct option succeeds.

        Raises:
            KeyError: if the reference is not one of the parsed option keys.
        """
        answer = self.extract_answer(solution)
        if self.compare_w_digits(self._reference, answer):
            return True

        correct_option = self._options[self._reference]
        # Options contained in the correct option's text (including the correct
        # option itself) must not count as wrong answers. Build a new list
        # instead of removing from the list being iterated, which used to skip
        # elements and could leave the correct option among the wrong ones.
        wrong_options = [
            option
            for option in self._options.values()
            if option not in correct_option
        ]
        LOGGER.debug('OPTIONS %s %s', correct_option, wrong_options)
        LOGGER.debug('ANSWER %s', answer)

        for option in wrong_options:
            if self.compare_w_digits(option, answer) or option in answer:
                return False
        return self.compare_w_digits(correct_option, answer) or (
            correct_option in answer
        )

    def extract_options(self, prompt: str) -> dict:
        """Parse the options listed after ``Options: `` into ``{letter: text}``.

        Options are separated by ``' , '``; within each one the letter is the
        part before the first ``)``. Letters and texts are lower-cased, and a
        trailing "Which option is correct?" / "Which one is correct?" sentence
        is removed from the text. Chunks without ``)`` are skipped.
        """
        options_text = prompt.split('Options: ')[-1]
        options = {}
        for chunk in options_text.split(' , '):
            # Split on the first ')' only so option text may itself contain ')'.
            letter, separator, text = chunk.strip("[]' ").partition(')')
            if not separator:
                continue
            # The text is lower-cased first, so the trailing question has to be
            # matched in lower case as well (the capitalised patterns used
            # before could never match).
            content = (
                text.lower()
                .strip('.')
                .replace('. which option is correct?', '')
                .replace('. which one is correct?', '')
                .strip()
            )
            options[letter.lower().strip()] = content
        return options


# ==== TheoremQA ====


def compare_two_numbers(p, gt):
    """Return whether the predicted number ``p`` matches the ground truth ``gt``.

    A float ground truth is compared via ``within_eps``; any other ground
    truth is compared with ``p`` rounded to the nearest integer.

    Returns:
        False when ``p`` is a bool, complex, dict, list, str or tuple, or when
        ``p`` cannot be rounded (infinity or NaN); otherwise the comparison
        result.

    Raises:
        ValueError: if ``p`` is of any other unsupported type (e.g. ``None``).
    """
    # bool is a subclass of int, so it has to be rejected before the numeric
    # check; otherwise True/False would be graded as the numbers 1/0.
    if isinstance(p, bool):
        return False
    if isinstance(p, (int, float)):
        pass
    elif isinstance(p, (complex, dict, list, str, tuple)):
        return False
    else:
        raise ValueError(p)

    if isinstance(gt, float):
        return within_eps(pred=p, gt=gt)
    try:
        return round(p) == gt
    except (OverflowError, ValueError):
        # round() raises OverflowError for infinities and ValueError for NaN;
        # neither can equal the ground truth.
        return False


def compare_two_list(pred, gt):
    """Return whether the predicted list matches the ground truth, ignoring order.

    Both sequences are sorted and then compared element-wise with
    ``compare_two_numbers``.

    Returns:
        False when ``pred`` is not a list, when ``gt`` has no length (for
        example ``None`` or a scalar), when the lengths differ, or when
        ``pred`` contains a non-numeric element; otherwise whether every
        sorted pair matches.
    """
    if not isinstance(pred, list):
        return False
    try:
        expected_len = len(gt)
    except TypeError:
        # The ground truth is not a sequence, so no list can match it.
        return False
    if len(pred) != expected_len:
        return False
    if any(not isinstance(x, (int, float)) for x in pred):
        return False
    return all(
        compare_two_numbers(p, g) for p, g in zip(sorted(pred), sorted(gt))
    )


def within_eps(pred: float, gt: float):
    """Return True when ``pred`` lies within 4% (relative to ``gt``) of ``gt``.

    The bounds are inclusive; for ``gt == 0`` only an exact match passes.
    """
    margin = abs(gt) * 0.04
    lower, upper = gt - margin, gt + margin
    return lower <= pred <= upper


def parse_number_list(s: str):
    """Parse ``s`` as a Python literal and return the resulting object.

    The parsing is delegated to ``ast.literal_eval``; no check is made that
    the result is actually a list, and any error it raises for text that is
    not a valid literal propagates to the caller.
    """
    return ast.literal_eval(s)


def is_number(string):
    """Return True if ``string`` is a plain decimal number.

    Accepted: an optional sign, then either digits grouped in thousands with
    commas ("1,234") or a plain run of digits, then an optional fractional
    part (".5" after the integer part).
    """
    number_re = re.compile(r'^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$')
    return number_re.match(string) is not None


def is_scientific_number(string):
    """Return True if ``string`` is a number in lower-case ``e`` notation.

    Accepted: an optional sign, digits with an optional fractional part, a
    lower-case ``e``, an optional minus sign and the exponent digits.
    """
    scientific_re = re.compile(r'^[-+]?\d+(\.\d+)?e[-]?\d+$')
    return scientific_re.match(string) is not None


def contain_num_and_str(string):
    """Return True if ``string`` contains both an ASCII letter and a digit."""
    has_letter = re.search(r'[a-zA-Z]', string) is not None
    has_digit = re.search(r'[0-9]', string) is not None
    return has_letter and has_digit


class TheoremqaTask(Task):
    """TheoremQA task whose answer is a number, a list of numbers, a bool or an option.

    Solutions and references are both normalised by ``extract_answer`` and
    then compared according to the ``answer_type`` keyword argument.
    """

    task_name = 'reasoning'

    def __init__(self, id: str, prompt: str, reference: str, **kwargs):
        super().__init__(**kwargs)
        self._id = id
        self._prompt = (
            'Answer the following question with a number, a list of numbers or True or False. '
            + prompt.strip()
        )
        self._reference = reference
        self._answer_type = kwargs.get('answer_type')

    def extract_answer(self, solution: str) -> Any:
        """Extract the answer from the given solution.

        The text is cleaned up (special tokens, boolean keywords,
        approximation keywords, units, power-of-ten notation, option letters)
        and then evaluated as a Python expression.

        Returns:
            The evaluated value, with tuples/sets/arrays converted to lists
            and complex/Rational values converted to floats, or ``None`` when
            the cleaned-up text cannot be evaluated.
        """
        # Function-scope import: the '10^n' rewrite below emits 'math.pow(...)'
        # and this file does not import math at module level.
        import math

        prediction = solution
        # Following the preprocessing steps from TheoremQA
        # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L170

        # Preprocessing the string [Stage 1]
        if not isinstance(prediction, str):
            prediction = str(prediction) if prediction is not None else '0'

        # Replace special tokens: keep only what follows the last '=' / '≈'.
        if '=' in prediction:
            prediction = prediction.split('=')[-1].strip()
        if '≈' in prediction:
            prediction = prediction.split('≈')[-1].strip()
        for token in ('`', '$', '°'):
            prediction = prediction.replace(token, '')

        # Detect the boolean keyword in the generation
        if prediction in ('true', 'yes'):
            prediction = 'True'
        elif prediction in ('false', 'no'):
            prediction = 'False'
        if 'True' in prediction or 'False' in prediction:
            prediction = 'True' if 'True' in prediction else 'False'

        # Detect the approximation keyword
        if 'approximately' in prediction:
            prediction = prediction.replace('approximately', '').strip()
        if ' or ' in prediction:
            prediction = prediction.split(' or ')[0]

        # Drop the units before and after the number. The patterns are applied
        # in order, each one to the result of the previous one.
        unit_patterns = (
            r'([-+]?(?:[\d,]*\.*\d+)) [^0-9 ]+$',
            r'[^0-9 ]+ ([-+]?(?:[\d,]*\.*\d+))$',
            r'([-+]?(?:[\d,]*\.*\d+))[^\d]{1,2}$',
            r'[^-+\d]{1,2}((?:[\d,]*\.*\d+))$',
        )
        for pattern in unit_patterns:
            number_match = re.match(pattern, prediction)
            if number_match:
                prediction = number_match.group(1)

        # Preprocessing the number [Stage 1]
        if '10^' in prediction:
            prediction = re.sub(r'10\^(-?\d+)', r'math.pow(10, \1)', prediction)
        if ' x ' in prediction:
            prediction = prediction.replace(' x ', '*')
        if ' × ' in prediction:
            prediction = prediction.replace(' × ', '*')
        if is_number(prediction):
            prediction = prediction.replace(',', '')

        # Preprocessing the option [Stage 3]: normalise 'a)', 'a )' or a bare
        # 'a' to '(a)'. Once a letter matched, the later letters no longer can.
        for letter in 'abcd':
            if (
                f'{letter})' in prediction
                or f'{letter} )' in prediction
                or prediction.lower().strip() == letter
            ):
                prediction = f'({letter})'

        # Quote a detected option so that it evaluates to the string '(x)'.
        option_match = re.search(r'\([a-d]\)', prediction)
        if option_match:
            prediction = '"' + option_match.group(0) + '"'

        # If the prediction is empty, use dummy '0'
        if not prediction:
            prediction = '0'

        # Converting the string answer to a number/list/bool/option
        # NOTE(security): eval() executes arbitrary Python code taken from the
        # solution text. Only use this on trusted or sandboxed input.
        # `math` is added to the evaluation globals so that the 'math.pow'
        # text generated above resolves instead of raising NameError.
        eval_globals = {**globals(), 'math': math}
        try:
            prediction = eval(prediction, eval_globals)
        except Exception:
            LOGGER.warning(
                f'[TASK] Failed to convert the answer: {prediction}\n{traceback.format_exc()}'
            )
            return None  # failed to convert the answer

        # Performing common type conversion
        if isinstance(prediction, (set, tuple)):
            prediction = list(prediction)
            # Guard against empty containers before looking at the first element.
            if prediction and isinstance(prediction[0], complex):
                prediction = [tmp.real for tmp in prediction]
            elif prediction and isinstance(prediction[0], Rational):
                prediction = [float(tmp) for tmp in prediction]
        elif isinstance(prediction, np.ndarray):
            prediction = prediction.tolist()
        elif isinstance(prediction, complex):
            prediction = prediction.real
        elif isinstance(prediction, Rational):
            prediction = float(prediction)

        return prediction

    def success(self, solution: str) -> bool:
        """This checks whether the given solution can complete the current task.

        Returns False when the parsed prediction is not a str, int, float or
        list, or when the task's answer type is not one of the known types.
        """
        # Follow the implementation from TheoremQA
        # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L301C9-L317C1
        prediction = self.extract_answer(solution)
        LOGGER.info(f'TheoremQA Parsed Prediction: {prediction}')
        answer_type = self._answer_type
        # The reference goes through the same normalisation as the prediction.
        gt = self.extract_answer(self.reference)

        if not isinstance(prediction, (str, int, float, list)):
            return False

        # Comparing prediction against the reference
        if answer_type in ('bool', 'option', 'Option'):
            return bool(prediction == f'({gt})' or prediction == gt)
        if answer_type in ('integer', 'float'):
            return bool(compare_two_numbers(prediction, gt))
        if answer_type in ('list of integer', 'list of float'):
            return bool(compare_two_list(prediction, gt))

        # An unrecognised answer type used to leave the result variable
        # unassigned (UnboundLocalError); treat it as a failed comparison.
        LOGGER.warning('[TASK] Unknown TheoremQA answer type: %r', answer_type)
        return False
