from typing import List
from copy import deepcopy
import numpy as np
import torch
from tqdm import trange
from transformers import AutoTokenizer, AutoModelForCausalLM


@torch.no_grad()
def count_ar_nll(
    model,
    tokenizer,
    generations: List[str],
    accelerator,
    batch_size: int = 16,
    average: bool = True,
):
    """Score texts with an autoregressive LM and aggregate the batch losses.

    Each text is prefixed with the tokenizer's BOS token (when it has one),
    tokenized with padding and truncation to 128 tokens, and passed to
    ``model`` together with ``labels``; the scalar ``.loss`` of every batch is
    multiplied by the number of texts in that batch.

    Args:
        model: Callable accepting the tokenizer output plus ``labels`` as
            keyword arguments and returning an object with a scalar ``.loss``
            (presumably a HF causal LM whose loss is the mean token NLL of
            the batch -- confirm against the caller).
        tokenizer: HF-style tokenizer. If it has no pad token, its
            ``pad_token_id`` is set to ``eos_token_id`` (the tokenizer object
            is modified in place).
        generations: Texts to score. The sequence is not modified.
        accelerator: Object exposing ``.device``; tensors are moved there.
        batch_size: Number of texts per forward pass.
        average: If True, return ``sum(loss_b * len(batch_b)) / len(generations)``.
            If False, return a 1-D array with one ``loss_b * len(batch_b)``
            entry per batch (NOT one entry per text).

    Returns:
        A NumPy float scalar when ``average`` is True (NaN for empty input),
        otherwise a 1-D NumPy array of per-batch weighted losses.
    """
    if tokenizer.pad_token is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    total = len(generations)
    if total == 0:
        # Nothing to score: keep the original result types without a 0/0 warning.
        return np.float64("nan") if average else np.array([], dtype=np.float64)

    # Some tokenizers define no BOS token; in that case score the raw text.
    bos_prefix = tokenizer.bos_token or ""
    device = accelerator.device

    weighted_losses = []
    for start_idx in range(0, total, batch_size):
        # Build a fresh list so the caller's sequence is never mutated.
        batch_texts = [
            bos_prefix + text
            for text in generations[start_idx : start_idx + batch_size]
        ]
        encoded = tokenizer(
            batch_texts, padding=True, truncation=True, max_length=128, return_tensors="pt"
        )

        labels = encoded["input_ids"].clone()
        if "attention_mask" in encoded:
            # Mask padding by position rather than by token id: when the pad
            # token falls back to EOS, an id comparison would also drop every
            # genuine EOS token from the loss.
            labels[encoded["attention_mask"] == 0] = -100
        else:
            labels[labels == tokenizer.pad_token_id] = -100

        inputs = {key: value.to(device) for key, value in encoded.items()}
        inputs["labels"] = labels.to(device)

        batch_loss = model(**inputs).loss.item()
        # Weight by batch size so a smaller final batch contributes proportionally.
        weighted_losses.append(batch_loss * len(batch_texts))

    if average:
        return np.sum(weighted_losses) / total
    return np.array(weighted_losses).flatten()
