import os
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, pipeline
from tqdm import tqdm
import torch
import argparse


# BBH subsets scored in multiple-choice mode: extract_ans reduces the model's
# answer to an option letter, and the evaluation loop compares it with the
# letter inside the target (e.g. "(A)" -> "A").
MULTIPLE_CHOICE_TASKS = [
    'temporal_sequences',
    'disambiguation_qa',
    'date_understanding',
    'tracking_shuffled_objects_three_objects',
    'penguins_in_a_table',
    'geometric_shapes',
    'snarks',
    'ruin_names',
    'tracking_shuffled_objects_seven_objects',
    'tracking_shuffled_objects_five_objects',
    'logical_deduction_three_objects',
    'hyperbaton',
    'logical_deduction_five_objects',
    'logical_deduction_seven_objects',
    'movie_recommendation',
    'salient_translation_error_detection',
    'reasoning_about_colored_objects',
]

# BBH subsets scored in free-form mode: the extracted answer text is compared
# with the target string as-is.
FREE_FORM_TASKS = [
    'multistep_arithmetic_two',
    'navigate',
    'dyck_languages',
    'word_sorting',
    'sports_understanding',
    'boolean_expressions',
    'object_counting',
    'formal_fallacies',
    'causal_judgement',
    'web_of_lies',
]

def extract_ans(ans, mode):
    """Extract the final answer from a chain-of-thought completion.

    The answer span is the text after the first occurrence of ``'answer is'``,
    stripped of surrounding whitespace and cut at the next period.

    Args:
        ans: Raw model completion text.
        mode: ``'multiple_choice'`` or ``'free_form'``.

    Returns:
        ``""`` when the completion never contains ``'answer is'``, so a
        completion without a final answer cannot be scored as correct.
        In ``'multiple_choice'`` mode, the letter of the first option
        ``(A)``..``(Z)`` (checked in alphabetical order) that appears in the
        span, or the span itself when no parenthesised option is present.
        In ``'free_form'`` mode, the span itself.

    Raises:
        ValueError: if ``mode`` is not one of the two supported modes.
    """
    if mode not in ('multiple_choice', 'free_form'):
        raise ValueError(f"unknown answer mode: {mode!r}")

    parts = ans.split('answer is')
    if len(parts) == 1:
        # No 'answer is' marker: deliberately return "" rather than the whole
        # completion, which would otherwise be compared against the target.
        return ""

    # Only the first sentence after the marker is kept.
    span = parts[1].strip().split('.')[0]

    if mode == 'multiple_choice':
        for letter in "ABCDEFGHIJKLMNOPQRSTUVWXYZ":
            if f"({letter})" in span:
                return letter
    return span


# Command-line configuration: model locations, the three budget ratios copied
# onto the model config, dataset/prompt locations, and which patcher to apply.
parser = argparse.ArgumentParser(description="Run model with configurable paths and ratios")
parser.add_argument("--big_model_path", type=str, required=True, help="Path to the big model (e.g., Qwen2-7B)")
parser.add_argument("--small_model_path", type=str, required=True, help="Path to the small model (e.g., Qwen2-0.5B)")
parser.add_argument("--heavy_budget_ratio", type=float, default=0.2, help="Heavy budget ratio")
parser.add_argument("--recent_budget_ratio", type=float, default=0.2, help="Recent budget ratio")
parser.add_argument("--compensate_budget_ratio", type=float, default=0.2, help="Compensate budget ratio")
parser.add_argument("--bbh_path", type=str, required=True, help="Path to BBH dataset")
parser.add_argument("--prompt_path", type=str, required=True, help="Path to prompt file")
# Restrict to the series the script knows how to patch; any other value would
# otherwise silently skip the small-model modification and evaluate the plain
# small model.
parser.add_argument("--model_series", type=str, required=True, choices=["qwen", "llama"], help="qwen or llama series")
args = parser.parse_args()

# Load the config and tokenizer (both taken from the big model's path), then
# both models.
# NOTE(review): the tokenizer from the big model is later used with the small
# model's pipeline; presumably both share a vocabulary — confirm.
config = AutoConfig.from_pretrained(args.big_model_path)
tokenizer = AutoTokenizer.from_pretrained(args.big_model_path)


def _load_eager_model(model_path):
    """Load a causal LM with automatic dtype/device placement and eager attention."""
    return AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager",
    )


small_model = _load_eager_model(args.small_model_path)
big_model = _load_eager_model(args.big_model_path)

# Copy the budget ratios from the CLI onto the config that is handed to the
# series-specific patcher (their meaning is defined in modify_qwen /
# modify_llama, which are not visible here).
for _ratio_name in ("heavy_budget_ratio", "recent_budget_ratio", "compensate_budget_ratio"):
    setattr(config, _ratio_name, getattr(args, _ratio_name))
del _ratio_name

# Patch the small model together with the big model; the patcher module is
# imported lazily so only the selected series' module needs to be importable.
if args.model_series == "qwen":
    from modify_qwen import enable_qwen2_small_model
    enable_qwen2_small_model(small_model, big_model, config)
elif args.model_series == "llama":
    from modify_llama import enable_llama_small_model
    enable_llama_small_model(small_model, big_model, config)


def _is_correct(prediction, target):
    """Return True when the extracted prediction equals the target.

    Both sides are stripped and case-folded so that e.g. "true" matches
    "True". Exact comparison is used instead of substring containment, which
    would wrongly accept "invalid" for target "valid" or "10" for target "1".
    """
    return prediction.strip().casefold() == target.strip().casefold()


bbh_path = args.bbh_path
subsets = ['boolean_expressions', 'causal_judgement', 'date_understanding', 'disambiguation_qa', 'dyck_languages', 'formal_fallacies', 'geometric_shapes', 'hyperbaton', 'logical_deduction_five_objects', 'logical_deduction_seven_objects', 'logical_deduction_three_objects', 'movie_recommendation', 'multistep_arithmetic_two', 'navigate', 'object_counting', 'penguins_in_a_table', 'reasoning_about_colored_objects', 'ruin_names', 'salient_translation_error_detection', 'snarks', 'sports_understanding', 'temporal_sequences', 'tracking_shuffled_objects_five_objects', 'tracking_shuffled_objects_seven_objects', 'tracking_shuffled_objects_three_objects', 'web_of_lies', 'word_sorting']
res = dict()
generator = pipeline("text-generation", model=small_model, tokenizer=tokenizer, max_new_tokens=500, do_sample=False)
instruction = "Please reference the following examples to answer the question.\n"

# Per-sample transcripts are appended to this file; create the directory so
# the first write cannot fail with FileNotFoundError.
os.makedirs("bbh_outputs", exist_ok=True)

with open("bbh_outputs/test.txt", "a", encoding="utf-8") as fd:
    for subset in subsets:
        acc = 0
        prompt_file = os.path.join(args.prompt_path, f"{subset}.txt")
        with open(prompt_file, encoding="utf-8") as pf:
            c_prompt = pf.read()
        mode = 'multiple_choice' if subset in MULTIPLE_CHOICE_TASKS else 'free_form'

        test_split = load_dataset(bbh_path, subset)['test']
        questions = test_split['input']
        targets = test_split['target']
        num_examples = len(questions)

        for q, a in tqdm(zip(questions, targets), total=num_examples):
            prompt = instruction + c_prompt + "\n\nQ: " + q + "\nA: "
            # return_full_text=False yields only the continuation, so the
            # echoed prompt does not have to be sliced off by character count.
            generated_answer = generator(prompt, return_full_text=False)[0]['generated_text'].lstrip() + "\n\n"

            ans_ = extract_ans(generated_answer, mode)
            # Multiple-choice targets are of the form "(A)"; keep the letter.
            ans = a[1] if mode == 'multiple_choice' else a
            fd.write('%s\nA_model:\n%s\nA_target:\n%s\n\n' % (q, generated_answer, ans))
            # Flush per sample so transcripts survive a crash mid-run.
            fd.flush()

            if _is_correct(ans_, ans):
                acc += 1

        score = acc / num_examples if num_examples else 0.0
        print('%s acc %.4f' % (subset, score))
        res[subset] = score

average_score = sum(res.values()) / len(res) if res else 0.0
print(f"Average Score {average_score}")