import csv
from tqdm import tqdm
import os
import random
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
from sklearn import metrics
from datasets import load_metric
import pickle
import torch
import numpy as np
from torch.nn import functional as F
from scipy.stats import entropy
from datasets import Dataset
import json
import sys
import math

def top1_acc(eval_pred):
    """Return top-1 accuracy for an ``(logits, labels)`` pair.

    The predicted class is the argmax over the last axis of ``logits``; the
    result is the fraction of predictions equal to the matching label, as a
    Python float.

    Computed directly with NumPy rather than ``datasets.load_metric("accuracy")``,
    which re-loaded a metric script on every call and has been removed from
    recent ``datasets`` releases.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # Flatten both sides so (N,) vs (N, 1) label layouts compare elementwise
    # instead of broadcasting to an (N, N) matrix.
    matches = np.ravel(predictions) == np.ravel(np.asarray(labels))
    return float(np.mean(matches))

def sp_auroc(eval_pred):
    """Return the ROC AUC of top-1 confidence as a predictor of correctness.

    Each example is scored by the maximum value along the last axis of
    ``logits`` and labelled positive when its argmax equals the label; the
    area under the ROC curve of that score is returned.

    NOTE(review): the confidence is the raw maximum of whatever ``logits``
    holds — if these are unnormalised logits rather than probabilities the
    ranking can differ from max-softmax confidence; confirm with callers.
    """
    scores, targets = eval_pred
    is_correct = np.argmax(scores, axis=-1) == targets
    top_score = np.max(scores, axis=-1)
    return metrics.roc_auc_score(is_correct, top_score)

def sp_auac(eval_pred):
    """Return the area under the accuracy-vs-coverage curve.

    Examples are ranked by their maximum score along the last axis of
    ``logits`` (highest first). For every prefix of length i of that ranking,
    the accuracy of the top-i predictions is computed; the curve of those
    accuracies against coverage i / N is integrated with ``metrics.auc``.
    """
    logits, labels = eval_pred
    hits = np.argmax(logits, axis=-1) == labels
    scores = np.max(logits, axis=-1)
    # Python's sort stays stable with reverse=True, so tied scores keep
    # their original input order.
    ranking = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)
    ranked_hits = np.asarray(hits)[ranking]
    n_items = len(ranked_hits)
    ranks = np.arange(1, n_items + 1)
    running_accuracy = np.cumsum(ranked_hits) / ranks
    coverage = ranks / n_items
    return metrics.auc(coverage, running_accuracy)

def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` containing ``n`` items each.

    The final slice may be shorter. ``iterable`` must support ``len`` and
    slicing (list, tuple, str, array, ...).
    """
    size = len(iterable)
    yield from (iterable[lo:min(lo + n, size)] for lo in range(0, size, n))

def split(iterable, k=1):
    """Yield at most ``k`` contiguous slices covering ``iterable``.

    Every slice has ceil(len / k) items except possibly the last, so fewer
    than ``k`` slices can be produced (e.g. 5 items with k=4 give slices of
    sizes 2, 2, 1). An empty ``iterable`` yields nothing.

    Raises:
        ValueError: if ``k`` is not positive (raised on first iteration,
            since this is a generator).
    """
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    total = len(iterable)
    if total == 0:
        # A zero chunk size would otherwise reach range(..., step=0).
        return
    # Integer ceil-division: exact for any length, unlike math.ceil(total / k).
    chunk = -(-total // k)
    for start in range(0, total, chunk):
        yield iterable[start:min(start + chunk, total)]

def calculate_ppl(model, encodings, stride=512):
    """Estimate the perplexity of ``model`` on ``encodings`` with a sliding window.

    The sequence (dimension 1 of ``encodings``) is scanned in steps of
    ``stride`` tokens. Each window spans at most ``model.config.n_positions``
    tokens. Tokens overlapping the previous window are labelled -100, so they
    act only as context and are not scored twice.

    Args:
        model: callable as ``model(input_ids, labels=...)`` whose first output
            is the mean loss over the labelled tokens (assumed to be a Hugging
            Face causal LM exposing ``config.n_positions`` — confirm for other
            architectures).
        encodings: 2-D token-id tensor indexed as (batch, seq); presumably a
            single sequence of shape (1, seq_len) — TODO confirm with callers.
        stride: number of new tokens scored per window; clamped to the model's
            context length.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: if ``encodings`` has no tokens or ``stride`` is not positive.
    """
    max_length = model.config.n_positions
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    # A stride longer than the context would make trg_len exceed the window length.
    stride = min(stride, max_length)
    seq_len = encodings.size(1)
    if seq_len == 0:
        raise ValueError("cannot compute perplexity of an empty sequence")

    nlls = []
    for i in range(0, seq_len, stride):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, seq_len)
        trg_len = end_loc - i  # may be shorter than stride on the last window
        input_ids = encodings[:, begin_loc:end_loc]
        target_ids = input_ids.clone()
        # Mask the context prefix along the sequence axis. Slicing dim 0 here
        # (the batch axis) would mask nothing and double-count overlap tokens.
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
        # outputs[0] is a per-token mean; rescale to a sum over scored tokens.
        nlls.append(outputs[0] * trg_len)

    ppl = torch.exp(torch.stack(nlls).sum() / seq_len)
    return ppl.cpu().item()

def calculate_nll(model, encodings):
    """Return the model's loss on ``encodings`` when they are also used as labels.

    For a Hugging Face causal LM (assumed — confirm with callers) this is the
    mean per-token negative log-likelihood, not a probability.

    Returns:
        The loss as a Python float.
    """
    with torch.no_grad():
        outputs = model(encodings, labels=encodings)
    # Only the loss is needed; taking [0] does not require a logits output.
    mean_nll = outputs[0]
    return mean_nll.item()

class NoPrint:
    """Context manager that discards anything written to ``sys.stdout``.

    Only stdout is redirected (to ``os.devnull``); stderr is untouched. The
    original stream is restored on exit even if the body raises, and
    exceptions are not suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # UTF-8 so printing non-ASCII text cannot raise under a narrow locale
        # encoding, even though the output is thrown away.
        self._devnull = open(os.devnull, 'w', encoding='utf-8')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first, then close the handle we opened — not whatever
        # sys.stdout happens to be now, in case the body reassigned it.
        sys.stdout = self._original_stdout
        self._devnull.close()
