import torch
import os
import json
from tqdm import tqdm
from utils import Model_Factory,ModelConfig
import torch
import concurrent.futures
import time
from ensemble_generate import ensemble_sevice








def _cuda_or_cpu(gpu_index):
    """Return ``cuda:<gpu_index>`` when CUDA is available, otherwise the CPU device."""
    device_name = f'cuda:{gpu_index}' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)


# One configuration mapping per model. Each mapping is later unpacked into
# ModelConfig(**config). model_path / tokenizer_path are empty strings here
# (presumably placeholders to be filled in before running -- TODO confirm).
llama3_config = {
    "model_type": "llama3_instruct",
    "model_path": "",
    "tokenizer_path": "",
    "device_id": _cuda_or_cpu(3),
}

mistral_config = {
    "model_type": "mistral_instruct_3.0",
    "model_path": "",
    "tokenizer_path": "",
    "device_id": _cuda_or_cpu(0),
}

Qwen_config = {
    "model_type": "qwen2-instruct",
    "model_path": "",
    "tokenizer_path": "",
    "device_id": _cuda_or_cpu(1),
}

Glm4_config = {
    "model_type": "glm4-instruct",
    "model_path": "",
    "tokenizer_path": "",
    "device_id": _cuda_or_cpu(2),
}

Gemma_config = {
    "model_type": "Gemma-instruct",
    "model_path": "",
    "tokenizer_path": "",
    "device_id": _cuda_or_cpu(4),
}


# Build one model per configuration, in a fixed order
# (llama3, mistral, Qwen, GLM4, Gemma).
llama3_model = Model_Factory(model_config=ModelConfig(**llama3_config))
mistral_model = Model_Factory(model_config=ModelConfig(**mistral_config))
Qwen_model = Model_Factory(model_config=ModelConfig(**Qwen_config))
GLM4_model = Model_Factory(model_config=ModelConfig(**Glm4_config))
Gemma_model = Model_Factory(model_config=ModelConfig(**Gemma_config))

# Every model takes part in both roles; the two lists are separate list
# objects that hold the same model instances in the same order.
inference_model_lists = [
    llama3_model,
    mistral_model,
    Qwen_model,
    GLM4_model,
    Gemma_model,
]
verify_model_lists = [
    llama3_model,
    mistral_model,
    Qwen_model,
    GLM4_model,
    Gemma_model,
]


# Load the evaluation set: the file is read line by line, each non-blank line
# being one JSON object with "instruction" and "answer" keys. `dic` maps each
# instruction to its answer (a repeated instruction keeps the last answer).
dic = {}
test_file_path = "../dataset/generation/en/instruction-response-en.json"
# Explicit UTF-8 so decoding does not depend on the platform's default locale
# (the results file written later in this script is also opened as UTF-8).
with open(test_file_path, 'r', encoding='utf-8') as f:
    for line in f:
        # Tolerate blank lines (e.g. a trailing newline at end of file), which
        # json.loads would otherwise reject.
        if not line.strip():
            continue
        data = json.loads(line)
        dic[data["instruction"]] = data["answer"]






# Switch for the ensemble generation run below; nothing inside the `if` runs
# while this is False.
stream_test = False
torch.cuda.empty_cache()

if stream_test:
    each_step_max_generate_token = 10
    max_tokens = 500

    # Not read anywhere in this section; kept so any later code that refers to
    # these names keeps working.
    all_turns = 0
    all_call_model_nums = 0

    LLMs = ensemble_sevice(
        inference_model_lists,
        verify_model_lists,
        max_tokens,
        each_step_max_generate_token,
    )

    # The output file name encodes the participating model types and the
    # per-step token window, so each setup writes to its own results file.
    model_name = "_".join(model.model_type for model in inference_model_lists)
    outfile = os.path.join("", f"Specfuse_{model_name}-window-{each_step_max_generate_token}.json")

    # Resume support: collect the instructions already present in the output
    # file (one JSON object per line) so they are not generated again.
    predicted_question_list = []
    if os.path.exists(outfile):
        with open(outfile, 'r', encoding='utf-8') as test_file:
            for line in test_file:
                # Skip blank lines instead of crashing in json.loads.
                if not line.strip():
                    continue
                predicted_question_list.append(json.loads(line)["instruction"])
    # Set for O(1) membership tests inside the generation loop.
    already_predicted = set(predicted_question_list)

    with open(outfile, 'a', encoding="utf8") as f:
        # `num` is the instruction's position in the dataset, so ids are unique
        # and stay stable across resumed runs (previously every record got 0).
        for num, i in enumerate(tqdm(dic)):
            if i in already_predicted:
                continue

            result, model_hit_times = LLMs.generate(i, True, False, False)
            item = dict(
                id=num,
                instruction=i,
                pred=result,
                ground_truth=dic[i],
            )
            json.dump(item, f)
            f.write('\n')
            # Push each finished record to disk right away so an interrupted
            # run can resume without losing completed generations.
            f.flush()
            torch.cuda.empty_cache()


