import os
import threading
from vllm import ModelRegistry
# from vllm_qwen3_gate import Qwen3Gating
# ModelRegistry.register_model("Qwen3Gating", Qwen3Gating)
from vllm import LLM, SamplingParams
from modelscope import AutoTokenizer
import pandas as pd
import sys
from datasets import load_dataset

def prepare_include(tokenizer=None, think=False):
    """Build prompts and gold answers for the INCLUDE multiple-choice benchmark.

    For every supported language the ``test`` split of the local dataset at
    ``./cs_gate_train/eval/include-base-44`` (config name = language name) is
    loaded, and each row is rendered as: instruction prefix, blank line,
    question, blank line, then the four options labelled ``A)``..``D)``.

    The instruction prefix and the question are separated by a blank line;
    previously they were concatenated with no separator, so the question text
    ran directly into the end of the instruction sentence.

    Args:
        tokenizer: Optional object exposing ``apply_chat_template``. When
            given, each prompt is wrapped as a single ``user`` turn with the
            generation prompt appended (``tokenize=False``). When ``None``
            the plain prompt text is used.
        think: When true, the literal ``'<think>'`` is appended to every
            prompt (after chat templating, if any).

    Returns:
        A ``(prompts, answers)`` tuple of two parallel lists. ``answers``
        holds each row's ``'answer'`` field unchanged.
    """
    # Localized word for "Answer" used in the required final-line format.
    # Dict insertion order is also the order in which languages are evaluated;
    # trim this mapping to a single entry for a quick debugging run.
    answer_trans = {
        'Arabic': 'إجابة',
        'Hebrew': 'תשובה',
        'Greek': 'Απάντηση',
        'Korean': '답변',
        'Russian': 'Ответ',
        'Vietnamese': 'Câu trả lời'
    }

    def get_prefix(lang):
        # English instruction that pins the answer format and reply language.
        return f"Answer the following multiple choice question. The last line of your response should be of the following format: '{answer_trans[lang]}: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering, reply in {lang}."

    prompts = []
    answers = []
    for lang in answer_trans:
        lang_ds = load_dataset("./cs_gate_train/eval/include-base-44", lang)
        prefix = get_prefix(lang)  # invariant per language; build it once
        for row in lang_ds['test']:
            prompt = (
                prefix + '\n\n'
                + row['question'] + '\n\n'
                + f"A) {row['option_a']}\n"
                + f"B) {row['option_b']}\n"
                + f"C) {row['option_c']}\n"
                + f"D) {row['option_d']}"
            )
            if tokenizer is not None:
                prompt = tokenizer.apply_chat_template(
                    [{'role': 'user', 'content': prompt}],
                    add_generation_prompt=True,
                    tokenize=False,
                )
            if think:
                # Open the reasoning block so the model starts by thinking.
                prompt += '<think>'
            prompts.append(prompt)
            answers.append(row['answer'])
    return prompts, answers

def prepare_flores(
    tokenizer,
    think=False,
    data_path='./cs_gate_train/open_source_dataset/flores_plus_devtest.jsonl',
    n_samples=500,
):
    """Build prompts and reference answers from the FLORES+ devtest JSONL file.

    Rows are selected for six ISO 15924 script codes (column ``iso_15924``)
    and then for two glottocodes (column ``glottocode``). For each selection
    up to ``n_samples`` rows are drawn with ``random_state=42``; the
    ``zh_query`` column supplies the prompt and ``src_zh`` the answer.

    Compared with the previous implementation:
      * ``think`` is now honoured (it used to be accepted but ignored), in the
        same way ``prepare_include`` does it: ``'<think>'`` is appended.
      * A selection with fewer than ``n_samples`` rows no longer raises
        ``ValueError``; all available rows are used and a warning is printed.
        Selections with at least ``n_samples`` rows sample exactly as before.

    Args:
        tokenizer: Object exposing ``apply_chat_template``, or ``None`` to use
            the raw ``zh_query`` text as the prompt.
        think: When true, append ``'<think>'`` to every prompt.
        data_path: Path of the JSON-lines file to read.
        n_samples: Maximum number of rows sampled per selection.

    Returns:
        A ``(prompts, answers)`` tuple of two parallel lists.
    """
    df = pd.read_json(data_path, lines=True)

    def render(query):
        # Wrap the query as a single user turn when a tokenizer is available.
        text = query
        if tokenizer is not None:
            text = tokenizer.apply_chat_template(
                [{'role': 'user', 'content': query}],
                add_generation_prompt=True,
                tokenize=False,
            )
        if think:
            text += '<think>'
        return text

    # (column, value) filters, in the order their samples are emitted.
    selections = (
        [('iso_15924', code)
         for code in ('Cyrl', 'Hebr', 'Hang', 'Grek', 'Arab', 'Thai')]
        + [('glottocode', code) for code in ('viet1252', 'stan1290')]
    )

    prompts = []
    answers = []
    for column, code in selections:
        subset = df[df[column] == code]
        if len(subset) < n_samples:
            # Surface the shortfall instead of crashing inside DataFrame.sample.
            print(
                f"[prepare_flores] only {len(subset)} rows for "
                f"{column}={code!r}; requested {n_samples}"
            )
        if subset.empty:
            continue
        sampled = subset.sample(n=min(n_samples, len(subset)), random_state=42)
        prompts.extend(render(query) for query in sampled['zh_query'])
        answers.extend(sampled['src_zh'].tolist())
    return prompts, answers

# Registry of evaluable checkpoints: short model name -> checkpoint path.
# The short name is what the entry points read from sys.argv[1]; the path is
# what they pass to AutoTokenizer.from_pretrained and to vllm.LLM.
# The name itself carries meaning for the entry points in this file: a name
# containing 'think' turns on thinking mode, and main() uses different sampling
# parameters for names containing 'gemma'.
model_dict = {
    '30b-nogate': '/cpfs01/user/jiawei.lyt/ckpt/verl_checkpoints/lyt-rl-gen/qwen3-tpp-nothink-0721-distilled-data0706-recitex1-bothtrans-mixlangx2-GenRM-32B-sentcs-GSPO-ref-turbopp-LENGTH_FLIP_THRESHOLD1.3-LENGTH_FLIP_PROB0.75-REF_ANSWER_POSITION-A-expert-12k_bs512_minibs128_n8/global_step_60/actor_hf',
    '30b-gate-inner': './codeswitch/cs_gate/models/gate-qwen3-20k_95p_flores_2025-07-28-09:02:44',
    '30b-gate-inner-think': './cs_gate_train/models/30b_think-gate-qwen3-controlfix-20k_95p_flores_2025-08-19-05:21:18_plugged',
    '30b-gate-oss-old': './cs_gate_train/models/opensource-turbo-nothink-gate-qwen3-controlfix-20k_95p_flores_2025-08-18-11:06:45_plugged',
    '30b-gate-oss': './cs_gate_train/models/opensource-turbo-nothink-gate-qwen3-controlfix-20k_95p_flores_2025-08-22-15:21:43_plugged',
    'llama-8b-oss': './cs_gate_train/models/gate-llama3-8b-20k_95p_norm_2025-08-25-03:50:13_plugged',
    'llama-8b-oss_nonorm': './cs_gate_train/models/gate-llama3-8b-20k_95p_2025-08-24-11:41:02_plugged',
    'llama-8b-nogate': './cs_gate_train/models/llama3-8b',
    'gemma-4b-oss': './cs_gate_train/models/gate-gemma3-4b-20k_95p_2025-08-25-07:50:43_plugged',
    'gemma-4b-oss-nonorm': './cs_gate_train/models/gate-gemma3-4b-nonorm-20k_95p_2025-08-26-12:21:17_plugged',
    'gemma-4b-nogate': './cs_gate_train/models/gemma3-4b',
    'gemma-12b-oss': './cs_gate_train/models/gate-gemma3-12b-20k_95p_2025-08-26-02:18:32_plugged',
    'gemma-12b-nogate': './cs_gate_train/models/gemma3-12b',
    'aya-expanse-8b-nogate': './cs_gate_train/models/aya-expanse-8b',
    'gpt-oss-20b-nogate': './cs_gate_train/models/gpt-oss-20b',
    'seed-oss-36b': './models/seed-oss-36b'
}



# NOTE: set the environment variable VLLM_USE_V1=0 before running this script.
def main():
    """Run the INCLUDE benchmark for one model and save the generations.

    The model's short name is read from ``sys.argv[1]`` and resolved through
    ``model_dict``. Prompts come from ``prepare_include``; every prompt is
    generated once with vLLM, and one record per prompt
    (``query``, ``query_response``, ``answer``) is written to
    ``./data/<model_name>_include_res_<timestamp>.jsonl``.

    The ``./data`` directory is created before the model is loaded, so a
    missing or unwritable output directory fails immediately rather than
    after the whole generation run has finished.

    Raises:
        IndexError: If no model name was given on the command line.
        KeyError: If the model name is not a key of ``model_dict``.
    """
    model_name = sys.argv[1]
    should_think = 'think' in model_name
    # Thinking and non-thinking models share one generation budget.
    # NOTE(review): the original kept a larger thinking budget (32768)
    # commented out; it would exceed max_model_len=8192 used below.
    max_tokens = 8000
    print("should_think", should_think)

    model_path = model_dict[model_name]
    # NOTE(review): TOK_PATH is not read anywhere in this file; presumably a
    # model plugin loaded by vLLM consumes it - confirm before removing.
    os.environ['TOK_PATH'] = model_path

    # Fail fast on an unusable output directory before the expensive work.
    os.makedirs('./data', exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    prompts, answers = prepare_include(tokenizer, should_think)
    llm = LLM(model=model_path, dtype="bfloat16", max_model_len=8192, tensor_parallel_size=2)
    if 'gemma' in model_name:
        sampling_params = SamplingParams(temperature=1, top_p=0.95, top_k=64, max_tokens=max_tokens)
    else:
        sampling_params = SamplingParams(temperature=0.7, top_p=0.8, top_k=20, max_tokens=max_tokens)
    outputs = llm.generate(prompts, sampling_params)

    res_arr = []
    for output, answer in zip(outputs, answers):
        generated_text = output.outputs[0].text
        # Printed as a one-element list so newlines show up escaped.
        print("Response:", [generated_text])
        res_arr.append({
            'query': output.prompt,
            'query_response': generated_text,
            'answer': answer
        })
    res_df = pd.DataFrame(res_arr)

    from datetime import datetime
    timestamp = datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    out_path = f'./data/{model_name}_include_res_{timestamp}.jsonl'
    res_df.to_json(out_path, orient='records', lines=True)
    print("Saved to:", out_path)

def eval_flores():
    """Generate responses for the FLORES prompts with one model and save them.

    The model's short name is read from ``sys.argv[1]`` and resolved through
    ``model_dict``. Prompts come from ``prepare_flores``; every prompt is
    generated once with vLLM (at most 500 new tokens), and one record per
    prompt (``query``, ``query_response``, ``answer``) is written to
    ``./data/<model_name>_flores_res_<timestamp>.jsonl``.

    The ``VLLM_USE_V1`` check is an explicit ``raise`` rather than an
    ``assert`` so that it still runs under ``python -O`` and reports an unset
    variable clearly. The ``./data`` directory is created up front so that an
    unusable output location fails before generation starts.

    Raises:
        RuntimeError: If the environment variable ``VLLM_USE_V1`` is not ``'0'``.
        IndexError: If no model name was given on the command line.
        KeyError: If the model name is not a key of ``model_dict``.
    """
    use_v1 = os.environ.get('VLLM_USE_V1')
    print(use_v1)
    if use_v1 != '0':
        raise RuntimeError(
            f"VLLM_USE_V1 must be set to '0' before running (got {use_v1!r})"
        )

    model_name = sys.argv[1]
    should_think = 'think' in model_name
    # Same budget for thinking and non-thinking models.
    max_tokens = 500
    print("should_think", should_think)

    model_path = model_dict[model_name]

    # Fail fast on an unusable output directory before the expensive work.
    os.makedirs('./data', exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    prompts, answers = prepare_flores(tokenizer, should_think)
    llm = LLM(model=model_path, dtype="bfloat16", max_model_len=8192, tensor_parallel_size=2)
    sampling_params = SamplingParams(temperature=0.7, top_p=0.8, top_k=20, max_tokens=max_tokens)
    outputs = llm.generate(prompts, sampling_params)

    res_arr = []
    for output, answer in zip(outputs, answers):
        generated_text = output.outputs[0].text
        # Printed as a one-element list so newlines show up escaped.
        print("Response:", [generated_text])
        res_arr.append({
            'query': output.prompt,
            'query_response': generated_text,
            'answer': answer
        })
    res_df = pd.DataFrame(res_arr)

    from datetime import datetime
    timestamp = datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    out_path = f'./data/{model_name}_flores_res_{timestamp}.jsonl'
    res_df.to_json(out_path, orient='records', lines=True)
    print("Saved to:", out_path)

def main2():
    """Sample ten continuations of one fixed, partially answered Arabic prompt.

    The prompt is a hard-coded chat transcript whose assistant turn already
    contains the step-by-step analysis and ends right after the answer label,
    so each generation only has to supply what follows that label. The model's
    short name is read from ``sys.argv[1]`` and resolved through
    ``model_dict``; the ten responses are printed and nothing is saved.

    The ``VLLM_USE_V1`` check is an explicit ``raise`` rather than an
    ``assert`` so that it still runs under ``python -O`` and reports an unset
    variable clearly.

    Raises:
        RuntimeError: If the environment variable ``VLLM_USE_V1`` is not ``'0'``.
        IndexError: If no model name was given on the command line.
        KeyError: If the model name is not a key of ``model_dict``.
    """
    use_v1 = os.environ.get('VLLM_USE_V1')
    print(use_v1)
    if use_v1 != '0':
        raise RuntimeError(
            f"VLLM_USE_V1 must be set to '0' before running (got {use_v1!r})"
        )

    model_name = sys.argv[1]
    model_path = model_dict[model_name]
    llm = LLM(model=model_path, dtype="bfloat16", max_model_len=8192, tensor_parallel_size=2)
    sampling_params = SamplingParams(temperature=0.7, top_p=0.8, top_k=20, max_tokens=8000)
    prompt = """<|im_start|>user\nAnswer the following multiple choice question. The last line of your response should be of the following format: 'إجابة: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering, reply in Arabic. \nيمكن تحديد وتوثيق إجراءات الضبط العامة واج ارءات ضبط أنظمة التشغيل اإللكتروني باستخدام:\n\nA) الرسوم البيانية.\nB) قوائم االستبيان.\nC) قوائم التأكيد.\nD) كل ما سبق.<|im_end|>\n<|im_start|>assistant\nلنحلل الخيارات واحدة تلو الأخرى:\n\n1) الرسوم البيانية:\n- يمكن استخدامها لتمثيل إجراءات الضبط بشكل مرئي.\n- تساعد في توضيح التدفق المنطقي للإجراءات.\n- لكنها لا تكفي وحدها لتوثيق جميع التفاصيل.\n\n2) قوائم الاستبيان:\n- تُستخدم لجمع المعلومات من الأفراد أو الأنظمة.\n- تساعد في التحقق من وجود إجراءات معينة.\n- لكنها ليست الأداة الأساسية لتوثيق الإجراءات.\n\n3) قوائم التأكيد:\n- تُستخدم بشكل شائع في التحقق من تنفيذ الإجراءات.\n- تضمن أن كل خطوة تم إنجازها بشكل صحيح.\n- تعتبر أداة فعالة في توثيق وتطبيق إجراءات الضبط.\n\n4) كل ما سبق:\n- جميع الأدوات المذكورة (الرسوم البيانية، قوائم الاستبيان، قوائم التأكيد) يمكن استخدامها بشكل متكامل لتحديد وتوثيق إجراءات الضبط.\n- في الممارسة العملية، تُستخدم هذه الأدوات معًا لضمان الشمولية والدقة.\n\nإذًا، جميع الخيارات صحيحة ومتراكبة في سياق توثيق إجراءات الضبط.\n\nإجابة:"""
    # Ten copies of the same prompt: sampling is stochastic, so this shows the
    # spread of continuations for one fixed context.
    prompts = [prompt] * 10
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        generated_text = output.outputs[0].text
        # Printed as a one-element list so newlines show up escaped.
        print("Response:", [generated_text])

if __name__ == '__main__':
    # Script entry point: currently runs the FLORES evaluation.
    # Call main() here instead to run the INCLUDE benchmark, or
    # prepare_include() to only build the INCLUDE prompts without a model.
    eval_flores()