# %%
import os
import argparse


def get_args():
    """Parse the command-line options of the script.

    Unknown arguments are ignored (``parse_known_args``), so the script can
    also be run from environments that inject their own argv entries.

    Returns:
        argparse.Namespace with the attributes ``model_name``,
        ``model_short_name``, ``sample_path``, ``subjects`` (list of str) and
        ``start_idx`` (int).

    Raises:
        OSError: if ``--subjects`` is not given and the ``samples_pools``
            directory cannot be listed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='qwen/Qwen2.5-1.5B-Instruct')
    parser.add_argument('--model_short_name', type=str, default='qwen_1.5B')
    parser.add_argument('--sample_path', type=str, default='lime_samples')
    # The default is resolved after parsing: listing the directory while the
    # parser is being built would fail whenever `samples_pools` is absent,
    # even if the caller passed --subjects explicitly.
    parser.add_argument('--subjects', nargs='+', default=None)
    parser.add_argument('--start_idx', type=int, default=0)
    args, _ = parser.parse_known_args()
    if args.subjects is None:
        args.subjects = os.listdir('samples_pools')
    return args


# Parse the command line once and unpack the settings used throughout the script.
args = get_args()
model_name, model_short_name, sample_path = (
    args.model_name,
    args.model_short_name,
    args.sample_path,
)


# os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'
import pandas as pd
import torch
from tqdm.auto import tqdm

from utils.llm_predict import QwenPredictor, LLamaPredictor

# model_name = 'Qwen/Qwen2.5-1.5B-Instruct'
# model_short_name = 'qwen_1.5B'
# Choose the predictor wrapper from the (case-insensitive) prefix of the short
# model name; anything other than a qwen*/llama* name is rejected up front.
_short_name_lower = model_short_name.lower()
if not _short_name_lower.startswith(('qwen', 'llama')):
    raise ValueError(f"Model {model_short_name} not supported")
_predictor_cls = QwenPredictor if _short_name_lower.startswith('qwen') else LLamaPredictor
llm_predictor = _predictor_cls(model_name)

# %%
# Build the prompts using Chat format. We support converting Chat conversations to text for non-Chat models

# Answer letters, in the order the options are listed in a prompt.
choices = ["A", "B", "C", "D"]
# System-message template; `{}` is filled with the subject.
sys_msg = "The following are multiple choice questions (with answers) about {}."


def create_chat_prompt(sys_msg, question, answers, subject):
    """Build the chat messages for one multiple-choice question.

    Args:
        sys_msg: Template for the system message; formatted with ``subject``.
        question: Question text.
        answers: Iterable of answer texts, paired in order with ``choices``.
        subject: Value substituted into ``sys_msg``.

    Returns:
        A two-element list: the system message followed by the user message,
        whose content is the question, one "<letter>. <answer>" line per
        option, and a trailing "Answer:" line.
    """
    option_lines = "\n".join(
        f"{letter}. {text}" for letter, text in zip(choices, answers)
    )
    system_turn = {"role": "system", "content": sys_msg.format(subject)}
    user_turn = {"role": "user", "content": f"{question}\n{option_lines}\nAnswer:"}
    return [system_turn, user_turn]

def create_chat_example(question, answers, correct_answer):
    """
    Build one few-shot example as a pair of named "system" messages, following
    the recommended few-shot format:
    https://github.com/openai/openai-python/blob/main/chatml.md#few-shot-prompting

    The first message (name "example_user") holds the question with its
    lettered options and a trailing "Answer:" line; the second (name
    "example_assistant") holds ``correct_answer``.
    """
    lettered_options = (
        f"{letter}. {text}" for letter, text in zip(choices, answers)
    )
    options_text = "\n".join(lettered_options)
    example_question = dict(
        role="system",
        content=f"{question}\n{options_text}\nAnswer:",
        name="example_user",
    )
    example_answer = dict(
        role="system",
        content=correct_answer,
        name="example_assistant",
    )
    return [example_question, example_answer]

# %%
# Root directory of the MMLU CSV files; the per-subject "dev" and "test"
# files are read from sub-directories of it further down.
data_path = "mmlu/data"
# subjects = os.listdir('samples_pools')
# Subjects to process, taken from the parsed command-line arguments.
subjects = args.subjects
# %%


# %%
def gen_prompt(sample, few_shot=()):
    """Splice few-shot examples between a sample's leading messages and its last message.

    Args:
        sample: Sequence of chat message dicts for the question. Every message
            except the last is placed first; the last one closes the prompt.
        few_shot: Iterable of examples, each a sequence of message dicts. The
            first message of an example is emitted with role "user" and the
            second with role "assistant"; any further messages are passed
            through unchanged.

    Returns:
        A new list of message dicts. Neither ``sample`` nor the dicts inside
        ``few_shot`` are modified: relabelled messages are shallow copies, so
        the same few-shot list can be reused safely across calls.
    """
    example_roles = ("user", "assistant")
    prompt = list(sample[:-1])
    for example in few_shot:
        for position, message in enumerate(example):
            if position < len(example_roles):
                # Copy instead of assigning into the caller's dict.
                message = {**message, "role": example_roles[position]}
            prompt.append(message)
    prompt.extend(sample[-1:])
    return prompt

# Row offset at which inference starts inside each subject's sample table.
start_idx = args.start_idx
# %%
# Map each answer letter to a token id: the first id the tokenizer returns
# when the bare letter is encoded without special tokens.
encoded_label = dict(
    (
        letter,
        llm_predictor.tokenizer.encode(
            letter, return_tensors="pt", add_special_tokens=False
        )[0][0].item(),
    )
    for letter in ("A", "B", "C", "D")
)
# %%
# Column names shared by the perturbation tables and the cached pool.
logit_columns = ['logits_A', 'logits_B', 'logits_C', 'logits_D']
answer_columns = ['A', 'B', 'C', 'D']
# Sentinel stored in the logit columns of rows that have not been scored yet;
# a row counts as "done" when its logits_A is strictly greater than this.
MISSING_LOGIT = -1000.0
# All sample / pool tables are tab-separated and must keep empty strings as-is.
tsv_read_options = dict(sep='\t', index_col=None, keep_default_na=False)
mmlu_columns = ("Question", "A", "B", "C", "D", "Answer")

for subject in sorted(subjects):
    print(f"Subject: {subject}")
    # few shot examples
    dev_df = pd.read_csv(os.path.join(data_path, "dev", subject + "_dev.csv"), names=mmlu_columns)
    dev_df["sample"] = dev_df.apply(
        lambda x: create_chat_example(x["Question"], x[answer_columns], x["Answer"]), axis=1
    )
    few_shot = dev_df["sample"].to_list()

    test_df = pd.read_csv(os.path.join(data_path, "test", subject + "_test.csv"), names=mmlu_columns)
    test_df["ideal"] = test_df.Answer

    # One output file per (subject, model, start offset). A run with a non-zero
    # start_idx gets its own file so it does not overwrite the run from 0.
    results_path = os.path.join(sample_path, subject)
    offset_suffix = '' if start_idx == 0 else f'_{start_idx}'
    output_path = os.path.join(results_path, f'{subject}_perturb_{model_short_name}{offset_suffix}.csv')

    if os.path.exists(output_path):
        # Resume from a previous (possibly partial) run.
        samples_df = pd.read_csv(output_path, **tsv_read_options)
    else:
        # Fresh start: load the raw perturbations and mark every row as unscored.
        samples_df = pd.read_csv(os.path.join(results_path, f'{subject}_perturb.csv'), **tsv_read_options)
        for column in logit_columns:
            samples_df[column] = MISSING_LOGIT

    pool_path = f'./samples_pools/{subject}/{subject}_cached_{model_short_name}.csv'
    if os.path.exists(pool_path):
        # Reuse logits already computed for identical perturbations
        # (same question_index and binary_representation).
        pool_df = pd.read_csv(pool_path, **tsv_read_options)
        print("get data from pool")
        sampless = []
        for question_idx in tqdm(range(max(samples_df['question_index'].tolist()) + 1)):
            # reset_index returns an independent frame, so the writes below never
            # go through a filtered view of samples_df.
            local_samples = samples_df[samples_df['question_index'] == question_idx].reset_index(drop=True)
            local_pool = pool_df[pool_df['question_index'] == question_idx]

            # binary_representation -> cached logits. setdefault keeps the first
            # pool row when a representation occurs more than once.
            cached = {}
            pool_rows = local_pool[['binary_representation', *logit_columns]].itertuples(index=False, name=None)
            for representation, *pool_logits in pool_rows:
                cached.setdefault(representation, pool_logits)

            for row in range(len(local_samples)):
                if local_samples.at[row, 'logits_A'] > MISSING_LOGIT:
                    continue  # already scored
                hit = cached.get(local_samples.at[row, 'binary_representation'])
                if hit is None:
                    continue  # not in the pool; left for the model below
                for column, value in zip(logit_columns, hit):
                    local_samples.at[row, column] = value

            sampless.append(local_samples)
        samples_df = pd.concat(sampless, ignore_index=True)

    # Attach the answer options of each sample's source question
    # (question_index is used as a row label of test_df).
    question_ids = samples_df['question_index'].to_list()
    for column in answer_columns:
        samples_df[column] = test_df.loc[question_ids, column].to_numpy()

    samples_df["input"] = samples_df.apply(
        lambda x: create_chat_prompt(sys_msg, x["sample_question"], x[answer_columns], subject), axis=1
    )

    pbar = tqdm(range(start_idx, len(samples_df)), dynamic_ncols=True)
    for idx in pbar:
        if samples_df.at[idx, 'logits_A'] > MISSING_LOGIT:
            continue  # scored in an earlier run or filled from the pool
        prompt = gen_prompt(samples_df.at[idx, 'input'], few_shot)
        res = llm_predictor.predict(prompt)
        # NOTE(review): assumes res.logits[0][0] is indexable by token id and
        # holds the scores for the answer position; confirm in utils.llm_predict.
        token_scores = res.logits[0][0]
        for letter, token_id in encoded_label.items():
            samples_df.at[idx, f'logits_{letter}'] = token_scores[token_id].item()
        # Periodic checkpoint so a long run can be resumed after an interruption.
        if idx % 10000 == 0:
            samples_df.to_csv(output_path, sep='\t', index=None)
    samples_df.to_csv(output_path, sep='\t', index=None)
    


# %%

# %%
