from transformers import GPT2Tokenizer, AutoModelForCausalLM
import torch
from torch.utils.data import RandomSampler, DataLoader
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from Datasets import Custom_Dataset
import argparse

# Prefer GPU 3, but only when that device index actually exists; otherwise
# fall back to the first GPU, or to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 3 if torch.cuda.device_count() > 3 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
# device = torch.device("cpu")

tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Load the causal LM with every dropout probability set to 0 and the EOS id
# registered as its pad id, then move it to the selected device.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-1.3B",
    resid_dropout=0,
    embed_dropout=0,
    attention_dropout=0,
    pad_token_id=tokenizer.eos_token_id,
).to(device)


# def collate_fn(batch):
#     print(batch)
#     return batch


def get_rid_of_pad(tokens, pad_token_id=None):
    """Strip trailing padding from a list of token ids, in place.

    Args:
        tokens: Mutable list of token ids. It is modified in place.
        pad_token_id: Id treated as padding in addition to the ignore index
            ``-100``. When omitted, the module-level tokenizer's
            ``pad_token_id`` is used.

    Returns:
        The same list object with every trailing ``-100`` / pad id removed.
        An empty or all-padding list yields ``[]`` instead of raising
        ``IndexError``.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    padding_ids = (-100, pad_token_id)
    # Walk back from the end to the last non-padding token; the bounds check
    # keeps an all-padding (or empty) list from indexing past the start.
    end = len(tokens)
    while end > 0 and tokens[end - 1] in padding_ids:
        end -= 1
    del tokens[end:]
    return tokens


def _model_call(inps):
    """
    Run the language model on a batch of token ids without tracking gradients.

    inps: a torch tensor of shape [batch, sequence]; the sequence length
    may differ between calls
    returns: the first element of the model output, documented by the
    original author as logits of shape [batch, sequence, vocab]
    """
    with torch.no_grad():
        outputs = model(inps)
    first_output = outputs[0]
    return first_output[:, :, :]


# Dataset to evaluate on and the subset used for validation.
task = "data/ai2_arc"
subset_path = "ARC-Easy"

# Lightweight argument container handed to Custom_Dataset as `args`.
config = argparse.Namespace(cache_dir=".cache/")

# Validation split, with inputs and outputs both limited to 512 tokens.
dataset = Custom_Dataset(
    args=config,
    dataset_name=task,
    valid_subset_path=subset_path,
    type_path="validation",
    tokenizer=tokenizer,
    input_length=512,
    output_length=512,
)

print(len(dataset))

# Iterate in dataset order, 32 examples per batch, loaded by 8 workers.
dataloader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=32,
    num_workers=8,
)


# Zero-shot multiple-choice evaluation.  For every example, each answer choice
# is appended to the context and scored by the summed negative log-likelihood
# of the choice tokens; the choice with the lowest loss is the prediction.
max_length = 512  # every scored sequence is truncated/padded to this length
# Longest continuation that still fits once the final token is dropped below.
max_cont_len = min(max_length - 1, model.config.max_position_embeddings)
pad_id = tokenizer.pad_token_id
total_correct = 0
total_examples = 0

for batch in dataloader:
    choices = batch["choices"]  # indexed as choices[choice][example]
    source_ids = batch["source_ids"].tolist()
    batch_size = len(source_ids)
    print("batch_size", batch_size)

    # Gold answer position for each example, as integers on the model device.
    answer_idx = torch.tensor(
        [int(batch["answer_index"][i]) for i in range(batch_size)],
        dtype=torch.long,
        device=device,
    )

    # Strip padding from each context once; it is shared by every choice.
    contexts = [get_rid_of_pad(ids) for ids in source_ids]
    for i, context_enc in enumerate(contexts):
        if not context_enc:
            raise ValueError(f"example {i}: context is empty after removing padding")

    # answers[example][choice] holds the loss of that choice.
    answers = torch.zeros(batch_size, len(choices), device=device)
    for c_idx, choice_texts in enumerate(choices):
        # Tokenize without padding so an empty choice is detectable below.
        choice_ids = tokenizer.batch_encode_plus(
            list(choice_texts),
            max_length=max_length,
            add_special_tokens=False,
            truncation=True,
        )["input_ids"]

        rows = []  # padded input ids, one row per example
        spans = []  # (start, end, continuation tokens) per example
        for i, (context_enc, raw_choice) in enumerate(zip(contexts, choice_ids)):
            # Trailing pad/EOS ids are still removed from a non-empty choice.
            continuation_enc = get_rid_of_pad(raw_choice) if raw_choice else raw_choice
            if not continuation_enc:
                raise ValueError(
                    f"example {i}, choice {c_idx}: choice has no tokens; "
                    f"choices were {[texts[i] for texts in choices]!r}"
                )
            if len(continuation_enc) > max_cont_len:
                raise ValueError(
                    f"example {i}, choice {c_idx}: choice has "
                    f"{len(continuation_enc)} tokens, more than the "
                    f"{max_cont_len} that can be scored"
                )
            # Keep the right-most max_length tokens, then drop the last one:
            # the logits at position t are the prediction for token t + 1.
            seq = (context_enc + continuation_enc)[-max_length:][:-1]
            inplen = len(seq)
            rows.append(seq + [pad_id] * (max_length - inplen))
            spans.append((inplen - len(continuation_enc), inplen, continuation_enc))

        batched_inps = torch.tensor(
            rows, dtype=torch.long, device=device
        )  # [batch, max_length]
        multi_logits = F.log_softmax(
            _model_call(batched_inps), dim=-1
        )  # [batch, max_length, vocab]

        for row, (start, end, cont_toks) in enumerate(spans):
            cont = torch.tensor(cont_toks, dtype=torch.long, device=device)
            # Log-probability assigned to each actual continuation token.
            token_logprobs = (
                multi_logits[row, start:end].gather(1, cont.unsqueeze(-1)).squeeze(-1)
            )  # [continuation length]
            answers[row, c_idx] = -token_logprobs.sum()
        # Release the large log-prob tensor before the next forward pass.
        del multi_logits

    predictions = torch.argmin(answers, dim=1)
    batch_acc = int((predictions == answer_idx).sum())
    batch_acc_avg = batch_acc / batch_size
    print(f"{task}/acc", batch_acc_avg)
    total_correct += batch_acc
    total_examples += batch_size

# Accuracy over the whole dataset; the per-batch figures above cannot simply
# be averaged when the last batch is smaller than the others.
if total_examples:
    print(f"{task}/acc_total", total_correct / total_examples)
