import sys
import os
import json
import re
import random
from math import isclose
from typing import Union, Any
import numpy as np
from pathlib import Path 
file = Path(__file__).resolve()
parent, root = file.parent, file.parents[1]
sys.path.append(str(root))
from tqdm import tqdm
import argparse
from decoding_algorithm import ContrastiveDecoding
from utils.dataset_loader import DatasetLoader, load_prompt


def parse_api_result(result):
    """Collect the message text of every choice in an API completion result.

    Returns ``None`` when ``result`` is falsy; otherwise a list holding
    ``choice.message.content`` for each element of ``result.choices``, in order.
    """
    if result:
        return [choice.message.content for choice in result.choices]
    return None

def round_with_error(x):
    """Snap ``x`` to 5 decimal places to wash out floating-point noise.

    ``round`` without ``ndigits`` returns an integer, so the value is scaled
    up, rounded to an integer, and divided back down.
    """
    scale = 1e5
    return round(x * scale) / scale

def floatify_ans(ans):
    """gsm8k"""
    if ans is None:
        return None
    elif type(ans) == dict:
        ans = list(ans.values())[0]
    elif type(ans) == bool:
        ans = ans
    elif type(ans) in [list, tuple]:
        if not ans:
            return None
        else:
            try:
                ans = float(ans[0])
            except Exception:
                ans = str(ans[0])
    else:
        try:
            ans = float(ans)
            ans = round_with_error(ans)
        except Exception:
            ans = str(ans)
    return ans

def find_last_number_in_last_sentence(text):
    """Return the last number in the last non-empty sentence of ``text``.

    A sentence ends at ``.``, ``!`` or ``?`` followed by whitespace or the end
    of the text, so a decimal point (``3.5``) does not split a sentence.
    Numbers may have a leading minus sign, thousands separators and a
    fractional part. Separators are removed from the result
    (``"1,000"`` -> ``"1000"``).

    Returns the number as a string, or ``None`` if the last sentence has none.
    """
    sentences = [s.strip() for s in re.split(r'[.!?](?=\s|$)', text) if s.strip()]
    if not sentences:
        return None
    # (?<!\d) stops the hyphen in a range like "10-12" being read as a minus sign.
    numbers = re.findall(r'(?<!\d)-?\d+(?:,\d{3})*(?:\.\d+)?', sentences[-1])
    if not numbers:
        return None
    return numbers[-1].replace(',', '')

def find_first_number_in_sentence(text):
    """Return the first run of digits in ``text`` as a string, or ``None`` if there is none."""
    found = re.search(r'\d+', text)
    return found.group() if found else None
    

def zero_shot_cot_answer_extract(text: str):
    """Extract a numeric answer from a zero-shot CoT completion (GSM8K).

    Uses the last ``\\boxed{...}`` whose content is a number; the final box is
    taken because earlier boxes may hold intermediate results. Signed, decimal
    and comma-grouped values are accepted, and commas are stripped. When no
    numeric box exists, falls back to ``find_last_number_in_last_sentence``.
    The result is normalized with ``floatify_ans``.
    """
    boxed = re.findall(r"\\boxed\{\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)\s*\}", text)
    if boxed:
        ans = boxed[-1].replace(',', '')
    else:
        ans = find_last_number_in_last_sentence(text)
    return floatify_ans(ans)

def few_shot_cot_answer_extract(text: str):
    """Split a few-shot CoT completion into its reasoning and numeric answer (GSM8K).

    If ``text`` contains the marker ``" So the answer is:"``:
    - the reasoning is everything before the first marker;
    - the answer is the first number after ``"So the answer is:"``.

    Otherwise:
    - the reasoning is the whole text minus one trailing period;
    - the answer is the last number anywhere in the text.

    Numbers may be signed, decimal and comma-grouped (``"$1,250"`` -> 1250).
    Commas are stripped before the value is normalized with ``floatify_ans``.

    Returns:
        ``(cot, cot_ans)``; ``cot_ans`` is ``None`` when no number is found.
    """
    # (?<!\d) stops the hyphen in "10-12" being read as a minus sign.
    number = r'(?<!\d)-?\d+(?:,\d{3})*(?:\.\d+)?'
    marker = " So the answer is:"
    cot_ans = None
    if marker in text:
        cot = text.partition(marker)[0]
        match = re.search(r'So the answer is:.*?(' + number + r')', text)
        if match:
            cot_ans = floatify_ans(match.group(1).replace(',', ''))
    else:
        # No answer marker found: fall back to the last number in the text.
        cot = text[:-1] if text.endswith('.') else text
        numbers = re.findall(number, text)
        if numbers:
            cot_ans = floatify_ans(numbers[-1].replace(',', ''))
    return cot, cot_ans

def get_precision(gt_ans: float) -> int:
    """Count the characters after the last ``.`` in ``str(gt_ans)``.

    Returns 5 when the string form contains no ``.`` at all.
    """
    _, dot, fraction = str(gt_ans).rpartition('.')
    return len(fraction) if dot else 5

def finqa_equal(prediction: Union[bool, float, str],
                reference: Union[float, str],
                include_percentage: bool = True,
                is_close: bool = False) -> bool:
    """Loosely decide whether ``prediction`` matches ``reference``.

    - ``None`` prediction -> ``False``.
    - bool prediction -> matches the reference strings ``'yes'`` / ``'no'``.
    - either side a string -> plain ``==`` comparison.
    - otherwise numeric: the candidates are the reference and, when
      ``include_percentage`` is true, also ``reference / 100`` and
      ``reference * 100``. A candidate matches if either holds:
      - ``is_close`` is set and ``isclose(..., rel_tol=0.001)`` is true;
      - both values are equal after rounding to the smaller of their
        ``get_precision`` values.

    The rounding rule is lenient: a prediction with fewer decimals than the
    reference is compared at the prediction's precision.
    """
    if prediction is None:
        return False
    if isinstance(prediction, bool):
        # Boolean predictions are compared against yes/no references.
        return reference == ('yes' if prediction else 'no')
    if isinstance(reference, str) or isinstance(prediction, str):
        return prediction == reference
    if include_percentage:
        candidates = [reference / 100, reference, reference * 100]
    else:
        candidates = [reference]
    for item in candidates:
        try:
            if is_close and isclose(item, prediction, rel_tol=0.001):
                return True
            precision = min(get_precision(prediction), get_precision(item))
            if round(prediction, precision) == round(item, precision):
                return True
        except Exception:
            # Best-effort comparison: values that cannot be compared or
            # rounded simply do not match this candidate.
            continue
    return False


def _parse_args():
    """Define and parse the command-line options, validating shard settings."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="huggyllama/llama-7b")
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument("--num-samples", type=int, default=80)
    parser.add_argument("--max_gpu_memory", type=int, default=27)
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("--data-path", type=str, default="./gsm8k")
    parser.add_argument("--output-path", type=str, default="./gsm8k_result")
    parser.add_argument("--early-exit-layers", type=str, default="-1")
    # Parallel mode: split the dataset into --total-shard parts and process
    # only the --shard-id part in this process.
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--total-shard", type=int, default=8)
    parser.add_argument("--shard-id", type=int, default=None)
    parser.add_argument("--do-rating", action="store_true")
    parser.add_argument("--is-chat", action="store_true")
    parser.add_argument("--mode", type=str, choices=["baseline", "cot-enhance"], default="baseline")
    parser.add_argument("--gpt3-config", type=str, default=None)
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--bias", action="store_true")
    # This flag used to be ignored in favour of a hard-coded 128; the default
    # now matches that value so existing runs keep generating 128 tokens.
    parser.add_argument("--max-new-tokens", type=int, default=128)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--top_k", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.9)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--relative_top", type=float, default=0.0)
    parser.add_argument("--relative_top_value", type=float, default=-1000.0)
    args = parser.parse_args()
    if args.parallel:
        if args.total_shard < 1:
            parser.error("--total-shard must be at least 1")
        if args.shard_id is None or not 0 <= args.shard_id < args.total_shard:
            parser.error("--parallel requires 0 <= --shard-id < --total-shard")
    return args


def _shard_dataset(dataset, shard_id, total_shard):
    """Return the contiguous slice of ``dataset`` owned by ``shard_id``.

    The last shard also takes the remainder rows, so no example is dropped.
    Assumes ``dataset`` supports ``len()`` and ``.select(indices)``, as the
    ``--debug`` path already relies on.
    """
    chunk_size = len(dataset) // total_shard
    start = shard_id * chunk_size
    end = len(dataset) if shard_id == total_shard - 1 else start + chunk_size
    return dataset.select(range(start, end))


def _attention_temperature(model_name, mode):
    """Pick the ``attention_temperature`` value passed to ``llm.generate``.

    Returns ``1`` in baseline mode, or for models not listed below. In
    ``cot-enhance`` mode, known model families get ``[0, {key: (indices, temperature)}]``.
    NOTE(review): keys/indices presumably mean layer -> head indices; the
    format is defined by ContrastiveDecoding — confirm there.
    The checks run in order, so when several patterns match, the last one wins.
    """
    attn_t = 1
    T = 0.5
    if mode == "cot-enhance":
        name = model_name.lower()
        if "gemma-2b" in name:
            attn_t = [0, {12: ([3, 7, 4, 1], 0.5), 14: ([0, 1, 2, 5], 0.5)}]
        if "gemma-7b" in name:
            attn_t = [0, {18: (range(0, 16), T)}]
        if "llama-2-7b" in name:
            attn_t = [0, {13: (range(0, 32), T), 14: (range(0, 32), T)}]
        if "llama-3-8b" in name:
            attn_t = [0, {17: ([0, 1, 3, 4, 5, 6, 7, 9, 10, 12, 13, 14, 16, 17, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31], 0.5)}]
    return attn_t


def main():
    """Run few-shot CoT on the GSM8K test split, logging accuracy and writing per-example JSON lines."""
    args = _parse_args()

    # Load the data before the model so a bad --data-path fails fast.
    data_file = f"{args.data_path}/test.json"
    if not os.path.exists(data_file):
        raise FileNotFoundError(data_file)
    print("Loading data from {}".format(data_file))
    dataset = DatasetLoader.load_dataset(
        "json", data_files={"test": data_file})["test"]
    if args.parallel:
        dataset = _shard_dataset(dataset, args.shard_id, args.total_shard)
    if args.debug:
        dataset = dataset.select(range(min(10, len(dataset))))

    llm = ContrastiveDecoding(args.model_name, args.device, args.max_gpu_memory,
                              num_gpus=int(args.num_gpus))
    # stop_word_list = ["Q:"]
    # llm.set_stop_words(stop_word_list)
    cot_prompt = load_prompt(args.data_path, "few_shot_cot")
    attn_t = _attention_temperature(args.model_name, args.mode)
    print("use attn_t {}".format(attn_t))

    os.makedirs(args.output_path, exist_ok=True)
    # Each shard gets its own file so parallel processes don't overwrite each other.
    result_name = f"result_{args.shard_id}.json" if args.parallel else "result.json"

    correct, wrong = 0, 0
    with open(os.path.join(args.output_path, result_name), 'w', encoding='utf-8') as writer:
        for idx, example in enumerate(tqdm(dataset)):
            problem = f'Question: {example["question"]}' + '\n'
            cot_content = cot_prompt + problem + "Answer: Let's think step by step. "
            # Keep only the text before the first blank line of the completion
            # (presumably to drop any follow-on few-shot question the model invents).
            cot_gen = llm.generate(
                input_text=cot_content,
                attention_temperature=attn_t,
                max_new_tokens=args.max_new_tokens,
            )[0].split('\n\n')[0]
            pred_cot, prediction = few_shot_cot_answer_extract(cot_gen)
            gt_cot, gt_ans = example['answer'].split("####")  # GSM8K format
            # Strip thousands separators so e.g. "1,000" parses as a number.
            gt_cot, gt_ans = gt_cot.strip(), floatify_ans(gt_ans.strip().replace(',', ''))
            is_correct = finqa_equal(prediction, gt_ans)
            sample = {'question': example['question'], 'gt_cot': gt_cot, 'gt': gt_ans,
                      'cot_gen': cot_gen, 'pred_cot': pred_cot, 'pred': prediction}
            if is_correct:
                correct += 1
            else:
                wrong += 1
            print("idx: {} is {}. Accuracy {}".format(idx, is_correct, correct / (correct + wrong)))
            print("Output: answer = {}, Gold Ans: {}".format(sample['pred'], sample["gt"]))
            writer.write(json.dumps(sample) + '\n')
            writer.flush()  # keep partial results if a long run is interrupted

    total = correct + wrong
    print("Accuracy : {}".format(correct / total if total else 0.0))


if __name__ == "__main__":
    main()


     