
from __future__ import annotations
import os
# These must be set before torch is imported to take effect.
# CUDA_LAUNCH_BLOCKING is a boolean flag ('0' = asynchronous kernel launches,
# '1' = synchronous launches for debugging); the previous value '0,1,2,3,4,5'
# was a copy of the device list below and is rejected as invalid by PyTorch.
os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
# GPUs exposed to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5'

import torch
import json
import time
import random
import transformers
from datetime import datetime

import llm_utils.eval_benchmarks as llm_eval
import llm_utils.load_llm as load_llm
import llm_utils.load_datasets as load_ds

# Silence transformers' warning/info output; only errors are printed.
transformers.logging.set_verbosity_error()

# ---- Experiment configuration: keep exactly one option per group uncommented. ----
# Model key passed to load_llm.get_llm().
choose_llm = 'qwen3-4b'
#choose_llm = 'qwen3-8b'
#choose_llm = 'mistral-7b'
#choose_llm = 'llama3-8b'

# Benchmark name plus TEST_OFFSET / SKIP_STEP, which are forwarded to
# load_ds.get_dataset() (their exact meaning is defined there, not in this file).
# NOTE(review): num_train is not used anywhere in this script as shown.
eval_benchmark, TEST_OFFSET, SKIP_STEP, num_train = 'mmluPro', 0, 5, 60
#eval_benchmark, TEST_OFFSET, SKIP_STEP, num_train = 'gpqa-main', 0, 1, 100
#eval_benchmark, TEST_OFFSET, SKIP_STEP, num_train = 'gsm8k', 0, 1, 100
#eval_benchmark, TEST_OFFSET, SKIP_STEP, num_train = 'math-500', 0, 1, 100

# Total number of few-shot example sets tried per test item: half of them are
# shared by every item (global_fsSet), half are resampled per item (local_fsSet).
num_fsSamples = 64
#num_fsSamples = 4

# Approximate number of test items to evaluate; it determines the stride used to
# subsample the test prompt ids.
#num_testSample = 160
num_testSample = 1000

# Generations per (test item, few-shot set); the set counts as correct when at
# least half of these generations are judged correct.
num_infSample = 3

print(f"------ script config ------")
print(f"---- LLM: {choose_llm} -- Eval: {eval_benchmark} -- No.FS: {num_fsSamples} -- No.TestS: {num_testSample} -- No.InfS: {num_infSample} ----")

# Presumably returns (dataset object, test prompt ids, train prompt ids) — TODO confirm
# against llm_utils.load_datasets.get_dataset.
cur_DS, llm_testPromptIds, llm_trainPromptIds = load_ds.get_dataset(eval_benchmark, TEST_OFFSET, SKIP_STEP)
# NOTE(review): layers_range is not used anywhere in this script as shown.
cur_model, cur_tokenizer, layers_range = load_llm.get_llm(choose_llm)


def tmp_fsFunc(test_id, gen_len, max_id):
    """Return ``gen_len`` distinct indices drawn uniformly from ``range(max_id)``.

    ``test_id`` is accepted only for call-site compatibility and does not affect
    the draw, which depends solely on the global ``random`` state.
    ``random.sample`` raises ValueError when ``gen_len`` exceeds ``max_id``.
    """
    # random.sample only needs len() and indexing, so the range is used directly
    # instead of being materialised into a list first.
    candidate_pool = range(max_id)
    return random.sample(candidate_pool, gen_len)

# Output file for the per-item statistics. The random suffix (0-10000) makes it
# unlikely that repeated runs with the same config overwrite each other.
rand_ID = str(round(random.random()*10000))
_filePath = f"stat_records/{choose_llm}_{eval_benchmark}_F{num_fsSamples}_T{num_testSample}_I{num_infSample}_R{rand_ID}.json"
#_filePath = "stat_records/llama3-8b_gpqa-main_F64_T280_I3_R4002_Ct30.json"
# Create the output directory up front: without it the first checkpoint write
# (after 10 evaluated items, i.e. possibly hours of generation) fails with
# FileNotFoundError and the collected results are lost.
os.makedirs(os.path.dirname(_filePath) or ".", exist_ok=True)


# res_stat gets one list of per-few-shot-set records per evaluated test item.
res_stat = []
print(f'########## Eval Raw {choose_llm} LLM on {eval_benchmark} ##########')
total_count, corr_count = 0, 0 
print(datetime.now().strftime("%H:%M:%S"))

# Few-shot sets shared by every test item: num_fsSamples//2 sets, each holding 4
# distinct indices into llm_trainPromptIds (requires at least 4 train prompts).
global_fsSet = [tmp_fsFunc(_, 4, len(llm_trainPromptIds)) for _ in range(num_fsSamples//2)]


# Evaluate roughly `num_testSample` test items taken at a uniform stride over the
# test prompt ids. The stride is clamped to >= 1: round() yields 0 when
# num_testSample exceeds about twice the number of test prompts, and a zero
# slice step raises "ValueError: slice step cannot be zero".
_test_stride = max(1, round(len(llm_testPromptIds) / num_testSample))
#for test_id in llm_testPromptIds[::_test_stride]:
# NOTE(review): this iterates the *test* prompt ids but binds them to `train_id`
# and formats them with get_trainPrompt; kept as in the original — confirm intended.
for train_id in llm_testPromptIds[::_test_stride]:
    best1_corr = 0  # few-shot sets so far where at least half the generations were correct
    res_stat.append([])  # records for this test item, one per few-shot set tried
    _ss = time.time()
    early_stop = False
    # Fresh few-shot sets for this item; they are tried after the shared global ones.
    local_fsSet = [tmp_fsFunc(_, 4, len(llm_trainPromptIds)) for _ in range(num_fsSamples//2)]
    for fs_id, fs_val in enumerate(global_fsSet + local_fsSet):
        #cur_messages = llm_eval.get_testPrompt(test_id, eval_benchmark, cur_DS, fs_val)
        cur_messages = llm_eval.get_trainPrompt(train_id, eval_benchmark, cur_DS, fs_val)
        tokenized_inputs = llm_eval.generate_model_inputs(cur_model, cur_tokenizer, cur_messages)

        tmp_corrCount, pred_list, label_val = 0, [], ''
        _ss2 = time.time()
        for _ in range(num_infSample):
            with torch.no_grad():
                thinking_outputs, raw_outputs = llm_eval.generate_model_outputs(cur_model, cur_tokenizer, tokenized_inputs)
            #isCorr, _label, _pred = llm_eval.eval_testPrompt(test_id, raw_outputs, eval_benchmark, cur_DS)
            isCorr, _label, _pred = llm_eval.eval_testPrompt(train_id, raw_outputs, eval_benchmark, cur_DS)
            pred_list.append(raw_outputs)
            if label_val == '': label_val = _label
            if isCorr: tmp_corrCount += 1
            # Time budget of num_infSample * 100 seconds per few-shot set.
            # NOTE(review): on timeout the skipped generations are effectively
            # counted as wrong (conf and the half-correct test below still divide
            # by num_infSample) — confirm that penalty is intended.
            if time.time() - _ss2 > num_infSample * 100: break
            torch.cuda.empty_cache()

        #res_stat[-1].append({'test_id': test_id, 'conf': round(tmp_corrCount/num_infSample, 2), 'fs_ids': fs_val, 'pred_list': pred_list, 'label': label_val})
        res_stat[-1].append({'train_id': train_id, 'conf': round(tmp_corrCount/num_infSample, 2), 'fs_ids': fs_val, 'pred_list': pred_list, 'label': label_val})
        if tmp_corrCount >= num_infSample / 2: best1_corr += 1
        # Early accept once the item is clearly solved: all of the first 4 sets,
        # 7 of the first 8, or 14 of the first 16 sets correct.
        if (fs_id == 3 and best1_corr == 4) or (fs_id == 7 and best1_corr == 7) or (fs_id == 15 and best1_corr == 14): early_stop = True; break
        # Give up when none of the first 32 sets was correct, or when this item's
        # wall-clock budget (num_fsSamples * 100 seconds) is used up.
        if (fs_id == 31 and best1_corr == 0) or (time.time() - _ss > num_fsSamples * 100): break

    total_count += 1
    # The item counts as solved if it was early-accepted or if at least half of
    # the configured number of few-shot sets were correct.
    if early_stop or best1_corr >= num_fsSamples / 2: corr_count += 1
    #if total_count % 10 == 0: print(f'------ TestId: {test_id} --- total_count: {total_count} --- isCorr: {isCorr} --- CurrAcc: {round(corr_count/total_count, 5)} --- TS: {datetime.now().strftime("%H:%M:%S")}')
    # Checkpoint every 10 items so that a crash does not lose the whole run.
    if total_count % 10 == 0:
        with open(_filePath, "w") as file: json.dump(res_stat, file, indent=4)
        print(f"Partial Data with Size={total_count} successfully saved to {_filePath}")

print(datetime.now().strftime("%H:%M:%S"))
# If no items were evaluated, report 0 instead of raising ZeroDivisionError
# before the results are saved.
_final_acc = round(corr_count/total_count, 5) if total_count else 0.0
print(f"--- Eval Raw {choose_llm} LLM on {eval_benchmark} with Acc: {_final_acc}")

with open(_filePath, "w") as file: json.dump(res_stat, file, indent=4)
print(f"Full Data successfully saved to {_filePath}")

