import argparse

from tqdm import tqdm

from utils.utils import set_seed
from torch.utils.data import DataLoader
from utils.utils import get_model
from utils.utils import get_promt
import torch.nn.functional as F
from utils.utils import compare_retrieval_acc

import numpy as np
import os
import torch

# Module-level settings dictionary. Nothing in the visible part of this
# script reads it — presumably intended for text generation; TODO confirm.
model_custom_config = dict(
    max_new_tokens=50,
    temperature=0.1,
    top_p=0.9,
)

def main(args):
    """Run one forward pass of the model over a synthetic repeated-word prompt.

    The prompt is ``"hello "`` repeated ``args.input_length`` times. It is
    tokenized, the resulting token count is printed, and the model is called
    once (under ``torch.no_grad``) on the token ids with the first 10
    positions dropped. The logits of that call are printed.

    Args:
        args: Parsed command-line namespace. Reads ``cuda`` (device index,
            converted with ``int``), ``model_path``, ``method`` and
            ``input_length``; the whole namespace is also forwarded to
            ``get_model``.

    Raises:
        ValueError: If ``args.input_length`` is ``None``.
    """
    # Validate before loading the model so a bad invocation fails fast.
    # An explicit raise (instead of ``assert``) also survives ``python -O``.
    if args.input_length is None:
        raise ValueError("input length must be an Integer")

    device = torch.device(int(args.cuda))

    tokenizer, model = get_model(args.model_path, device, method=args.method, args=args)

    model.eval()
    with torch.no_grad():
        # Each repetition carries a trailing space so the words stay separated
        # (marked "1.29" in the original source).
        query = "hello " * args.input_length
        inputs_token = tokenizer(query, return_tensors="pt").to(model.device)
        input_ids = inputs_token.input_ids
        print("input token length: {}".format(len(input_ids[0])))

        # NOTE(review): the first 10 token positions are skipped before the
        # forward pass; the reason is not evident from this file — confirm.
        output = model(
            input_ids=input_ids[:, 10:],
            use_cache=True,
            output_hidden_states=True,
        )
        print("logis: {}".format(output["logits"]))



if __name__ == "__main__":
    # Script entry point: register the command-line options from a small
    # table of (flag, type, default), parse them, and hand them to main().
    cli_options = (
        ("--model_path", str, "/data/persist/models/llama2-7b-chat"),
        ("--method", str, "old"),
        ("--input_length", int, 8000),
        ("--cuda", str, "1"),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in cli_options:
        parser.add_argument(flag, type=value_type, default=fallback)

    args = parser.parse_args()
    main(args)