# %%
import os
import argparse


def get_args():
    """Parse the command-line options of this script.

    Recognised options (all optional):
        --model_name        str, default 'qwen/Qwen2.5-1.5B-Instruct'
        --model_short_name  str, default 'qwen_1.5B'
        --sample_path       str, default 'lime_samples'
        --chunk_size        int, default 1000

    Arguments that are not recognised are silently ignored, because
    ``parse_known_args`` is used instead of ``parse_args``.

    Returns:
        The ``argparse.Namespace`` holding the recognised options.
    """
    option_specs = (
        ('--model_name', str, 'qwen/Qwen2.5-1.5B-Instruct'),
        ('--model_short_name', str, 'qwen_1.5B'),
        ('--sample_path', str, 'lime_samples'),
        ('--chunk_size', int, 1000),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)

    # Only the recognised part is returned; leftovers are discarded.
    known, _leftover = cli.parse_known_args()
    return known


# Unpack the command-line options into module-level settings.
args = get_args()
model_name, model_short_name = args.model_name, args.model_short_name
sample_path = args.sample_path
# Models whose short name contains 'deepseek' never use a chunk smaller than 5000.
chunk_size = max(5000, args.chunk_size) if 'deepseek' in model_short_name else args.chunk_size


import concurrent
import concurrent.futures

import pandas as pd
import torch
from tqdm.auto import tqdm

from utils.llm_predict import OpenAIAPIPredictor

# The model is chosen on the command line (--model_name); a single predictor
# is created here and reused for every prediction made further down.
llm_predictor = OpenAIAPIPredictor(model_name)

# %%
# Build the prompts using Chat format. We support converting Chat conversations to text for non-Chat models

# Answer letters, in the order the answer options are listed in a prompt.
choices = ["A", "B", "C", "D"]
# System-message template; "{}" is filled with the subject name.
sys_msg = "The following are multiple choice questions (with answers) about {}."


def create_chat_prompt(sys_msg, question, answers, subject):
    """Build the two-message chat prompt for one multiple-choice question.

    Args:
        sys_msg: Template for the system message; formatted with ``subject``.
        question: The question text.
        answers: Iterable of answer texts, paired positionally with the
            module-level ``choices`` letters (extra items on either side are
            dropped by ``zip``).
        subject: Value substituted into ``sys_msg``.

    Returns:
        ``[system_message, user_message]`` where the user message is the
        question, one "<letter>. <answer>" line per option, and a trailing
        "Answer:" line.
    """
    option_block = "\n".join(
        f"{letter}. {text}" for letter, text in zip(choices, answers)
    )
    user_prompt = f"{question}\n{option_block}\nAnswer:"

    system_message = {"role": "system", "content": sys_msg.format(subject)}
    user_message = {"role": "user", "content": user_prompt}
    return [system_message, user_message]

def create_chat_example(question, answers, correct_answer):
    """
    Build one few-shot demonstration as a [question, answer] message pair.

    The layout follows the recommended few-shot format:
    https://github.com/openai/openai-python/blob/main/chatml.md#few-shot-prompting

    Both messages are emitted with the "system" role and are told apart by
    their "name" field ("example_user" / "example_assistant"). The question
    message lists the options as "<letter>. <answer>" lines, pairing
    ``answers`` positionally with the module-level ``choices``, and ends with
    "Answer:". The answer message carries ``correct_answer`` unchanged.
    """
    lettered_options = "\n".join(
        f"{letter}. {text}" for letter, text in zip(choices, answers)
    )
    user_prompt = f"{question}\n{lettered_options}\nAnswer:"

    demo_question = {"role": "system", "content": user_prompt, "name": "example_user"}
    demo_answer = {"role": "system", "content": correct_answer, "name": "example_assistant"}
    return [demo_question, demo_answer]

# %%
# Root directory of the dataset; the per-subject loop reads
# <data_path>/dev/<subject>_dev.csv and <data_path>/test/<subject>_test.csv.
data_path = "mmlu/data"
# Every entry found directly inside ./samples_pools is treated as a subject name.
subjects = os.listdir('samples_pools')
# %%


# %%
def gen_prompt(sample, few_shot=None):
    """Insert few-shot examples between a prompt's leading messages and its last one.

    Args:
        sample: List of chat messages; the final element is the message that
            must stay last (the actual question), everything before it is
            kept at the front.
        few_shot: Optional list of examples. Each example is a sequence whose
            first two elements are message dicts: the first is re-labelled
            with role 'user', the second with role 'assistant'. Any further
            elements are passed through unchanged. ``None`` (the default)
            means no examples.

    Returns:
        A new list: ``sample[:-1]`` + all example messages + ``sample[-1:]``.

    Raises:
        IndexError: If an example has fewer than two messages.

    Neither ``sample`` nor ``few_shot`` (nor the dicts inside it) is modified:
    the re-labelled messages are copies. The same ``few_shot`` list is handed
    to many concurrent calls, so rewriting it in place would make every call
    write to shared state.
    """
    if few_shot is None:
        few_shot = []

    prompt = list(sample[:-1])
    for example in few_shot:
        # Copy before changing the role so the caller's dicts stay untouched.
        prompt.append({**example[0], 'role': 'user'})
        prompt.append({**example[1], 'role': 'assistant'})
        prompt.extend(example[2:])
    prompt.extend(sample[-1:])
    return prompt


# %%
# Value stored in the logits_* columns for a row that has not been predicted
# yet; any value greater than this marks the row as processed.
UNPROCESSED_LOGIT = -1000.0
LOGIT_COLUMNS = ['logits_A', 'logits_B', 'logits_C', 'logits_D']
CSV_COLUMNS = ("Question", "A", "B", "C", "D", "Answer")

# For every subject: load its perturbed samples, reuse logits already present
# in the per-model cache pool, query the model for the remaining rows in
# parallel, and write the result to <subject>_perturb_<model_short_name>.csv.
for subject in sorted(subjects):
    print(f"Subject: {subject}")

    # Few-shot examples: each row of the dev split becomes one
    # [question, answer] message pair.
    dev_df = pd.read_csv(os.path.join(data_path, "dev", subject + "_dev.csv"), names=CSV_COLUMNS)
    dev_df["sample"] = dev_df.apply(lambda x: create_chat_example(x["Question"], x[["A", "B", "C", "D"]], x["Answer"]), axis=1)
    few_shot = dev_df["sample"].to_list()

    test_df = pd.read_csv(os.path.join(data_path, "test", subject + "_test.csv"), names=CSV_COLUMNS)
    test_df["ideal"] = test_df.Answer

    results_path = os.path.join(sample_path, subject)
    output_csv = os.path.join(results_path, f'{subject}_perturb_{model_short_name}.csv')

    if os.path.exists(output_csv):
        # Resume from a previous (possibly partial) run for this model.
        samples_df = pd.read_csv(output_csv, sep='\t', index_col=None, keep_default_na=False)
    else:
        # First run for this model: start from the model-independent samples
        # and mark every row as unprocessed.
        samples_df = pd.read_csv(os.path.join(results_path, f'{subject}_perturb.csv'), sep='\t', index_col=None, keep_default_na=False)
        for column in LOGIT_COLUMNS:
            samples_df[column] = UNPROCESSED_LOGIT

    # Also true for an empty samples file, which is then skipped instead of
    # failing further down.
    if (samples_df['logits_A'] > UNPROCESSED_LOGIT).all():
        print("all samples have been processed")
        continue

    pool_csv = f'./samples_pools/{subject}/{subject}_cached_{model_short_name}.csv'
    if os.path.exists(pool_csv):
        pool_df = pd.read_csv(pool_csv, sep='\t', index_col=None, keep_default_na=False)
        print("get data from pool")
        filled_groups = []
        for idx in tqdm(range(samples_df['question_index'].max() + 1)):
            # reset_index returns a fresh frame, so the writes below never
            # touch a view of samples_df.
            local_samples = samples_df[samples_df['question_index'] == idx].reset_index(drop=True)
            local_pool = pool_df[pool_df['question_index'] == idx]
            for i in range(len(local_samples)):
                if local_samples.at[i, 'logits_A'] > UNPROCESSED_LOGIT:
                    continue
                # Look the perturbation up once and reuse the match for all
                # four logit columns; the first cached row wins.
                cached = local_pool[local_pool['binary_representation'] == local_samples.at[i, 'binary_representation']]
                if cached.empty:
                    continue
                for column in LOGIT_COLUMNS:
                    local_samples.at[i, column] = cached[column].values[0]
            filled_groups.append(local_samples)
        # Rows end up grouped by question_index with a fresh 0..n-1 index.
        samples_df = pd.concat(filled_groups, ignore_index=True)

    # Attach the answer options of the question each sample was derived from.
    question_indices = samples_df['question_index'].to_list()
    for choice in ("A", "B", "C", "D"):
        samples_df[choice] = [test_df.loc[i, choice] for i in question_indices]

    samples_df["input"] = samples_df.apply(lambda x: create_chat_prompt(sys_msg, x["sample_question"], x[["A", "B", "C", "D"]], subject), axis=1)

    def process_row(idx, sample, few_shot, choices):
        """Predict one row.

        Returns ``(idx, {"logits_<choice>": value, ...})``, or ``None`` when
        the row already holds logits and is skipped.
        """
        if samples_df.loc[idx, 'logits_A'] > UNPROCESSED_LOGIT:
            return None  # already processed
        prompt = gen_prompt(sample, few_shot)
        # NOTE(review): assumes predict() returns one value per choice, in the
        # same order as `choices` - confirm against OpenAIAPIPredictor.
        res = llm_predictor.predict(prompt)
        result = {f"logits_{k}": res[i] for i, k in enumerate(choices)}
        return idx, result

    # Run the predictions in parallel, one chunk of rows at a time.
    with concurrent.futures.ThreadPoolExecutor(max_workers=64) as executor:
        for start_idx in range(0, len(samples_df), chunk_size):
            end_idx = min(start_idx + chunk_size, len(samples_df))
            futures = [
                executor.submit(process_row, idx, samples_df.loc[idx, 'input'], few_shot, choices)
                for idx in range(start_idx, end_idx)
            ]

            # Collected per chunk, so finished chunks are not written into the
            # DataFrame again on every later chunk.
            chunk_results = []
            for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
                result = future.result()
                if result is not None:  # skipped rows return None
                    chunk_results.append(result)

            if not chunk_results:
                print(f"All samples in range {start_idx} to {end_idx} have been processed.")
                continue

            # Write the new logits back; this happens only after every worker
            # of the chunk has finished reading samples_df.
            for idx, row_result in chunk_results:
                for key, value in row_result.items():
                    samples_df.loc[idx, key] = value

            # Checkpoint after every chunk so an interrupted run can resume.
            samples_df.to_csv(output_csv, sep='\t', index=None)
            print(f"Processed and saved rows {start_idx} to {end_idx - 1}")

    # Final save: also persists values taken from the pool when no chunk
    # needed a model call (and therefore never triggered a checkpoint).
    samples_df.to_csv(output_csv, sep='\t', index=None)

# %%

# %%
