import torch
import numpy as np
from tqdm import tqdm
from data_utils import get_test_data
import os
import sys

# Absolute directory holding this file, and the directory one level above it.
current_path = os.path.split(os.path.abspath(__file__))[0]
parent_path = os.path.split(current_path)[0]
# Let modules living next to this file be imported by later code.
# NOTE(review): this runs after `from data_utils import get_test_data` above,
# so it does not make that particular import resolvable.
sys.path.append(current_path)

@torch.no_grad()
def ppl_eval(model, tokenizer, datasets=('wikitext2', 'ptb', 'c4'), model_seq_len=2048, batch_size=32, device="cuda", after_prune=True):
    """Evaluate the perplexity of ``model`` on each dataset in ``datasets``.

    For every dataset, batches from ``get_test_data`` are run through the model.
    The next-token cross-entropy of every position is accumulated, and the
    perplexity is ``exp(mean per-token NLL)``. Batches whose logits contain
    inf/nan are excluded from the average, and a warning with their count is
    printed.

    Args:
        model: model called as ``model(batch, use_cache=False)``; the result
            must expose ``.logits``. It is moved to ``device`` and set to eval
            mode.
        tokenizer: forwarded to ``get_test_data``.
        datasets: iterable of dataset names understood by ``get_test_data``.
        model_seq_len: forwarded to ``get_test_data`` as ``seq_len``.
        batch_size: forwarded to ``get_test_data``.
        device: device the model and every batch are moved to.
        after_prune: only selects the wording of the printed summary.

    Returns:
        dict mapping each dataset name to its perplexity. A dataset for which no
        batch produced finite logits maps to ``nan`` instead of raising.
    """
    model.to(device)
    model.eval()
    ppls = {}
    for dataset in datasets:
        test_loader = get_test_data(dataset, tokenizer, seq_len=model_seq_len, batch_size=batch_size)
        # Keep running totals rather than every per-token loss tensor. Each batch
        # is summed in float32 and accumulated as a Python float, so half-precision
        # losses neither overflow nor lose precision in the final mean.
        nll_sum = 0.0
        token_count = 0
        skipped = 0
        for batch in tqdm(test_loader):
            batch = batch.to(device)
            lm_logits = model(batch, use_cache=False).logits
            if not torch.isfinite(lm_logits).all():
                skipped += 1
                continue
            # Logits at position t are scored against the token at t + 1.
            shift_logits = lm_logits[:, :-1, :]
            shift_labels = batch[:, 1:]
            loss = torch.nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                reduction="none",
            )
            nll_sum += loss.float().sum().item()
            token_count += loss.numel()
        if skipped:
            print("Warning: skipped {} batch(es) with non-finite logits on {}".format(skipped, dataset))
        ppls[dataset] = np.exp(nll_sum / token_count) if token_count else float("nan")
    if after_prune:
        print("PPL after pruning: {}".format(ppls))
    else:
        print("PPL before pruning: {}".format(ppls))
    print("Weight Memory: {} MiB\n".format(torch.cuda.memory_allocated()/1024/1024))
    return ppls

@torch.no_grad()
def zero_shot_eval(model, tokenizer, tasks=("piqa", "openbookqa", "hellaswag", "arc_challenge", "arc_easy", "winogrande"), device="cuda"):
    """Run zero-shot accuracy evaluation through lm-eval-harness.

    The model is moved to ``device``, put in eval mode, wrapped in lm_eval's
    ``HFLM`` and evaluated with ``lm_eval.simple_evaluate``. For every result
    entry that reports an ``'acc,none'`` metric, the accuracy is printed and
    collected. Entries without that metric are reported and skipped, so they no
    longer discard the whole run. The average over the collected accuracies is
    printed when there is at least one.

    Args:
        model: model passed to ``HFLM`` as ``pretrained``.
        tokenizer: tokenizer passed to ``HFLM``.
        tasks: a single task name or an iterable of task names.
        device: device the model is moved to before evaluation.

    Returns:
        dict mapping result name to accuracy (a fraction, not a percentage).
        Returns ``{}`` if importing or running lm-eval-harness fails
        (best-effort; the error is printed).
    """
    model.to(device)
    model.eval()
    # Normalise to a list. NOTE(review): lm_eval's task lookup appears to accept
    # only a str or a list, so a tuple is converted here; confirm for the
    # installed lm_eval version.
    task_list = [tasks] if isinstance(tasks, str) else list(tasks)
    try:
        import lm_eval
        from lm_eval.models.huggingface import HFLM

        hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size='auto')
        results = lm_eval.simple_evaluate(hflm, tasks=task_list, batch_size='auto')['results']
    except Exception as e:
        # Deliberately best-effort: a missing or failing harness must not abort the caller.
        print("lm-eval-harness evaluation encountered an error:", e)
        print("lm-eval-harness evaluation failed.")
        return {}

    processed_results = {}
    for name, metrics in results.items():
        acc = metrics.get('acc,none') if isinstance(metrics, dict) else None
        if acc is None:
            # Presumably group/aggregate entries or tasks reporting another metric.
            print(name, "has no 'acc,none' metric; skipped")
            continue
        print(name, f"{acc*100:.2f}%")
        processed_results[name] = acc
    if processed_results:
        print("average acc:", sum(processed_results.values()) / len(processed_results))
    return processed_results
