# %%
import argparse
import os
import time



def get_args():
    """Parse this script's command-line options.

    Unrecognised arguments are ignored (``parse_known_args``), so the script
    still runs when extra argv entries are present.

    Returns:
        argparse.Namespace with ``model_name``, ``model_short_name`` and
        ``sample_path`` (all strings).
    """
    option_defaults = {
        '--model_name': 'qwen/Qwen2.5-1.5B-Instruct',
        '--model_short_name': 'qwen_1.5B',
        '--sample_path': 'lime_samples',
    }
    parser = argparse.ArgumentParser()
    for flag, default_value in option_defaults.items():
        parser.add_argument(flag, type=str, default=default_value)
    known_args, _ignored = parser.parse_known_args()
    return known_args

args = get_args()
model_name, model_short_name, sample_path = (
    args.model_name,
    args.model_short_name,
    args.sample_path,
)

# Echo the resolved configuration so each run's log records its settings.
print(f"Model Name: {model_name}")
print(f"Model Short Name: {model_short_name}")
print(f"Sample Path: {sample_path}")



# %%
import torch
from openai import OpenAI
import numpy as np
import sys
import pandas as pd

# %%
# Load the NQ dev split (JSON Lines). The literal has no placeholders, so it is
# a plain string rather than an f-string.
# NOTE(review): `data` is not used in the visible part of this script — confirm
# it is still needed.
data = pd.read_json("nq/nq-dev.jsonl", lines=True)

# %%

class OpenAIAPIPredictor:
    """Ask an OpenAI chat model a question and return its short answer.

    Args:
        model_name: Model identifier sent as ``model`` to
            ``chat.completions.create``.
        **kwargs: Forwarded to ``super().__init__()`` (``object``), so in
            practice it must be empty.
    """

    def __init__(self, model_name='gpt-4o-2024-11-20', **kwargs):
        super().__init__(**kwargs)
        self.model_name = model_name
        self.model = OpenAI()

    def predict(self, text, *, max_retries=10, retry_delay=1.0, **kwargs):
        """Send ``text`` as the user message and return the model's reply.

        A system prompt asks for an answer of at most ~10 words, and
        ``max_tokens`` is capped at 20. Failed attempts are retried with
        exponential backoff.

        Args:
            text: The question to ask.
            max_retries: Total number of attempts before giving up.
            retry_delay: Base backoff in seconds. Attempt ``i`` is followed
                by a sleep of ``retry_delay * 2**i`` (capped at 30 s); 0
                disables sleeping.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The reply text of the first successful attempt, or ``None`` if
            every attempt failed. Failures are reported on stderr.
        """
        messages = [
            {"role": "system", "content": "You are a helpful assistant, answer the question briefly within 10 words. You will get penalty if you answer too long."},
            {"role": "user", "content": text},
        ]
        # Initialised up front: if every call raises, the post-loop check
        # must still see a bound name (previously an UnboundLocalError).
        response = None
        last_error = None
        for attempt in range(max_retries):
            try:
                response = self.model.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=1e-5,  # near-zero: (almost) deterministic answers
                    max_tokens=20,
                    frequency_penalty=0.0,
                    presence_penalty=0.0,
                    logprobs=None,
                )
                return response.choices[0].message.content
            except Exception as err:
                # API/network error or a malformed response. Catching
                # Exception, not a bare except, lets Ctrl-C stop the run.
                last_error = err
            # Back off before the next attempt so rate limits can clear.
            if retry_delay > 0 and attempt < max_retries - 1:
                time.sleep(min(retry_delay * 2 ** attempt, 30.0))

        if response is None:
            print(f"Error in response for text: {messages}", file=sys.stderr)
        else:
            print(f"Error in response for text: {messages}, and the response is {response}", file=sys.stderr)
        if last_error is not None:
            print(f"Last error: {last_error!r}", file=sys.stderr)
        return None
        



# %%
# One shared predictor instance; its predict() is used by process_row below.
predictor = OpenAIAPIPredictor(model_name)

# %%

# Load the perturbed NQ samples (tab-separated). The path is built with
# os.path.join, like the save step, so an absolute sample_path is honoured
# (an f'./{sample_path}/...' prefix would make it cwd-relative).
samples_df = pd.read_csv(
    os.path.join(sample_path, 'nq_perturb.csv'),
    sep='\t',
    index_col=None,
    keep_default_na=False,  # keep empty cells as "" instead of NaN
    # Read as text; presumably a 0/1 mask whose leading zeros must survive.
    dtype={'binary_representation': str},
)
samples_df

# %%
# Start every row unanswered; process_row treats "" as "still needs an answer".
samples_df = samples_df.assign(Answer="")

# %%
# Resume: if an earlier run saved a checkpoint, continue from it so rows that
# already have a non-empty Answer are skipped. The path uses os.path.join to
# match exactly where the processing loop writes the checkpoint.
checkpoint_path = os.path.join(sample_path, f'nq_perturb_{model_short_name}.csv')
if os.path.exists(checkpoint_path):
    samples_df = pd.read_csv(
        checkpoint_path,
        sep='\t',
        index_col=None,
        keep_default_na=False,  # unanswered rows come back as "" (not NaN)
        # Force Answer to text so numeric replies (e.g. "1945") are not parsed
        # into an int column that later string assignments would fight with.
        dtype={'binary_representation': str, 'Answer': str},
    )

# %%
from tqdm.auto import tqdm

from tqdm import tqdm
import pandas as pd
from concurrent.futures import ProcessPoolExecutor
import concurrent

def process_row(idx, row, predictor):
    """Answer a single row unless it already has an answer.

    Args:
        idx: Row label, passed through unchanged in the result.
        row: Mapping-like row with 'Answer' and 'sample_question' entries.
        predictor: Object exposing ``predict(question)``.

    Returns:
        ``(idx, answer)`` for an unanswered row, or ``None`` when the row's
        'Answer' is already non-empty (skipped).
    """
    already_answered = row['Answer'] != ""
    if already_answered:
        return None
    question = row['sample_question']
    return (idx, predictor.predict(question))

chunk_size = 10000
max_workers = 64  # maximum number of worker threads (API calls are I/O-bound)

# Checkpoint file, rewritten after every chunk so an interrupted run can resume
# (rows with a non-empty Answer are skipped by process_row).
save_path = os.path.join(sample_path, f'nq_perturb_{model_short_name}.csv')


def _save_atomically(df, path):
    """Write ``df`` as TSV to ``path`` via a temp file and ``os.replace``.

    If the process dies mid-write, the previous checkpoint stays intact
    instead of being left truncated (which would break the resume read).
    """
    tmp_path = f"{path}.tmp"
    df.to_csv(tmp_path, sep='\t', index=False)
    os.replace(tmp_path, path)


results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
    # Work through the DataFrame in chunks so progress is checkpointed regularly.
    for start_idx in range(0, len(samples_df), chunk_size):
        end_idx = min(start_idx + chunk_size, len(samples_df))

        # One task per row. Uses .loc with positional ranges, which assumes the
        # default RangeIndex that read_csv(index_col=None) produces.
        future_to_idx = {
            executor.submit(process_row, idx, samples_df.loc[idx], predictor): idx
            for idx in range(start_idx, end_idx)
        }

        # Collect results as they finish.
        for future in tqdm(concurrent.futures.as_completed(future_to_idx),
                           total=len(future_to_idx),
                           desc=f"Processing rows {start_idx}-{end_idx-1}"):
            try:
                result = future.result()
            except Exception as err:
                # Don't let one failing row discard the rest of the chunk's
                # answers; that row's Answer is simply left unchanged.
                print(f"Row {future_to_idx[future]} failed: {err!r}", file=sys.stderr)
                continue
            if result is not None:
                results.append(result)

        # Write answers back from the main thread only.
        for idx, answer in results:
            samples_df.loc[idx, 'Answer'] = answer

        # Checkpoint after every chunk.
        _save_atomically(samples_df, save_path)
        print(f"Saved results for rows {start_idx} to {end_idx-1}")

        results.clear()

# Final save (also covers an empty DataFrame, where the loop never runs).
final_save_path = save_path
_save_atomically(samples_df, final_save_path)
print("All processing completed. Final results saved.")

