import os,glob, subprocess, re
# os.environ['HF_HOME'] = "../llama_on_glue/checkpoints"
from huggingface_hub import login
# #

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import torch
from tqdm import tqdm
import pandas as pd
import local_datasets.ceval_exam as ceval 
# from human_eval.data import write_jsonl, read_problems
from glue_utils import get_processFunc, get_dataset
# Generation batch size used by test_glue for each GLUE task name.
batch_dict = {
    'sst2': 64,
    'cola': 64,
    'rte': 32,
    'qnli': 64,
    'mrpc': 64,
    'wnli': 64,
    'qqp': 32,
    'mnli': 32,
}

def extract_from_logs(log):
    """Read a log file and pull out every field that parses as a float.

    The text is split on the characters ``(``, ``)`` and ``|`` and each
    resulting piece is passed to ``float``; pieces that are not numbers are
    skipped.

    Args:
        log: Path of the log file to read.

    Returns:
        A tuple ``(numbers, text)`` where ``numbers`` is the list of parsed
        floats in file order and ``text`` is the log content.
    """
    # Read-only access is enough; "r+" would fail on read-only log files.
    with open(log, "r") as f:
        # readlines() keeps each trailing newline, so this join leaves a blank
        # line between log lines. Kept as-is because callers receive this text.
        lines = '\n'.join(f.readlines())

    ret = []
    for word in re.split('[()|]', lines):
        try:
            ret.append(float(word))
        except ValueError:
            # Not a numeric field (header text, separators, whitespace).
            continue
    return ret, lines

def load_ckpt(ckpt, device):
    """Load the tokenizer and causal language model stored at ``ckpt``.

    The model weights are loaded as bfloat16 and then moved to ``device``.

    Returns:
        A tuple ``(model, tokenizer)``.
    """
    tok = AutoTokenizer.from_pretrained(ckpt)
    lm = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)
    return lm.to(device), tok

def _label_comes_first(ans, label, possible_answers):
    """Return True when ``label`` occurs in ``ans`` before every other candidate answer."""
    label_pos = ans.find(label)
    if label_pos < 0:
        return False
    for candidate in possible_answers:
        pos = ans.find(candidate)
        if 0 <= pos < label_pos:
            return False
    return True


def test_glue(ckpt, dataset, device, super_small=False):
    """Score a checkpoint on one GLUE task by generating answers and string matching.

    Every text produced by ``get_processFunc(dataset)`` is split at the
    assistant header: everything up to and including the header is the prompt,
    and the first space-separated word after it (lower-cased) is the expected
    label. An example counts as correct when the expected label appears in the
    generated text and no other label seen in the dataset appears earlier.

    Args:
        ckpt: Checkpoint path passed to ``load_ckpt``.
        dataset: Task name; must be a key of ``batch_dict``.
        device: Device the model is moved to.
        super_small: When true, print every answer/label pair and stop once
            more than 20 examples have been scored.

    Returns:
        A tuple ``(accuracy, outputs)`` where ``outputs`` is every generated
        answer joined by blank lines. ``accuracy`` is 0.0 when no example was
        scored.

    Raises:
        KeyError: If ``dataset`` has no entry in ``batch_dict``.
    """
    # Look the batch size up first so an unknown task fails before the
    # (expensive) model load instead of after it.
    batch_size = batch_dict[dataset]

    model, tokenizer = load_ckpt(ckpt, device)
    # Left padding keeps every prompt adjacent to the tokens generated for it.
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = '<|reserved_special_token_0|>'

    # Evaluation uses the FIRST split returned by get_dataset (presumably the
    # training split, per the original "training set evaluation" note).
    ds_val, _ = get_dataset(dataset)
    output_texts = get_processFunc(dataset)(ds_val)

    def moderate(batch):
        """Generate a continuation for each (chat, label) pair in ``batch``."""
        chats = [chat for chat, _ in batch]
        inputs = tokenizer(chats, padding="longest", return_tensors="pt")
        inputs = {key: val.to(model.device) for key, val in inputs.items()}
        prompt_len = inputs["input_ids"].shape[-1]
        gen_tokens = model.generate(
            **inputs,
            max_new_tokens=512,
            # Pad with the token configured on the tokenizer rather than a
            # hard-coded id that may not match it.
            pad_token_id=tokenizer.pad_token_id,
        )
        # Drop the prompt at token level; slicing the decoded string by the
        # decoded prompt length breaks if decoding is not prefix-stable.
        return tokenizer.batch_decode(gen_tokens[:, prompt_len:], skip_special_tokens=True)

    comp = '<|start_header_id|>assistant<|end_header_id|>'
    examples = []
    for text in output_texts:
        # Assumes every text contains the assistant header -- TODO confirm.
        p = text.find(comp) + len(comp)
        chat, label = text[:p], text[p:]
        label = label.lower().strip(' ').split(' ')[0]
        examples.append((chat, label))
    possible_answers = list({label for _, label in examples})

    succ, tot = 0, 0
    outputs = []
    for start in tqdm(range(0, len(examples), batch_size)):
        if super_small and tot > 20:
            break
        batch = examples[start:start + batch_size]
        batch_ans = moderate(batch)
        outputs += batch_ans
        for ans, (_, label) in zip(batch_ans, batch):
            if super_small:
                print('-------')
                print(ans)
                print('-------')
                print(label)
                print('-------')
            # NOTE(review): labels are lower-cased but the generated answer is
            # matched case-sensitively -- confirm this is intended.
            succ += _label_comes_first(ans, label, possible_answers)
            tot += 1

    del model, tokenizer
    torch.cuda.empty_cache()
    accuracy = succ / tot if tot else 0.0
    return accuracy, '\n\n'.join(outputs)


def test_toxic_chat(ckpt, device):
    """Evaluate a checkpoint as a moderation model on the ToxicChat test split.

    Each example's user input and model output are formatted with the
    Llama Guard 2 chat template, the checkpoint generates up to 100 new tokens,
    and the example is treated as flagged when the generated text contains
    ``'unsafe'``.

    Args:
        ckpt: Checkpoint path passed to ``load_ckpt``.
        device: Device the model and inputs are placed on.

    Returns:
        A tuple ``(f1, outputs)`` where ``f1`` is the F1 score of the
        "unsafe" predictions against the ``toxicity`` field and ``outputs``
        is every generated answer joined by blank lines. ``f1`` is 0.0 when
        there is no true positive.
    """
    model, tokenizer = load_ckpt(ckpt, device)
    dataset = load_dataset("lmsys/toxic-chat", "toxicchat0124", split='test')
    # The prompt format comes from the Llama Guard 2 tokenizer's chat template,
    # independent of the checkpoint under test.
    template_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-Guard-2-8B")

    def moderate(chat):
        """Return the text generated for one templated conversation."""
        chat = template_tokenizer.apply_chat_template(chat, tokenize=False)
        input_ids = tokenizer.encode(chat, return_tensors="pt").to(device)
        output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
        prompt_len = input_ids.shape[-1]
        return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    tp, fp, fn = 0, 0, 0
    outputs = []
    for i in tqdm(dataset):
        ans = moderate([
            {"role": "user", "content": i['user_input']},
            {"role": "assistant", "content": i['model_output']},
        ])
        outputs.append(ans)
        flagged = 'unsafe' in ans
        if i['toxicity']:
            if flagged:
                tp += 1
            else:
                fn += 1
        elif flagged:
            fp += 1

    del model, tokenizer
    torch.cuda.empty_cache()

    # With no true positive, precision/recall are 0 or undefined and the
    # original formula divided by zero; report 0.0 instead.
    if tp == 0:
        return 0.0, '\n\n'.join(outputs)
    p, r = tp / (tp + fp), tp / (tp + fn)
    return 2 * (p * r) / (p + r), '\n\n'.join(outputs)

def test_ceval(ckpt, device):
    """Evaluate a checkpoint on the local C-Eval validation CSVs with few-shot prompts.

    For each ``local_datasets/ceval_exam/val/*.csv`` file whose subject name
    contains ``'middle_school'`` and has a translation in
    ``ceval.name_en2zh``, the matching ``dev`` CSV supplies the solved
    examples that prefix every prompt. The generated continuation is reduced
    to a choice with ``ceval.extract_choice`` and compared to the ``answer``
    column.

    Args:
        ckpt: Checkpoint path passed to ``load_ckpt``.
        device: Device the model and inputs are placed on.

    Returns:
        A tuple ``(accuracy, outputs)`` where ``outputs`` is every generated
        answer joined by blank lines. ``accuracy`` is 0.0 when no question
        was evaluated.
    """
    def into_multiple(d, no_answer=False):
        """Format one row as a multiple-choice question, optionally with its answer."""
        options = '\n'.join(f'{i}.{d[i]}' for i in 'ABCD')
        chat = '题目：' + d['question'] + '\n' + options + '\n' + '答案：'
        if not no_answer:
            chat += d['answer']
        return chat + '\n'

    model, tokenizer = load_ckpt(ckpt, device)
    succ, tot = 0, 0
    outputs = []
    # Sorted so the order of the collected outputs does not depend on the
    # filesystem's directory ordering.
    for excel in sorted(glob.glob('local_datasets/ceval_exam/val/*.csv')):
        subject = os.path.basename(excel).split('.')[0].replace('_val', '')

        # WARNING: only the middle-school subjects are evaluated.
        if 'middle_school' not in subject:
            continue
        if subject not in ceval.name_en2zh:
            continue

        # Read the CSVs only after filtering, so skipped subjects neither cost
        # I/O nor fail on a missing dev file.
        df = pd.read_csv(excel)
        sample_df = pd.read_csv(f'local_datasets/ceval_exam/dev/{subject}_dev.csv')
        subject_zh = ceval.name_en2zh[subject]

        # NOTE(review): .iloc[1:] drops the first data row of both the dev and
        # the val file (read_csv already consumes the header) -- confirm this
        # is intended.
        context = f'以下是中国关于{subject_zh}考试的单项选择题，请选出其中的正确答案。\n\n'
        context = context + '\n'.join(
            [into_multiple(d) for _, d in sample_df.iloc[1:].iterrows()]
        )
        for _, d in tqdm(list(df.iloc[1:].iterrows())):
            prompt = context + '\n' + into_multiple(d, no_answer=True)

            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
            output = model.generate(input_ids=input_ids, max_new_tokens=512, pad_token_id=0)
            prompt_len = input_ids.shape[-1]
            answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
            outputs.append(answer)
            if ceval.extract_choice(answer) == d['answer']:
                succ += 1
            tot += 1

    del model, tokenizer
    torch.cuda.empty_cache()
    accuracy = succ / tot if tot else 0.0
    return accuracy, '\n\n'.join(outputs)



# Task names that are evaluated through the lm_eval command-line tool.
lm_eval_sets = ['kmmlu', 'minerva_math', 'minerva_math_algebra', 'arc_it', 'kobest', 'arc_easy']


def test_through_lm_eval(ckpt, dataset, device):
    """Run the ``lm_eval`` CLI on a checkpoint and read the metric from its output.

    The command's standard output is written to
    ``partial/lm_eval_<device>.txt`` and then parsed with
    ``extract_from_logs``.

    Args:
        ckpt: Checkpoint path passed as ``pretrained=`` in ``--model_args``.
        dataset: Task name passed to ``--tasks``.
        device: Device passed to ``--device``; also part of the log file name.

    Returns:
        A tuple ``(metric, log_text)`` where ``metric`` is the second-to-last
        number found in the log (presumably the value that precedes its
        stderr in the last result row -- TODO confirm) and ``log_text`` is the
        log content.

    Raises:
        IndexError: If the log contains fewer than two numbers, e.g. because
            ``lm_eval`` failed before printing results.
    """
    os.environ['NCCL_P2P_DISABLE'] = '1'
    os.environ['NCCL_IB_DISABLE'] = '1'

    log_path = f'partial/lm_eval_{device}.txt'
    # The redirect target must exist before the command can write to it.
    os.makedirs('partial', exist_ok=True)

    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters from being reinterpreted.
    cmd = [
        'lm_eval',
        '--model', 'hf',
        '--model_args', f'pretrained={ckpt},dtype=bfloat16',
        '--device', str(device),
        '--batch_size', '1',
        '--tasks', str(dataset),
    ]
    with open(log_path, 'w') as log_file:
        subprocess.call(cmd, stdout=log_file)

    ret, lines = extract_from_logs(log_path)
    if len(ret) < 2:
        raise IndexError(
            f'expected at least two numbers in {log_path}; lm_eval may have failed'
        )
    return ret[-2], lines
    
    
def test(ckpt, dataset, device):
    """Run the evaluation that matches ``dataset``.

    Args:
        ckpt: Checkpoint path to evaluate.
        dataset: Name of the benchmark: a key of ``batch_dict`` (GLUE tasks),
            an entry of ``lm_eval_sets``, ``'toxic-chat'`` or ``'ceval'``.
        device: Device to run on.

    Returns:
        The ``(metric, text)`` tuple returned by the selected evaluation.

    Raises:
        ValueError: If no evaluation is registered for ``dataset``.
    """
    # GLUE tasks are checked first, so every name test_glue can handle still
    # goes to test_glue.
    if dataset in batch_dict:
        return test_glue(ckpt, dataset, device)
    if dataset in lm_eval_sets:
        return test_through_lm_eval(ckpt, dataset, device)
    if dataset == 'toxic-chat':
        return test_toxic_chat(ckpt, device)
    if dataset == 'ceval':
        return test_ceval(ckpt, device)
    # 'humaneval' is intentionally not dispatched: no test_humaneval is
    # defined in this file.

    # Raised explicitly because an assert would be stripped under -O.
    raise ValueError(f"Dataset {dataset} not found!")

def model_name_to_ckpt(model_name, root='../llama_on_glue/checkpoints/hub'):
    """Resolve a hub-style model name to its latest local checkpoint directory.

    The directory searched is
    ``<root>/<model_name with '/' replaced by '--'>/snapshots/<snapshot>/``.
    The first snapshot in sorted order is used. Inside it, the
    ``checkpoint-<step>`` entry with the highest numeric step is chosen; if no
    entry has that form, the last entry in sorted order is used.

    Args:
        model_name: Model name such as ``'org/model'``.
        root: Directory that holds the per-model folders.

    Returns:
        Path of the selected checkpoint directory.

    Raises:
        FileNotFoundError: If the snapshots directory or the chosen snapshot
            is missing or empty.
    """
    snapshots_dir = os.path.join(root, model_name.replace('/', '--'), 'snapshots')
    snapshots = sorted(os.listdir(snapshots_dir))
    if not snapshots:
        raise FileNotFoundError(f'no snapshot found under {snapshots_dir}')
    snapshot_dir = os.path.join(snapshots_dir, snapshots[0])

    entries = sorted(os.listdir(snapshot_dir))
    if not entries:
        raise FileNotFoundError(f'no checkpoint found under {snapshot_dir}')

    # Compare steps numerically: as plain strings 'checkpoint-4000' would sort
    # after 'checkpoint-10000' and the older checkpoint would be picked.
    numbered = []
    for entry in entries:
        match = re.fullmatch(r'checkpoint-(\d+)', entry)
        if match:
            numbered.append((int(match.group(1)), entry))
    latest = max(numbered)[1] if numbered else entries[-1]
    return os.path.join(snapshot_dir, latest)


if __name__ == '__main__':
    # Tasks to evaluate. The full GLUE list is
    # ['sst2', 'cola', 'rte', 'qnli', 'mrpc', 'wnli', 'qqp', 'mnli'].
    cs = ['rte']
    for i in cs:
        # One fine-tuned model per task; the model name ends with the task name.
        cc = model_name_to_ckpt('my-llama/llama-Meta-Llama-3-8Bsst2'.replace('sst2', i))
        a, _ = test_glue(cc, i, 'cuda:0', super_small=True)
        print(f'metric for {i}:')
        print(a)
    # The script ends here. The statements that used to follow an exit(0)
    # call could never run and requested a 'humaneval' evaluation that has no
    # handler in this file.
    # with open('llama-base.txt',"w+") as f:
    #     f.write(str(test(cc, 'ceval', 'cuda:1')[0]))
    #     f.write('\n')
    #     f.write(str(test(cc, 'toxic-chat', 'cuda:1')[0]))
    #     f.write('\n')
    #     f.write(str(test(cc, 'haerae', 'cuda:1')[0]))
    #     f.write('\n')
    
    # def model_name_to_ckpt(model_name):
    #     path = 'checkpoints/hub/models--'+model_name.replace('/','--')
    #     path += '/snapshots'
    #     path = glob.glob(path+'/*')[0]
    #     return path
    # from models_and_datas import models_and_datas
    # # for i in models_and_datas:
    # for i in ['code']:
    #     if 0==len(models_and_datas[i]['datasets']):
    #         continue
    #     for m in models_and_datas[i]['model']:
    #         cc = model_name_to_ckpt(m)
            
    #     with open(f'llama-{i}.txt',"w+") as f:
    #         ans, logs = test(cc, models_and_datas[i]['datasets'][0], 'cuda:1')
    #         f.write(str(ans))
    #         f.write(logs)
    #         f.write('\n')
    

# export NCCL_P2P_DISABLE="1"
# export NCCL_IB_DISABLE="1"
# lm_eval --model hf \
#     --model_args pretrained=checkpoints/hub/models--meta-llama--Meta-Llama-Guard-2-8B/snapshots/7d257f3c1a0ec6ed99b2cb715027149dfb9784ef \
#     --tasks haerae \
#     --device cuda \
#     --batch_size 1
