import gzip, json
from human_eval.evaluation import evaluate_functional_correctness

class gsm8k():
    """GSM8K grade-school math benchmark: prompt construction and scoring.

    ``read`` builds Alpaca-style chain-of-thought prompts from a JSONL test
    file; ``test`` extracts the number following ``"The answer is: "`` in each
    completion and compares it with the integer gold label.
    """

    def __init__(self, ):
        pass

    def read(self, path='datasets/gsm8k_test.jsonl'):
        """Load the GSM8K test split.

        Args:
            path: JSONL file whose records have a ``"question"`` string and an
                ``"answer"`` string containing ``"#### <number>"``.

        Returns:
            dict with ``'prompts'`` (formatted prompt strings) and ``'labels'``
            (int gold answers with thousands separators removed), same order.
        """
        import jsonlines
        problem_prompt = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
        )
        gsm8k_ins = []
        gsm8k_answers = []
        # Read-only: "r+" would needlessly require write permission on the dataset.
        with open(path, "r", encoding="utf8") as f:
            for item in jsonlines.Reader(f):
                gsm8k_ins.append(problem_prompt.format(instruction=item["question"]))
                # The gold answer follows the "#### " marker, e.g. "#### 1,234".
                temp_ans = item['answer'].split('#### ')[1]
                gsm8k_answers.append(int(temp_ans.replace(',', '')))
        return {
            'prompts': gsm8k_ins,
            'labels': gsm8k_answers,
        }

    def test(self, output):
        """Score completions against gold labels.

        Args:
            output: dict with ``'outputs'`` (completion strings) and
                ``'labels'`` (numeric gold answers), paired by position.

        Returns:
            dict with ``'accuracy'``: the fraction of completions whose
            extracted (rounded) answer equals the label; 0.0 when empty.
        """
        import math
        import re
        import unicodedata
        from fractions import Fraction

        # Optional sign, optional integer part, at most one of . , / and digits.
        number_re = re.compile(r'[\-+]?\d*[\.,/]?\d+')

        def is_number(s):
            try:
                float(s)
                return True
            except ValueError:
                pass
            try:
                unicodedata.numeric(s)
                return True
            except (TypeError, ValueError):
                pass
            return False

        def finite_round(value):
            # round() raises OverflowError on +/-inf (e.g. a 400-digit number),
            # so treat non-finite values as unparseable.
            if not math.isfinite(value):
                return None
            return round(value)

        def extract_answer_number(completion):
            text = completion.split('The answer is: ')
            if len(text) < 2:
                return None
            match = number_re.search(text[-1].strip())
            if match is None:
                return None
            matched = match.group()
            if '/' not in matched:
                return finite_round(float(matched.replace(',', '')))
            # The regex allows a single separator, so a "/" match holds no ",".
            numerator, denominator = matched.split('/')
            if not (is_number(numerator) and is_number(denominator)):
                return None
            try:
                frac = Fraction(matched)
            except ZeroDivisionError:
                # Zero denominator ("5/0", "5/00"): fall back to the numerator.
                return finite_round(float(numerator))
            except ValueError:
                return None
            try:
                return finite_round(float(frac))
            except OverflowError:
                # Ratio too large to represent as a float.
                return None

        results = []
        for completion, prompt_answer in zip(output['outputs'], output['labels']):
            y_pred = extract_answer_number(completion)
            results.append(y_pred is not None and float(y_pred) == float(prompt_answer))
        accuracy = sum(results) / len(results) if results else 0.0
        return {
            'accuracy': accuracy
        }

class MATH():
    """Hendrycks MATH benchmark: prompt construction and answer scoring.

    ``read`` builds Alpaca-style chain-of-thought prompts and takes the gold
    answer from the last ``\\boxed{...}`` in each reference solution. ``test``
    compares the text after ``"The answer is: "`` in each completion with the
    gold answer after a series of LaTeX string normalisations.
    """

    def __init__(self, ):
        pass

    def read(self, path='datasets/MATH_test.jsonl'):
        """Load the MATH test split.

        Args:
            path: JSONL file whose records have an ``"instruction"`` (problem
                text) and an ``"output"`` (reference solution) field.

        Returns:
            dict with ``'prompts'`` (formatted prompt strings) and ``'labels'``
            (gold answer strings; None when no ``\\boxed{...}`` answer can be
            extracted from the solution).
        """

        def remove_boxed(s):
            # Strip the "\boxed{" prefix and the final "}"; None if s is None or
            # not of that form (this includes "\fbox{...}" strings).
            left = "\\boxed{"
            if s is None or not s.startswith(left) or not s.endswith("}"):
                return None
            return s[len(left):-1]

        def last_boxed_only_string(string):
            # Return the last "\boxed..." (or, failing that, "\fbox...") substring
            # up to the brace that balances its first "{"; None if absent/unbalanced.
            idx = string.rfind("\\boxed")
            if idx < 0:
                idx = string.rfind("\\fbox")
                if idx < 0:
                    return None

            right_brace_idx = None
            num_left_braces_open = 0
            for i in range(idx, len(string)):
                if string[i] == "{":
                    num_left_braces_open += 1
                elif string[i] == "}":
                    num_left_braces_open -= 1
                    if num_left_braces_open == 0:
                        right_brace_idx = i
                        break
            if right_brace_idx is None:
                return None
            return string[idx:right_brace_idx + 1]

        import jsonlines
        problem_prompt = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
        )
        hendrycks_math_ins = []
        hendrycks_math_answers = []
        # Read-only: "r+" would needlessly require write permission on the dataset.
        with open(path, "r", encoding="utf8") as f:
            for item in jsonlines.Reader(f):
                hendrycks_math_ins.append(problem_prompt.format(instruction=item["instruction"]))
                hendrycks_math_answers.append(
                    remove_boxed(last_boxed_only_string(item['output']))
                )
        return {
            'prompts': hendrycks_math_ins,
            'labels': hendrycks_math_answers,
        }

    def test(self, output):
        """Score completions by normalised string equivalence with the gold answers.

        Args:
            output: dict with ``'outputs'`` (completion strings) and
                ``'labels'`` (gold answer strings or None), paired by position.

        Returns:
            dict with ``'accuracy'``: the fraction of completions judged
            equivalent to their label; 0.0 when there are no completions.
        """

        def fix_fracs(string):
            # "\frac12" / "\frac1b" / "\frac1{72}" -> braced "\frac{1}{2}" form.
            substrs = string.split("\\frac")
            new_str = substrs[0]
            for substr in substrs[1:]:
                new_str += "\\frac"
                if substr.startswith("{"):
                    new_str += substr
                    continue
                if len(substr) < 2:
                    # Too short (or empty: "\frac" at the very end) to hold a
                    # numerator and denominator; leave the input untouched.
                    return string
                a, b = substr[0], substr[1]
                if b != "{":
                    new_str += "{" + a + "}{" + b + "}" + substr[2:]
                else:
                    new_str += "{" + a + "}" + b + substr[2:]
            return new_str

        def fix_a_slash_b(string):
            # Plain integer ratio "a/b" -> "\frac{a}{b}"; anything else unchanged.
            parts = string.split("/")
            if len(parts) != 2:
                return string
            try:
                a, b = int(parts[0]), int(parts[1])
            except ValueError:
                # Non-integer operands such as "\pi/2" must not abort the whole
                # normalisation; just leave them as-is.
                return string
            if string != "{}/{}".format(a, b):
                return string
            return "\\frac{" + str(a) + "}{" + str(b) + "}"

        def remove_right_units(string):
            # "\\text{ " only ever occurs (at least in the val set) when describing units
            if "\\text{ " not in string:
                return string
            splits = string.split("\\text{ ")
            if len(splits) != 2:
                # Ambiguous units; is_equiv catches this and falls back to a raw
                # string comparison.
                raise ValueError("expected a single unit annotation")
            return splits[0]

        def fix_sqrt(string):
            # "\sqrt3" -> "\sqrt{3}" (only the single character after \sqrt is wrapped).
            if "\\sqrt" not in string:
                return string
            splits = string.split("\\sqrt")
            new_string = splits[0]
            for split in splits[1:]:
                if split and split[0] != "{":
                    new_string += "\\sqrt{" + split[0] + "}" + split[1:]
                else:
                    # Already braced, or "\sqrt" at the very end of the string.
                    new_string += "\\sqrt" + split
            return new_string

        def strip_string(string):
            # linebreaks
            string = string.replace("\n", "")

            # remove inverse spaces
            string = string.replace("\\!", "")

            # replace \\ with \
            string = string.replace("\\\\", "\\")

            # replace tfrac and dfrac with frac
            string = string.replace("tfrac", "frac")
            string = string.replace("dfrac", "frac")

            # remove \left and \right
            string = string.replace("\\left", "")
            string = string.replace("\\right", "")

            # Remove circ (degrees)
            string = string.replace("^{\\circ}", "")
            string = string.replace("^\\circ", "")

            # remove dollar signs
            string = string.replace("\\$", "")

            # remove units (on the right)
            string = remove_right_units(string)

            # remove percentage (backslash + percent). A second replace spelled
            # "\%" denotes the very same string via an invalid escape sequence
            # (SyntaxWarning on Python 3.12+), so one replace suffices.
            string = string.replace("\\%", "")

            # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
            string = string.replace(" .", " 0.")
            string = string.replace("{.", "{0.")
            # if empty, return empty string
            if len(string) == 0:
                return string
            if string[0] == ".":
                string = "0" + string

            # to consider: get rid of e.g. "k = " or "q = " at beginning
            if len(string.split("=")) == 2:
                if len(string.split("=")[0]) <= 2:
                    string = string.split("=")[1]

            # fix sqrt3 --> sqrt{3}
            string = fix_sqrt(string)

            # remove spaces
            string = string.replace(" ", "")

            # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
            string = fix_fracs(string)

            # manually change 0.5 --> \frac{1}{2}
            if string == "0.5":
                string = "\\frac{1}{2}"

            # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
            string = fix_a_slash_b(string)

            return string

        def is_equiv(str1, str2, verbose=False):
            if str1 is None and str2 is None:
                print("WARNING: Both None")
                return True
            if str1 is None or str2 is None:
                return False

            try:
                ss1 = strip_string(str1)
                ss2 = strip_string(str2)
                if verbose:
                    print(ss1, ss2)
                return ss1 == ss2
            except Exception:
                # Normalisation failed; fall back to exact string comparison.
                return str1 == str2

        def process_results(completion, answer):
            # Compare the text after the last "The answer is: " marker (cut at
            # the first ".\n", minus one trailing period) with the gold answer.
            split_ans = completion.split('The answer is: ')
            if len(split_ans) < 2:
                return False
            extract_ans = split_ans[-1].split('.\n')[0].strip()
            if extract_ans.endswith('.'):
                extract_ans = extract_ans[:-1]
            return is_equiv(extract_ans.strip(), answer)

        results = [
            process_results(completion, prompt_answer)
            for completion, prompt_answer in zip(output['outputs'], output['labels'])
        ]
        accuracy = sum(results) / len(results) if results else 0.0
        return {
            'accuracy': accuracy
        }

class AlpacaEval():
    """AlpacaEval instruction-following benchmark; scoring is left to an external judge."""

    def __init__(self,):
        pass

    def read(self,):
        """Load the ``eval`` split of ``tatsu-lab/alpaca_eval`` via ``datasets.load_dataset``.

        Returns:
            tuple ``(instructions, reference_outputs)``: the ``"instruction"``
            field of every example, and the full example records, same order.
        """
        import datasets
        dataset = datasets.load_dataset(path="tatsu-lab/alpaca_eval", name="alpaca_eval")
        eval_split = dataset["eval"]
        prompts, references = [], []
        for record in eval_split:
            prompts.append(record["instruction"])
            references.append(record)
        return prompts, references

    def test(self, output, label):
        """Placeholder that returns None; judging is meant to be done by GPT-4 elsewhere."""
        return None
class mbpp():
    """MBPP benchmark: builds one prompt per problem from its text and test cases."""

    def __init__(self,):
        pass

    def read(self, path='datasets/mbpp.test.jsonl'):
        """Build one prompt per MBPP problem.

        Args:
            path: JSONL file whose records have ``"task_id"``, ``"text"`` and
                ``"test_list"`` fields.

        Returns:
            dict with ``'task_ids'`` (sorted task ids) and ``'prompts'`` (one
            prompt string per id, same order). Previously the prompts were
            built and then discarded, so the method returned None.
        """
        import jsonlines

        with jsonlines.open(path, "r") as fin:
            problems = {obj["task_id"]: obj for obj in fin}
        task_ids = sorted(problems)

        prompts = []
        for task_id in task_ids:
            problem = problems[task_id]
            prompt = f"\n{problem['text']}\nTest examples:"
            if task_id == 493:
                # The test examples are too long, we choose to only include the function name.
                prompt += "\ncalculate_polygons(startx, starty, endx, endy, radius)"
            else:
                prompt += "".join(f"\n{test_example}" for test_example in problem['test_list'])
            prompts.append(prompt)
        return {
            'task_ids': task_ids,
            'prompts': prompts,
        }

class human_eval():
    """HumanEval code-generation benchmark: prompt construction and pass@1 scoring."""

    def __init__(self,):
        pass

    def read(self, path='datasets/HumanEval.jsonl.gz'):
        """Load HumanEval problems and build one instruction prompt per task.

        Args:
            path: gzip-compressed JSONL file whose records have ``"task_id"``
                and ``"prompt"`` fields.

        Returns:
            dict with ``'task_ids'`` (sorted) and ``'prompts'`` (same order).
        """

        def stream_jsonl():
            # Yield one parsed record per non-blank line; decode explicitly as
            # UTF-8 rather than relying on the locale's default encoding.
            with gzip.open(path, 'rt', encoding='utf-8') as fp:
                for line in fp:
                    if line.strip():
                        yield json.loads(line)

        def process(prompt):
            # Re-indent 4-space runs as tabs (``test`` maps tabs back to spaces in
            # the completions). str.replace returns a new string, so assign it.
            prompt = prompt.replace('    ', '\t')
            return (
                "Below is an instruction that describes a task. "
                "Write a response that appropriately completes the request.\n\n"
                f"### Instruction:\n{prompt}\n\n### Response:"
            )

        problems = {task["task_id"]: task for task in stream_jsonl()}
        task_ids = sorted(problems.keys())
        prompts = [process(problems[task_id]['prompt']) for task_id in task_ids]
        return {
            'task_ids': task_ids,
            'prompts': prompts
        }

    def test(self, output, problem_file='datasets/HumanEval.jsonl.gz'):
        """Post-process completions, write them to ``output['path']`` and score pass@1.

        Args:
            output: dict with ``'outputs'`` (raw completions), ``'task_ids'``
                (matching task ids) and ``'path'`` (where the JSONL samples
                file is written; overwritten if it exists).
            problem_file: HumanEval problem file passed to the evaluator.

        Returns:
            The result of ``evaluate_functional_correctness`` with ``k=[1]``.
        """

        def process_code(completion):
            # Extract the code part of a raw completion.
            completion = completion.split("### Response:")[-1].replace('\t', '    ')
            completion = completion.replace("\r", "").strip()
            if '```python' in completion:
                # Keep text from the first ```python fence up to the next fence.
                completion = completion[completion.index('```python'):].strip()
                completion = completion.replace('```python', '')
                end = completion.find('```')
                if end != -1:
                    completion = completion[:end].strip()
            if "__name__ == \"__main__\"" in completion:
                # Only this exact spelling of the entry-point guard is cut.
                end = completion.find('if __name__ == "__main__":')
                if end != -1:
                    completion = completion[:end].strip()
            if "# Example usage" in completion:
                completion = completion[:completion.index('# Example usage')].strip()
            # the following codes are used to deal with the outputs of code-alpaca
            if "The solution is:" in completion:
                start = completion.index("The solution is:")
                completion = completion[start:].strip().replace('The solution is:', '')
                end = completion.find('\n\nThe answer is:')
                completion = completion[:end].strip() if end != -1 else completion.strip()
            if "The answer is:" in completion:
                start = completion.index("The answer is:")
                # Every marker is removed here, so no later "\n\nThe answer is:"
                # can remain to cut at.
                completion = completion[start:].strip().replace('The answer is:', '').strip()
            return completion

        outputs = [
            {
                'task_id': task_id,
                'completion': process_code(raw),
                'all_code': raw.replace('\t', '    '),
            }
            for raw, task_id in zip(output['outputs'], output['task_ids'])
        ]

        file_path = output['path']
        with open(file_path, 'wb') as fp:
            for x in outputs:
                fp.write((json.dumps(x) + "\n").encode('utf-8'))
        # pass@1; equivalently run `evaluate_functional_correctness <file_path>` in a terminal.
        results = evaluate_functional_correctness(
            file_path, k=[1], n_workers=4, timeout=3, problem_file=problem_file
        )
        return results