import torch
import os
import json
from tqdm import tqdm
from utils import Model_Factory,ModelConfig
import torch
import concurrent.futures
import time
from ensemble_generate import ensemble_sevice








def _pick_device(index):
    """Return ``torch.device('cuda:<index>')`` if that GPU exists, else CPU.

    ``torch.cuda.is_available()`` alone does not guarantee that GPU ``index``
    exists (e.g. ``cuda:3`` on a 2-GPU machine). ``torch.cuda.device_count()``
    returns 0 when CUDA is unavailable, so the CPU-only case is covered too.
    """
    if torch.cuda.device_count() > index:
        return torch.device(f'cuda:{index}')
    return torch.device('cpu')


# Per-model settings passed to ``ModelConfig(**cfg)``. Each model is pinned to
# its own GPU when available.
# TODO: fill in model_path / tokenizer_path before running.
mistral_config=dict(
    model_type="mistral_instruct_3.0",
    model_path="",
    tokenizer_path="",
    device_id=_pick_device(0),
)

Qwen_config=dict(
    model_type="qwen2-instruct",
    model_path="",
    tokenizer_path="",
    device_id=_pick_device(1),
)

Glm4_config=dict(
    model_type="glm4-instruct",
    model_path="",
    tokenizer_path="",
    device_id=_pick_device(2),
)

Gemma_config=dict(
    model_type="Gemma-instruct",
    # Was the stray value "t"; now an empty placeholder like the other configs.
    model_path="",
    tokenizer_path="",
    device_id=_pick_device(3),
)



# Instantiate each model once, in the same order as before; the same
# instances are shared by the inference and verification lists.
mistral_model = Model_Factory(model_config=ModelConfig(**mistral_config))
Qwen_model = Model_Factory(model_config=ModelConfig(**Qwen_config))
GLM4_model = Model_Factory(model_config=ModelConfig(**Glm4_config))
Gemma_model = Model_Factory(model_config=ModelConfig(**Gemma_config))

# Two distinct list objects holding the same models in the same order.
inference_model_lists = [mistral_model, Qwen_model, GLM4_model, Gemma_model]
verify_model_lists = list(inference_model_lists)















# NOTE(review): `dic` is not used anywhere in this file; kept so the module
# namespace is unchanged.
dic = {}

# True: run the COIG-CQIA jsonl and append predictions to a json-lines output
# file, skipping instructions already present there (resumable).
# False: run `test_file` through `LLMs.generate(..., True, True)` without
# saving results. The meaning of the two boolean flags is defined in
# `ensemble_generate`; TODO confirm there.
stream_test = True

# Ensemble settings shared by both modes (previously duplicated per branch).
each_step_max_generate_token = 10
max_tokens = 1000
LLMs = ensemble_sevice(inference_model_lists, verify_model_lists, max_tokens, each_step_max_generate_token)

if stream_test:
    model_name = "_".join(model.model_type for model in inference_model_lists)
    print(model_name)
    outfile = f"Specfuse_{model_name}-window-{each_step_max_generate_token}.json"
    print(outfile)
    num = 0  # 1-based position of the current (non-blank) input record

    # Detect already-processed data: collect the instructions written to
    # `outfile` by a previous run so they are skipped. A set makes each
    # membership check O(1) instead of scanning a list per input line.
    predicted_questions = set()
    if os.path.exists(outfile):
        with open(outfile, 'r', encoding='utf-8') as done_file:
            for done_line in done_file:
                if done_line.strip():  # tolerate blank lines
                    predicted_questions.add(json.loads(done_line)["instruction"])

    test_file = "../dataset/generation/cn/COIG-CQIA-cn.jsonl"

    # Decode the input explicitly as UTF-8 rather than the platform default,
    # which is not UTF-8 on every system.
    with open(test_file, 'r', encoding='utf-8') as r:
        with open(outfile, 'a', encoding='utf-8') as f:
            for line in tqdm(r):
                if not line.strip():
                    continue

                data = json.loads(line)
                num += 1

                if data["instruction"] in predicted_questions:
                    continue

                result = LLMs.generate(data["instruction"], True, False)

                item = dict(
                    id=num,
                    instruction=data["instruction"],
                    pred=result,
                    ground_truth=data["output"],
                    source=data["source"],
                )
                json.dump(item, f, ensure_ascii=False)
                f.write('\n')
                # Flush per record so an interrupted run leaves complete lines
                # that the resume check above can parse.
                f.flush()

                torch.cuda.empty_cache()

else:
    test_file = ""  # TODO: set to the input .jsonl path for this mode
    if not test_file:
        # open("") would also raise FileNotFoundError, but with an opaque message.
        raise FileNotFoundError(
            "stream_test is False but test_file is empty; "
            "set it to an input .jsonl path"
        )

    with open(test_file, 'r', encoding='utf-8') as r:
        for line in tqdm(r):
            if not line.strip():
                continue
            data = json.loads(line)
            # The return value is not used in this mode.
            LLMs.generate(data["instruction"], True, True)
            
                












