
import re
import time
import types

################## general model generation ##################

def generate_model_inputs(_model, _tokenizer, _messages):
    """Turn a list of chat messages into tokenized inputs on the model's device.

    The messages are rendered to a single prompt string with the tokenizer's
    chat template (generation prompt appended, ``enable_thinking=False``), then
    tokenized as a batch of one with ``return_tensors="pt"``.

    Args:
        _model: Object exposing a ``device`` attribute.
        _tokenizer: Tokenizer exposing ``apply_chat_template`` and ``__call__``.
        _messages: Chat messages passed straight to ``apply_chat_template``.

    Returns:
        The tokenizer output after ``.to(_model.device)``.
    """
    prompt = _tokenizer.apply_chat_template(
        _messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    encoded = _tokenizer([prompt], return_tensors="pt")
    return encoded.to(_model.device)


def generate_model_outputs(_model, _tokenizer, _inputs, max_new_tokens=2048, think_end_token_id=151668):
    """Generate a completion and split it into "thinking" and final content.

    Only the first sequence of the generated batch is decoded. The prompt
    tokens are dropped, then the new tokens are split right after the LAST
    occurrence of ``think_end_token_id``: everything up to and including that
    token is the thinking part, the rest is the answer. If the marker never
    appears, the thinking part is empty and all new tokens form the answer.

    Args:
        _model: Model exposing ``generate(**kwargs)``.
        _tokenizer: Tokenizer exposing ``decode(ids, skip_special_tokens=...)``.
        _inputs: Mapping of generation inputs that also exposes ``input_ids``
            as an attribute (used to measure the prompt length).
        max_new_tokens: Upper bound on newly generated tokens (default 2048,
            the previously hard-coded value).
        think_end_token_id: Token id marking the end of the thinking section.
            The default 151668 is presumably the end-of-thinking token of the
            tokenizer this script was written for -- confirm when switching
            model families.

    Returns:
        ``(thinking_content, content)`` as strings with leading and trailing
        newlines stripped.
    """
    generated_ids = _model.generate(
        **_inputs,
        max_new_tokens=max_new_tokens,
        temperature=0.1, top_p=0.95, top_k=20, min_p=0,
    )
    prompt_len = len(_inputs.input_ids[0])
    output_ids = generated_ids[0][prompt_len:].tolist()

    # Search the reversed list so the split lands after the last marker.
    try:
        split_at = len(output_ids) - output_ids[::-1].index(think_end_token_id)
    except ValueError:
        split_at = 0

    thinking_content = _tokenizer.decode(output_ids[:split_at], skip_special_tokens=True).strip("\n")
    content = _tokenizer.decode(output_ids[split_at:], skip_special_tokens=True).strip("\n")
    return thinking_content, content


################## get fewshotIds via Hashing ##################

def hash_id_asRandom(test_id, gen_len=7, max_id=70):
    """Derive a deterministic list of ids in ``[0, max_id)`` from ``test_id``.

    Each of the ``gen_len`` ids mixes a linear, a quadratic and a cubic term of
    the step index with ``test_id``. An id that collides with ``test_id`` is
    bumped by one, wrapping around at ``max_id`` so the result never leaves the
    ``[0, max_id)`` range (previously the bump could yield ``max_id`` itself).
    Duplicate ids within the result are possible.

    Args:
        test_id: Integer seed for the id sequence.
        gen_len: Number of ids to produce.
        max_id: Exclusive upper bound of the ids. Must be at least 2, because
            ``max_id // 2`` is used as a modulus.

    Returns:
        A list of ``gen_len`` integers.
    """
    res_list = []
    for step in range(gen_len):
        linear = (test_id + 11 * step) % max_id
        quadratic = round(test_id * 0.7 + 13 * step**2) % (max_id // 2)
        cubic = round(test_id * 0.2 - 3 * step**3) % (max_id * 2)
        candidate = (linear + quadratic + cubic) % max_id
        if candidate == test_id:
            # Step off the seed id, wrapping so the id stays inside the range.
            candidate = (candidate + 1) % max_id
        res_list.append(candidate)
    return res_list


################## EVAL benchmark on MMLU-pro ##################

def get_mmluPro_cotPrompt(_DS, test_id, fs_func=None):
    """Build the few-shot exemplar turns for an MMLU-Pro question.

    Exemplars are read from ``_DS['default']['validation']``.

    Args:
        _DS: Nested mapping/dataset indexed as ``_DS['default']['validation'][i]``.
        test_id: Id of the test question; only used to derive exemplar ids.
        fs_func: Selects the exemplar ids. ``None`` uses ``hash_id_asRandom``;
            a list or tuple is taken as the explicit ids; anything else is
            called as ``fs_func(test_id, gen_len=4, max_id=65)``.

    Returns:
        A flat list of chat messages, one user/assistant pair per exemplar.
    """
    if fs_func is None:
        fewShot_ids = hash_id_asRandom(test_id, gen_len=4, max_id=65)
    elif isinstance(fs_func, (list, tuple)):
        fewShot_ids = fs_func
    else:
        fewShot_ids = fs_func(test_id, gen_len=4, max_id=65)

    cot_messages = []
    for fewShot_id in fewShot_ids:
        _item = _DS['default']['validation'][fewShot_id]
        cot_messages.append(
            {"role": "user", "content": f"The question is *** {_item['question']} ***, and the options are *** {_item['options']} ***"}
        )
        cot_messages.append(
            {"role": "assistant", "content": f"The answer_index is *** {_item['answer_index']} ***, and the Final answer is *** Option {_item['answer']} ***"}
        )
    return cot_messages

def get_mmluPro_testPrompt(_DS, test_id, fs_func):
    """Assemble the full MMLU-Pro prompt: few-shot exemplars plus the test question.

    Args:
        _DS: Dataset indexed as ``_DS['default']['test'][test_id]``.
        test_id: Index of the test question.
        fs_func: Exemplar selector forwarded to ``get_mmluPro_cotPrompt``.

    Returns:
        The exemplar messages followed by one user message with the question
        and its options.
    """
    exemplar_turns = get_mmluPro_cotPrompt(_DS, test_id, fs_func)
    sample = _DS['default']['test'][test_id]
    query_turn = {
        "role": "user",
        "content": f"The question is *** {sample['question']} ***, and the options are *** {sample['options']} ***",
    }
    return exemplar_turns + [query_turn]

def eval_mmluPro_testPrompt(_DS, test_id, model_output):
    """Score a model answer against the MMLU-Pro gold option label.

    The prediction text is whatever follows the last ``######`` marker, or the
    final 100 characters when the marker is absent. Within that text the last
    ``Option <label>`` mention is taken as the model's choice. If there is no
    such mention, the answer counts as correct only when the gold label occurs
    as a standalone token. A plain substring test is deliberately avoided: it
    would match the gold label inside ordinary words (e.g. "F" in "Final").

    Args:
        _DS: Dataset indexed as ``_DS['default']['test'][test_id]``.
        test_id: Index of the test question.
        model_output: Decoded model answer.

    Returns:
        ``(is_correct, gold_label, clean_pred)``.
    """
    gold = _DS['default']['test'][test_id]['answer']
    if '######' in model_output:
        clean_pred = model_output.split('######')[-1]
    else:
        clean_pred = model_output[-100:]

    chosen = re.findall(r'\bOption\s+([A-Za-z0-9]+)\b', clean_pred)
    if chosen:
        # The last mention wins: earlier ones are usually options being ruled out.
        is_correct = chosen[-1] == gold
    else:
        standalone = rf'(?<![A-Za-z0-9]){re.escape(gold)}(?![A-Za-z0-9])'
        is_correct = re.search(standalone, clean_pred) is not None
    return is_correct, gold, clean_pred


################## EVAL benchmark on GPQA-main ##################

def get_gpqa_cotPrompt(_DS, test_id, fs_func=None):
    """Build the few-shot exemplar turns for a GPQA question.

    Exemplars are read from ``_DS['gpqa_extended']['train']``. For each
    exemplar the correct answer is inserted among the three incorrect ones at
    slot ``fewShot_id % 4``, and the assistant turn names the letter of that
    same slot. Previously the slot came from ``test_id`` (the same slot the
    test question uses) while the assistant turn always said "Option A", so
    exemplars were mislabeled whenever the slot was not 0.

    Args:
        _DS: Dataset indexed as ``_DS['gpqa_extended']['train'][i]``.
        test_id: Id of the test question; only used to derive exemplar ids.
        fs_func: Selects the exemplar ids. ``None`` uses ``hash_id_asRandom``;
            a list or tuple is taken as the explicit ids; anything else is
            called as ``fs_func(test_id, gen_len=4, max_id=540)``.

    Returns:
        A flat list of chat messages, one user/assistant pair per exemplar.
    """
    if fs_func is None:
        fewShot_ids = hash_id_asRandom(test_id, gen_len=4, max_id=540)
    elif isinstance(fs_func, (list, tuple)):
        fewShot_ids = fs_func
    else:
        fewShot_ids = fs_func(test_id, gen_len=4, max_id=540)

    cot_messages = []
    for fewShot_id in fewShot_ids:
        _item = _DS['gpqa_extended']['train'][fewShot_id]
        _query = _item['Question']
        # Slot depends on the exemplar itself, not on the test question.
        ans_id = fewShot_id % 4
        ans_alpha = 'ABCD'[ans_id]
        options_context = [_item['Incorrect Answer 1'], _item['Incorrect Answer 2'], _item['Incorrect Answer 3']]
        options_context.insert(ans_id, _item['Correct Answer'])
        _options = "; ".join(
            f"Option {label}. {choice}" for label, choice in zip('ABCD', options_context)
        )
        _sampleAns = f"******** Final answer: Option {ans_alpha}. {_item['Correct Answer']}"
        cot_messages += [
            {"role": "user", "content": f"The question is *** {_query} ***, and the options are: *** {_options} ***"},
            {"role": "assistant", "content": _sampleAns}
        ]
    return cot_messages

def get_gpqa_testPrompt(_DS, test_id, fs_func):
    """Assemble the full GPQA prompt: few-shot exemplars plus the test question.

    The test question is read from ``_DS['gpqa_main']['train'][test_id]``. Its
    correct answer is placed among the three incorrect answers at slot
    ``test_id % 4`` and the four choices are labeled Option A to Option D.

    Args:
        _DS: Dataset holding the ``gpqa_main`` train split.
        test_id: Index of the test question.
        fs_func: Exemplar selector forwarded to ``get_gpqa_cotPrompt``.

    Returns:
        The exemplar messages followed by one user message with the question
        and its labeled options.
    """
    exemplar_turns = get_gpqa_cotPrompt(_DS, test_id, fs_func)

    record = _DS['gpqa_main']['train'][test_id]
    question = record['Question']
    gold_slot = test_id % 4
    choices = [record['Incorrect Answer 1'], record['Incorrect Answer 2'], record['Incorrect Answer 3']]
    choices.insert(gold_slot, record['Correct Answer'])
    labeled = "; ".join(f"Option {letter}. {choice}" for letter, choice in zip("ABCD", choices))

    query_turn = {
        "role": "user",
        "content": f"The question is *** {question} ***, and the options are: *** {labeled} ***",
    }
    return exemplar_turns + [query_turn]

def eval_gpqa_testPrompt(_DS, test_id, model_output):
    """Score a model answer for a GPQA test question.

    The gold letter is ``'ABCD'[test_id % 4]``. The prediction text is whatever
    follows the last ``Final answer:`` marker, or the final 100 characters when
    the marker is absent. The answer counts as correct when the prediction text
    contains ``Option <gold letter>``.

    Args:
        _DS: Dataset indexed as ``_DS['gpqa_main']['train'][test_id]``.
        test_id: Index of the test question.
        model_output: Decoded model answer.

    Returns:
        ``(is_correct, correct_answer_text, clean_pred)``.
    """
    gold_text = _DS['gpqa_main']['train'][test_id]['Correct Answer']
    gold_label = 'Option ' + 'ABCD'[test_id % 4]

    marker = 'Final answer:'
    if marker in model_output:
        clean_pred = model_output.split(marker)[-1]
    else:
        clean_pred = model_output[-100:]

    return gold_label in clean_pred, gold_text, clean_pred


################## EVAL benchmark on GSM-8k ##################

def get_gsm8k_cotPrompt(_DS, test_id, fs_func=None):
    """Build the few-shot exemplar turns for a GSM8K question.

    Exemplars are read from ``_DS['main']['train']``; each contributes its
    question as a user turn and its answer as an assistant turn.

    Args:
        _DS: Dataset indexed as ``_DS['main']['train'][i]``.
        test_id: Id of the test question; only used to derive exemplar ids.
        fs_func: Selects the exemplar ids. ``None`` uses ``hash_id_asRandom``;
            a list or tuple is taken as the explicit ids; anything else is
            called as ``fs_func(test_id, gen_len=4, max_id=7400)``.

    Returns:
        A flat list of chat messages, one user/assistant pair per exemplar.
    """
    if fs_func is None:
        fewShot_ids = hash_id_asRandom(test_id, gen_len=4, max_id=7400)
    elif isinstance(fs_func, (list, tuple)):
        fewShot_ids = fs_func
    else:
        fewShot_ids = fs_func(test_id, gen_len=4, max_id=7400)

    cot_messages = []
    for fewShot_id in fewShot_ids:
        _item = _DS['main']['train'][fewShot_id]
        cot_messages.append({"role": "user", "content": _item['question']})
        cot_messages.append({"role": "assistant", "content": _item['answer']})
    return cot_messages

def get_gsm8k_testPrompt(_DS, test_id, fs_func):
    """Assemble the full GSM8K prompt: few-shot exemplars plus the test question.

    Args:
        _DS: Dataset indexed as ``_DS['main']['test'][test_id]``.
        test_id: Index of the test question.
        fs_func: Exemplar selector forwarded to ``get_gsm8k_cotPrompt``.

    Returns:
        The exemplar messages followed by one user message holding the test
        question text.
    """
    exemplar_turns = get_gsm8k_cotPrompt(_DS, test_id, fs_func)
    sample = _DS['main']['test'][test_id]
    query_turn = {"role": "user", "content": sample['question']}
    return exemplar_turns + [query_turn]

def eval_gsm8k_testPrompt(_DS, test_id, model_output):
    """Score a model answer against the GSM8K gold number.

    The prediction text is whatever follows the last ``####`` marker, or the
    final 50 characters when the marker is absent. Thousands separators are
    removed from both the prediction and the gold answer. The answer counts as
    correct when any whole number in the prediction equals the gold value, so
    gold "5" no longer matches inside "15" and gold "1,000" matches "1000".
    If the gold answer is not numeric, plain substring containment is used.

    Args:
        _DS: Dataset indexed as ``_DS['main']['test'][test_id]``.
        test_id: Index of the test question.
        model_output: Decoded model answer.

    Returns:
        ``(is_correct, clean_ans, clean_pred)``, both strings without commas.
    """
    test_item = _DS['main']['test'][test_id]
    if '####' in model_output:
        clean_pred = model_output.split('####')[-1]
    else:
        clean_pred = model_output[-50:]
    clean_pred = clean_pred.replace(',', '')
    clean_ans = test_item['answer'].split('####')[-1].strip().replace(',', '')

    try:
        gold_value = float(clean_ans)
    except ValueError:
        return clean_ans in clean_pred, clean_ans, clean_pred

    # The lookbehind keeps a match from starting in the middle of a number.
    number_tokens = re.findall(r'(?<![\d.])-?\d+(?:\.\d+)?', clean_pred)
    is_correct = any(float(token) == gold_value for token in number_tokens)
    return is_correct, clean_ans, clean_pred


################## EVAL benchmark on MATH-500 ##################

def get_math500_cotPrompt(_DS, test_id, fs_func=None):
    """Build the few-shot exemplar turns for a MATH-500 question.

    Exemplars are read from ``_DS['default']['train']``. An exemplar is kept
    only when its solution contains ``\\boxed{`` and the text of the first box
    (up to the first ``}``) is shorter than 20 characters and contains none of
    ``frac``, ``begin``, ``text``. Fewer exemplars than ids may be returned.

    Args:
        _DS: Dataset indexed as ``_DS['default']['train'][i]``.
        test_id: Id of the test question; only used to derive exemplar ids.
        fs_func: Selects the exemplar ids. ``None`` uses ``hash_id_asRandom``;
            a list or tuple is taken as the explicit ids; anything else is
            called as ``fs_func(test_id, gen_len=4, max_id=7400)``.

    Returns:
        A flat list of chat messages, one user/assistant pair per kept exemplar.
    """
    if fs_func is None:
        fewShot_ids = hash_id_asRandom(test_id, gen_len=4, max_id=7400)
    elif isinstance(fs_func, (list, tuple)):
        fewShot_ids = fs_func
    else:
        fewShot_ids = fs_func(test_id, gen_len=4, max_id=7400)

    boxed = '\\boxed{'
    cot_messages = []
    for fewShot_id in fewShot_ids:
        _item = _DS['default']['train'][fewShot_id]
        solution = _item['solution']
        if boxed not in solution:
            continue
        clean_ans = solution.split(boxed)[1].split('}')[0]
        if len(clean_ans) >= 20 or any(tag in clean_ans for tag in ('frac', 'begin', 'text')):
            continue
        cot_messages += [
            {"role": "user", "content": _item['problem']},
            {"role": "assistant", "content": solution}
        ]
    return cot_messages

def get_math500_testPrompt(_DS, test_id, fs_func):
    """Assemble the full MATH-500 prompt: few-shot exemplars plus the test problem.

    Args:
        _DS: Dataset indexed as ``_DS['default']['test'][test_id]``.
        test_id: Index of the test problem.
        fs_func: Exemplar selector forwarded to ``get_math500_cotPrompt``.

    Returns:
        The exemplar messages followed by one user message holding the test
        problem text.
    """
    exemplar_turns = get_math500_cotPrompt(_DS, test_id, fs_func)
    sample = _DS['default']['test'][test_id]
    query_turn = {"role": "user", "content": sample['problem']}
    return exemplar_turns + [query_turn]

def eval_math500_testPrompt(_DS, test_id, model_output):
    """Score a model answer against the boxed answer of a MATH-500 solution.

    The gold answer is the text of the first ``\\boxed{`` in the reference
    solution, cut at the first ``}``. Questions are skipped when the reference
    has no box (previously an IndexError), or when the gold text is 20+
    characters long or contains ``frac``, ``begin`` or ``text``.

    The prediction is extracted the same way from the model output, falling
    back to its final 100 characters when it has no box. The answer counts as
    correct when the gold text is contained in the prediction text.

    Args:
        _DS: Dataset indexed as ``_DS['default']['test'][test_id]``.
        test_id: Index of the test problem.
        model_output: Decoded model answer.

    Returns:
        ``(is_correct, clean_ans, clean_pred)``, or ``(None, None, None)`` when
        the question is skipped.
    """
    boxed = '\\boxed{'
    solution = _DS['default']['test'][test_id]['solution']
    if boxed not in solution:
        return None, None, None

    clean_ans = solution.split(boxed)[1].split('}')[0]
    if len(clean_ans) >= 20 or any(tag in clean_ans for tag in ('frac', 'begin', 'text')):
        return None, None, None

    if boxed in model_output:
        clean_pred = model_output.split(boxed)[1].split('}')[0]
    else:
        clean_pred = model_output[-100:]
    return clean_ans in clean_pred, clean_ans, clean_pred


def get_cotPrompt(test_id, _benchmark, _DS, fs_func):
    """Dispatch to the few-shot exemplar builder of the named benchmark.

    Recognized names are ``'mmluPro'``, ``'gpqa-main'``, ``'gsm8k'`` and
    ``'math-500'``; any other name yields ``None``.
    """
    if _benchmark == 'mmluPro':
        builder = get_mmluPro_cotPrompt
    elif _benchmark == 'gpqa-main':
        builder = get_gpqa_cotPrompt
    elif _benchmark == 'gsm8k':
        builder = get_gsm8k_cotPrompt
    elif _benchmark == 'math-500':
        builder = get_math500_cotPrompt
    else:
        return None
    return builder(_DS, test_id, fs_func=fs_func)

def get_testPrompt(test_id, _benchmark, _DS, fs_func):
    """Dispatch to the full-prompt builder of the named benchmark.

    Recognized names are ``'mmluPro'``, ``'gpqa-main'``, ``'gsm8k'`` and
    ``'math-500'``; any other name yields ``None``.
    """
    if _benchmark == 'mmluPro':
        builder = get_mmluPro_testPrompt
    elif _benchmark == 'gpqa-main':
        builder = get_gpqa_testPrompt
    elif _benchmark == 'gsm8k':
        builder = get_gsm8k_testPrompt
    elif _benchmark == 'math-500':
        builder = get_math500_testPrompt
    else:
        return None
    return builder(_DS, test_id, fs_func)

def eval_testPrompt(test_id, model_output, _benchmark, _DS):
    """Dispatch to the answer scorer of the named benchmark.

    Recognized names are ``'mmluPro'``, ``'gpqa-main'``, ``'gsm8k'`` and
    ``'math-500'``; any other name yields ``None``.
    """
    if _benchmark == 'mmluPro':
        scorer = eval_mmluPro_testPrompt
    elif _benchmark == 'gpqa-main':
        scorer = eval_gpqa_testPrompt
    elif _benchmark == 'gsm8k':
        scorer = eval_gsm8k_testPrompt
    elif _benchmark == 'math-500':
        scorer = eval_math500_testPrompt
    else:
        return None
    return scorer(_DS, test_id, model_output)



def access_exemplar_hiddenStates(_model, _tokenizer, _benchmark, _DS, fs_id):
    """Run one few-shot exemplar through the model and return its hidden states.

    The exemplar is built with ``get_cotPrompt`` using ``-1`` as the test id
    and ``[fs_id]`` as the explicit exemplar id list. It is rendered with the
    chat template (generation prompt appended), tokenized as a batch of one,
    and passed through the model with ``output_hidden_states=True``.

    Returns:
        The ``hidden_states`` attribute of the model output.
    """
    exemplar_turns = get_cotPrompt(-1, _benchmark, _DS, fs_func=[fs_id])
    prompt = _tokenizer.apply_chat_template(
        exemplar_turns,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = _tokenizer([prompt], return_tensors="pt").to(_model.device)
    forward_out = _model(**encoded, output_hidden_states=True)
    return forward_out.hidden_states

def access_testQuery_hiddenStates(_model, _tokenizer, _benchmark, _DS, test_id):
    """Run a bare test question through the model and return its hidden states.

    The full prompt is built with ``get_testPrompt`` (default exemplar
    selection) and only its last message, the test question, is kept. That
    message is rendered with the chat template (generation prompt appended),
    tokenized as a batch of one, and passed through the model with
    ``output_hidden_states=True``.

    Returns:
        The ``hidden_states`` attribute of the model output.
    """
    full_prompt = get_testPrompt(test_id, _benchmark, _DS, None)
    query_only = full_prompt[-1:]  # one-element list: just the final user turn
    prompt = _tokenizer.apply_chat_template(
        query_only,
        tokenize=False,
        add_generation_prompt=True,
    )
    encoded = _tokenizer([prompt], return_tensors="pt").to(_model.device)
    forward_out = _model(**encoded, output_hidden_states=True)
    return forward_out.hidden_states


def eval_inference_TC(_model, _tokenizer, _benchmark, _DS, n_repeats=10):
    """Measure the average wall-clock time of one forward pass.

    Two probe prompts (test ids 10 and 100, default exemplar selection) are
    tokenized once. After a warm-up pass over both, each is run ``n_repeats``
    times and the elapsed time is divided by the number of forward calls.

    Timing uses ``time.perf_counter``, a monotonic high-resolution clock
    intended for measuring intervals, instead of ``time.time``.

    NOTE(review): if the model runs on a device that executes asynchronously,
    the forward call may return before the computation finishes. Confirm
    whether a device synchronization is needed around the timed region.

    Args:
        _model: Callable model exposing ``device``.
        _tokenizer: Tokenizer exposing ``apply_chat_template`` and ``__call__``.
        _benchmark: Benchmark name forwarded to ``get_testPrompt``.
        _DS: Dataset forwarded to ``get_testPrompt``.
        n_repeats: Timed repetitions per probe prompt (default 10).

    Returns:
        Average seconds per forward pass.

    Raises:
        ValueError: If ``n_repeats`` is smaller than 1.
    """
    if n_repeats < 1:
        raise ValueError("n_repeats must be at least 1")

    probe_batches = []
    for probe_id in (10, 100):
        text = _tokenizer.apply_chat_template(
            get_testPrompt(probe_id, _benchmark, _DS, fs_func=None),
            tokenize=False,
            add_generation_prompt=True,
        )
        probe_batches.append(_tokenizer([text], return_tensors="pt").to(_model.device))

    for batch in probe_batches:  # warmup
        _model(**batch)

    start = time.perf_counter()
    for _ in range(n_repeats):
        for batch in probe_batches:
            _model(**batch)
    elapsed = time.perf_counter() - start
    return elapsed / (n_repeats * len(probe_batches))






