# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

"""
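Inference with FunctionLM (LLaMA plus function tokens) on GSM8K and its "levelup"
variants, LAMA, KAMEL, VirtualHome (vh, vh1, vh2), MathQA and FuncQA.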

"""

from typing import Tuple
import os
import sys
import torch
import fire
import time
import json
import random
import numpy as np
import pandas as pd
from pathlib import Path
import re
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from tqdm import tqdm
from llama import ModelArgs, Transformer, Tokenizer, LLaMA, FunctionLM
import wandb
from collections import Counter

from funchub.math import _add_, _subtract_, _multiply_, _divide_, _power_, _sqrt_, _log_, _ln_, \
    _sin_, _cos_, _tan_, _asin_, _acos_, _atan_, _factorial_, _floor_, _ceil_, _round_, _radians_, _degrees_, \
    _exp_, _choose_, _permutate_, _gcd_, _lcm_, _root_, _remainder_



def setup_model_parallel() -> Tuple[int, int]:
    """Initialise NCCL distributed state and fairscale model parallelism.

    ``LOCAL_RANK`` and ``WORLD_SIZE`` are read from the environment (each
    falling back to -1 when unset). This process is bound to CUDA device
    ``LOCAL_RANK``, and torch's RNG is seeded with 1.

    Returns:
        ``(local_rank, world_size)`` as parsed from the environment.
    """
    env = os.environ
    rank = int(env.get("LOCAL_RANK", -1))
    n_procs = int(env.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(n_procs)
    torch.cuda.set_device(rank)

    # The seed has to be identical in every process.
    torch.manual_seed(1)
    return rank, n_procs


def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int, func_load_path: str, func_dict: dict) -> "FunctionLM":
    """Build a ``FunctionLM`` around this rank's shard of a LLaMA checkpoint.

    Args:
        ckpt_dir: Directory holding the ``*.pth`` shards and ``params.json``.
        tokenizer_path: Passed to ``Tokenizer(model_path=...)``.
        local_rank: Index (in sorted filename order) of the shard to load.
        world_size: Expected number of shards; must equal the ``*.pth`` count.
        func_load_path: Passed to ``FunctionLM`` as ``load_path``.
        func_dict: Function-token mapping passed to ``FunctionLM``.

    Returns:
        The constructed ``FunctionLM``.

    Raises:
        ValueError: if the number of ``*.pth`` files differs from ``world_size``.
    """
    start_time = time.time()
    ckpt_root = Path(ckpt_dir)
    checkpoints = sorted(ckpt_root.glob("*.pth"))
    # An explicit check (not ``assert``) so it still runs under ``python -O``.
    # Without it, a mismatch could silently pick the wrong shard below.
    if world_size != len(checkpoints):
        raise ValueError(f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}")
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(ckpt_root / "params.json", "r") as f:
        params = json.load(f)

    model_args: ModelArgs = ModelArgs(max_seq_len=2048, max_batch_size=1, **params)
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    # Build the model with half-precision CUDA tensors as the default type.
    # Always restore the float default, even if construction fails.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    try:
        model = Transformer(model_args).cuda().half()
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)
    # strict=False: missing/unexpected keys in the checkpoint are tolerated.
    model.load_state_dict(checkpoint, strict=False)
    funcmodel = FunctionLM(model, tokenizer, func_dict=func_dict, load_path=func_load_path)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return funcmodel

def func_embedding_inference(templates, case_idx, question, funcmodel: "FunctionLM", temperature, top_p, max_gen_len, func_dict, return_top=0, stop_token=None, func_stop_token=None, block_repetitive=False, max_func_call=100):
    """Generate an answer, expanding function tokens as the model emits them.

    The model runs in ``"func_embedding"`` mode on ``templates["funcgeneral"]``.
    If that key is missing, it is cached from ``templates["general"]``. Each time
    the generation ends with a ``func_dict`` key followed by ``"("`` or ``" <"``,
    the model is re-run in ``"baseline"`` mode on ``templates[op]`` to produce
    the call. Then:

    * ``<...>`` ops without ``P``: the arguments are normalised and evaluated as
      ``_name_(args)``, and the result replaces the call text.
    * ops containing ``P``: generation stops.
    * other ops: the call line is recorded and generation continues. It stops
      if the same line repeats.

    Args:
        templates: Prompt templates containing ``[QUESTION]``; may be mutated
            (``"funcgeneral"`` is added).
        case_idx: Identifier copied into the returned log.
        question: Question text substituted for ``[QUESTION]``.
        funcmodel: Model wrapper exposing ``generate`` and ``inference_mode``.
        temperature, top_p, max_gen_len: Forwarded to ``funcmodel.generate``.
        func_dict: Function tokens to watch for (iterated for its keys).
        return_top: When > 0, ``generate`` is expected to return
            ``(results, token_log)``.
        stop_token: Stop tokens for the main pass (default ``[13]``).
        func_stop_token: Stop tokens for the argument pass (default
            ``[29897, 3892]``).
        block_repetitive: Currently unused; kept for interface compatibility.
        max_func_call: Maximum number of main-pass generations.

    Returns:
        dict with ``case_idx``, ``question``, ``func_calls``, ``generation``
        (newlines escaped, stripped) and ``status`` (``"success"`` or the
        exception message).
    """
    # ``None`` defaults avoid sharing one mutable list across calls; the
    # effective defaults are unchanged.
    if stop_token is None:
        stop_token = [13]
    if func_stop_token is None:
        func_stop_token = [29897, 3892]

    funcmodel.inference_mode = "func_embedding"
    cur_generation = ""
    func_calls = []
    # Per-generation token logs (only when return_top > 0).
    # TODO: collected but not yet included in the returned log.
    logs = []
    if "funcgeneral" not in templates:
        templates["funcgeneral"] = templates["general"]
    try:
        prompt_prefix = templates["funcgeneral"].replace("[QUESTION]", question)
        while max_func_call != 0:
            results = funcmodel.generate([prompt_prefix + cur_generation], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, stop_token=stop_token, return_top=return_top, disable_func=[])
            max_func_call -= 1
            # Only unpack a token log when generate actually returned one.
            # Reading token_log unconditionally raised NameError for return_top == 0.
            if return_top > 0:
                results, token_log = results
                logs.append(token_log)
            endflag = True
            cur_generation = results[0].replace(prompt_prefix, "")
            print("results: ", cur_generation)
            for op in func_dict:
                if not (cur_generation.endswith(op + "(") or cur_generation.endswith(op + " <")):
                    continue
                endflag = False

                # Generate the call text with the op-specific template in
                # baseline mode; restore func_embedding mode even on error.
                op_prefix = templates[op].replace("[QUESTION]", question)
                funcmodel.inference_mode = "baseline"
                try:
                    results = funcmodel.generate([op_prefix + cur_generation], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, stop_token=func_stop_token, return_top=return_top)
                finally:
                    funcmodel.inference_mode = "func_embedding"
                if return_top > 0:
                    results, token_log = results
                    logs.append(token_log)
                cur_generation = results[0].replace(op_prefix, "")

                if "<" in op and "P" not in op:
                    args = cur_generation.split(op)[-1].replace("=", "").replace(">", "").replace("((", "(").replace("))", ")")  # shouldn't have >, but there is one in case study
                    args = args.replace("$", "")
                    if ", " in args:
                        # Remove commas not followed by a space (e.g. "1,000"),
                        # keeping ", " argument separators.
                        args = args.replace(", ", ";").replace(",", "").replace(";", ", ")
                    args = args.replace(" ", "")
                    # Validate before eval, not after it.
                    if "(" not in args or ")" not in args:
                        raise Exception("invalid args")
                    # SECURITY NOTE: evaluates model-generated text. ``_name_`` must
                    # resolve to a helper in this module's globals.
                    res = eval(f"_{op[1:-1]}_{args}")
                    func_calls.append(f"{op}{args} = {res}")
                    cur_generation = cur_generation.split(op)[0] + str(res)
                    print("overide results", cur_generation)
                elif "P" in op:
                    endflag = True
                else:
                    cur_line = op + cur_generation.split(op)[-1]
                    func_calls.append(cur_line)
                    print("current line::", cur_line)
                    # NOTE(review): ``> 2`` ignores a repeat of the very first call;
                    # ``>= 2`` may have been intended. Kept unchanged.
                    if len(func_calls) > 2 and func_calls[-1] == func_calls[-2]:
                        print("repetitive func call!!")
                        endflag = True
                break
            if endflag:
                break
    except Exception as e:
        status = str(e)
    else:
        status = "success"

    return {
        "case_idx": case_idx,
        "question": question,
        "func_calls": func_calls,
        "generation": cur_generation.replace("\n", "\\n").strip(),
        "status": status,
    }


def ground_baseline_inference(templates, case_idx, question, funcmodel, temperature, top_p, max_gen_len, stop_token=[13]):
    """Run ``baseline_inference`` on a ``(question_text, objects)`` pair.

    The first element of ``question`` is used as the prompt question. The second
    is forwarded to ``baseline_inference`` as ``objs``.
    """
    question_text, allowed_objs = question[0], question[1]
    return baseline_inference(
        templates,
        case_idx,
        question_text,
        funcmodel,
        temperature,
        top_p,
        max_gen_len,
        stop_token=stop_token,
        objs=allowed_objs,
    )

def baseline_inference(templates, case_idx, question, funcmodel, temperature, top_p, max_gen_len, stop_token=None, objs=None):
    """Answer one question with a single plain generation (no function calls).

    Args:
        templates: Prompt templates; ``templates["general"]`` contains
            ``[QUESTION]``.
        case_idx: Identifier copied into the returned log.
        question: Question text substituted for ``[QUESTION]``.
        funcmodel: Model wrapper; switched to ``"baseline"`` inference mode.
        temperature, top_p, max_gen_len: Forwarded to ``funcmodel.generate``.
        stop_token: Stop tokens for generation (default ``[13]``).
        objs: Forwarded to ``funcmodel.generate`` as ``objs`` (default ``[]``).

    Returns:
        dict with ``case_idx``, ``question``, ``generation`` (prompt removed,
        newlines escaped, stripped) and ``status`` (``"success"`` or the
        exception message).
    """
    # Fresh lists per call. The old ``[13]`` / ``[]`` defaults were single shared
    # objects handed to ``generate``, so any mutation there leaked into later calls.
    if stop_token is None:
        stop_token = [13]
    if objs is None:
        objs = []
    funcmodel.inference_mode = "baseline"
    cur_generation = ""
    try:
        prompt = templates["general"].replace("[QUESTION]", question)
        results = funcmodel.generate([prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, stop_token=stop_token, objs=objs)
        cur_generation = results[0].replace(prompt, "")
    except Exception as e:
        status = str(e)
    else:
        status = "success"

    return {
        "case_idx": case_idx,
        "question": question,
        "generation": cur_generation.replace("\n", "\\n").strip(),
        "status": status,
    }

def _normalize_react_args(args):
    """Clean a parenthesised argument string such as ``"(1,000, 50%)"`` for eval.

    The following steps are applied in order:

    * Drop ``=`` and ``>`` and collapse ``((`` / ``))``.
    * When ``", "`` separators are present, remove the other commas (``"1,000"``).
    * Remove spaces.
    * Convert ``%`` and ``a/b`` arguments to decimal strings.

    Raises ``Exception("invalid args")`` when parentheses are missing.
    """
    args = args.replace("=", "").replace(">", "").replace("((", "(").replace("))", ")")
    if ", " in args:
        args = args.replace(", ", ";").replace(",", "").replace(";", ", ")
    args = args.replace(" ", "")
    if "(" not in args or ")" not in args:
        raise Exception("invalid args")

    if "%" in args or "/" in args:
        converted = []
        for arg in args.split("(")[1].split(")")[0].split(","):
            if "%" in arg:
                arg = str(float(arg.replace("%", "").strip()) / 100)
            if "/" in arg:
                numerator, denominator = [a.strip() for a in arg.split("/")]
                arg = str(float(numerator) / float(denominator))
            converted.append(arg)
        args = f"({', '.join(converted)})"
    return args


def react_inference(templates, case_idx, question, funcmodel, temperature, top_p, max_gen_len, stop_token=None, func_stop_token=None):
    """ReAct-style inference: the model writes ``<op>(args)=`` and the script fills in the result.

    Generation runs in ``"react"`` mode and stops on ``stop_token`` or
    ``func_stop_token``. While the text ends with ``")"`` or ``")="``:

    1. Parse the last ``<op>(args)=`` occurrence.
    2. Normalise the arguments and evaluate ``_op_(args)``.
    3. Append the result.
    4. Generate one token from ``templates["func"]`` with the listed digit
       token ids disabled.

    Args:
        templates: Needs ``"general"``; ``"func"`` is used once a call is made.
        case_idx: Identifier copied into the returned log.
        question: Question text.
        funcmodel: Model wrapper exposing ``func_dict``, ``generate`` and
            ``inference_mode``.
        temperature, top_p, max_gen_len: Forwarded to ``funcmodel.generate``.
        stop_token: Normal stop tokens (default ``[13]``).
        func_stop_token: Extra stop tokens for function calls (default
            ``[29897, 3892]``, the values previously hard-coded here).
            ``main`` passes this keyword, which previously raised TypeError.

    Returns:
        dict with ``case_idx``, ``question``, ``func_calls``, ``generation``
        and ``status`` (``"success"`` or the exception message).
    """
    if stop_token is None:
        stop_token = [13]
    if func_stop_token is None:
        func_stop_token = [29897, 3892]
    funcmodel.inference_mode = "react"
    known_funcs = set(funcmodel.func_dict)

    cur_generation = ""
    func_calls = []
    try:
        generation_stop = list(stop_token) + list(func_stop_token)
        prompt_prefix = templates["general"] + "Question: " + question + "\nAnswer:"
        while True:
            results = funcmodel.generate([prompt_prefix + cur_generation], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, stop_token=generation_stop)
            cur_generation = results[0].replace(prompt_prefix, "")
            if not cur_generation.endswith((")", ")=")):
                break

            matches = re.findall(r"\<(.*?)\>\((.*?)\)\=", cur_generation)
            if not matches:
                raise Exception("invalid args")
            raw_op, raw_args = matches[-1]

            op = "<" + raw_op.strip() + ">"
            if op not in known_funcs:
                raise Exception(f"invalid func -- {op}")
            # Check emptiness before wrapping in parentheses. The old check ran
            # afterwards and could never fire.
            if not raw_args.strip():
                raise Exception("invalid args")
            args = _normalize_react_args("(" + raw_args.strip() + ")")

            try:
                # SECURITY NOTE: evaluates model-generated text. ``_op_`` must
                # resolve to a helper in this module's globals.
                res = eval(f"_{op[1:-1]}_{args}")
                func_calls.append(f"{op}{args} = {res}")
                cur_generation = cur_generation + str(res)
                # Generate only the next token, with the digit tokens disabled.
                func_prefix = templates["func"].replace("[QUESTION]", question)
                results = funcmodel.generate([func_prefix + cur_generation], max_gen_len=1, temperature=temperature, top_p=top_p, stop_token=[13],
                                             disable_token=[29900, 29896, 29906, 29941, 29946, 29945, 29953, 29955, 29947, 29929])
                cur_generation = results[0].replace(func_prefix, "")
            except Exception as e:
                raise Exception(f"error -- {e}")
    except Exception as e:
        status = str(e)
    else:
        status = "success"

    return {
        "case_idx": case_idx,
        "question": question,
        "func_calls": func_calls,
        "generation": cur_generation,
        "status": status,
    }

# Function-token vocabulary shared by the arithmetic datasets (GSM8K variants, MathQA).
_ARITHMETIC_FUNC_DICT = {
    "<add>": 0,
    "<subtract>": 1,
    "<multiply>": 2,
    "<divide>": 3,
}

# Inference modes dispatched by ``main``.
_SUPPORTED_MODES = ("func_embedding", "baseline", "ground_baseline", "react", "special_func")

# Extracts the object list from a VirtualHome task description.
_VH_OBJECTS_RE = re.compile(r"The objects I can manipulate are (.*?)\.")


def _load_templates(template_dir, templates):
    """Read every file in ``template_dir`` into ``templates`` and return it.

    The key is the filename's last ``_``-separated part with ``.txt`` removed
    (e.g. ``foo_general.txt`` -> ``general``).
    """
    for name in os.listdir(template_dir):
        with open(f"{template_dir}/{name}") as f:
            templates[name.split("_")[-1].replace(".txt", "")] = f.read()
    return templates


def _vh_object_names(func_dict):
    """Return a new list of the ``<...>`` keys of ``func_dict`` with brackets stripped."""
    return [fun[1:-1] for fun in func_dict if fun.startswith("<")]


def _parse_vh_objects(desc):
    """Return the object list embedded in a VirtualHome description.

    NOTE(review): this ``eval``s text produced by ``vh_eval.get_desc``.
    ``ast.literal_eval`` would be safer if that text is always a literal.
    """
    return eval(_VH_OBJECTS_RE.search(desc).group(1))


def main(ckpt_dir: str, tokenizer_path: str, temperature: float = 0, top_p: float = 0.95, mode: str = "baseline", dataset = "original", return_top: int = 5, logits_bias: float = 0, func_load_path: str = "None", self_consistency_k: int =0, save_name="None"):
    """Run FunctionLM inference over one dataset and append per-case logs to JSONL.

    Args:
        ckpt_dir: LLaMA checkpoint directory. Its last path component is used as
            the size tag in the output filename.
        tokenizer_path: Tokenizer model path, forwarded to ``load``.
        temperature, top_p: Sampling parameters forwarded to generation.
        mode: One of ``func_embedding``, ``baseline``, ``ground_baseline``,
            ``react`` or ``special_func``.
        dataset: ``original`` (GSM8K), ``levelup*``, ``lama`` /
            ``lama-template``, ``kamel*_<N>``, ``vh``, ``vh1``, ``vh2``,
            ``mathqa`` or ``funcqa``.
        return_top: Forwarded to ``func_embedding_inference``.
        logits_bias: Passed to ``funcmodel.set_bias``; part of the output filename.
        func_load_path: Forwarded to ``load``. Its file stem is part of the
            output filename.
        self_consistency_k: When > 1, selects an alternative output filename.
        save_name: Tag included in the default output filename.

    Raises:
        ValueError: for an unknown ``mode``.
        NotImplementedError: for an unknown ``dataset``, or for ``special_func``
            mode when ``special_func_inference`` is not defined.
    """
    # set random seed
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(1)
    np.random.seed(1)

    # Fail fast, before the slow model setup. Previously an unknown mode only
    # surfaced as a NameError on ``log`` after loading the model.
    if mode not in _SUPPORTED_MODES:
        raise ValueError(f"unknown mode: {mode!r}; expected one of {_SUPPORTED_MODES}")
    if mode == "special_func" and "special_func_inference" not in globals():
        raise NotImplementedError("mode 'special_func' requires special_func_inference, which is not defined in this script")

    # Last path component of the checkpoint dir, tolerant of a trailing slash.
    size = os.path.basename(os.path.normpath(ckpt_dir))

    local_rank, world_size = setup_model_parallel()
    if local_rank > 0:
        sys.stdout = open(os.devnull, 'w')

    templates = {}
    # Defaults shared by most datasets; branches override them as needed.
    # Previously mathqa/funcqa never set these, so any mode raised NameError.
    stop_token = [13]
    func_stop_token = [29897, 3892]
    block_repetitive = False
    max_func_call = 100

    if dataset == "original":
        _load_templates("data/gsm8k/template", templates)
        with open("data/gsm8k/gsm8k.json") as f:
            data = json.load(f)
        test_cases = [i["question"] for i in data["test"]]
        max_gen_len = 512
        func_dict = dict(_ARITHMETIC_FUNC_DICT)

    elif dataset.startswith("levelup"):
        # version is the last two characters of the dataset name
        version = dataset[-2:]
        _load_templates("data/gsm8k/template", templates)
        with open(f"data/gsm8k_enhanced/gsm_8k_enhanced_{version}.json") as f:
            data = [json.loads(line) for line in f]
        test_cases = []
        for item in data:
            # fill the {v_1}, {v_2}, ... placeholders with the enhanced values
            q = item["question"]
            for i, value in enumerate(item["enhanced_v"]):
                q = q.replace(f"{{v_{i+1}}}", str(value))
            test_cases.append(q)
        max_gen_len = 512
        func_dict = dict(_ARITHMETIC_FUNC_DICT)

    elif dataset == "lama-template" or dataset == "lama":
        _load_templates("data/gsm8k/template", templates)
        data_path = "data/lama-t-rex/last_token_template_shuffled.json" if dataset == "lama-template" else "data/lama-t-rex/last_token_data_shuffled.json"
        with open(data_path) as f:
            data = json.load(f)
        test_cases = [i[0] for i in data]
        max_gen_len = 5
        with open("data/lama-t-rex/my_relations.json") as f:
            relations = json.load(f)
        func_dict = {f"<{r['relation']}>": ind for ind, r in enumerate(relations)}
        max_func_call = 1

    elif dataset.startswith("kamel"):
        n_first = int(dataset.split("_")[-1])
        _load_templates("data/kamel/template", templates)
        if "reactood" in dataset:
            try:
                templates["general"] = templates[f"ood{n_first}"]
            except KeyError:
                print("using 60")
                templates["general"] = templates["ood60"]
        elif "react" in dataset:
            try:
                templates["general"] = templates[f"{n_first}"]
            except KeyError:
                print("using 30")
                templates["general"] = templates["30"]
        with open(f"data/kamel/test_first_{n_first}.json") as f:
            data = json.load(f)
        test_cases = [i["question"] for i in data]
        max_gen_len = 30
        with open("data/kamel/api_desc.json") as f:
            func_desc = json.load(f)
        reversed_func_desc = {v: k for k, v in func_desc.items()}
        with open("data/kamel/func_dict.json") as f:
            raw_func_dict = json.load(f)
        func_dict = {f"<{reversed_func_desc[k]}>": v for k, v in raw_func_dict.items()}
        func_dict = {k: v for k, v in func_dict.items() if v < n_first}
        print(len(func_dict))
        max_func_call = 1

    elif dataset == "vh":
        from vh_eval import get_desc
        with open("data/vh/legal_test.json") as f:
            file_list = json.load(f)
        test_cases = [get_desc(graph_file_name=state_file, script_file_name=script_file) for script_file, state_file in file_list]
        max_gen_len = 96
        max_func_call = 8
        _load_templates("data/vh/template", templates)
        with open("data/vh/vh_func_dict.json") as f:
            func_dict = json.load(f)
        stop_token = [[13, 13]]  # stop after two "\n"
        func_stop_token = [13]
        block_repetitive = True

    elif dataset == "vh1":
        with open("data/vh/May3_func_dict.json") as f:
            func_dict = json.load(f)
        from vh_eval import get_desc
        assert mode in ["special_func", "baseline"]
        with open("data/vh/legal_test_v1.json") as f:
            file_list = json.load(f)

        test_cases = []
        if mode == "special_func":
            with open("data/vh/template/vh_specialgood.txt") as f:
                template = f.read()
            for script_file, state_file in file_list:
                with open(script_file) as f:
                    script_lines = f.read().split("\n")
                title, goal = script_lines[0], script_lines[1]
                desc = get_desc(graph_file_name=state_file, script_file_name=script_file)
                objects = _parse_vh_objects(desc)
                print(len(objects))
                obj_tokens = [f"<{o}>" for o in objects]
                discard_list = [o for o in func_dict if o not in obj_tokens and o.startswith("<")]
                test_cases.append((template.replace("[QUESTION]", title + "\n" + goal + "\n"), discard_list))
            print(test_cases[0][0]+"[START]")
            print(test_cases[0][1])
        elif mode == "baseline":
            with open("data/vh/template/vh_general.txt") as f:
                template = f.read()
            templates = {"general": template}
            for script_file, state_file in file_list:
                desc = get_desc(graph_file_name=state_file, script_file_name=script_file, obj_list=_vh_object_names(func_dict))
                test_cases.append(desc)
            print(test_cases[0])

        stop_token = [[13,13]]
        max_gen_len = 96
        max_func_call = 32

    elif dataset == "vh2":
        from vh_eval import get_desc
        assert mode in ["special_func", "baseline", "ground_baseline"]
        with open("data/vh/legal_test_v2.json") as f:
            file_list = json.load(f)
        with open("data/vh/func_dict_v4.json") as f:
            func_dict = json.load(f)

        test_cases = []
        if mode == "special_func":
            with open("data/vh/template/vh_special_v4.txt") as f:
                template = f.read()
            existing_obj_list = _vh_object_names(func_dict)
            for script_file, state_file in file_list:
                desc = get_desc(graph_file_name=state_file, script_file_name=script_file, obj_list=existing_obj_list)
                objects = _parse_vh_objects(desc)
                print(len(objects))
                obj_tokens = [f"<{o}>" for o in objects]
                discard_list = [o for o in func_dict if o not in obj_tokens and o.startswith("<")]
                test_cases.append((template.replace("[QUESTION]", desc), discard_list))
            print(test_cases[0][0]+"[START]")
            print(test_cases[0][1])
        elif mode == "baseline":
            with open("data/vh/template/vh_baseline_v2.txt") as f:
                template = f.read()
            templates = {"general": template}
            for script_file, state_file in file_list:
                desc = get_desc(graph_file_name=state_file, script_file_name=script_file, obj_list=_vh_object_names(func_dict))
                test_cases.append(desc)
            print(test_cases[0])
        elif mode == "ground_baseline":
            with open("data/vh/template/vh_baseline_v2.txt") as f:
                template = f.read()
            templates = {"general": template}
            for script_file, state_file in file_list:
                desc = get_desc(graph_file_name=state_file, script_file_name=script_file, obj_list=_vh_object_names(func_dict))
                test_cases.append((desc, _parse_vh_objects(desc)))
            print(test_cases[0])

        stop_token = [[13,13]]
        max_gen_len = 96
        max_func_call = 32

    elif dataset == "mathqa":
        _load_templates("data/mathqa/template", templates)
        with open("data/mathqa/test_v2.json") as f:
            data = [json.loads(line) for line in f]
        test_cases = [i["question"]+'\nAnswer Choices: '+i['options'] for i in data]
        max_gen_len = 1024
        func_dict = dict(_ARITHMETIC_FUNC_DICT)

    elif dataset == "funcqa":
        _load_templates("data/funcqa/template", templates)
        with open("data/funcqa/funcqa.json") as f:
            data = json.load(f)
        test_cases = [i["question"] for i in data]
        max_gen_len = 768
        func_dict = {
            "<add>": 0,
            "<subtract>": 1,
            "<multiply>": 2,
            "<divide>": 3,
            "<power>": 4,
            "<sqrt>": 5,
            "<log>": 6,
            "<root>": 7,
            "<sin>": 8,
            "<cos>": 9,
            "<tan>": 10,
            "<asin>": 11,
            "<acos>": 12,
            "<atan>": 13,
            "<radians>": 14,
            "<degrees>": 15,
            "<lcm>": 16,
            "<gcd>": 17,
            "<ln>": 18,
            "<exp>": 19,
            "<choose>": 20,
            "<factorial>": 21,
            "<remainder>": 22,
            "<permutate>": 23,
        }

    else:
        raise NotImplementedError(f"unknown dataset: {dataset!r}")

    funcmodel = load(ckpt_dir, tokenizer_path, local_rank, world_size, func_load_path=func_load_path, func_dict=func_dict)
    funcmodel.set_bias(logits_bias)
    funcmodel.eval()

    # The output location does not depend on the case; compute it once.
    try:
        func_model_name = func_load_path.split('/')[-1].split('.')[0]
    except AttributeError:
        # non-string value (e.g. a number parsed from the command line)
        func_model_name = func_load_path
    output_dir = f"final_outputs/{dataset}"
    if self_consistency_k > 1:
        output_path = f"{output_dir}/inference-{size}-{func_model_name}-{mode}-{dataset}-{logits_bias}-{self_consistency_k}.jsonl"
    else:
        output_path = f"{output_dir}/inference-{save_name}-{size}-{func_model_name}-{mode}-{dataset}-{logits_bias}.jsonl"
    if local_rank == 0:
        os.makedirs(output_dir, exist_ok=True)

    for case_idx, question in tqdm(enumerate(test_cases), total=len(test_cases)):
        if mode == "func_embedding":
            log = func_embedding_inference(templates, case_idx, question, funcmodel, temperature, top_p, max_gen_len, func_dict, return_top, stop_token=stop_token, func_stop_token=func_stop_token, block_repetitive=block_repetitive, max_func_call=max_func_call)
        elif mode == "baseline":
            log = baseline_inference(templates, case_idx, question, funcmodel, temperature, top_p, max_gen_len, stop_token=stop_token)
        elif mode == "ground_baseline":
            log = ground_baseline_inference(templates, case_idx, question, funcmodel, temperature, top_p, max_gen_len, stop_token=stop_token)
        elif mode == "react":
            log = react_inference(templates, case_idx, question, funcmodel, temperature, top_p, max_gen_len, stop_token=stop_token, func_stop_token=func_stop_token)
        else:  # "special_func"; mode was validated above
            log = special_func_inference(case_idx, question, funcmodel, temperature, top_p, func_dict, max_func_call=max_func_call)

        # Append one JSON line per case (rank 0 only) so partial runs keep results.
        if local_rank == 0:
            with open(output_path, "a") as f:
                f.write(json.dumps(log) + "\n")


if __name__ == "__main__":
    # Build a command-line interface from main()'s parameters.
    fire.Fire(component=main)