import re
import os
import json
import torch
import random
import argparse
import pandas as pd
from prompts import *
from tqdm import trange
import concurrent.futures

def convert_to_float(frac_str):
    """Parse a decimal, a fraction ("3/4") or a mixed number ("1 3/4") to float.

    A negative whole part makes the fractional part negative too
    ("-1 1/2" -> -1.5). If the text before the single space is not a number,
    it is ignored and only the fraction is used.

    Raises:
        ValueError: if the string is neither a float nor contains exactly one '/'.
        ZeroDivisionError: if the denominator is zero.
    """
    try:
        return float(frac_str)
    except ValueError:
        pass

    numerator, denominator = frac_str.split('/')
    pieces = numerator.split(' ')
    whole = 0
    if len(pieces) == 2:
        # The numerator is always taken from after the space, even when the
        # leading part turns out not to be a number.
        numerator = pieces[1]
        try:
            whole = float(pieces[0])
        except ValueError:
            whole = 0

    fraction = float(numerator) / float(denominator)
    if whole < 0:
        return whole - fraction
    return whole + fraction


class Eval:
    """Evaluate a model on one benchmark dataset, one shard per process.

    The instance loads the full dataset, keeps the contiguous shard selected
    by ``process_idx`` / ``total_processes``, queries the configured
    completion backend in batches, extracts answers from the responses with
    per-dataset regular expressions and writes one JSON line per example.
    """

    def __init__(self, dataset, model_path_or_name, eval_type="dlm_normal",
                 shot=0, max_input_length=None,
                 cards=None, process_idx=0, total_processes=1, ppl_model=None,
                 api_addr=None, api_key=None,
                 ):
        """Configure the dataset, the answer patterns and the model backend.

        Args:
            dataset: dataset key, one of the keys of ``path_dict`` below.
            model_path_or_name: model path (DLM backends) or model name (API).
            eval_type: ``"dlm_normal"``, ``"dlm_test_time_scaling"`` or
                ``"arlm_api"``.
            shot: number of in-context examples; selects the prompt template.
            max_input_length: forwarded to the DLM backends.
            cards: GPU ids as strings, shared by all processes. Defaults to
                ``['0']``. Only used by the DLM backends.
            process_idx: index of this process's shard.
            total_processes: number of shards / processes.
            ppl_model: forwarded to ``TestTimeScaling`` as
                ``model_to_calculate_ppl``.
            api_addr: base URL for the API backend.
            api_key: key for the API backend.

        Raises:
            AssertionError: if ``dataset`` is not supported.
            Exception: if ``eval_type`` is not supported.
        """
        # None instead of a shared mutable list as the default.
        if cards is None:
            cards = ['0']
        path_dict = {
            "gsm8k": "openai/gsm8k",
            "AIME_2024": "Maxwell-Jia/AIME_2024",
            "MATH": "Maxwell-Jia/MATH",
            "MATH-500": "HuggingFaceH4/MATH-500",
            "gpqa": "Idavidrein/gpqa",
            "Countdown-3": "countdown-3",
            "Sudoku": "sudoku",
            "5_digit_multiplication": "5_digit_multiplication",
        }
        assert dataset in path_dict
        # Regexes tried in order by collect_answer_from_response.
        self.match_pattern = {
            "gsm8k": [r"The final answer is \$?((?:-?[0-9.,]{2,})|(?:-?[0-9]+))", r"((?:-?[0-9.,]{2,})|(?:-?[0-9]+))"],
            "AIME_2024": [r"boxed{(.*)}","framebox{(.*)}"],
            "MATH": [r"boxed{(.*)}","framebox{(.*)}"],
            "MATH-500": [r"boxed{(.*)}","framebox{(.*)}"],
            "gpqa": [r"(?<=The answer is )(.*)(?=.)",r"(\([A-Z]\))"],
            "Countdown-3": [r"([0-9][0-9\,\s\+\-\*\=\/]*)"],
            "Sudoku": [r"((?:[0-9]\ ?){16}?)"],
            "5_digit_multiplication": [r"[-+]?[,.\d]+"],
        }
        self.process_idx = process_idx
        self.total_processes = total_processes

        if eval_type in ["dlm_normal", "dlm_test_time_scaling"]:
            # Each process gets its own contiguous slice of the cards
            # (leftover cards stay unused). Set before the backend below is
            # imported and constructed.
            cards_per_process = len(cards) // total_processes
            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
                cards[process_idx * cards_per_process:(process_idx + 1) * cards_per_process]
            )

        self.dataset = dataset
        self.dataset_path = os.path.join(os.path.dirname(__file__), "datasets", path_dict[self.dataset])
        self.eval_type = eval_type
        if eval_type == "dlm_normal":
            from chat_completion import ModelCompletion
            self.model_completion = ModelCompletion(model_path_or_name, max_input_length=max_input_length)
        elif eval_type == "dlm_test_time_scaling":
            from test_time_scaling import TestTimeScaling
            self.model_completion = TestTimeScaling(model_path_or_name, max_input_length=max_input_length, model_to_calculate_ppl=ppl_model)
        elif eval_type == "arlm_api":
            from chat_completion_openai_api import ApiCompletion
            self.model_completion = ApiCompletion(model_name=model_path_or_name, base_url=api_addr, api_key=api_key)
        else:
            raise Exception(f"Unsupported evaluating type {eval_type}!")

        self.shot = shot

    def postprocess_countdown(self, response):
        """Join lines with ',', drop spaces and trim leading/trailing commas."""
        return response.replace('\n', ',').replace(' ', '').strip(',')

    def postprocess_sudoku(self, response):
        """Drop the spaces between sudoku digits."""
        return response.replace(' ', '')

    def collect_answer_from_response(self, responses, regex_list=None):
        """Extract one answer string per response.

        For each response the regexes are tried in order. For each regex the
        LAST match is taken, and the first non-empty one wins. Leading and
        trailing '.' are stripped, then a dataset-specific post-processor
        runs for Countdown-3 and Sudoku. A response that cannot be searched
        (e.g. ``None``) yields ``""``.

        Args:
            responses: a response string or a list of them.
            regex_list: patterns to try. Defaults to the patterns for
                ``self.dataset``.

        Returns:
            list of extracted answer strings, one per response.
        """
        map_postprocess = {
            "Countdown-3": self.postprocess_countdown,
            "Sudoku": self.postprocess_sudoku,
        }

        if regex_list is None:
            regex_list = self.match_pattern[self.dataset]
        if not isinstance(responses, list):
            responses = [responses]

        tmp_res = []
        for response in responses:
            _res = ""
            try:
                for regex in regex_list:
                    matches = re.findall(regex, response, flags=re.MULTILINE)
                    _res = matches[-1] if matches else ""
                    if _res != "":
                        break
            except (TypeError, re.error):
                # Non-string responses (e.g. None) count as "no answer".
                _res = ""
            _res = _res.strip('.')
            if self.dataset in map_postprocess:
                _res = map_postprocess[self.dataset](_res)
            tmp_res.append(_res)

        return tmp_res

    def load_parquet(self, file_path_list):
        """Concatenate the rows of several parquet files as a list of dicts."""
        all_data = []
        for file_path in file_path_list:
            # to_dict yields native Python values, which json.dumps accepts.
            all_data.extend(pd.read_parquet(file_path).to_dict(orient="records"))

        return all_data

    def load_jsonl(self, file_path_list):
        """Concatenate the records of several JSON-lines files, skipping blank lines."""
        all_data = []
        for file_path in file_path_list:
            with open(file_path, 'r', encoding='utf-8') as f:
                all_data.extend(json.loads(line) for line in f if line.strip())

        return all_data

    def load_csv(self, file_path_list, **read_csv_kwargs):
        """Concatenate the rows of several CSV files as a list of dicts.

        Extra keyword arguments (e.g. ``dtype=str``) go to ``pd.read_csv``.
        """
        all_data = []
        for file_path in file_path_list:
            _data = pd.read_csv(file_path, **read_csv_kwargs)
            all_data.extend(_data.to_dict(orient="records"))

        return all_data

    def _get_template(self):
        """Return the prompt template for ``self.shot`` and ``self.dataset``.

        Raises:
            AssertionError: if no template exists for that combination.
        """
        message = f"{self.shot} shot(s) is not supported for dataset {self.dataset}"
        assert self.shot in prompts.keys() and self.dataset in prompts[self.shot].keys(), message
        template = prompts[self.shot][self.dataset]
        assert template is not None, message
        return template

    def prepare_gsm8k(self):
        """Load the GSM8K test split. The ground truth is the text after '####'."""
        main_path = os.path.join(self.dataset_path, "main", "test-00000-of-00001.parquet")
        template = self._get_template()

        all_data = self.load_parquet([main_path])
        for item in all_data:
            item["prompt"] = [{"role": "user", "content": template.replace("{{question}}", item["question"])}]
            item["ground_truth"] = item["answer"].split('####')[-1].strip()

        return all_data

    def prepare_aime_2024(self):
        """Load AIME 2024. The ground truth is the stringified ``Answer`` column."""
        file_path = os.path.join(self.dataset_path, "aime_2024_problems.parquet")
        template = self._get_template()

        all_data = self.load_parquet([file_path])
        for item in all_data:
            item["prompt"] = [{"role": "user", "content": template.replace("{{question}}", item["Problem"])}]
            item["ground_truth"] = str(item["Answer"])

        return all_data

    def prepare_math(self):
        """Load the MATH test split. The ground truth is the last \\boxed{} in the solution."""
        file_path = os.path.join(self.dataset_path, "data", "test", "0000.parquet")
        template = self._get_template()

        all_data = self.load_parquet([file_path])
        for item in all_data:
            item["prompt"] = [{"role": "user", "content": template.replace("{{question}}", item["problem"])}]
            item["ground_truth"] = self.collect_answer_from_response(
                item["solution"], regex_list=[r"boxed{(.*)}"]
            )[0]

        return all_data

    def prepare_math_500(self):
        """Load MATH-500. The ground truth is the ``answer`` field."""
        file_path = os.path.join(self.dataset_path, "test.jsonl")
        template = self._get_template()

        all_data = self.load_jsonl([file_path])
        for item in all_data:
            item["prompt"] = [{"role": "user", "content": template.replace("{{question}}", item["problem"])}]
            item["ground_truth"] = item["answer"]

        return all_data

    def prepare_gpqa(self):
        """Load GPQA-diamond with shuffled choices. The ground truth is "(A)".."(D)"."""
        file_path = os.path.join(self.dataset_path, "gpqa_diamond.csv")
        template = self._get_template()

        all_data = self.load_csv([file_path])
        # https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gpqa/cot_n_shot/utils.py
        def preprocess(text):
            # read_csv gives NaN (not None) for empty cells.
            if text is None or (not isinstance(text, str) and pd.isna(text)):
                return " "
            text = str(text).strip()
            text = text.replace(" [title]", ". ")
            text = re.sub("\\[.*?\\]", "", text)
            text = text.replace("  ", " ")
            return text

        def _process_doc(doc):
            correct = preprocess(doc["Correct Answer"])
            choices = [
                preprocess(doc["Incorrect Answer 1"]),
                preprocess(doc["Incorrect Answer 2"]),
                preprocess(doc["Incorrect Answer 3"]),
                correct,
            ]

            # NOTE: uses the global, unseeded `random`; seed it for
            # reproducible choice orders.
            random.shuffle(choices)
            correct_answer_index = choices.index(correct)

            return {
                "choice1": choices[0],
                "choice2": choices[1],
                "choice3": choices[2],
                "choice4": choices[3],
                "choices": [choices[0], choices[1], choices[2], choices[3]],
                "answer": f"({chr(65 + correct_answer_index)})",
            }

        for item in all_data:
            out_doc = _process_doc(item)
            item["prompt"] = [{
                "role": "user",
                "content": template.replace("{{question}}", item["Question"])
                                   .replace("{{choice1}}", out_doc["choice1"])
                                   .replace("{{choice2}}", out_doc["choice2"])
                                   .replace("{{choice3}}", out_doc["choice3"])
                                   .replace("{{choice4}}", out_doc["choice4"]),
            }]
            item["ground_truth"] = out_doc["answer"]

        return all_data

    def prepare_countdown_3(self):
        """Load Countdown-3. ``input`` is "n1,n2,...,target"."""
        file_path = os.path.join(self.dataset_path, "cd3_test.jsonl")
        template = self._get_template()

        all_data = self.load_jsonl([file_path])
        for item in all_data:
            *numbers, target = item["input"].split(',')
            item["prompt"] = [{"role": "user", "content": template.replace("{{numbers}}", " ".join(numbers)).replace("{{target}}", target)}]
            item["ground_truth"] = item["output"]

        return all_data

    def prepare_sudoku(self):
        """Load 4x4 sudoku puzzles. Every field is kept as a string."""
        file_path = os.path.join(self.dataset_path, "4x4_test_sudoku.csv")
        template = self._get_template()

        # dtype=str keeps digit strings verbatim; integer parsing would drop
        # a leading '0' cell.
        all_data = self.load_csv([file_path], dtype=str)
        for item in all_data:
            item.update({k: str(v) for k, v in item.items()})
            item["prompt"] = [{"role": "user", "content": template.replace("{{puzzle}}", ' '.join(list(item["Puzzle"])))}]
            item["ground_truth"] = item["Solution"]

        return all_data

    def prepare_5_digit_multiplication(self):
        """Load the 5-digit multiplication examples (``input`` / ``target``)."""
        file_path = os.path.join(self.dataset_path, "arithmetic_5_digit_multiplication.json")
        template = self._get_template()

        with open(file_path, 'r', encoding='utf-8') as f:
            all_data = json.load(f)["examples"]

        for item in all_data:
            item["prompt"] = [{"role": "user", "content": template.replace("{{question}}", item["input"])}]
            item["ground_truth"] = item["target"]

        return all_data

    def eval(self, res_jsonl, batch_size=8, **kwargs):
        """Run the model on this process's shard and append results to ``res_jsonl``.

        If ``res_jsonl`` already exists, its records are kept and generation
        resumes after them. This assumes the file came from the same shard
        configuration, which is not checked.

        Args:
            res_jsonl: output JSON-lines path for this shard.
            batch_size: number of prompts per ``complete`` call.
            **kwargs: forwarded to ``self.model_completion.complete``.

        Returns:
            list of result dicts (resumed ones first, then new ones).
        """
        prepare_map = {
            "gsm8k": self.prepare_gsm8k,
            "AIME_2024": self.prepare_aime_2024,
            "MATH": self.prepare_math,
            "MATH-500": self.prepare_math_500,
            "gpqa": self.prepare_gpqa,
            "Countdown-3": self.prepare_countdown_3,
            "Sudoku": self.prepare_sudoku,
            "5_digit_multiplication": self.prepare_5_digit_multiplication,
        }
        assert self.dataset in prepare_map, f"Not supported for dataset {self.dataset}"
        all_data = prepare_map[self.dataset]()

        # Contiguous shards of ceil(len / total_processes) items.
        shard_size, remainder = divmod(len(all_data), self.total_processes)
        if remainder > 0:
            shard_size += 1
        all_data = all_data[self.process_idx * shard_size:(self.process_idx + 1) * shard_size]

        res = []
        file_mode = "r+" if os.path.exists(res_jsonl) else "w"
        with open(res_jsonl, file_mode, encoding='utf-8') as f:
            if file_mode == "r+":
                # Reading to EOF leaves the write position at the end of the file.
                res = [json.loads(line) for line in f.readlines() if line.strip()]
            for start in trange(len(res), len(all_data), batch_size):
                batch = all_data[start:start + batch_size]
                completions, out = self.model_completion.complete(
                    [it["prompt"] for it in batch],
                    ground_truth=[it["ground_truth"] for it in batch],
                    **kwargs
                )

                for item, completion, trajectory in zip(batch, completions, out):
                    item["model_response"] = completion
                    item["model_answer"] = self.collect_answer_from_response(completion)
                    if self.eval_type == "dlm_test_time_scaling":
                        item["trajectory"] = trajectory
                    f.write(json.dumps(item, ensure_ascii=False) + '\n')
                    res.append(item)
                # Resume counts the records on disk, so persist every batch.
                f.flush()

        return res


if __name__ == "__main__":
    '''
    Usage:

    To run the baseline with 8 cards:

    $ python eval_datasets.py \
               -m /path/to/model -d gsm8k -r /path/to/res.jsonl \
               --batch-size 16 --steps 128 --gen-length 256 --block-length 8 \
               --temperature 0.0 --shot 0 \
               --cards 0,1,2,3,4,5,6,7 --cards-per-model 2

    To run test time scaling with 8 cards:

    $ python eval_datasets.py \
               -m /path/to/model -d gsm8k -r /path/to/res.jsonl \
               --batch-size 16 --steps 128 --gen-length 256 --block-length 8 \
               --temperature 0.5 --shot 0 \
               --test-time-scaling -n 4 --topk 2 \
               --search-every-steps=64 --ppl-model /path/to/model/like/GPT2 \
               --cards 0,1,2,3,4,5,6,7 --cards-per-model 2
    '''
    import multiprocessing

    # CLI dataset name -> dataset key understood by Eval.
    dataset_map = {
        "gsm8k": "gsm8k",
        "aime2024": "AIME_2024",
        "math": "MATH",
        "math500": "MATH-500",
        "gpqa": "gpqa",
        "countdown3": "Countdown-3",
        "sudoku": "Sudoku",
        "5digmulti": "5_digit_multiplication",
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", required=True)
    parser.add_argument("-d", "--dataset", choices=list(dataset_map.keys()), required=True)
    parser.add_argument("-r", "--res-jsonl", required=True)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--steps", type=int, default=128)
    parser.add_argument("--gen-length", type=int, default=128)
    parser.add_argument("--block-length", type=int, default=32)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--shot", type=int, choices=[0, 4, 5], default=0)
    parser.add_argument("--cards", type=str)
    parser.add_argument("--cards-per-model", type=int, default=1)
    parser.add_argument("--max-input-length", type=int, default=None)
    parser.add_argument("--without-think", action="store_true")
    ### test-time scaling
    parser.add_argument("--test-time-scaling", action="store_true")
    parser.add_argument("--reward-batch-size", type=int, default=4)
    parser.add_argument("--search-every-steps", type=int, default=16)
    parser.add_argument("-n", type=int, default=8)
    parser.add_argument("--topk", type=int, default=4)
    parser.add_argument("--ppl-model", type=str, default=None)
    parser.add_argument("--sample_ratio_calculating_correlation_inside_response", type=str, default="1")
    parser.add_argument("--reward_list", type=str, default="correlation,format,accuracy,ppl")
    ### openai api
    parser.add_argument("--api", action="store_true")
    parser.add_argument("--api-addr", type=str, default="http://127.0.0.1:8000/v1")
    parser.add_argument("--api-key", type=str, default="token-abc123")
    parser.add_argument("--num-process", type=int, default=4)
    args = parser.parse_args()

    args.sample_ratio_calculating_correlation_inside_response = convert_to_float(
        args.sample_ratio_calculating_correlation_inside_response
    )
    args.reward_list = [i.strip() for i in args.reward_list.split(",")]

    print(f"Evaluating {args.dataset} ...")

    assert not (args.test_time_scaling and args.api), \
        "Using api to do test time scaling is not supported yet."

    available_cards = None
    if not args.api:
        if args.cards is None:
            parser.error("--cards is required unless --api is given")
        available_cards = [c.strip() for c in args.cards.split(',') if c.strip()]
        assert len(available_cards) > 0, "No cards available currently."
        process_num, unused_cards = divmod(len(available_cards), args.cards_per_model)
        if unused_cards > 0:
            print(f"{unused_cards} cards will not be used...")
        assert process_num > 0, \
            f"Need at least {args.cards_per_model} card(s) per model, got {len(available_cards)}."
    else:
        process_num = args.num_process
    # One shard (and one partial result file) per worker process.
    files = [f"{args.res_jsonl}.{i}" for i in range(process_num)]

    if not args.without_think:
        prompts = prompt_with_think

    def run(i):
        """Evaluate shard ``i`` in a worker process and write it to ``files[i]``."""
        if args.test_time_scaling:
            Eval(
                dataset_map[args.dataset], args.model, eval_type="dlm_test_time_scaling",
                shot=args.shot, cards=available_cards, max_input_length=args.max_input_length,
                ppl_model=args.ppl_model, process_idx=i, total_processes=process_num
            ).eval(
                files[i], batch_size=args.batch_size,
                steps=args.steps, block_length=args.block_length,
                temperature=args.temperature,
                gen_length=args.gen_length,
                reward_batch_size=args.reward_batch_size,
                search_every_steps_n=args.search_every_steps, n=args.n, topk=args.topk,
                sample_ratio_calculating_correlation_inside_response=args.sample_ratio_calculating_correlation_inside_response,
                reward_list=args.reward_list,
            )
        elif args.api:
            Eval(
                dataset_map[args.dataset], args.model, eval_type="arlm_api",
                shot=args.shot, max_input_length=None,
                process_idx=i, total_processes=process_num, api_addr=args.api_addr, api_key=args.api_key
            ).eval(
                files[i], batch_size=args.batch_size,
                gen_length=args.gen_length,
                temperature=args.temperature
            )
        else:
            Eval(
                dataset_map[args.dataset], args.model, eval_type="dlm_normal",
                shot=args.shot, max_input_length=args.max_input_length,
                cards=available_cards, process_idx=i, total_processes=process_num
            ).eval(
                files[i], batch_size=args.batch_size,
                gen_length=args.gen_length,
                steps=args.steps, block_length=args.block_length,
                temperature=args.temperature
            )

    # `run` and the `prompts` rebinding above both live under the __main__
    # guard. Workers started with spawn/forkserver re-import this file
    # without running this block, so they would see neither. Pin "fork"
    # because the platform default differs on macOS and on Linux from
    # Python 3.14.
    mp_context = multiprocessing.get_context("fork")
    with concurrent.futures.ProcessPoolExecutor(process_num, mp_context=mp_context) as executor:
        futures = [executor.submit(run, i) for i in range(process_num)]
        for future in futures:
            future.result()  # re-raise any exception from the worker

    # Merge the per-shard files, in shard order, into the final result file.
    res = []
    for file in files:
        with open(file, 'r', encoding='utf-8') as f:
            res.extend(json.loads(line) for line in f if line.strip())

    with open(args.res_jsonl, 'w', encoding='utf-8') as f:
        for r in res:
            f.write(json.dumps(r, ensure_ascii=False)+'\n')

    print(f"Result saved at file {args.res_jsonl}")
