# Example: python mmlu_evaluate.py --max_new_tokens 1 --device cuda:0 --output_dir ./data/llm_frozen_status_v2 --model_loop
import os
import argparse
import torch
import random
import json
from datetime import datetime
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
import random
import gc
import numpy as np
from pathlib import Path
import time
from tqdm import tqdm
import logging
import sys


# Hugging Face repo ids evaluated by `--model_loop`, in order. Each id must
# appear only once; otherwise the same model is loaded and evaluated twice.
# Commented-out entries record why a model was dropped.
ALL_MODELS = [
    'Qwen/Qwen3-4B-Instruct-2507',
    'rombodawg/Rombos-LLM-V2.5-Qwen-7b',
    'mistralai/Mistral-7B-Instruct-v0.3',
    'google/gemma-2-9b-it',
    'google/gemma-2-2b-it',
    'Qwen/Qwen2.5-7B-Instruct',
    'Qwen/Qwen2.5-3B-Instruct',
    'TinyLlama/TinyLlama-1.1B-Chat-v1.0',
    'deepseek-ai/DeepSeek-R1-Distill-Qwen-7B',
    # 'internlm/internlm2_5-7b-chat',              # CRASH RuntimeError: The size of tensor a (607) must match the size of tensor b (606) at non-singleton dimension

    'meta-llama/Llama-3.2-3B-Instruct',
    'meta-llama/Llama-3.2-1B-Instruct',

    # 'mistralai/Mixtral-8x7B-Instruct-v0.1',   # disabled: the ~80 GB download filled the disk
    'mistralai/Mistral-Nemo-Instruct-2407',

    'microsoft/Phi-3.5-mini-instruct',
    # 'microsoft/Phi-3-small-8k-instruct',       # CRASH ValueError: The repository microsoft/Phi-3-small-8k-instruct contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/microsoft/Phi-3-small-8k-instruct .
    'microsoft/Phi-3-medium-4k-instruct',

    'Qwen/Qwen2.5-14B-Instruct',
    'Qwen/Qwen2.5-Coder-7B-Instruct',
    'Qwen/Qwen2.5-Math-7B-Instruct',

    'deepseek-ai/deepseek-coder-6.7b-instruct',

    '01-ai/Yi-1.5-9B-Chat',
    '01-ai/Yi-1.5-6B-Chat',

    'openchat/openchat-3.6-8b-20240522',
    'HuggingFaceH4/zephyr-7b-beta',
    'teknium/OpenHermes-2.5-Mistral-7B',
    'NousResearch/Nous-Hermes-2-Mistral-7B-DPO',

    'Upstage/SOLAR-10.7B-Instruct-v1.0',

    'CohereForAI/aya-23-8B',

    # 'allenai/OLMo-7B-Instruct',  # CRASH ImportError: This modeling file requires the following packages that were not found in your environment: hf_olmo. Run `pip install hf_olmo`

    'Intel/neural-chat-7b-v3-3',

    # 'tiiuae/Falcon3-7B-Base',   # CRASH ValueError: Cannot use chat template functions because tokenizer has no chat template
    # 'baichuan-inc/Baichuan2-7B-Chat',   # CRASH AttributeError: BaichuanTokenizer has no attribute vocab (from the `tokenizer.vocab[letter]` lookup)
]


# Short, filename-friendly aliases for Hugging Face repo ids. They are used in
# log-file and checkpoint-file names; ids missing from this table fall back to
# the segment after the last '/' (see get_model_name).
SAVE_NAME = dict([
    ('Qwen/Qwen3-4B-Instruct-2507', 'Qwen3-4B'),
    ('ministral/Ministral-3b-instruct', 'Ministral-3b'),
    ('rombodawg/Rombos-LLM-V2.5-Qwen-7b', 'Rombos-Qwen-7b'),
    ('meta-llama/Llama-3.1-8B-Instruct', 'Llama3.1-8B'),
    ('mistralai/Mistral-7B-Instruct-v0.3', 'Mistral-7B'),
    ('microsoft/Phi-3.5-mini-instruct', 'Phi-3.5-mini'),
    ('google/gemma-2-9b-it', 'Gemma-2-9B'),
    ('google/gemma-2-2b-it', 'Gemma-2-2B'),
    ('Qwen/Qwen2.5-7B-Instruct', 'Qwen2.5-7B'),
    ('Qwen/Qwen2.5-3B-Instruct', 'Qwen2.5-3B'),
    ('TinyLlama/TinyLlama-1.1B-Chat-v1.0', 'TinyLlama-1.1B'),
    ('deepseek-ai/DeepSeek-R1-Distill-Qwen-7B', 'DeepSeek-R1-Qwen-7B'),
    ('internlm/internlm2_5-7b-chat', 'InternLM2.5-7B-Chat'),
])
# System message prepended to each conversation when system prompts are enabled.
SYSTEM = "You are an LLM, capable of solving difficult questions. You are a helpful assistant."
# User turn for one question; {choices} receives the newline-joined "A. ..." lines.
USER_TMPL = "Question: {question}\n\nChoices:\n{choices}"

def setup_logger(name, level=logging.INFO, extra_file=None, log_dir="./logs"):
    """Create (or reset) a named logger that writes to stdout and to ``<log_dir>/<name>.log``.

    Calling this again with the same ``name`` first detaches and closes every
    existing handler, so re-configuring never duplicates output.

    Args:
        name: logger name; also the stem of the main log file.
        level: logging level set on the logger.
        extra_file: optional additional log-file stem, relative to ``log_dir``
            (it may contain sub-directories, which are created). It receives
            the same records as the main file.
        log_dir: directory that holds the log files (default ``./logs``).

    Returns:
        The configured ``logging.Logger``. ``propagate`` is disabled so records
        are not emitted a second time by ancestor loggers.
    """
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    logger = logging.getLogger(name)

    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        try:
            handler.close()
        except Exception:
            # Best effort: a handler that fails to close must not block reconfiguration.
            pass

    logger.setLevel(level)
    logger.propagate = False
    fmt = logging.Formatter("%(asctime)s %(levelname)s %(message)s", "%Y-%m-%d %H:%M:%S")

    def _attach(handler):
        # Give the handler the shared format and register it on the logger.
        handler.setFormatter(fmt)
        logger.addHandler(handler)

    _attach(logging.StreamHandler(sys.stdout))
    _attach(logging.FileHandler(log_dir / f"{name}.log", encoding="utf-8"))

    if extra_file:
        extra_path = log_dir / f"{extra_file}.log"
        extra_path.parent.mkdir(parents=True, exist_ok=True)
        _attach(logging.FileHandler(extra_path, encoding="utf-8"))

    return logger

# Module-wide logger named after this script's file stem; main() re-creates it
# with an additional per-model log file.
logger = setup_logger(name=Path(__file__).stem)

class MyWrapper:
    """Bundle a loaded causal LM with its tokenizer and per-model prompt metadata.

    Attributes:
        model: the model object passed in.
        tokenizer: the tokenizer passed in.
        model_name: short display/save name for the model.
        device: device string the model was placed on (kept for reference).
        dtype: dtype the model was loaded with (kept for reference).
        supports_system_prompt: False for Gemma models (matched
            case-insensitively on ``model_name``). Presumably this is because
            Gemma chat templates reject a ``system`` role; confirm if more
            models need it.
        letter_logit_idx: maps each choice letter A-J to its token id in the
            tokenizer vocabulary, in A..J order.
    """

    # Choice labels whose next-token scores are recorded for every question.
    LETTERS = ('A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J')

    def __init__(self, model, tokenizer, device, dtype, model_name):
        self.model = model
        self.model_name = model_name
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.supports_system_prompt = "gemma" not in model_name.lower()
        # get_vocab() is the generic tokenizer API; the `.vocab` attribute is
        # missing on many slow/custom tokenizers, and on fast tokenizers it
        # rebuilds the full dict on every access. Fetch the vocab once.
        vocab = tokenizer.get_vocab()
        self.letter_logit_idx = {letter: vocab[letter] for letter in self.LETTERS}

def _format_choices(choices):
    """Render options as newline-separated ``"A. <option>"`` lines, lettered from 'A' upward."""
    return "\n".join(f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices))


def build_prompt(tokenizer, task, question, choices, shots=None, system_prompt=True):
    """Build the chat-formatted few-shot prompt for one multiple-choice question.

    Args:
        tokenizer: object whose ``apply_chat_template`` renders the messages.
        task: unused; kept for interface compatibility.
        question: question text.
        choices: option strings, labelled A, B, C, ... in order.
        shots: optional iterable of ``(question, choices, answer_letter)``
            exemplars. Each is rendered as a user turn plus an assistant turn
            ``"Final answer: (<letter>)"``.
        system_prompt: whether to prepend the ``SYSTEM`` message.

    Returns:
        The rendered prompt (with generation prompt) followed by
        ``"Final answer: ("``, so the answer letter is what comes next.
    """
    final_answer = "Final answer:"
    messages = [{"role": "system", "content": SYSTEM}] if system_prompt else []
    for s_question, s_choices, s_answer in shots or ():
        messages.append({"role": "user", "content": USER_TMPL.format(question=s_question, choices=_format_choices(s_choices))})
        messages.append({"role": "assistant", "content": f"{final_answer} ({s_answer})"})

    messages.append({"role": "user", "content": USER_TMPL.format(question=question, choices=_format_choices(choices))})
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The suffix must match the exemplar answers above ("Final answer: (X)")
    # exactly. The former "Final answer:(" dropped the space, so the model was
    # asked to continue a different pattern from the one it was shown.
    return prompt + f"{final_answer} ("

def norm(s):
    """Canonicalise a string for comparison: trim, collapse whitespace runs to one space, lowercase."""
    words = s.split()  # no-arg split() also discards leading/trailing whitespace
    return " ".join(words).lower()

def extract_answer(text, task_type):
    """Extract the predicted choice letter (A-J) from a raw model completion.

    The prompt ends with an opening parenthesis, so the completion normally
    starts with the letter itself (``"C"``, ``"C) because ..."``). When the
    text contains an answer anchor ("answer:", "the answer is", ...), only the
    text after the last occurrence of the first anchor found is examined. Then:

      1. a leading letter, optionally preceded by "(" and not followed by
         another alphanumeric (``"C"``, ``"C) ..."``, ``"(D)."``);
      2. otherwise, a parenthesised letter anywhere, case-insensitive (``"(b)"``);
      3. otherwise, the first standalone upper-case letter A-J (``"Option B"``).

    Args:
        text: raw decoded completion.
        task_type: unused; kept for interface compatibility.

    Returns:
        The upper-case letter found. If none is found, returns the stripped
        remaining text unchanged. ``correct`` compares case-insensitively, so
        a bare ``"c"`` still counts.
    """
    import re

    t = text.strip()
    lower_t = t.lower()

    # Keep only what follows an explicit answer anchor, if any.
    for pattern in ("answer:", "the answer is", "answer is", "correct answer is"):
        anchor_idx = lower_t.rfind(pattern)
        if anchor_idx != -1:
            t = t[anchor_idx + len(pattern):].strip()
            break

    # Letters must stand alone. The previous "is any capital A-J present
    # anywhere?" scan returned e.g. 'B' for "D because ...".
    match = re.match(r"\(?([A-J])(?![A-Za-z0-9])", t)
    if match is None:
        match = re.search(r"\(([A-Ja-j])\)", t)
    if match is None:
        match = re.search(r"(?<![A-Za-z0-9])([A-J])(?![A-Za-z0-9])", t)
    return match.group(1).upper() if match else t

def correct(pred, gold):
    """Return True when the normalised prediction equals the (or any) normalised gold answer.

    ``gold`` may be a single string or a list of acceptable strings.
    """
    candidates = gold if isinstance(gold, list) else [gold]
    wanted = norm(pred)
    return any(norm(candidate) == wanted for candidate in candidates)

def load_few_shots(ds, i_to_avoid, k=3, rng=None):
    """Draw up to ``k`` random few-shot exemplars from ``ds``, skipping given indices.

    Args:
        ds: indexable sequence of examples with ``'question'``, ``'options'``
            and ``'answer'`` keys.
        i_to_avoid: indices that must not be chosen (e.g. the example under
            evaluation). Indices outside ``range(len(ds))`` are ignored.
        k: number of exemplars wanted; capped at the number of eligible rows.
        rng: optional ``random.Random`` for reproducible draws. Defaults to
            the module-level ``random`` functions.

    Returns:
        A list of ``(question, options, answer)`` tuples in random order, or
        ``[]`` when ``k <= 0`` or no row is eligible.
    """
    rng = random if rng is None else rng
    n = len(ds)
    avoid = {i for i in i_to_avoid if 0 <= i < n}
    k = min(k, n - len(avoid))
    if k <= 0:
        return []
    # Draw k + len(avoid) distinct indices straight from range(n). This avoids
    # building an O(n) candidate set on every call, and guarantees at least k
    # non-avoided indices. The first k of them form a uniform random k-subset
    # of the eligible rows.
    drawn = rng.sample(range(n), k + len(avoid))
    picked = [i for i in drawn if i not in avoid][:k]
    return [(ds[i]['question'], ds[i]['options'], ds[i]['answer']) for i in picked]

@torch.no_grad()
def generate(my_wrapper, prompt, device, max_new_tokens=512, temperature=0.0, top_p=1.0, target=None, sample_name=''):
    """Generate a completion for ``prompt`` and record answer-letter scores.

    Args:
        my_wrapper: object exposing ``model``, ``tokenizer`` and
            ``letter_logit_idx`` (letter -> token id), e.g. ``MyWrapper``.
        prompt: fully formatted prompt text (chat template already applied).
        device: device the tokenized inputs are moved to.
        max_new_tokens, temperature, top_p: forwarded to ``model.generate``.
            Sampling is enabled only when ``temperature > 0``.
        target: optional gold answer text (e.g. ``"C"``). Its token id is the
            last id obtained by tokenizing it.
        sample_name: identifier copied into the returned record.

    Returns:
        ``(predicted_text, output_dict)``. ``predicted_text`` is the decoded
        completion. ``output_dict`` holds:
        - the first generated token and the target token;
        - raw and softmaxed scores for letters A-J, the target token and the
          predicted token.
        All scores are taken at the FIRST generation step, i.e. the position
        right after the prompt, where the answer letter is expected.
        ``target_*`` entries are ``None`` when ``target`` is ``None``.
    """
    model, tokenizer = my_wrapper.model, my_wrapper.tokenizer

    # Chat templates often emit the BOS token themselves. Letting the tokenizer
    # add another one would produce a duplicated BOS.
    bos = getattr(tokenizer, "bos_token", None)
    add_special = not (bos and prompt.startswith(bos))
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=add_special).to(device)

    # `pad_token_id or eos` would wrongly discard a legitimate pad id of 0.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    gen = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=(temperature > 0),
        pad_token_id=pad_id,
        eos_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,  # hidden states are never used, so they are not requested
    )
    out_ids = gen.sequences[0][enc["input_ids"].shape[1]:]
    # Use the first generated step. `.item()` on all of out_ids crashed whenever
    # more than one token was generated; for single-token runs this is the same
    # step as before.
    first_id = int(out_ids[0])
    step_logits = gen.scores[0][-1]  # (vocab,)
    step_probs = F.softmax(step_logits, dim=-1)

    predicted_text = tokenizer.decode(out_ids, skip_special_tokens=True)

    target_id = None
    if target is not None:
        target_id = int(tokenizer(target, return_tensors="pt")['input_ids'][0][-1])

    # Gather all letter scores with one indexing op instead of 20 .item() syncs.
    letters = list(my_wrapper.letter_logit_idx)
    letter_ids = torch.tensor([my_wrapper.letter_logit_idx[c] for c in letters], device=step_logits.device)
    letter_logits = step_logits[letter_ids].tolist()
    letter_probs = step_probs[letter_ids].tolist()

    output_dict = {
        'sample_name': sample_name,
        'predicted_token': first_id,
        'predicted_text': predicted_text,
        'target_token': target_id,
        'target_text': target,
        'last_token_scores': {
            'logits': dict(zip(letters, letter_logits)),
            'softmaxed': dict(zip(letters, letter_probs)),
            'target_logit': None if target_id is None else step_logits[target_id].item(),
            'target_softmaxed': None if target_id is None else step_probs[target_id].item(),
            'pred_logit': step_logits[first_id].item(),
            'pred_softmaxed': step_probs[first_id].item(),
        },
    }
    return predicted_text, output_dict

def load_mmlu(task):
    """Return the full MMLU-Pro test split as a list of example dicts.

    ``task`` is accepted for interface symmetry but not used: the whole
    ``TIGER-Lab/MMLU-Pro`` test split is always loaded.
    """
    test_split = load_dataset('TIGER-Lab/MMLU-Pro', split="test")
    return list(test_split)

def eval_task(my_wrapper, task, device, max_new_tokens, temperature, top_p, output_dir, shots_k=3, verbose_every=1000):
    """Run the few-shot multiple-choice evaluation over every example returned by ``load_mmlu(task)``.

    For each example:
    - draw ``shots_k`` exemplars (never the example itself);
    - build the prompt, generate, extract the letter and score it;
    - keep the per-example record returned by ``generate``.

    Progress is logged every ``verbose_every`` examples and on the last one.
    Every 100 examples, and on the last one, all records so far are dumped to
    ``<output_dir>/MMLU_<model_name>_<index>_<timestamp>.json``. The previous
    dump is then deleted, so only the latest checkpoint remains.

    Returns:
        ``(good, total)``: number of correct predictions and of examples seen.
    """
    examples = load_mmlu(task)
    n_examples = len(examples)
    records = []
    n_good = 0
    n_seen = 0
    save_every = 100
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    prev_checkpoint = None

    progress = tqdm(examples)
    for idx, example in enumerate(progress):
        is_last = idx == n_examples - 1
        shots = load_few_shots(examples, [idx], k=shots_k)
        prompt = build_prompt(
            my_wrapper.tokenizer, task, example['question'], example['options'],
            shots=shots, system_prompt=my_wrapper.supports_system_prompt,
        )
        gold = example['answer']
        raw_text, record = generate(
            my_wrapper=my_wrapper,
            prompt=prompt,
            device=device,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            target=gold,
            sample_name=f'MMLU_{task}_{idx}',
        )
        records.append(record)
        pred = extract_answer(raw_text, task)
        ok = correct(pred, gold)
        n_seen += 1
        n_good += int(ok)
        acc = n_good / max(n_seen, 1)

        if (idx % verbose_every == 0 and idx > 0) or is_last:
            logger.info(f"[{task}]/[{my_wrapper.model_name}] {idx+1}/{n_examples} accuracy: {acc*100:.2f}%  ({n_good}/{n_seen})  raw='{raw_text}' pred='{pred}'  gold='{gold}'  ok={ok}")
        progress.set_description(f"[{task}]/[{my_wrapper.model_name}] accuracy: {acc*100:.2f}%  ({n_good}/{n_seen})")

        if (idx % save_every == 0 and idx > 0) or is_last:
            stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            checkpoint = out_dir / f'MMLU_{my_wrapper.model_name}_{idx}_{stamp}.json'
            with open(checkpoint, 'w') as fh:
                json.dump(records, fh)
            if prev_checkpoint is not None:
                os.remove(prev_checkpoint)
            prev_checkpoint = checkpoint

    return n_good, n_seen

def parse_args(argv=None):
    """Parse the command-line options of the MMLU-Pro evaluation script.

    Args:
        argv: optional argument list. Defaults to ``sys.argv[1:]``; pass a
            list explicitly for tests.

    Returns:
        ``argparse.Namespace`` with the options declared below.

    Exits with argparse's usage error (status 2) when neither ``--model`` nor
    ``--model_loop`` is given. A single-model run needs a model id, and
    without this check the script would fail later on ``None``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=str)
    ap.add_argument("--model_loop", action='store_true', help="Whether to loop through the model")
    ap.add_argument("--tasks", type=str, default="pro")
    ap.add_argument("--max_new_tokens", type=int, default=128)
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--top_p", type=float, default=1.0)
    ap.add_argument("--shots", type=int, default=3, help="Number of few-shot exemplars to prepend (default: 3)")
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--output_dir", type=str, default="./data/llm_frozen_status_v0")
    args = ap.parse_args(argv)
    if args.model is None and not args.model_loop:
        ap.error("--model is required unless --model_loop is given")
    return args

def get_model_name(full_name):
    """Map a Hugging Face repo id to the short name used in logs and file names.

    Uses the ``SAVE_NAME`` alias when one exists; otherwise returns the part
    of the id after the last '/'.
    """
    try:
        return SAVE_NAME[full_name]
    except KeyError:
        return full_name.split('/')[-1]

def my_wrapper(args):
    """Load ``args.model`` (tokenizer and causal LM) onto ``args.device`` and wrap it in ``MyWrapper``.

    Weights are loaded in bfloat16 on CUDA devices and float32 otherwise.
    ``trust_remote_code`` is disabled for ``microsoft`` repos and enabled for
    all others. The same flag is used for both the tokenizer and the model.

    NOTE(security): ``trust_remote_code=True`` executes Python code shipped
    inside the model repository. Only list repos you trust.
    """
    dtype = torch.bfloat16 if args.device.startswith("cuda") else torch.float32
    logger.info(f"Loading {args.model} on {args.device} ({dtype})")
    # microsoft/* repos are loaded with the classes built into transformers
    # (no remote code); previously only the model honoured this, the tokenizer did not.
    trust_remote_code = 'microsoft' not in args.model
    tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=trust_remote_code)
    mdl = AutoModelForCausalLM.from_pretrained(args.model, dtype=dtype, trust_remote_code=trust_remote_code)
    mdl.to(args.device).eval()
    return MyWrapper(mdl, tok, device=args.device, dtype=dtype, model_name=get_model_name(args.model))

def main(args, mt=None):
    """Evaluate one model (``args.model``) on every comma-separated task in ``args.tasks``.

    Steps:
    - Re-creates the module ``logger`` with an extra per-model log file.
    - If ``args.output_dir`` already contains a checkpoint named
      ``MMLU_<model_name>_12031_*``, the model is treated as finished and
      skipped. 12031 is presumably the last example index of the MMLU-Pro
      test split; TODO confirm or derive it from the dataset.
    - Otherwise loads the model, unless a wrapper is passed as ``mt``.
    - Evaluates each task and logs the overall accuracy.
    """
    global logger
    model_name = get_model_name(args.model)
    logger = setup_logger(Path(__file__).stem, extra_file=f"{model_name}")

    if args.model in ALL_MODELS:
        position = ALL_MODELS.index(args.model) + 1
        logger.info(f"Current loop running model {position}/{len(ALL_MODELS)} ({args.model})")

    tasks = [name.strip() for name in args.tasks.split(",") if name.strip()]

    finished = list(Path(args.output_dir).glob(f'MMLU_{model_name}_12031_*'))
    if finished:
        logger.warning(f"Model {model_name} already evaluated {finished} examples")
        return

    wrapper = my_wrapper(args) if mt is None else mt

    total_good = 0
    total_all = 0
    for task in tasks:
        good, seen = eval_task(wrapper, task, args.device, args.max_new_tokens, args.temperature, args.top_p, shots_k=args.shots, output_dir=args.output_dir)
        total_good += good
        total_all += seen

    logger.info(f"OVERALL Accuracy: {100*total_good/max(total_all,1):.2f}%  ({total_good}/{total_all})")

def main_model_loop(args):
    """Run ``main`` for every model in ``ALL_MODELS``, continuing past failures.

    Duplicate entries are evaluated only once (first occurrence wins).
    ``args.model`` is overwritten for each model. Any exception from ``main``
    is logged with its traceback.

    After each model, success or failure, the loop:
    - waits 5 s;
    - runs the garbage collector;
    - empties the CUDA cache.

    This cleanup runs after the ``except`` clause has finished. By then the
    exception has been released, along with the traceback frames that could
    still reference the failed model. Previously the cleanup ran while those
    frames were still alive.
    """
    for model in dict.fromkeys(ALL_MODELS):
        args.model = model
        try:
            main(args)
        except Exception as e:
            logger.error(f"Error evaluating {model}: {e}")
            logger.exception(e)
        time.sleep(5)
        gc.collect()
        torch.cuda.empty_cache()

if __name__ == "__main__":
    cli_args = parse_args()
    # --model_loop evaluates every entry of ALL_MODELS; otherwise only --model.
    entry_point = main_model_loop if cli_args.model_loop else main
    entry_point(cli_args)
