import argparse
import json
import os
import time
from generate import generate
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from math_reward import last_boxed_only_string, remove_boxed

# Module-level setup: the checkpoint and its tokenizer are loaded once at
# import time and shared (as globals) by the helper functions below.
device = 'cuda'

# Load the weights in bfloat16, then move them to `device` and switch the
# module to evaluation mode.
model = AutoModel.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model = model.to(device)
model = model.eval()

# Tokenizer that belongs to the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    'GSAI-ML/LLaDA-8B-Instruct',
    trust_remote_code=True,
)

# Names of the MMLU subjects to evaluate; each one is also the name of the
# sub-directory of the data directory that holds that subject's parquet files.
# Every entry needs its own trailing comma: adjacent string literals are
# silently concatenated by Python, which would fuse two subjects into one
# non-existent task name.
TASKS = [
    'abstract_algebra',
    'anatomy',
    'astronomy',
    'business_ethics',
    'clinical_knowledge',
    'college_biology',
    'college_chemistry',
    'college_computer_science',
    'college_mathematics',
    'college_medicine',
    'college_physics',
    'computer_security',
    'conceptual_physics',
    'econometrics',
    'electrical_engineering',
    'elementary_mathematics',
    'formal_logic',
    'global_facts',
    'high_school_biology',
    'high_school_chemistry',
    'high_school_computer_science',
    'high_school_european_history',
    'high_school_geography',
    'high_school_government_and_politics',
    'high_school_macroeconomics',
    'high_school_mathematics',
    'high_school_microeconomics',
    'high_school_physics',
    'high_school_psychology',
    'high_school_statistics',
    'high_school_us_history',
    'high_school_world_history',
    'human_aging',
    'human_sexuality',
    'international_law',
    'jurisprudence',
    'logical_fallacies',
    'machine_learning',
    'management',
    'marketing',
    'medical_genetics',
    'miscellaneous',
    'moral_disputes',
    'moral_scenarios',
    'nutrition',
    'philosophy',
    'prehistory',
    'professional_accounting',
    'professional_law',
    'professional_medicine',
    'professional_psychology',
    'public_relations',
    'security_studies',
    'sociology',
    'us_foreign_policy',
    'virology',
    'world_religions',
]

# Option letters used to label the answer choices of each question, in order.
choices = list("ABCD")

def _accuracy(correct, total):
    """Return correct/total, or NaN when there is nothing to score."""
    return correct / total if total else float('nan')


def compute_metric(output_filename):
    """Print per-task and overall accuracy for a saved results file.

    Args:
        output_filename: Path to a JSON file that maps each task name to a
            dict with ``pred_answers`` and ``gold_answers`` lists.

    Predictions and gold labels are paired by position and a prediction is
    counted as correct only when it equals its gold label. The denominator
    is always the number of gold answers, so a missing prediction counts as
    wrong. A task without gold answers (or a file without tasks) is reported
    as ``nan`` instead of raising ``ZeroDivisionError``.
    """
    with open(output_filename, 'r') as f:
        run_results = json.load(f)

    total_acc = 0
    total_num = 0
    for task, result in run_results.items():
        gold_answers = result['gold_answers']
        acc = sum(
            pred == gold
            for pred, gold in zip(result['pred_answers'], gold_answers)
        )
        print("ACC-%s: %.4f" % (task, _accuracy(acc, len(gold_answers))))
        total_acc += acc
        total_num += len(gold_answers)
    print("ACC-all: %.4f" % _accuracy(total_acc, total_num))


def format_subject(subject):
    """Turn an underscore-separated subject name into space-separated words.

    Every word is prefixed with one space, so the result always starts with
    a space, e.g. ``"college_physics"`` becomes ``" college physics"``.
    """
    return "".join(" " + word for word in subject.split("_"))

def format_example(df, idx, include_answer=True):
    """Render row ``idx`` of ``df`` as one multiple-choice question block.

    The row is read purely by column position: column 0 is the question
    text, the last column is the answer, and all but one of the columns in
    between (``df.shape[1] - 2`` of them, starting at column 1) are printed
    as options labelled with the letters from the module-level ``choices``.
    NOTE(review): this presumes a [question, option..., answer] column
    layout -- confirm against the schema of the parquet files being loaded.

    Args:
        df: DataFrame holding the questions.
        idx: Positional index of the row to render.
        include_answer: When true, the answer and a blank line are appended
            after "Answer:" (used for few-shot examples); when false the
            text stops right after "Answer:".

    Returns:
        The formatted question text.
    """
    n_options = df.shape[1] - 2
    option_lines = [
        "\n{}. {}".format(choices[col], df.iloc[idx, col + 1])
        for col in range(n_options)
    ]
    text = df.iloc[idx, 0] + "".join(option_lines) + "\nAnswer:"
    if not include_answer:
        return text
    return text + " {}\n\n".format(df.iloc[idx, n_options + 1])

def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot part of a prompt for ``subject``.

    The result is a header naming the subject, followed by the first ``k``
    rows of ``train_df`` rendered (with answers) by ``format_example``,
    followed by the step-by-step / ``\\boxed{}`` instruction line. The test
    question itself is appended by the caller.

    Args:
        train_df: DataFrame with the worked examples.
        subject: Underscore-separated subject name.
        k: Number of examples to include; ``-1`` means every row.

    Returns:
        The prompt prefix as a string.
    """
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject))
    num_shots = train_df.shape[0] if k == -1 else k
    shots = "".join(format_example(train_df, row) for row in range(num_shots))
    instruction = "\nLet's think step by step and output the final answer within \\boxed{}.\n"
    return header + shots + instruction

def prepare_input(prompts):
    """Tokenize the first prompt of a batch and return its input ids.

    Only ``prompts[0]`` is encoded; any further entries in the batch are
    ignored. The ids are returned as a PyTorch tensor exactly as produced by
    the module-level tokenizer (no device transfer happens here).
    """
    first_prompt = prompts[0]
    encoding = tokenizer(first_prompt, return_tensors="pt")
    return encoding["input_ids"]

def batch_split(prompts, batch_num):
    """Group ``prompts`` into consecutive lists of ``batch_num`` items.

    The final list may be shorter when the number of prompts is not a
    multiple of ``batch_num``; empty input yields an empty list.
    """
    # The last element of `batches` is always the batch currently being filled.
    batches = [[]]
    for item in prompts:
        batches[-1].append(item)
        if len(batches[-1]) == batch_num:
            batches.append([])
    # Drop the trailing placeholder if nothing was put into it.
    if not batches[-1]:
        batches.pop()
    return batches

def batch_infer(prompts, remasking):
    """Generate an answer for every prompt and extract its boxed content.

    Args:
        prompts: List of prompt strings.
        remasking: Remasking strategy name, forwarded to ``generate``.

    Returns:
        A list with exactly one entry per prompt, in prompt order: the
        content of the last ``\\boxed{...}`` found in the decoded output, or
        an empty string when no boxed answer could be extracted.
    """
    # prepare_input encodes only the first prompt of each batch, so the
    # prompts are processed one at a time.
    batch_size = 1
    answers = []
    for batch_input in tqdm(batch_split(prompts, batch_size)):
        encode_inputs = prepare_input(batch_input)
        outputs, _ = generate(model, encode_inputs, tokenizer, remasking=remasking)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        boxed = last_boxed_only_string(decoded)
        # Guard in case the helper signals "no \boxed{} found" with None, so
        # one unparsable generation does not abort the whole run.
        ans = remove_boxed(boxed) if boxed is not None else None
        # append, not extend: extending with a string adds one entry per
        # character (and none for an empty answer), which shifts every later
        # prediction out of line with its gold label.
        answers.append(ans if ans is not None else "")
    return answers

def _build_records(data_dir, task, num_shots=1, max_tokens=2048):
    """Build the prompt / gold-answer records for every test row of a task.

    Each prompt is the few-shot prefix from ``gen_prompt`` followed by the
    unanswered test question. When a prompt exceeds ``max_tokens`` the number
    of few-shot examples is reduced; the test question is never dropped.
    """
    dev_df = pd.read_parquet(os.path.join(data_dir, task, "dev-00000-of-00001.parquet"))[:5]
    test_df = pd.read_parquet(os.path.join(data_dir, task, "test-00000-of-00001.parquet"))
    records = []
    for i in range(test_df.shape[0]):
        k = num_shots
        prompt_end = format_example(test_df, i, include_answer=False)
        prompt = gen_prompt(dev_df, task, k) + prompt_end
        # "+ 1" accounts for the BOS token. Once no examples are left the
        # prompt is kept as it is, even if it is still too long.
        while k > 0 and len(tokenizer.tokenize(prompt)) + 1 > max_tokens:
            k -= 1
            prompt = gen_prompt(dev_df, task, k) + prompt_end
        label = test_df.iloc[i, test_df.shape[1] - 1]
        # pandas may hand back a NumPy scalar, which json.dump cannot
        # serialize; convert it to the equivalent native Python value.
        if hasattr(label, 'item'):
            label = label.item()
        records.append({'prompt': prompt, 'answer': label})
    return records


def main(data_dir: str):
    """Evaluate every task in ``TASKS`` once per remasking strategy.

    For each strategy the predictions and gold answers of all tasks are
    written to ``results/mmlu_instruct_<strategy>.json`` and the accuracies
    for that file are printed with ``compute_metric``.

    Args:
        data_dir: Directory containing one sub-directory per task with the
            ``dev-00000-of-00001.parquet`` and ``test-00000-of-00001.parquet``
            files.
    """
    # Create the output directory up front so a missing folder cannot throw
    # away a finished inference run when the results are written.
    os.makedirs('results', exist_ok=True)

    # Prompts do not depend on the remasking strategy: build them once per
    # task and reuse them for every strategy.
    records_by_task = {}
    for remasking in ["bst", "low_confidence", "binary_search", "auto_regression"]:
        output_filename = 'results/mmlu_instruct_%s.json' % (remasking)
        # Fresh dict per strategy so a file only ever holds its own results.
        run_results = {}
        for task in TASKS:
            print('Testing %s ...' % task)
            if task not in records_by_task:
                records_by_task[task] = _build_records(data_dir, task)
            records = records_by_task[task]

            pred_answers = batch_infer([record['prompt'] for record in records], remasking)
            gold_answers = [record['answer'] for record in records]
            run_results[task] = {'pred_answers': pred_answers, 'gold_answers': gold_answers}
        with open(output_filename, 'w') as f:
            json.dump(run_results, f, ensure_ascii=False, indent=2)

        # Score each strategy as soon as its file exists instead of only the
        # last one after the loop.
        compute_metric(output_filename)
    
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The default is the dataset snapshot path that used to be hard-coded in
    # the main() call, so running without arguments behaves as before while
    # --data_dir now actually takes effect.
    parser.add_argument(
        '--data_dir',
        type=str,
        default=r"/home/jovyan/.cache/huggingface/hub/datasets--cais--mmlu/snapshots/c30699e8356da336a370243923dbaf21066bb9fe",
        help='directory with one sub-directory of parquet files per task',
    )
    args = parser.parse_args()
    main(args.data_dir)