from torch.utils.data import Dataset, DataLoader
import os
import json
import itertools
import torch
from task import Task
from numbers_class import Domain, NumberBasic
from typing import Literal, get_args
import tqdm
from peft import PeftModel
from task import Task
from collections import defaultdict
from transformers import AutoModelForCausalLM, AutoTokenizer

class NLDataset(Dataset):
    
    def __init__(self, path: str, train_or_test: Literal["train", "test"], num_each: int | None = None, random_seed: int = 20222943, tokenizer=None, one_digit_converter_file: str | None = None):
        self.path = path
        assert train_or_test in ["train", "test"], "train_or_test must be either 'train' or 'test'"
        self.train = train_or_test == "train"
        self.data_file: dict[str, dict[int, list[str]]] = json.load(open(path, 'r'))
        if num_each is not None:
            self.data_file = self.truncate_data(num_each, random_seed)
        self.num_each = num_each
        self.tasks = list(self.data_file.keys())
        
        if one_digit_converter_file is not None:
            self.one_digit_converter = json.load(open(one_digit_converter_file))
        else:
            self.one_digit_converter = None
        
        # create the global index for each task and the digits in each task
        self.task_digit_indices: list[list[int]] = []
        for task in self.tasks:
            self.task_digit_indices.append([len(self.data_file[task][digit]) for digit in self.data_file[task]])
        self.task_indices = [0] + list(itertools.accumulate([sum(digit_indices) for digit_indices in self.task_digit_indices]))
        self.tokenizer = tokenizer
        
    def __len__(self):
        return sum([sum(digit_indices) for digit_indices in self.task_digit_indices])
    
    def truncate_data(self, num_each: int, random_seed: int):
        import random
        random_rng = random.Random(random_seed)
        new_data_file: dict[str, dict[int, list[str]]] = {}
        for task, digits in self.data_file.items():
            new_data_file[task] = {}
            for digit, data in digits.items():
                new_data_file[task][digit] = random_rng.sample(data, min(num_each, len(data)))
        return new_data_file
    
    def __getitem__(self, idx: int) -> tuple[str, int, str, str] | torch.Tensor:
        # find first task idx larger than idx, then this idx - 1 is the task idx
        task_idx = 0
        while idx >= self.task_indices[task_idx + 1]:
            task_idx += 1
        task = self.tasks[task_idx]
        idx -= self.task_indices[task_idx]
        digits = list(self.data_file[task].keys())
        digit_idx = 0
        while idx >= self.task_digit_indices[task_idx][digit_idx]:
            idx -= self.task_digit_indices[task_idx][digit_idx]
            digit_idx += 1
        digit = digits[digit_idx]
        data = self.data_file[task][digit][idx]
        assert "=" in data
        
        # If in train, return the tokenized data
        if self.train:
            if self.one_digit_converter is None:
                return self.tokenizer(data, return_tensors="pt")["input_ids"].squeeze(0)
            else:
                inputs = self.tokenizer(data)
                assert all(i == 1 for i in inputs["attention_mask"])
                new_tokens: list[int] = []
                for index in inputs["input_ids"]:
                    if str(index) not in self.one_digit_converter:
                        new_tokens.append(index)
                    else:
                        new_tokens.extend(self.one_digit_converter[str(index)])
                return torch.tensor(new_tokens)
                
        groundtruth = data.split("=")[1].strip()
        data = data.split("=")[0] + "="
        return task, int(digit), data, groundtruth
    
def collate_fn(batch: list[tuple[str, int, str, str]]) -> tuple[list[str], list[int], list[str], list[str]]:
    """Transpose ``(task, digit, prompt, groundtruth)`` samples into four parallel lists.

    Intended as the DataLoader ``collate_fn`` for test-mode samples. An empty
    batch raises ``ValueError`` (nothing to unpack).
    """
    task_col, digit_col, prompt_col, answer_col = (list(column) for column in zip(*batch))
    return task_col, digit_col, prompt_col, answer_col

class StringMetrics:
    """Compare a predicted number string with its ground truth, digit group by digit group.

    Each string is split into the digit groups of the expected number type
    (Integer/int: one group; Float: before/after "."; Fraction: before/after "/";
    ScientificNotation: mantissa integer part, mantissa fractional part,
    exponent) and every non-digit character is dropped. Calling the object
    returns a dict with:

    * ``exact_match`` — 1 if all groups are equal, else 0;
    * ``digit_match`` — matched digits / ground-truth digits, comparing each
      group position-wise from the side given by the type's read direction
      (``"right"`` groups are compared from their last character);
    * ``dlength`` — absolute difference of the total digit counts.
    """

    # Number of digit groups per number type; used to shape the malformed-input fallback.
    _NUM_PARTS: dict[str, int] = {"Integer": 1, "int": 1, "Float": 2, "Fraction": 2, "ScientificNotation": 3}

    def __init__(self):
        self._set_default_read_direct()

    def _set_default_read_direct(self) -> None:
        """Cache, per number type, the read direction ("left"/"right") of each digit group."""
        self.default_read_direct: dict[Domain | Literal["int"], list[Literal['left', 'right']]] = {}
        for domain in get_args(Domain):
            self.default_read_direct[domain] = NumberBasic.get_subclass(domain).default_read_direct()
        # Plain integers are compared from their last digit.
        self.default_read_direct["int"] = ["right"]

    @staticmethod
    def _digits_only(text: str) -> str:
        """Return ``text`` with every non-digit character removed."""
        return "".join(c for c in text if c.isdigit())

    def _clean_str_and_part(self, str_: str, number_type: Domain | Literal['int']) -> list[str]:
        """Split ``str_`` into the digit-only groups of ``number_type``.

        When a required separator ("." / "/" / "e") is missing, a list of empty
        groups with the type's group count is returned instead.

        Raises:
            ValueError: if ``number_type`` is not a supported type.
        """
        digits = self._digits_only
        try:
            if number_type == "Integer" or number_type == "int":
                return [digits(str_)]
            elif number_type == "Float":
                parts = str_.split(".")
                return [digits(parts[0]), digits(parts[1])]
            elif number_type == "Fraction":
                parts = str_.split("/")
                return [digits(parts[0]), digits(parts[1])]
            elif number_type == "ScientificNotation":
                parts = str_.split("e")
                # "a.be c" -> ["a", "b", "c"]
                parts = parts[0].split(".") + [parts[1]]
                return [digits(parts[0]), digits(parts[1]), digits(parts[2])]
            else:
                raise ValueError(f"Invalid number type {number_type}")
        except IndexError:
            return [""] * self._NUM_PARTS[number_type]

    def __call__(self, pred_str: str, gt_str: str, expected_type: Domain | Literal['int']):
        """Score ``pred_str`` against ``gt_str`` interpreted as ``expected_type``.

        Returns:
            ``{"exact_match": int, "digit_match": float, "dlength": int}``. When the
            ground truth contains no digits at all, ``digit_match`` falls back to
            ``float(exact_match)`` instead of dividing by zero.
        """
        pred_parts = self._clean_str_and_part(pred_str, expected_type)
        gt_parts = self._clean_str_and_part(gt_str, expected_type)
        exact_match = int(pred_parts == gt_parts)
        read_direct = self.default_read_direct[expected_type]
        digit_match_c = 0
        for pred_p, gt_p, direct in zip(pred_parts, gt_parts, read_direct):
            if direct == "right":
                pred_p, gt_p = pred_p[::-1], gt_p[::-1]
            digit_match_c += sum(p == g for p, g in zip(pred_p, gt_p))
        gt_len = sum(len(p) for p in gt_parts)
        digit_match = digit_match_c / gt_len if gt_len else float(exact_match)
        dlength = abs(sum(len(p) for p in pred_parts) - gt_len)
        return {
            "exact_match": exact_match,
            "digit_match": digit_match,
            "dlength": dlength
        }
        
        
class MetricsRecorder:
    """Accumulate per-(metric, task, digit) sums and counts of evaluation metrics.

    ``value[metric][task][digit]`` holds the running sum of a metric and
    ``count[metric][task][digit]`` the number of observations, so the mean is
    ``value / count``. Fill it from local generations (``process``) or OpenAI
    batch outputs (``process_gpt_generated``), persist it with ``save``/``load``,
    and summarise it with ``statistics`` / ``report``.
    """

    # Regex an answer must match at the start of the (stripped) answer text, per output domain.
    _ANSWER_PATTERNS: dict[str, str] = {
        "Integer": r"\d+",
        "int": r"\d+",
        "Float": r"\d+\.\d+",
        "Fraction": r"\d+/\d+",
        "ScientificNotation": r"\d+\.\d+[eE][+-]?\d+",
    }

    def __init__(self) -> None:
        self.value: dict[str, dict[str, defaultdict[int, float]]] = {}  # metric_name -> task_name -> digit -> summed value
        self.count: dict[str, dict[str, defaultdict[int, int]]] = {}  # metric_name -> task_name -> digit -> number of observations

    def _record_dict(self, metric_name: str, task_name: str, digit: int, value: float):
        """Add one observation ``value`` to the (metric, task, digit) accumulators."""
        if metric_name not in self.value:
            self.value[metric_name] = {}
            self.count[metric_name] = {}
        if task_name not in self.value[metric_name]:
            self.value[metric_name][task_name] = defaultdict(float)
            self.count[metric_name][task_name] = defaultdict(int)
        self.value[metric_name][task_name][digit] += value
        self.count[metric_name][task_name][digit] += 1

    def _record_metrics(self, metrics_result: dict[str, float], task_name: str, digit: int) -> None:
        """Record every entry of a ``StringMetrics`` result for one (task, digit)."""
        for key, value in metrics_result.items():
            self._record_dict(metric_name=key, task_name=task_name, digit=digit, value=value)

    def process(self, processed_file_path: str):
        """Score a local generation file and accumulate its metrics.

        processed_file_path: str, the path to the processed file, which should be a jsonl file. Each line is dict{"task": str, "digit": int, "generated_text": list[str], "groundtruth": str}

        Every generated text of a line is scored separately against that line's groundtruth.
        """
        metrics = StringMetrics()
        with open(processed_file_path, 'r') as rf:
            for line in tqdm.tqdm(rf):
                data = json.loads(line)
                task_name = data["task"]
                _, _, _, domain_output = Task.name2components(task_name)
                digit = data["digit"]
                gt = data["groundtruth"]
                for generated_text in data["generated_text"]:
                    pred = self.retrieve_answer(generated_text, output_domain=domain_output)
                    self._record_metrics(metrics(pred, gt, domain_output), task_name, digit)
        print('Done!')

    def _get_task_digit_gt_for_gpt_batches(self, dataset: NLDataset | None = None, batches_dir: str | None = None) -> dict[int, tuple[str, int, str]]:
        """Map request id -> ``(task_name, digit, groundtruth)``.

        With ``dataset`` the id is the sample's enumeration index. Otherwise every
        ``*.jsonl`` file in ``batches_dir`` is read and the id is parsed from the
        trailing number of ``custom_id``; each line must carry a ``metadata`` entry
        with ``task_name``, ``digit`` and ``groundtruth`` (``KeyError`` otherwise).
        """
        return_dict = {}

        if dataset is None:
            assert batches_dir is not None
            for file in os.listdir(batches_dir):
                if not file.endswith(".jsonl"):
                    continue
                with open(os.path.join(batches_dir, file), 'r') as rf:
                    for line in tqdm.tqdm(rf):
                        data = json.loads(line)
                        custom_id = int(data["custom_id"].split("-")[-1])
                        return_dict[custom_id] = (data["metadata"]["task_name"], int(data["metadata"]["digit"]), data["metadata"]["groundtruth"])
        else:
            for i, (task, digit, _, groundtruth) in enumerate(dataset):
                return_dict[i] = (task, digit, groundtruth)

        return return_dict

    def process_gpt_generated(self, generated_file_dir: str, dataset_file: str | None = None, batches_dir: str = "benchmark/gpt_batches"):
        """Score OpenAI Batch API output files and accumulate their metrics.

        Args:
            generated_file_dir: directory whose ``*.jsonl`` files are batch outputs;
                each line carries ``custom_id`` (``request-<n>``) and the chat
                completion ``response`` including token ``usage``.
            dataset_file: optional test-set JSON. When given, ground truth comes from
                that dataset rebuilt with ``num_each=100`` and seed 20222943 (assumed
                to match the dataset the batches were built from); otherwise it is
                read from the batch input files in ``batches_dir``.
            batches_dir: directory of the batch input files.

        Extracted predictions are kept in ``self.gpt_cache`` (request id -> answer).
        Only ids that appear in the output files are scored.
        """
        import re
        self.gpt_cache: dict[int, str] = {}
        read_tokens = 0
        generate_tokens = 0
        dataset = None
        if dataset_file is not None:
            dataset = NLDataset(dataset_file, train_or_test="test", num_each=100, random_seed=20222943)

        task_digit_gt: dict[int, tuple[str, int, str]] = self._get_task_digit_gt_for_gpt_batches(dataset, batches_dir)

        metrics = StringMetrics()

        # Extract the answer of every request in the output files.
        for file in os.listdir(generated_file_dir):
            if not file.endswith(".jsonl"):
                continue
            with open(os.path.join(generated_file_dir, file), 'r') as rf:
                for line in tqdm.tqdm(rf):
                    data = json.loads(line)
                    body = data["response"]["body"]
                    generated_text = body["choices"][0]["message"]["content"]
                    custom_id = int(re.match(r"request-(\d+)", data["custom_id"]).group(1))
                    output_domain = Task.name2components(task_digit_gt[custom_id][0])[3]
                    self.gpt_cache[custom_id] = self.retrieve_answer(generated_text, output_domain=output_domain, start_answer="The answer is")
                    read_tokens += body["usage"]["prompt_tokens"]
                    generate_tokens += body["usage"]["completion_tokens"]

        # Score in id order; iterating the real ids tolerates missing (failed) requests.
        for custom_id in sorted(self.gpt_cache):
            task, digit, gt = task_digit_gt[custom_id]
            output_domain = Task.name2components(task)[3]
            self._record_metrics(metrics(self.gpt_cache[custom_id], gt, output_domain), task, digit)

        print("Done! The number of read tokens is", read_tokens, "and the number of generated tokens is", generate_tokens)

    def retrieve_answer(self, text: str, output_domain: Domain | Literal["int"], start_answer: str = " = ") -> str:
        """Extract the answer of ``output_domain`` that follows ``start_answer`` in ``text``.

        The text after the first ``start_answer`` (or the whole text if it is
        absent) is stripped and must *start* with the domain's pattern; otherwise
        ``""`` is returned. "+" and "-" are removed and "E" becomes "e".

        Raises:
            ValueError: if ``output_domain`` is not supported.
        """
        import re
        start = text.find(start_answer)
        if start != -1:
            text = text[start + len(start_answer):]
        text = text.strip()
        pattern = self._ANSWER_PATTERNS.get(output_domain)
        if pattern is None:
            raise ValueError(f"Invalid output domain {output_domain}")
        match = re.match(pattern, text)
        if match is None:
            return ""
        # remove "+" "-" in text and change E to e
        return match.group().replace("+", "").replace("-", "").replace("E", "e")

    def save(self, save_path: str):
        """Pickle ``(value, count)`` to ``save_path``."""
        import pickle
        with open(save_path, 'wb') as wf:
            pickle.dump((self.value, self.count), wf)
        print(f"Save results in {save_path}")

    def load(self, load_path: str):
        """Restore ``(value, count)`` from a pickle written by ``save``.

        Security note: ``pickle.load`` can execute arbitrary code; only load files you produced.
        """
        import pickle
        with open(load_path, 'rb') as rf:
            self.value, self.count = pickle.load(rf)
        print(f"Load results from {load_path}")

    @staticmethod
    def _thresholds(metric: str) -> tuple[float, float, bool]:
        """Return ``(has_performance_thre, well_learned_thre, larger_is_better)`` for ``metric``."""
        if metric == "exact_match":
            return 0.1, 0.9, True
        if metric == "digit_match":
            return 0.5, 0.9, True
        # Any other metric (e.g. "dlength") is treated as an error where smaller is better.
        return 1, 0.1, False

    @staticmethod
    def _passes(average: float, threshold: float, larger_is_better: bool) -> bool:
        """Whether ``average`` meets ``threshold`` in the metric's direction."""
        return average >= threshold if larger_is_better else average <= threshold

    @staticmethod
    def _mean(totals: list[tuple[float, int]]) -> float | None:
        """Pooled mean of ``(summed value, count)`` pairs, or ``None`` if there are no observations."""
        count = sum(c for _, c in totals)
        return sum(v for v, _ in totals) / count if count else None

    def statistics(self, output_file: str | None = None):
        """Summarise every (metric, task) over four consecutive digit ranges.

        Ranges are ``[0,5), [5,9), [9,15), [15,21)`` when the task's largest digit
        is 20 (21 is treated as 20) and ``[0,11), [11,21), [21,61), [61,101)`` when
        it is 100. Each task summary contains:

        * ``well_learned_digit`` / ``has_performance_digit`` — largest digit in the
          ranges whose mean passes the metric's threshold (0 if none does);
        * ``short_range``, ``medium_range``, ``long_range``, ``very_long_range`` —
          mean over each range;
        * ``in_domain`` / ``out_domain`` — pooled mean over the first / last two ranges.

        A mean is ``None`` when its ranges hold no observations (previously an
        empty range shifted the remaining entries or crashed).

        Args:
            output_file: if given, the summary is also written there as JSON.

        Returns:
            ``{metric: {task: summary}}``.
        """
        output = {}
        for metric, task_dict in self.value.items():
            has_performance_thre, well_learned_thre, larger_is_better = self._thresholds(metric)

            record_dict = {}
            for task, digit_dict in task_dict.items():
                task_counts = self.count[metric][task]
                top_digit = max(digit_dict.keys())
                if top_digit == 21:
                    top_digit = 20  # some bug, to_float has some 21 digit in the dataset
                assert top_digit == 20 or top_digit == 100, f"The maximum digit of a task should be either 20 or 100, but find {top_digit} in task {task}."
                bounds = [0, 5, 9, 15, 21] if top_digit == 20 else [0, 11, 21, 61, 101]

                well_learned_digit = 0
                has_performance_digit = 0
                range_totals: list[tuple[float, int]] = []  # (summed value, count) for each of the 4 ranges
                for low, high in zip(bounds[:-1], bounds[1:]):
                    value_c = 0.0
                    count_c = 0
                    for digit in range(low, high):
                        if digit not in digit_dict:
                            continue
                        value_c += digit_dict[digit]
                        count_c += task_counts[digit]
                        average_digit = digit_dict[digit] / task_counts[digit]
                        if self._passes(average_digit, well_learned_thre, larger_is_better):
                            well_learned_digit = max(well_learned_digit, digit)
                        if self._passes(average_digit, has_performance_thre, larger_is_better):
                            has_performance_digit = max(has_performance_digit, digit)
                    range_totals.append((value_c, count_c))

                record_dict[task] = {
                    "well_learned_digit": well_learned_digit,
                    "has_performance_digit": has_performance_digit,
                    "in_domain": self._mean(range_totals[:2]),
                    "out_domain": self._mean(range_totals[2:]),
                    "short_range": self._mean(range_totals[0:1]),
                    "medium_range": self._mean(range_totals[1:2]),
                    "long_range": self._mean(range_totals[2:3]),
                    "very_long_range": self._mean(range_totals[3:4])
                }
            output[metric] = record_dict

        if output_file is not None:
            with open(output_file, 'w') as wf:
                json.dump(output, wf, indent=2)
        return output

    def _report_task(self, metric_name: str, task: str, digit_dict: dict[int, float], every: bool, digit_range: tuple[int, int], out) -> None:
        """Print per-digit averages (``every``) or the pooled average over ``[low, high)`` for one task."""
        counts = self.count[metric_name][task]
        if every:
            for digit, value in digit_dict.items():
                print(f"        Digit: {digit}; Average: {value / counts[digit]}", file=out)
            return
        value_c = 0
        count_c = 0
        for digit, value in digit_dict.items():
            if digit_range[0] <= digit < digit_range[1]:
                value_c += value
                count_c += counts[digit]
        # An empty range has no average; report it instead of dividing by zero.
        average = value_c / count_c if count_c else "N/A"
        print(f"        Digit Range: [{digit_range[0]}, {digit_range[1]}); Average: {average}", file=out)

    def report(self, digit_range: int | tuple[int, int | None] | None = None, task_name: str | None = None, output_file: str | None = None):
        """Print per-digit or per-range averages of every metric.

        Args:
            digit_range: ``None`` prints every digit separately; an ``int`` d uses
                ``[d, d+1)``; a ``(low, high)`` tuple uses ``[low, high)`` with
                ``high=None`` meaning unbounded.
            task_name: report only this task (``KeyError`` if a metric lacks it);
                ``None`` reports every task.
            output_file: write the report to this path instead of stdout.
        """
        every = digit_range is None
        if digit_range is None:
            digit_range = (0, None)
        if isinstance(digit_range, int):
            digit_range = (digit_range, digit_range + 1)
        if digit_range[1] is None:
            digit_range = (digit_range[0], 10000000000)
        assert isinstance(digit_range, tuple) and len(digit_range) == 2, "Invalid digit range"
        out = open(output_file, 'w') if output_file is not None else None
        try:
            for metric_name, task_dict in self.value.items():
                print(f"Metric: {metric_name}", file=out)
                task_names = list(task_dict) if task_name is None else [task_name]
                for name in task_names:
                    print(f"    Task: {name}", file=out)
                    self._report_task(metric_name, name, task_dict[name], every, digit_range, out)
        finally:
            if out is not None:
                out.close()
    
def gpt_query(data: str, answer_domain: str, idx: int, model: str, groundtruth: str, task_name: str, digit: int) -> dict:
    """Build one OpenAI Batch API request line that asks a chat model to solve ``data``.

    Args:
        data: prompt text sent as the user message.
        answer_domain: ``"Integer"``, ``"Float"``, ``"Fraction"`` or
            ``"ScientificNotation"``; selects the answer-format regex appended to
            the system prompt. Any other value raises ``KeyError``.
        idx: sample index, encoded as ``custom_id = "request-<idx>"``.
        model: chat model name for the request body.
        groundtruth, task_name, digit: accepted but not included in the request.

    Returns:
        A JSON-serialisable dict with keys ``custom_id``, ``method``, ``url``, ``body``.
    """
    # The continuation lines keep the 4-space indent the prompt has always carried.
    instructions = (
        "You are a capable math assistant.\n"
        "    Return your solution without any process in the format: The answer is [YOUR ANSWER].\n"
        "    The final answer must strictly match the format "
    )
    regex_hint_by_domain = {
        "Integer": r'r"\d+"',
        "Float": r'r"\d+\.\d+"',
        "Fraction": r'r"\d+/\d+"',
        "ScientificNotation": r'r"\d+\.\d+e\d+"',
    }
    system_prompt = instructions + regex_hint_by_domain[answer_domain]

    request_body = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": data},
        ],
        "max_tokens": 256,
    }
    return {
        "custom_id": f"request-{idx}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": request_body,
    }
    
def create_batches_gpt(dataset: NLDataset, model: str, batch_size: int = 40000, output_dir: str = "benchmark/gpt_batches") -> int:
    """Write the test set as OpenAI Batch API input files.

    Every sample becomes one request line built by ``gpt_query`` with
    ``idx`` equal to its enumeration index; lines are grouped into files of at
    most ``batch_size`` lines named ``<output_dir>/batch_<k>.jsonl``, k = 0, 1, ...

    Args:
        dataset: test-mode dataset yielding ``(task, digit, prompt, groundtruth)``.
        model: chat model name written into every request.
        batch_size: maximum number of requests per file.
        output_dir: directory for the batch files (created if missing).

    Returns:
        The number of batch files written; they are numbered ``0 .. n-1``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    os.makedirs(output_dir, exist_ok=True)

    def write_batch(file_idx: int, queries: list[dict]) -> None:
        """Write ``queries`` as JSON lines to batch file ``file_idx``."""
        with open(os.path.join(output_dir, f"batch_{file_idx}.jsonl"), 'w') as f:
            f.writelines(json.dumps(query) + '\n' for query in queries)

    num_batches = 0
    batch: list[dict] = []
    for i, (task, digit, data, groundtruth) in enumerate(dataset):
        answer_domain = (Task.name2components(task)[-1]).strip()
        batch.append(gpt_query(data=data, answer_domain=answer_domain, idx=i, model=model, groundtruth=groundtruth, task_name=task, digit=digit))
        if len(batch) == batch_size:
            write_batch(num_batches, batch)
            num_batches += 1
            batch = []  # start a fresh batch so a full batch is never written twice
    if batch:  # trailing partial batch
        write_batch(num_batches, batch)
        num_batches += 1
    return num_batches
                
def main_gpt_test(create_batches: bool = False, create_requests: bool = False, num_batches: int | None = None, model: str = "gpt-4o-2024-08-06"):
    """Create OpenAI batch input files and/or submit them as batch jobs.

    Args:
        create_batches: build ``benchmark/gpt_batches/batch_*.jsonl`` from
            ``benchmark/tasks/test.json`` (100 samples per group, fixed seed);
            the number of files created overrides ``num_batches``.
        create_requests: upload each batch file and start a 24h batch job; the
            job ids are logged to ``record_gpt_requests_<model>.txt``. The API key
            is read by the OpenAI client from the ``OPENAI_API_KEY`` environment variable.
        num_batches: number of existing batch files to submit when not creating them.
        model: chat model used in the requests (e.g. "gpt-4o-mini-2024-07-18").
    """
    if create_batches:
        dataset = NLDataset("benchmark/tasks/test.json", train_or_test="test", num_each=100, random_seed=20222943)
        num_batches = create_batches_gpt(dataset, model=model)
    if create_requests:
        assert num_batches is not None, "If not creating batches, please provide the number of batches"
        from openai import OpenAI
        # No key is hard-coded: the client falls back to the OPENAI_API_KEY environment variable.
        client = OpenAI()
        with open(f"record_gpt_requests_{model}.txt", "w") as wf:
            for i in range(num_batches):
                with open(f"benchmark/gpt_batches/batch_{i}.jsonl", 'rb') as batch_file:
                    batch_input_file = client.files.create(
                        file=batch_file,
                        purpose="batch"
                    )
                response = client.batches.create(
                    input_file_id=batch_input_file.id,
                    endpoint="/v1/chat/completions",
                    completion_window="24h"
                )
                wf.write(f"Batch idx {i} with id {response.id}\n")
            
def _resolve_dataset_path(dataset_path: str, reverse: bool, pad: bool) -> str:
    """Redirect ``dataset_path`` to the reversed and/or padded task directory variant."""
    if reverse:
        dataset_path = dataset_path.replace("tasks", "tasks_reverse")
    if pad:
        dataset_path = dataset_path.replace("tasks", "tasks_pad")
    return dataset_path


def _build_save_path(model_name_or_path: str, checkpoint: str | None, suffix: str | None) -> str:
    """Output JSONL path: ``./generated_<model>[_<checkpoint>][<suffix>].jsonl``."""
    save_path = os.path.join(".", f"generated_{os.path.split(model_name_or_path)[-1]}.jsonl")
    if checkpoint is not None:
        save_path = save_path[:-6] + f"_{os.path.split(checkpoint)[-1]}.jsonl"
    if suffix is not None:
        save_path = save_path[:-6] + suffix + '.jsonl'
    return save_path


def _split_multi_digit_tokens(input_ids: list[int], one_digit_converter: dict[str, list[int]]) -> list[int]:
    """Replace every token id listed in ``one_digit_converter`` by its replacement ids."""
    new_ids: list[int] = []
    for token_id in input_ids:
        if str(token_id) in one_digit_converter:
            new_ids.extend(one_digit_converter[str(token_id)])
        else:
            new_ids.append(token_id)
    return new_ids


def main_test(dataset_path: str, model_name_or_path: str, batchsize: int, num_each: int | None = None, continue_: bool = False, checkpoint: str | None = None, load_in_4_bit: bool = False, reverse: bool = False, pad: bool = False, one_digit_tokenizer: bool = False, nope: bool = False, suffix: str | None = None):
    """Generate answers for the test set with a causal LM and append them to a JSONL file.

    Each output line is ``{"task", "digit", "generated_text": [decoded text], "groundtruth"}``;
    the decoded text includes the prompt.

    Args:
        dataset_path: test-set JSON; redirected to the ``tasks_reverse`` /
            ``tasks_pad`` variant when ``reverse`` / ``pad`` is set.
        model_name_or_path: model passed to ``from_pretrained``; also names the output file.
        batchsize: DataLoader batch size.
        num_each: samples kept per (task, digit) group (``None`` keeps all).
        continue_: skip as many samples as the output file already has lines.
        checkpoint: optional PEFT adapter merged into the model.
        load_in_4_bit: load the base model with 4-bit NF4 quantization.
        reverse, pad: select the dataset variant (see ``dataset_path``).
        one_digit_tokenizer: split prompt tokens with ``one_digit_converter.json``
            and forbid generating the split (multi-digit) tokens.
        nope: apply ``PEModifier("nope")`` before loading the model.
        suffix: extra suffix appended to the output file name.
    """
    if nope:
        from modify_pe import PEModifier
        PEModifier("nope")(None)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.padding_side = "left"  # prompts end exactly where generation starts
    tokenizer.pad_token = tokenizer.eos_token

    quantization_config = None
    if load_in_4_bit:
        from transformers import BitsAndBytesConfig
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16
        )
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)

    if checkpoint is not None:
        model = PeftModel.from_pretrained(model, checkpoint)
        model = model.merge_and_unload()
        print(f"Successfully merge checkpoint {checkpoint}")

    # Use this function's own arguments, not the script-level ``args``, so importing callers work.
    dataset_path = _resolve_dataset_path(dataset_path, reverse, pad)

    test_dataset = NLDataset(dataset_path, train_or_test="test", num_each=num_each, random_seed=20222943)
    dataloader = DataLoader(test_dataset, batch_size=batchsize, collate_fn=collate_fn, shuffle=False)
    num_return_sequences = 1
    save_path = _build_save_path(model_name_or_path, checkpoint, suffix)

    # When resuming, the number of lines already written is the number of samples to skip.
    have_generated = 0
    if continue_ and os.path.exists(save_path):
        with open(save_path, "r") as rf:
            have_generated = sum(1 for _ in rf)

    one_digit_converter: dict[str, list[int]] | None = None
    bad_words_ids = None
    if one_digit_tokenizer:
        with open('one_digit_converter.json') as rf:
            one_digit_converter = json.load(rf)  # dict[str, list[int]]
        # ``bad_words_ids`` is a list of banned token *sequences*: wrap each multi-digit
        # token id on its own so every one of them is banned individually.
        bad_words_ids = [[int(token_id)] for token_id in one_digit_converter] or None

    for tasks, digits, texts, groundtruths in tqdm.tqdm(dataloader):
        if have_generated >= len(tasks):
            have_generated -= len(tasks)
            continue
        if have_generated > 0:
            tasks = tasks[have_generated:]
            digits = digits[have_generated:]
            texts = texts[have_generated:]
            groundtruths = groundtruths[have_generated:]
            have_generated = 0

        if one_digit_converter is not None:
            new_inputs_ids_list = [_split_multi_digit_tokens(tokenizer(text)["input_ids"], one_digit_converter) for text in texts]
            # re-pad them as a batch
            inputs = tokenizer.pad({'input_ids': new_inputs_ids_list}, return_tensors="pt", padding=True)
        else:
            inputs = tokenizer(texts, return_tensors="pt", padding=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=2 * max(digits), num_return_sequences=num_return_sequences, bad_words_ids=bad_words_ids)  # N_sen * N_return

        generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        # regroup the text to align with input
        generated_texts_batch: list[list[str]] = [generated_text[i:i + num_return_sequences] for i in range(0, len(generated_text), num_return_sequences)]

        with open(save_path, "a") as wf:
            for task, digit, generated_text_batch, groundtruth in zip(tasks, digits, generated_texts_batch, groundtruths):
                wf.write(json.dumps({
                    "task": task,
                    "digit": digit,
                    "generated_text": generated_text_batch,
                    "groundtruth": groundtruth
                }) + '\n')
                
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Generate answers on the benchmark test set with a causal LM.")
    parser.add_argument("model", type=str, help="Model name or local path passed to from_pretrained.")
    parser.add_argument("--checkpoint", type=str, default=None, help="Optional PEFT adapter merged into the model.")
    parser.add_argument("--batchsize", type=int, default=64)
    parser.add_argument("--load_in_4_bit", action="store_true", help="Load the base model with 4-bit NF4 quantization.")
    parser.add_argument("--num", type=int, default=100, help="Samples kept per (task, digit) group.")
    parser.add_argument("--reverse", action="store_true", help="Use the tasks_reverse dataset variant.")
    parser.add_argument("--pad", action="store_true", help="Use the tasks_pad dataset variant.")
    parser.add_argument("--one_digit_tokenizer", action="store_true", help="Split multi-digit tokens using one_digit_converter.json.")
    parser.add_argument("--nope", action="store_true", help='Apply PEModifier("nope") before loading the model.')
    parser.add_argument("--suffix", type=str, default=None, help="Extra suffix for the output file name.")
    parser.add_argument("--dataset", type=str, default="benchmark/tasks/test.json", help="Test-set JSON file.")
    args = parser.parse_args()

    # Always resume from an existing output file if one is present.
    main_test(args.dataset, args.model, batchsize=args.batchsize, num_each=args.num, continue_=True, checkpoint=args.checkpoint, load_in_4_bit=args.load_in_4_bit, reverse=args.reverse, pad=args.pad, one_digit_tokenizer=args.one_digit_tokenizer, nope=args.nope, suffix=args.suffix)