import os,glob, subprocess, re
# os.environ['HF_HOME'] = "../llama_on_glue/checkpoints"
from huggingface_hub import login
# #

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import torch
from tqdm import tqdm
import pandas as pd
import local_datasets.ceval_exam as ceval 
# from human_eval.data import write_jsonl, read_problems
from glue_utils import get_processFunc, get_dataset

def extract_from_logs(log):
    """Collect every numeric field found in a log file.

    The file content is split on the characters ``(``, ``)`` and ``|``;
    every resulting piece that ``float()`` accepts is kept, in file order.

    Args:
        log: Path of the text file to read.

    Returns:
        A ``(values, text)`` tuple: the list of parsed floats and the text
        that was scanned.
    """
    # Read-only mode: the file is never written here, and "r+" would fail on
    # logs the process has no write permission for.
    with open(log, "r") as f:
        # readlines() keeps each trailing newline, so joining with '\n' leaves
        # a blank line between the original lines. Kept as-is so the text
        # handed back to callers is unchanged.
        lines = '\n'.join(f.readlines())

    ret = []
    for piece in re.split('[()|]', lines):
        try:
            ret.append(float(piece))
        except ValueError:
            # Not a number (label cell, separator, blank) - skip it. Only
            # ValueError is caught so Ctrl-C and real errors still propagate.
            continue
    return ret, lines

def load_ckpt(ckpt, device):
    """Load the tokenizer and causal LM stored at ``ckpt``.

    The model weights are requested in bfloat16 and the model is then moved
    to ``device``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    model = model.to(device)
    return model, tokenizer

def test_glue(ckpt, dataset, device, super_small=False):
    """Score a checkpoint on the validation texts of ``dataset``.

    Each validation text is cut right after the assistant header: the part up
    to and including the header is the prompt, and the first space-delimited
    word of the lowercased remainder is the label. A sample counts as correct
    when its label occurs in the generated text and no label seen anywhere in
    the validation texts occurs at an earlier position (lenient matching, as
    opposed to comparing only the first generated word).

    Args:
        ckpt: Checkpoint path handed to ``load_ckpt``.
        dataset: Dataset name handed to ``get_dataset`` / ``get_processFunc``.
        device: Device the model and the input ids are placed on.
        super_small: When true, stop once more than 63 samples have been
            scored and print each generation next to its label.

    Returns:
        ``(accuracy, generations)`` where ``generations`` is every generated
        text joined by blank lines.
    """
    model, tokenizer = load_ckpt(ckpt, device)

    print('going to run0',flush=True)
    _, ds_val = get_dataset(dataset)
    output_texts = get_processFunc(dataset)(ds_val)

    assistant_header = '<|start_header_id|>assistant<|end_header_id|>'

    def split_example(text):
        # Prompt ends right after the assistant header; the label is the
        # first word of what follows.
        cut = text.find(assistant_header) + len(assistant_header)
        gold = text[cut:].lower().strip(' ').split(' ')[0]
        return text[:cut], gold

    def position(haystack, needle):
        # Index of the first occurrence, or a huge sentinel when absent.
        idx = haystack.find(needle)
        return 1e9 if idx < 0 else idx

    def moderate(chat):
        print('start tokens',flush=True)
        prompt_ids = tokenizer.encode(chat, return_tensors="pt").to(device)
        print('end tokens',flush=True)
        generated = model.generate(input_ids=prompt_ids, max_new_tokens=100, pad_token_id=0)
        print('end gen',flush=True)
        # Keep only the tokens produced after the prompt.
        completion = generated[0][prompt_ids.shape[-1]:]
        return tokenizer.decode(completion, skip_special_tokens=True)

    print('going to run1',flush=True)
    # Every distinct label seen in the validation texts.
    possible_answers = list({split_example(text)[1] for text in output_texts})
    print('going to run2',flush=True)

    succ = 0
    tot = 0
    outputs = []
    for text in tqdm(output_texts):
        print('start',flush=True)
        chat, label = split_example(text)
        ans = moderate(chat)
        print('generation done',flush=True)
        outputs.append(ans)
        tot += 1

        print('test answers',flush=True)
        label_at = position(ans, label)
        # Correct only if the label appears and no other candidate appears
        # strictly before it.
        matched = label_at < 1e8 and not any(
            position(ans, candidate) < label_at for candidate in possible_answers
        )
        succ += matched

        if super_small:
            if tot > 63:
                break
            print('-------')
            print(ans)
            print('-------')
            print(label)
            print('-------')

    del model, tokenizer
    torch.cuda.empty_cache()
    return succ / tot, '\n\n'.join(outputs)


def test_toxic_chat(ckpt, device):
    """Score a checkpoint on the ``lmsys/toxic-chat`` test split.

    Each sample's ``user_input`` / ``model_output`` pair is rendered with the
    Llama Guard 2 chat template and completed by the checkpoint. A sample is
    predicted toxic when the generated text contains ``'unsafe'``; the
    dataset's ``toxicity`` field is the ground truth.

    Args:
        ckpt: Checkpoint path handed to ``load_ckpt``.
        device: Device the model and the input ids are placed on.

    Returns:
        ``(f1, generations)``: the F1 score of the toxic class and every
        generated text joined by blank lines. Precision, recall and F1 are
        reported as ``0.0`` when their denominator is zero (for example when
        the model never answers ``'unsafe'``) instead of raising
        ``ZeroDivisionError``.
    """
    model, tokenizer = load_ckpt(ckpt, device)
    dataset = load_dataset("lmsys/toxic-chat", "toxicchat0124", split='test')
    # Only the chat template of this tokenizer is used; the text is tokenized
    # with the checkpoint's own tokenizer below.
    template_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-Guard-2-8B")

    def moderate(chat):
        chat = template_tokenizer.apply_chat_template(chat, tokenize=False)
        # NOTE(review): if the rendered template already starts with a BOS
        # token, encode() may prepend a second one - confirm.
        input_ids = tokenizer.encode(chat, return_tensors="pt").to(device)
        output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
        prompt_len = input_ids.shape[-1]
        return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    tp, fp, fn = 0, 0, 0
    outputs = []
    for i in tqdm(dataset):
        ans = moderate([
            {"role": "user", "content": i['user_input']},
            {"role": "assistant", "content": i['model_output']},
        ])
        outputs.append(ans)
        flagged = 'unsafe' in ans
        if i['toxicity']:
            if flagged:
                tp += 1
            else:
                fn += 1
        elif flagged:
            fp += 1

    del model, tokenizer
    torch.cuda.empty_cache()

    # Guard every denominator: a model that never flags anything (tp+fp == 0)
    # or a split without toxic samples (tp+fn == 0) must yield 0.0, not crash
    # after the whole evaluation has already run.
    p = tp / (tp + fp) if (tp + fp) else 0.0
    r = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * (p * r) / (p + r) if (p + r) else 0.0
    return f1, '\n\n'.join(outputs)

def test_ceval(ckpt, device):
    """Score a checkpoint on the local C-Eval validation CSVs (few-shot).

    For every ``local_datasets/ceval_exam/val/*.csv`` whose subject passes
    the filters below, a few-shot context is built from the matching
    ``dev/{subject}_dev.csv`` file, each validation question is appended
    without its answer, and the choice extracted from the generation by
    ``ceval.extract_choice`` is compared with the ``answer`` column.

    Args:
        ckpt: Checkpoint path handed to ``load_ckpt``.
        device: Device the model and the input ids are placed on.

    Returns:
        ``(accuracy, generations)`` where ``generations`` is every generated
        text joined by blank lines.

    Raises:
        ZeroDivisionError: If no question was evaluated (no CSV matched the
            filters).
    """
    def into_multiple(d, no_answer=False):
        # Render one row as a multiple-choice item; the answer letter is
        # appended unless ``no_answer`` is set.
        question = d['question']
        options = [f'{i}.{d[i]}' for i in 'ABCD']
        chat = '题目：' + question + '\n' + '\n'.join(options)+'\n'+'答案：'
        if not no_answer:
            chat += d['answer']
        return chat+'\n'

    model, tokenizer = load_ckpt(ckpt, device)
    succ, tot = 0, 0
    outputs = []
    # Sorted so the order of the returned generations is reproducible.
    for excel in sorted(glob.glob('local_datasets/ceval_exam/val/*.csv')):
        subject = os.path.basename(excel).split('.')[0].replace('_val','')

        # WARNING: only part of the benchmark is evaluated - subjects whose
        # name does not contain 'middle_school' are skipped.
        if 'middle_school' not in subject:
            continue
        # Subjects without a Chinese display name are skipped as well.
        if subject not in ceval.name_en2zh:
            continue

        # Read the CSVs only for subjects that are actually evaluated, so
        # skipped subjects cost no I/O and cannot fail on a missing dev file.
        df = pd.read_csv(excel)
        sample_df = pd.read_csv(f'local_datasets/ceval_exam/dev/{subject}_dev.csv')
        subject_zh = ceval.name_en2zh[subject]

        # NOTE(review): iloc[1:] drops the first data row of both the dev and
        # the val file (read_csv has already consumed the header) - confirm
        # this is intended.
        context = f'以下是中国关于{subject_zh}考试的单项选择题，请选出其中的正确答案。\n\n'
        context = context + '\n'.join(
            [into_multiple(d) for _, d in sample_df.iloc[1:].iterrows()]
        )
        for _, d in tqdm(list(df.iloc[1:].iterrows())):
            prompt = context+'\n'+into_multiple(d,no_answer=True)

            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
            output = model.generate(input_ids=input_ids, max_new_tokens=512, pad_token_id=0)
            prompt_len = input_ids.shape[-1]
            answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
            outputs.append(answer)
            if ceval.extract_choice(answer) == d['answer']:
                succ += 1
            tot += 1
    del model, tokenizer
    torch.cuda.empty_cache()
    return succ / tot, '\n\n'.join(outputs)

# def test_humaneval(ckpt, device):
#     model, tokenizer = load_ckpt(ckpt, device)
#     outputs = []
#     def generate_one_completion(prompt):
#         prompt = f'''### System:
# Below is an instruction that describes a task, Write a response that appropriately completes the request.

# ### Instruction:
# {prompt}

# ### Response:\n'''
#         input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
#         output = model.generate(input_ids=input_ids, max_new_tokens=512, pad_token_id=0)
#         prompt_len = input_ids.shape[-1]
#         answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
#         # if answer[0]==':':
#         #     answer = answer[1:]
#         # find first function
#         answer = answer[answer.find('def'):]
        
        
#         # from https://github.com/abacaj/code-eval/blob/main/core/evaluation.py
#         def filter_code(completion: str) -> str:
#             # The program tends to overwrite, we only take the first function
#             completion = completion.lstrip("\n")
#             return completion.split("\n\n")[0]
#         def fix_indents(text: str) -> str:
#             return text.replace("\t", "    ").replace('`',"\'")

#         answer = filter_code(fix_indents(answer))
        
#         answer_list = answer.split('\n')
#         for i in range(1,len(answer_list)):
#             if len(answer_list[i])>0 and answer_list[i][0]!=' ':
#                 answer = '\n'.join(answer_list[:i-1])
#                 break
            
#         nonlocal outputs
#         outputs += [answer]
#         return answer
#     problems = read_problems()
#     num_samples_per_task = 10
#     samples = [
#         dict(task_id=task_id, completion=generate_one_completion(problems[task_id]["prompt"]))
#         for task_id in tqdm(problems)
#         for _ in range(num_samples_per_task)
#     ]
#     write_jsonl(f"partial/samples.jsonl__{device}", samples)
#     cmd = f'evaluate_functional_correctness partial/samples.jsonl__{device} > partial/output.jsonl__{device}'
#     subprocess.call([cmd],shell=True)
#     ret,lines = extract_from_logs(f'partial/output.jsonl__{device}')
#     return ret[-2], '\n\n'.join(outputs)
        
        

# def test_medQA(model, tokenizer):
#     dataset = load_dataset("medalpaca/medical_meadow_medqa")
#     # template_tokenizer = AutoTokenizer.from_pretrained("ruslanmv/Medical-Llama3-8B")
#     def askme(question):
#         # sys_message = ''' 
#         # You are an AI Medical Assistant trained on a vast dataset of health information. Please be thorough and
#         # provide an informative answer. If you don't know the answer to a specific medical inquiry, advise seeking professional help.
#         # '''   
#         # messages = [{"role": "system", "content": sys_message}, {"role": "user", "content": question}]
        
#         # # prompt = template_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#         # # prompt = sys_message + question
#         # prompt = ''' 
#         # Q:A pulmonary autopsy specimen from a 58-year-old woman who died of acute hypoxic respiratory failure was examined. She had recently undergone surgery for a fractured femur 3 months ago. Initial hospital course was uncomplicated, and she was discharged to a rehab facility in good health. Shortly after discharge home from rehab, she developed sudden shortness of breath and had cardiac arrest. Resuscitation was unsuccessful. On histological examination of lung tissue, fibrous connective tissue around the lumen of the pulmonary artery is observed. Which of the following is the most likely pathogenesis for the present findings?? {'A': 'Thromboembolism', 'B': 'Pulmonary ischemia', 'C': 'Pulmonary hypertension', 'D': 'Pulmonary passive congestion', 'E': 'Pulmonary hemorrhage'},
#         # Please answer with one of the option in the bracket.
#         # '''
#         inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
#         outputs = model.generate(**inputs, max_new_tokens=100, use_cache=True)
        
#         response_text = tokenizer.batch_decode(outputs)[0].strip()
#         answer = response_text.split('<|im_start|>assistant')[-1].strip()
#         return response_text


#     question = '''I'm a 35-year-old male and for the past few months, I've been experiencing fatigue, 
#     increased sensitivity to cold, and dry, itchy skin. 
#     Could these symptoms be related to hypothyroidism? 
#     If so, what steps should I take to get a proper diagnosis and discuss treatment options?'''

#     return askme(question)


# Task names that are evaluated through the lm_eval command-line tool.
lm_eval_sets = ['kmmlu', 'minerva_math', 'minerva_math_algebra', 'arc_it', 'kobest', 'arc_easy']
def test_through_lm_eval(ckpt, dataset, device):
    """Run the ``lm_eval`` CLI on ``ckpt`` and parse its report.

    The tool's standard output is written to ``partial/lm_eval_{device}.txt``
    and then scanned with ``extract_from_logs``.

    Args:
        ckpt: Checkpoint path passed as ``pretrained=`` in ``--model_args``.
        dataset: Task name(s) passed to ``--tasks``.
        device: Device string passed to ``--device``; also part of the log
            file name.

    Returns:
        ``(score, report)``: the second-to-last number found in the report
        and the report text.

    Raises:
        IndexError: If the report holds fewer than two numbers (typically
            because ``lm_eval`` failed).
    """
    # Set in this process so the lm_eval child process inherits them.
    os.environ['NCCL_P2P_DISABLE'] = '1'
    os.environ['NCCL_IB_DISABLE'] = '1'

    log_path = f'partial/lm_eval_{device}.txt'
    os.makedirs('partial', exist_ok=True)

    # Argument list without a shell: paths or task names containing spaces or
    # shell metacharacters are passed through verbatim instead of being
    # re-parsed by /bin/sh.
    cmd = [
        'lm_eval',
        '--model', 'hf',
        '--model_args', f'pretrained={ckpt},dtype=bfloat16',
        '--device', f'{device}',
        '--batch_size', '1',
        '--tasks', f'{dataset}',
    ]
    with open(log_path, 'w') as log_file:
        returncode = subprocess.call(cmd, stdout=log_file)

    ret, lines = extract_from_logs(log_path)
    if len(ret) < 2:
        raise IndexError(
            f'expected at least two numbers in {log_path}, found {len(ret)} '
            f'(lm_eval exit code {returncode})'
        )
    return ret[-2], lines
    
    
def test(ckpt, dataset, device):
    """Evaluate ``ckpt`` on ``dataset`` with the matching runner.

    Names listed in ``lm_eval_sets`` go through the lm_eval CLI,
    ``'toxic-chat'`` and ``'ceval'`` have dedicated runners, and every other
    name is treated as a GLUE-style task and handed to ``test_glue`` (which
    is what every call did before the dispatch was restored).

    Returns:
        The ``(score, text)`` tuple produced by the selected runner.
    """
    if dataset in lm_eval_sets:
        return test_through_lm_eval(ckpt, dataset, device)
    if dataset == 'toxic-chat':
        return test_toxic_chat(ckpt, device)
    if dataset == 'ceval':
        return test_ceval(ckpt, device)
    # No HumanEval runner is defined in this file, so that name is not
    # special-cased; it falls through like any other.
    return test_glue(ckpt, dataset, device)

if __name__ == '__main__':
    # Kept from the original script; nothing below reads the imported name.
    from models_and_datas import models_and_datas

    def _natural_key(name):
        # Split into alternating text / digit runs so that numeric parts
        # compare as integers: 'checkpoint-500' < 'checkpoint-4000'.
        return [
            int(tok) if idx % 2 else tok
            for idx, tok in enumerate(re.split(r'(\d+)', name))
        ]

    def model_name_to_ckpt(model_name):
        """Resolve a hub-style model name to a checkpoint directory.

        Looks under ``../llama_on_glue/checkpoints/hub/<name>/snapshots``,
        takes the first snapshot (sorted by name, so the choice is
        reproducible) and, inside it, the entry that sorts last in natural
        order - presumably the most recent ``checkpoint-<step>`` directory.
        """
        snapshots_dir = os.path.join(
            '../llama_on_glue', 'checkpoints/hub',
            model_name.replace('/', '--'), 'snapshots',
        )
        snapshots = sorted(os.listdir(snapshots_dir))
        if not snapshots:
            raise FileNotFoundError(f'no snapshot found in {snapshots_dir}')
        snapshot_dir = os.path.join(snapshots_dir, snapshots[0])

        entries = os.listdir(snapshot_dir)
        if not entries:
            raise FileNotFoundError(f'no checkpoint found in {snapshot_dir}')
        # Plain lexicographic order would rank 'checkpoint-500' above
        # 'checkpoint-4000'; natural order picks the highest step.
        return os.path.join(snapshot_dir, max(entries, key=_natural_key))

    # Other GLUE tasks previously run with this script: 'cola', 'rte', 'qnli'.
    tasks = ['sst2']
    for task in tasks:
        ckpt_path = model_name_to_ckpt(f'my-llama/llama-Meta-Llama-3-8B{task}')
        metric, _ = test_glue(ckpt_path, task, 'cuda:1', super_small=True)
        print(f'metric for {task}:')
        print(metric)
    raise SystemExit(0)
    # with open('llama-base.txt',"w+") as f:
    #     f.write(str(test(cc, 'ceval', 'cuda:1')[0]))
    #     f.write('\n')
    #     f.write(str(test(cc, 'toxic-chat', 'cuda:1')[0]))
    #     f.write('\n')
    #     f.write(str(test(cc, 'haerae', 'cuda:1')[0]))
    #     f.write('\n')
    
    # def model_name_to_ckpt(model_name):
    #     path = 'checkpoints/hub/models--'+model_name.replace('/','--')
    #     path += '/snapshots'
    #     path = glob.glob(path+'/*')[0]
    #     return path
    # from models_and_datas import models_and_datas
    # # for i in models_and_datas:
    # for i in ['code']:
    #     if 0==len(models_and_datas[i]['datasets']):
    #         continue
    #     for m in models_and_datas[i]['model']:
    #         cc = model_name_to_ckpt(m)
            
    #     with open(f'llama-{i}.txt',"w+") as f:
    #         ans, logs = test(cc, models_and_datas[i]['datasets'][0], 'cuda:1')
    #         f.write(str(ans))
    #         f.write(logs)
    #         f.write('\n')
    

# export NCCL_P2P_DISABLE="1"
# export NCCL_IB_DISABLE="1"
# lm_eval --model hf \
#     --model_args pretrained=checkpoints/hub/models--meta-llama--Meta-Llama-Guard-2-8B/snapshots/7d257f3c1a0ec6ed99b2cb715027149dfb9784ef \
#     --tasks haerae \
#     --device cuda \
#     --batch_size 1