import itertools
import json
import re

from tqdm import tqdm

from add_int_int import Dataset_Generator

# Maximum digit level to scan for each task: digit keys whose integer value
# exceeds the task's cap are skipped by the token-count loop further down.
# Tasks that do not appear in this mapping are not scanned at all.
task_names = {
    # addition
    'add_Integer_Integer_Integer': 20,
    'add_Float_Float_Float': 8,
    'add_Fraction_Fraction_Fraction': 4,
    'add_easy_Fraction_Fraction_Fraction': 4,
    'add_ScientificNotation_ScientificNotation_ScientificNotation': 3,
    # subtraction
    'sub_Integer_Integer_Integer': 20,
    'sub_Float_Float_Float': 8,
    'sub_Fraction_Fraction_Fraction': 6,
    'sub_ScientificNotation_ScientificNotation_ScientificNotation': 6,
    # maximum (regular and hard variants)
    'max_Integer_Integer_Integer': 100,
    'max_Float_Float_Float': 60,
    'max_Fraction_Fraction_Fraction': 5,
    'max_ScientificNotation_ScientificNotation_ScientificNotation': 40,
    'max_hard_Integer_Integer_Integer': 30,
    'max_hard_Float_Float_Float': 50,
    'max_hard_ScientificNotation_ScientificNotation_ScientificNotation': 20,
    # multiplication (hard and easy variants)
    'multiply_hard_Integer_Integer_Integer': 15,
    'multiply_hard_Float_Float_Float': 4,
    'multiply_hard_Fraction_Fraction_Fraction': 3,
    'multiply_hard_ScientificNotation_ScientificNotation_ScientificNotation': 4,
    'multiply_easy_Integer_Integer_Integer': 15,
    'multiply_easy_Float_Float_Float': 4,
    'multiply_easy_Fraction_Fraction_Fraction': 3,
    'multiply_easy_ScientificNotation_ScientificNotation_ScientificNotation': 4,
    # digit-level, length, and integer-division tasks
    'digit_max_Integer_Integer_Integer': 20,
    'get_digit_Integer_int_int': 40,
    'get_digit_Float_int_int': 20,
    'length_Integer_none_int': 50,
    'floordiv_Integer_Integer_Integer': 6,
    'mod_Integer_Integer_Integer': 6,
    'mod_easy_Integer_Integer_Integer': 6,
}



# Results: task name -> {digit: [token count per sampled test]}, filled by the
# scan loop below.
rfft = {}
import time  # NOTE(review): not used anywhere in this script; kept in case other code relies on it.

# Benchmark data, iterated below as {task_name: {digit_str: tests}}.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on
# the platform's locale-dependent default encoding.
with open("test.json", "r", encoding="utf-8") as f:
    data = json.load(f)

model_path = 'llama3'
from transformers import AutoTokenizer
# NOTE(review): trust_remote_code=True lets the checkpoint execute its own Python
# code on load; only point model_path at trusted model directories.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Number of tests tokenized per digit level. The original `if cnt > 20: break`
# guard let 21 tests through; that count is preserved here.
SAMPLES_PER_DIGIT = 21
# Mean token count (prompt + response) above which a digit level is reported.
TOKEN_BUDGET = 2000
# Template markers that must never survive into a rendered sample.
UNFILLED_PLACEHOLDERS = ('_var1_', '_var2_', '_var_res_')

# For each task listed in task_names, tokenize up to SAMPLES_PER_DIGIT samples
# (rendered by Dataset_Generator.rfft_IO) per digit level, in the data's digit
# order. Print the first digit whose mean token count exceeds TOKEN_BUDGET, or,
# if none does, the largest digit present in the data. Per-digit token counts
# for the levels that stayed within budget are stored in rfft[task_name].
for task_name, digit_tests in data.items():
    if task_name not in task_names:
        continue
    print(task_name, end=' ')
    max_digit = task_names[task_name]
    generator = Dataset_Generator(task_name)
    counts_by_digit = {}
    for digit, tests in digit_tests.items():
        # Digit keys are strings (JSON object keys); skip levels above the cap.
        if int(digit) > max_digit:
            continue
        token_counts = []
        for test in itertools.islice(tests, SAMPLES_PER_DIGIT):
            sample = generator.rfft_IO(test)
            # Keep the input up to (not including) its last '=', then append
            # the response header and the generated output.
            prompt = '='.join(sample["input"].split('=')[:-1])
            final_str = prompt + '\n\n## Response:\n' + sample["output"]
            if any(marker in final_str for marker in UNFILLED_PLACEHOLDERS):
                # A leftover template marker means the generator emitted a
                # malformed sample: show it and abort with a failure status.
                print(final_str)
                raise SystemExit(1)
            token_counts.append(len(tokenizer.tokenize(final_str)))
        # An empty test list has no mean; record it instead of dividing by zero.
        if token_counts and sum(token_counts) / len(token_counts) > TOKEN_BUDGET:
            print(digit)
            break
        counts_by_digit[digit] = token_counts
    else:
        # No digit level exceeded the budget.
        # NOTE(review): this is the max over *all* digits in the data, including
        # levels skipped above the task's cap — confirm that is intended.
        print(max(map(int, digit_tests), default=None))
    rfft[task_name] = counts_by_digit
    
# with open("test_rfft.json", "w") as f:
#     json.dump(rfft, f)
