import os,glob, subprocess, re, evaluate
# os.environ['HF_HOME'] = "../llama_on_glue/checkpoints"
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:8118'
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:8118'
from huggingface_hub import login
# #

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoModelForSeq2SeqLM
from datasets import load_dataset
import torch
from tqdm import tqdm
import pandas as pd
import local_datasets.ceval_exam as ceval 
# from human_eval.data import write_jsonl, read_problems
from glue_utils import get_processFunc, get_dataset
# from evaluations_t5 import test_t5
# Evaluation batch size per dataset name. test_glue falls back to 32 for any
# dataset that is not listed here.
batch_dict = dict(
    sst2=256,
    cola=256,
    rte=128,
    qnli=256,
    mrpc=256,
    wnli=256,
    qqp=128,
    mnli=128,
    code_to_text=16,
    text_to_code=16,
    defect_detection=16,
    clone_detection=16,
)

def extract_from_logs(log):
    """Read a log file and collect every number delimited by '(', ')' or '|'.

    The file text is split on the characters ``(``, ``)`` and ``|``; each
    resulting piece that ``float()`` accepts (surrounding whitespace is
    tolerated) is kept, in order of appearance.

    Args:
        log: Path of the log file to read.

    Returns:
        A tuple ``(values, text)`` where ``values`` is the list of parsed
        floats and ``text`` is the file content the values were parsed from.
    """
    # Read-only access is enough; "r+" would fail on logs we cannot write to.
    with open(log, "r") as f:
        # NOTE: every line returned by readlines() keeps its own trailing
        # newline, so joining with '\n' doubles the line breaks. This is kept
        # as-is because the joined text is part of the return value.
        lines = '\n'.join(f.readlines())

    ret = []
    for piece in re.split('[()|]', lines):
        try:
            ret.append(float(piece))
        except ValueError:
            # Not a number (label text, whitespace, ...): skip it.
            continue
    return ret, lines

def load_ckpt(ckpt, device, tokenizer_ckpt=None):
    """Load a pretrained model and its tokenizer.

    Args:
        ckpt: Checkpoint name or path passed to ``from_pretrained`` for the
            model (and for the tokenizer when ``tokenizer_ckpt`` is None).
        device: Device the model is moved to via ``.to(device)``.
        tokenizer_ckpt: Optional separate checkpoint to load the tokenizer
            from.

    Returns:
        A ``(model, tokenizer)`` tuple. Checkpoints whose name contains
        't5' or 'T5' are loaded with ``AutoModelForSeq2SeqLM``; everything
        else with ``AutoModelForCausalLM``. No dtype is passed, so
        ``from_pretrained`` defaults apply.
    """
    tokenizer_source = ckpt if tokenizer_ckpt is None else tokenizer_ckpt
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

    # Accept both spellings, the same way test_glue detects T5 checkpoints;
    # a name such as "flan-T5-base" must not be sent to the causal-LM loader.
    is_t5 = 't5' in ckpt or 'T5' in ckpt
    model_cls = AutoModelForSeq2SeqLM if is_t5 else AutoModelForCausalLM
    model = model_cls.from_pretrained(ckpt).to(device)
    return model, tokenizer

# Datasets for which test_glue neither normalises the label nor collects it
# as a candidate answer.
no_need_possible_answers = ["code_to_text", "text_to_code"]
# ROUGE metric, loaded once when this module is imported and reused by
# solve_ans_label.
rouge = evaluate.load("rouge")
# Holds the BLEU metric once it has been loaded, so that scoring many samples
# does not call evaluate.load("bleu") again for every single prediction.
_BLEU_METRIC_CACHE = []


def _get_bleu_metric():
    """Return the shared BLEU metric, loading it on first use."""
    if not _BLEU_METRIC_CACHE:
        _BLEU_METRIC_CACHE.append(evaluate.load("bleu"))
    return _BLEU_METRIC_CACHE[0]


def solve_ans_label(ans, label, dataset, possible_answers, super_small=False):
    """Score one generated answer against its reference label.

    Args:
        ans: Text generated by the model.
        label: Reference answer.
        dataset: Dataset name; selects the scoring rule.
        possible_answers: Candidate labels used by the substring rule.
        super_small: When true, print debugging details for the substring
            rule.

    Returns:
        * ROUGE-L when ``dataset`` contains 'Lots-of-LoRAs';
        * BLEU when ``dataset`` is 'code_to_text' or 'text_to_code';
        * otherwise 1 if ``label`` occurs in ``ans`` and no entry of
          ``possible_answers`` occurs at an earlier position, else 0.
    """
    if 'Lots-of-LoRAs' in dataset:
        results = rouge.compute(predictions=[ans], references=[[label]])
        return results['rougeL']

    if dataset in ('code_to_text', 'text_to_code'):
        if not ans.strip():
            # BLEU of an empty hypothesis is 0. Returning early also avoids
            # handing an empty prediction to the metric implementation,
            # which (as far as I can tell) divides by the hypothesis length.
            return 0.0
        results = _get_bleu_metric().compute(predictions=[ans], references=[[label]])
        return results['bleu']

    def first_index(needle):
        # Position of needle in ans, or a huge sentinel when it is absent.
        pos = ans.find(needle)
        return 1e9 if pos < 0 else pos

    if super_small:
        print(f'-----{ans}-----')

    label_pos = first_index(label)
    if label_pos > 1e8:
        # The reference label never appears in the answer.
        if super_small:
            print(f'<<<{label}:{label_pos}>>>\n\n')
        return 0

    for pa in possible_answers:
        pa_pos = first_index(pa)
        if pa_pos < label_pos:
            # Another candidate is mentioned before the reference label.
            if super_small:
                print(f'<<<{pa}:{pa_pos},,,{label}:{label_pos}>>>\n\n')
            return 0
    return 1
        

# Datasets for which generation is capped at 32 new tokens; every other
# dataset is allowed 512.
_SHORT_ANSWER_DATASETS = ('sst2', 'cola', 'rte', 'qnli', 'mnli', 'mrpc', 'wnli', 'qqp')


def _drop_suffix(text, suffix):
    """Return text without the literal trailing suffix, if it ends with it."""
    if suffix and text.endswith(suffix):
        return text[:-len(suffix)]
    return text


def _select_chat_format(ckpt2, tokenizer):
    """Pick (prompt, response_template, pad_token) from the checkpoint name.

    Raises ValueError when the name matches no supported model family.
    """
    if ('t5' in ckpt2 or 'T5' in ckpt2) and 'flan' not in ckpt2:
        # The dedicated T5 evaluation path (test_t5) is disabled in this
        # file, so no prompt format is available for plain T5 checkpoints.
        raise ValueError(f'No prompt format available for T5 checkpoint {ckpt2!r}')
    elif 'qwen' in ckpt2 or 'Qwen' in ckpt2:
        prompt = '''<|im_start|>system
{}<|im_end|>
<|im_start|>user
{}<|im_end|>
<|im_start|>assistant
{}<|im_end|>'''
        response_template = "<|im_start|>assistant"
        pad_token = '<|endoftext|>'
    elif 'Mistral' in ckpt2:
        prompt =  '''<s>[INST] {} [/INST] {} </s>'''
        response_template = "[/INST]"
        pad_token = '<unk>'
    elif 'deepseek' in ckpt2:
        prompt = '''<｜begin▁of▁sentence｜> User: {}
{}
Assistant: {} <｜end▁of▁sentence｜>'''
        response_template = "Assistant: "
        pad_token = tokenizer.pad_token
    elif 'flan' in ckpt2:
        prompt = '''{}
{}
Answer: {}'''
        response_template = "Answer: "
        pad_token = tokenizer.pad_token
    elif 'llama' in ckpt2 or 'Llama' in ckpt2:
        # Checked last: presumably local checkpoint folder paths contain
        # 'llama' regardless of the model family, so the more specific names
        # above must win -- TODO confirm against the checkpoint layout.
        prompt = '''<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{} <|eot_id|><|start_header_id|>user<|end_header_id|>
{} <|eot_id|><|start_header_id|>assistant<|end_header_id|>
{} <|eot_id|>'''
        response_template = "<|start_header_id|>assistant<|end_header_id|>"
        pad_token = '<|reserved_special_token_0|>'
    else:
        raise ValueError(f'Cannot infer a prompt format from checkpoint name {ckpt2!r}')
    return prompt, response_template, pad_token


def test_glue(ckpt, dataset, device, super_small=False, model=None, 
              tokenizer=None, smaller_batch=1, tokenizer_ckpt=None):
    """Generate answers for a dataset's validation split and score them.

    Args:
        ckpt: Checkpoint name or path. Also used (together with
            ``tokenizer_ckpt``) to choose the prompt format by substring.
        dataset: Dataset name passed to ``get_dataset`` / ``get_processFunc``
            and to ``solve_ans_label``.
        device: Device used when the model has to be loaded here.
        super_small: Debug mode: prints every answer/label pair and stops
            once more than 20 examples have been scored (checked between
            batches).
        model: Optional preloaded model; when None, model and tokenizer are
            loaded with ``load_ckpt``.
        tokenizer: Tokenizer matching ``model`` when a model is supplied.
        smaller_batch: A positive value divides the dataset's batch size;
            zero or a negative value multiplies it by ``-smaller_batch``.
            The result is clamped to at least 1.
        tokenizer_ckpt: Optional separate tokenizer checkpoint.

    Returns:
        ``(score, text)`` where ``score`` is the mean of the per-example
        scores (0.0 when nothing was scored) and ``text`` is all generated
        answers joined by blank lines.

    Raises:
        ValueError: If no prompt format is known for the checkpoint name.
    """
    print(ckpt)

    if model is None:
        model, tokenizer = load_ckpt(ckpt, device, tokenizer_ckpt=tokenizer_ckpt)
    ckpt2 = ckpt if tokenizer_ckpt is None else ckpt + tokenizer_ckpt
    prompt, response_template, pad_token = _select_chat_format(ckpt2, tokenizer)

    tokenizer.max_seq_length = 4096
    # Left padding keeps every prompt adjacent to its generated continuation.
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = pad_token

    _, ds_val = get_dataset(dataset)
    output_texts = get_processFunc(dataset)(ds_val, prompt)

    max_new_tokens = 32 if dataset in _SHORT_ANSWER_DATASETS else 512
    # For non-flan checkpoints the decoded prompt is cut off the front of
    # the decoded generation; flan output is used as returned.
    strip_prompt = 'flan' not in ckpt2
    is_mistral = 'Mistral' in ckpt2

    def moderate(batch):
        """Generate one answer string per (chat, label) pair in batch."""
        chats = [chat for chat, _ in batch]
        inputs = tokenizer(chats, padding="longest", return_tensors="pt")
        inputs = {key: val.to(model.device) for key, val in inputs.items()}
        gen_tokens = model.generate(**inputs, max_new_tokens=max_new_tokens,
                                    pad_token_id=tokenizer.pad_token_id)
        gen_text = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
        if strip_prompt:
            prompt_texts = tokenizer.batch_decode(inputs["input_ids"], skip_special_tokens=True)
            gen_text = [full[len(prefix):] for full, prefix in zip(gen_text, prompt_texts)]
        return gen_text

    batch_size = batch_dict.get(dataset)
    if batch_size is None:
        print(f'{dataset} not in batch_dict! seting batchsize=32')
        batch_size = 32
    if smaller_batch > 0:
        batch_size = batch_size // smaller_batch
    else:
        batch_size = - batch_size * smaller_batch
    batch_size = max(1, int(batch_size))

    # Split every formatted example into the prompt part (up to and including
    # the response template) and the reference answer that follows it.
    examples = []
    candidate_labels = set()
    for text in output_texts:
        split_at = text.find(response_template) + len(response_template)
        chat, label = text[:split_at], text[split_at:]

        if dataset not in no_need_possible_answers and not is_mistral:
            label = label.lower().strip('\n').strip(' ')
            # Remove the literal end marker. str.strip('<|im_end|>') would
            # instead strip any of those characters from both ends and
            # mangle labels (e.g. 'entailment' -> 'tailment', 'no' -> 'o').
            label = _drop_suffix(label, '<|im_end|>').split(' ')[0]
            candidate_labels.add(label)
        if is_mistral:
            # Same pitfall as above: drop the literal '</s>' suffix only.
            label = _drop_suffix(label, '</s>')
        examples.append((chat, label))
    possible_answers = list(candidate_labels)

    batches = [examples[start:start + batch_size]
               for start in range(0, len(examples), batch_size)]

    succ, tot = 0, 0
    outputs = []
    for batch in tqdm(batches):
        if super_small and tot>20:
            break
        batch_ans = moderate(batch)
        outputs += batch_ans
        for ans, (chat, label) in zip(batch_ans, batch):
            if super_small:
                print('-------')
                print(ans)
                print('-------')
                print(label)
                print('-------')

            succ += solve_ans_label(ans, label, dataset, possible_answers, super_small=super_small)
            tot += 1

    # Drop the local references before asking CUDA to release cached memory.
    del model, tokenizer
    torch.cuda.empty_cache()

    if tot == 0:
        # Nothing was scored (empty validation set): avoid dividing by zero.
        print(f'{dataset}: no examples were evaluated')
        score = 0.0
    else:
        score = succ / tot
    if super_small:
        print(score)
    return score, '\n\n'.join(outputs)

    
def test(ckpt, dataset, device):
    """Evaluate ``ckpt`` on ``dataset`` and return the result of ``test_glue``.

    Every dataset is currently routed to ``test_glue``. Dispatch to the other
    evaluators (lm_eval sets, 'toxic-chat', 'ceval', 'humaneval') is
    disabled in this file, so no "dataset not found" check is performed here.

    Args:
        ckpt: Checkpoint name or path.
        dataset: Dataset name.
        device: Device to run the evaluation on.

    Returns:
        Whatever ``test_glue(ckpt, dataset, device)`` returns.
    """
    return test_glue(ckpt, dataset, device)
