#!/usr/bin/env python3

from glob import glob
import json
import re
import numpy as np
from random import randint


class Example:
    """One data set item.

    Attributes:
        query: The problem text presented to the solver.
        answer: The reference answer.
        solution: Optional worked solution (``None`` when not supplied).
        info: Optional extra information about the example (``None`` when
            not supplied).
    """

    def __init__(self, query, answer, solution=None, info=None):
        self.query = query
        self.answer = answer
        self.solution = solution
        self.info = info


class Dataset:
    """Common base for the data sets in this module.

    Subclasses are expected to fill ``self.examples`` with a list of
    ``Example`` objects; length and iteration are delegated to that list.
    """

    @staticmethod
    def load(path, option=None):
        """Build the data set whose marker substring occurs in *path*.

        The checks are case-sensitive and evaluated in order, so the first
        marker found wins.  Paths without any known marker are treated as
        GSM8K data.
        """
        if 'MATH' in path:
            return MATH(path, option)
        if 'self' in path or 'GSM' in path:
            return GSM8K(path, option)
        if 'ARC' in path:
            return ARC(path, option)
        if 'math' in path:
            return Math(path, option)
        # Fallback: anything unrecognised is loaded as GSM8K.
        return GSM8K(path, option)

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __iter__(self):
        """Iterate over the loaded examples in order."""
        return iter(self.examples)

    @staticmethod
    def extract_confidence_estimate(text_with_estimate):
        """Return the first stand-alone integer 0-10 in the text, as a string.

        Raises:
            IndexError: If the text contains no such number.
        """
        pattern = re.compile(r'\b[0-9]\b|\b10\b')
        candidates = pattern.findall(text_with_estimate)
        return candidates[0]




class MATH(Dataset):
    """MATH-style problems loaded from a JSON-lines file.

    Every non-blank line of the file must be a JSON object providing the
    keys ``problem``, ``solution`` and ``answer``.  The stored answer is
    passed through ``normalize_final_answer`` when the file is loaded.
    """

    # Literal (before, after) replacements, applied in this order before
    # anything else.  The order matters: later entries see the result of
    # earlier ones (e.g. all spaces are gone before "\\text{and}" is tried).
    _SUBSTITUTIONS = (
        ("an ", ""),
        ("a ", ""),
        (".$", "$"),
        ("\\$", ""),
        (r"\ ", ""),
        (" ", ""),
        ("mbox", "text"),
        (",\\text{and}", ","),
        ("\\text{and}", ","),
        ("\\text{m}", "\\text{}"),
    )

    # Literal fragments (units, filler words, TeX spacing) deleted after
    # the substitutions above, again in this order.
    _REMOVED_EXPRESSIONS = (
        "square",
        "ways",
        "integers",
        "dollars",
        "mph",
        "inches",
        "ft",
        "hours",
        "km",
        "units",
        "\\ldots",
        "sue",
        "points",
        "feet",
        "minutes",
        "digits",
        "cents",
        "degrees",
        "cm",
        "gm",
        "pounds",
        "meters",
        "meals",
        "edges",
        "students",
        "childrentickets",
        "multiples",
        "\\text{s}",
        "\\text{.}",
        "\\text{\ns}",
        "\\text{}^2",
        "\\text{}^3",
        "\\text{\n}",
        "\\text{}",
        r"\mathrm{th}",
        r"^\circ",
        r"^{\circ}",
        r"\;",
        r",\!",
        "{,}",
        '"',
        "\\dots",
    )

    def __init__(self, path, option):
        """Load all examples from the JSON-lines file at *path*.

        Args:
            path: File with one JSON object per line.
            option: Accepted for a uniform constructor signature; unused.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record lacks a required key.
        """
        print(f'Preparing data set from {path}')
        self.examples = []
        # Iterate over the file instead of using str.splitlines(): the
        # latter also breaks on characters such as U+2028, which may occur
        # unescaped inside a JSON string and would cut a record in half.
        with open(path, encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. at the end of the file).
                    continue
                record = json.loads(line)
                query = record['problem']
                solution = record['solution']
                answer = MATH.normalize_final_answer(record['answer'])
                self.examples.append(Example(query, answer, solution))

    @staticmethod
    def normalize_final_answer(answer):
        """Return *answer* in a canonical textual form for comparison.

        The steps, in order: literal substitutions, removal of unit and
        filler fragments, unwrapping of TeX wrappers, expansion of
        shorthand fractions and roots, removal of dollar signs, and removal
        of thousands separators from pure integers.

        Args:
            answer: Answer string, possibly containing LaTeX.

        Returns:
            The normalized answer string.
        """
        for before, after in MATH._SUBSTITUTIONS:
            answer = answer.replace(before, after)
        for expr in MATH._REMOVED_EXPRESSIONS:
            answer = answer.replace(expr, "")

        # Extract answer that is in LaTeX math, is bold,
        # is surrounded by a box, etc.
        answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", answer)
        answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", answer)
        answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", answer)
        answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", answer)
        answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", answer)

        # Normalize shorthand TeX:
        #  \fracab -> \frac{a}{b}
        #  \frac{abc}{bef} -> \frac{abc}{bef}
        #  \fracabc -> \frac{a}{b}c
        #  \sqrta -> \sqrt{a}
        #  \sqrtab -> sqrt{a}b
        answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", answer)
        answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", answer)
        answer = answer.replace("$", "")

        # Normalize 100,000 -> 100000
        if answer.replace(",", "").isdigit():
            answer = answer.replace(",", "")

        return answer

    @staticmethod
    def get_unnormalized_answer(text):
        """Pull the raw answer out of a "Final Answer: ..." sentence.

        The closing sentence "I hope it is correct." is appended to *text*
        unconditionally, so a text that stops right before that sentence
        still matches.  When several matches exist the last one is used.

        Args:
            text: Generated solution text.

        Returns:
            The captured answer (not stripped or normalized), or a
            placeholder of the form ``[invalidanswer_NNNN]`` with a random
            four-digit number when nothing matches.
        """
        # A fresh random suffix per call; presumably so that two missing
        # answers do not compare equal to each other.
        INVALID_ANSWER = f"[invalidanswer_{randint(1000,9999)}]"
        end_seq = "I hope it is correct."
        text += end_seq
        # NOTE: the dots in this pattern are unescaped and therefore match
        # any character except a newline; kept as-is on purpose so that
        # existing extraction results do not change.
        r = re.compile(r"Final Answer: The final answer is(.*?). I hope it is correct.")
        answers = r.findall(text)
        if answers:
            return answers[-1]
        return INVALID_ANSWER

    @staticmethod
    def extract_final_answer(solution):
        """Return the normalized final answer found in *solution*."""
        answer = MATH.get_unnormalized_answer(solution)
        answer = MATH.normalize_final_answer(answer)
        return answer


class GSM8K(Dataset):
    """Word problems loaded from a JSON-lines file.

    Every non-blank line of the file must be a JSON object providing the
    keys ``question`` and ``answer``.  The reference answer stored for an
    example is the last number found in the ``answer`` text.
    """

    # Optional sign, optional leading digits, any number of ",ddd"
    # thousands groups, then a run of separators and the trailing digits.
    # The thousands groups keep "1,234,567" and "1,234.56" in one piece;
    # without them the last match would be only ",567" or ".56".
    _NUMBER_RE = re.compile(r'-?[0-9]*(?:,[0-9]{3})*[.,]*[0-9]+')

    def __init__(self, path, option):
        """Load all examples from the JSON-lines file at *path*.

        Args:
            path: File with one JSON object per line.
            option: Accepted for a uniform constructor signature; unused.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
            KeyError: If a record lacks a required key.
        """
        print(f'Preparing data set from {path}')
        self.examples = []
        # Iterate over the file instead of using str.splitlines(): the
        # latter also breaks on characters such as U+2028, which may occur
        # unescaped inside a JSON string and would cut a record in half.
        with open(path, encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. at the end of the file).
                    continue
                record = json.loads(line)
                query = record['question']
                answer = self.extract_final_answer(record['answer'])
                self.examples.append(Example(query, answer))

    @staticmethod
    def extract_final_answer(solution):
        """Return the last number in *solution* as a string.

        Commas are removed from the number.  When the value is integral,
        the fractional part is dropped ("18.0" -> "18").

        Args:
            solution: Text that contains the answer.

        Returns:
            The extracted number, or ``''`` when the text has no number.
        """
        numbers = GSM8K._NUMBER_RE.findall(solution)
        if not numbers:
            return ''
        ans = numbers[-1].replace(',', '')
        try:
            is_whole = float(ans) == int(float(ans))
        except (ValueError, OverflowError):
            # Not a parseable number (e.g. "1..5") or too large for a
            # float; hand back the raw match unchanged.
            return ans
        if is_whole:
            whole = ans.split('.')[0]
            # ".0" / "-.0" have no digits before the point; supply a zero
            # rather than returning an empty or sign-only string.
            ans = whole if whole.lstrip('-') else whole + '0'
        return ans


class ARC(Dataset):
    """Grid tasks loaded from a directory tree of JSON files.

    Each JSON file must provide the keys ``train`` and ``test``, both
    lists of objects with ``input`` and ``output`` grids.  Only the first
    test pair of a task is used.
    """

    def __init__(self, path, option):
        """Load one example per ``*.json`` file found below *path*.

        The query text contains all training pairs, followed by the test
        input grid unless *option* is ``'no-test-grid'``.  The answer is
        the text form of the first test output grid, and ``info`` holds
        the path of the task file.

        Args:
            path: Directory that is searched recursively.
            option: ``'no-test-grid'`` to leave the test input out of the
                query; any other value includes it.

        Raises:
            AssertionError: If no JSON file is found.
        """
        print(f'Preparing data set from {path}')
        # NOTE: glob() returns paths in an arbitrary, file-system dependent
        # order, so the order of the examples may differ between machines.
        examples_paths = glob(path + '/**/*.json', recursive=True)
        if not examples_paths:
            # Raised explicitly (instead of via `assert`) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError('Could not find any data!')
        self.examples = []
        for ep in examples_paths:
            with open(ep, encoding='utf-8') as f:
                task = json.load(f)
            train_grids = task['train']
            test_grids = task['test']
            train_grids_str = '\n'.join(
                [f"Input:\n{str(np.array(e['input']))}\n\nOutput:\n{str(np.array(e['output']))}\n"
                 for e in train_grids])
            test_input_grid_str = \
                f"Test input:\n{str(np.array(test_grids[0]['input']))}\n"
            test_output_grid_str = \
                f"{str(np.array(test_grids[0]['output']))}"
            if option == 'no-test-grid':
                train_test_grids = train_grids_str
            else:
                train_test_grids = train_grids_str + '\n' + test_input_grid_str
            self.examples.append(Example(
                train_test_grids,
                test_output_grid_str,
                info=ep
            ))

    @staticmethod
    def extract_final_answer(output):
        """Return the grid text spanning the second ``[[`` segment of *output*.

        The text is split on ``[[`` and the piece after the first
        separator is kept, cut at its first ``]]``; both delimiters are
        put back.  If that piece has no ``]]``, it is returned in full
        with ``]]`` appended.

        Args:
            output: Generated text that should contain a grid.

        Returns:
            The extracted grid text, or ``''`` when *output* contains no
            ``[[`` or is not a string.
        """
        try:
            output = '[[' + output.split('[[')[1]
            output = output.split(']]')[0] + ']]'
            return output
        except (IndexError, AttributeError, TypeError):
            # IndexError: no '[[' present.  AttributeError / TypeError:
            # output is not a str (e.g. None); still answered with ''.
            return ''


class Math(Dataset):
    """Problems loaded from a text file with one ``query@answer`` per line."""

    # Optional sign, optional leading digits, any number of ",ddd"
    # thousands groups, then a run of separators and the trailing digits.
    # The thousands groups keep "1,234,567" and "1,234.56" in one piece;
    # without them the last match would be only ",567" or ".56".
    _NUMBER_RE = re.compile(r'-?[0-9]*(?:,[0-9]{3})*[.,]*[0-9]+')

    def __init__(self, path, option):
        """Load all examples from the text file at *path*.

        The first ``@``-separated field of a line is the query, the second
        one the answer; further fields are ignored.

        Args:
            path: Text file with one example per line.
            option: Accepted for a uniform constructor signature; unused.

        Raises:
            IndexError: If a non-blank line contains no ``@``.
        """
        print(f'Preparing data set from {path}')
        self.examples = []
        with open(path) as f:
            examples = f.read().splitlines()
        for e in examples:
            if not e.strip():
                # Tolerate blank lines instead of failing on the missing '@'.
                continue
            fields = e.split('@')
            query, answer = fields[0], fields[1]
            self.examples.append(Example(query, answer))

    @staticmethod
    def extract_final_answer(solution):
        """Return the last number in *solution* as a string, commas removed.

        Args:
            solution: Text that contains the answer.

        Returns:
            The extracted number, or ``''`` when the text has no number.
        """
        numbers = Math._NUMBER_RE.findall(solution)
        if not numbers:
            return ''
        return numbers[-1].replace(',', '')

# Manual smoke test.
# Usage: <this script> PATH [OPTION]
# Loads the data set at PATH, prints every problem with its reference
# answer, and for MATH data also prints the answer extracted from the
# reference solution.
if __name__ == '__main__':
    import sys
    path = sys.argv[1]
    # The option argument is optional.
    option = sys.argv[2] if len(sys.argv) > 2 else None
    d = Dataset.load(path, option)
    for e in d:
        print()
        print('PROBLEM:\n', e.query, sep='')
        print()
        print('FINAL ANSWER:\n', e.answer, sep='')
    print('\nNumber of problems:', len(d))

    if 'MATH' in path:
        for e in d:
            extracted_answer = MATH.extract_final_answer(e.solution)
            # e.answer has already been normalized by MATH.__init__; it is
            # normalized once more below.
            provided_answer = e.answer
            provided_answer_normalized = MATH.normalize_final_answer(provided_answer)
            # Previously computed but never shown.
            print('EXTRACTED ANSWER:')
            print(extracted_answer)
            print('PROVIDED ANSWER:')
            print(provided_answer)
            print('PROVIDED ANSWER NORMALIZED:')
            print(provided_answer_normalized)
            print()
