
import os,glob, subprocess, re, evaluate
from huggingface_hub import login
import argparse
# #

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,AutoModelForSeq2SeqLM
from datasets import load_dataset
import torch
import numpy as np
from tqdm import tqdm
import pandas as pd
import local_datasets.ceval_exam as ceval 
# from human_eval.data import write_jsonl, read_problems
from glue_utils import get_processFunc, get_dataset
from glue_tasks import get_glue_tasks, for_Trainer

def load_ckpt(ckpt, device):
    """Load a seq2seq checkpoint together with the tokenizer used for it.

    Args:
        ckpt: Model identifier or path handed to
            ``AutoModelForSeq2SeqLM.from_pretrained``.
        device: Target passed to ``.to()`` on the freshly loaded model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE: the tokenizer is always 'google/t5-v1_1-base', independent of
    # ``ckpt``; the model itself is loaded with its default dtype.
    tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-base')
    model = AutoModelForSeq2SeqLM.from_pretrained(ckpt).to(device)
    return model, tokenizer


# Per-task candidate answers, each written as a sequence of token ids.
# ``get_metric`` multiplies the per-step probabilities of the ids that are
# greater than 1 and stops at the first id that is not, so the trailing 1
# terminates every sequence.
# NOTE(review): presumably these are T5 vocabulary ids of the label strings
# with 1 as the end-of-sequence id -- confirm against the tokenizer.
answer_list = {
    'rte': [
        [3, 35, 5756, 297, 1],
        [59, 834, 35, 5756, 297, 1],
    ],
    'cola': [
        [9961, 1],
        [29452, 1],
    ],
    'mrpc': [
        [7072, 1],
        [59, 834, 15, 1169, 15592, 1],
    ],
    'mnli': [
        [3, 35, 5756, 297, 1],
        [7163, 1],
        [27252, 1],
    ],
    'sst2': [
        [2841, 1],
        [1465, 1],
    ],
    'qnli': [
        [3, 35, 5756, 297, 1],
        [59, 834, 35, 5756, 297, 1],
    ],
    'qqp': [
        [19197, 1],
        [59, 834, 26, 413, 26221, 1],
    ],
    'wnli': [
        [3, 35, 5756, 297, 1],
        [59, 834, 35, 5756, 297, 1],
    ],
}
# Verbose switch read by ``get_metric``: when truthy, the per-sample
# probabilities and verdicts are printed.
debug = False

# Per-task batch sizes. Not referenced anywhere in the visible part of this
# file. NOTE(review): presumably the effective batch size used elsewhere
# (e.g. during training) -- confirm against the callers.
REAL_BATCH = {
    'rte': 16,
    'wnli': 32,
    'cola': 256,
    'mrpc': 256,
    'mnli': 256,
    'sst2': 256,
    'qnli': 256,
    'qqp': 256,
}


def _sequence_prob(step_probs, token_ids):
    """Return the product of per-step probabilities along ``token_ids``.

    ``step_probs`` is indexed as ``[step, token_id]``. Accumulation stops at
    the first id that is not greater than 1, so such ids (and everything
    after them) contribute nothing to the product.
    """
    prob = 1
    for step, token in enumerate(token_ids):
        if token <= 1:
            break
        prob *= step_probs[step, token]
    return prob


def _is_prediction_correct(step_probs, reference, candidates):
    """Tell whether no candidate beats or falsely ties the reference."""
    ref_prob = _sequence_prob(step_probs, reference)
    if debug:
        print(reference, ref_prob)
    for candidate in candidates:
        cand_prob = _sequence_prob(step_probs, candidate)
        if debug:
            print(candidate, cand_prob)
        if cand_prob > ref_prob:
            return False
        if cand_prob == ref_prob:
            # An exact tie is expected for the candidate that *is* the
            # reference (same ids -> same product). Any other candidate
            # reaching the same probability makes the sample a miss.
            differs = any(
                token > 1 and token != reference[pos]
                for pos, token in enumerate(candidate)
            )
            if differs:
                return False
    return True


def get_metric(model, dataloader, answer_list, no_metric=False, super_small=False):
    """Score a seq2seq model by ranking candidate answers per sample.

    For every sample the per-step generation scores are softmaxed, and the
    probability of the reference label sequence is compared with that of
    each sequence in ``answer_list``. A sample counts as correct when no
    candidate is more probable than the reference and no *different*
    candidate ties with it.

    Args:
        model: Object exposing ``eval()``, ``generate()`` and ``device``.
        dataloader: Iterable of 4-tuples
            ``(source_ids, source_mask, lm_labels, target_mask)``. Zero
            entries of ``lm_labels`` are overwritten in place with -100.
        answer_list: Candidate answers as token-id sequences.
        no_metric: When true, skip evaluation entirely and return 0.0.
        super_small: When true, stop as soon as more than 30 samples have
            been scored.

    Returns:
        The fraction of correct samples, or 0.0 when ``no_metric`` is set or
        the dataloader yields no samples.
    """
    if no_metric:
        return 0.0
    model.eval()
    # Generation must run for at least as many steps as the longest
    # candidate, otherwise indexing the scores by step would run out.
    longest = max(map(len, answer_list))
    correct, total = 0, 0
    for batch in tqdm(dataloader):
        with torch.no_grad():
            source_ids, source_mask, lm_labels, _ = batch
            lm_labels[lm_labels[:, :] == 0] = -100

            outputs = model.generate(
                input_ids=source_ids.to(model.device),
                attention_mask=source_mask.to(model.device),
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
                min_length=longest + 1,
            )
        # Stack the per-step scores along a new axis 1. The float() cast
        # keeps the NumPy conversion working for dtypes NumPy lacks
        # (e.g. bfloat16).
        scores = torch.stack(outputs.scores, dim=1).float().to("cpu").numpy()
        references = np.asarray(lm_labels)[:, :]

        # The raw scores are not normalised; apply a numerically stable
        # softmax over the last axis.
        probs = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        probs = probs / np.sum(probs, axis=-1, keepdims=True)

        for i, step_probs in enumerate(probs):
            good = int(_is_prediction_correct(step_probs, references[i], answer_list))
            if debug:
                print(good)
                print('-' * 50)
            correct += good
            total += 1
            if super_small and total > 30:
                return correct / total

    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return correct / total

# Base evaluation batch size per task; ``test_t5`` scales the entry for the
# requested dataset by its ``smaller_batch`` argument.
batch_dict = {
    'sst2': 64,
    'cola': 64,
    'rte': 32,
    'qnli': 64,
    'mrpc': 64,
    'wnli': 64,
    'qqp': 32,
    'mnli': 32,
    'code_to_text': 4,
    'text_to_code': 4,
    'defect_detection': 4,
    'clone_detection': 4,
}

class empty:
    """Bare attribute holder; ``test_t5`` fills one in as an args object."""

def test_t5(ckpt, dataset, device, super_small=False, model=None, tokenizer=None, smaller_batch=1):
    """Evaluate a checkpoint on the validation split of one task.

    Args:
        ckpt: Checkpoint handed to ``load_ckpt``; only used when ``model``
            is None.
        dataset: Task key; must be present in ``batch_dict`` and
            ``answer_list``.
        device: Device for the model when it is loaded here.
        super_small: Forwarded to ``get_metric``.
        model: Optional already-loaded model. When None, both the model and
            the tokenizer come from ``load_ckpt``.
        tokenizer: Tokenizer stored on the args object given to
            ``get_glue_tasks``; replaced when ``model`` is None.
        smaller_batch: When positive, the base batch size is floor-divided
            by it; otherwise the base batch size is multiplied by its
            absolute value. The result is never allowed to drop below 1.

    Returns:
        A ``(metric, "dsafsd")`` tuple; the second element is a constant
        string.
    """
    if model is None:
        model, tokenizer = load_ckpt(ckpt, device)

    base_batch = batch_dict[dataset]
    if smaller_batch > 0:
        batch_size = base_batch // smaller_batch
    else:
        batch_size = -base_batch * smaller_batch
    # A divisor larger than the base size (or a zero factor) would give a
    # batch size of 0; fall back to the smallest usable value instead.
    batch_size = max(1, batch_size)

    args2 = empty()
    args2.dataset = dataset
    args2.tokenizer = tokenizer
    args2.tokenizer_model = 'already_done'
    args2.batch = batch_size

    # Only the validation split is scored; the training split is discarded.
    _, ds_val = get_glue_tasks(args2)
    return get_metric(model, ds_val, answer_list[dataset], super_small=super_small), "dsafsd"