import torch
from transformers import AutoTokenizer
from datasets import load_from_disk
from lm_eval.models.huggingface import HFLM
from lm_eval.api.instance import Instance
from sentence_transformers import util
import numpy as np
import random
import os
import argparse
import math
from TTMM import TTMM

# Fix the seed of every RNG this script relies on so runs are repeatable.
seed = 23
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed)
# NOTE: Python reads PYTHONHASHSEED at interpreter startup, so setting it here
# does not change hashing in this process; it is only inherited by children.
os.environ['PYTHONHASHSEED'] = str(seed)

# Command-line options as (flag, type, default, help) rows; they are registered
# in this order, which is also the order shown by --help.
_CLI_OPTIONS = (
    ('--model_name', str, 'meta-llama/Llama-3.2-1B', 'model name'),
    ('--corpus_path', str, '', 'corpus path'),
    ('--dataset_path', str, '/path/to/dataset', 'path of test dataset'),
    ('--ls', float, 0.01, 'length scale for kernel model'),
    ('--threshold', float, 0.01, 'threshold for kernel model'),
    ('--prefixLen', int, 0, 'Evaluation Prefix Len'),
    ('--adapter_path', str, '', 'adapter path'),
    ('--null_adapter_path', str, '/path/to/nullAdapter', 'path to untrained initialized adapter'),
)

parser = argparse.ArgumentParser(description='Trainer Script')
for _flag, _type, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

args = parser.parse_args()
# Echo the resolved configuration so it is recorded alongside the results.
print(args, flush=True)


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Switch off parallelism in the Hugging Face tokenizers backend; set before
# the tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Pad on the right, reusing the end-of-sequence token as the padding token.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

def build_eval_reqs(strings):
    """Wrap each text in a ``loglikelihood_rolling`` request ``Instance``.

    Args:
        strings: Iterable of texts to score.

    Returns:
        A list with one ``Instance`` per input, in input order. Each request
        has an empty ``doc``, the text as its only argument, and ``idx`` set
        to the text's position in ``strings``.
    """
    requests = []
    for position, text in enumerate(strings):
        requests.append(
            Instance(
                request_type="loglikelihood_rolling",
                doc={},
                arguments=(text,),
                idx=position,
            )
        )
    return requests

# Load the corpus onto the CPU and normalize it before handing it to the model
# (presumably a tensor of embeddings, given normalize_embeddings -- confirm).
# NOTE(review): torch.load deserializes with pickle unless weights-only loading
# applies in the installed torch version; only pass trusted files here.
corpus = torch.load(args.corpus_path, map_location=torch.device('cpu'))
corpusNorm = util.normalize_embeddings(corpus)

adapterLocation = args.adapter_path

dataset = load_from_disk(args.dataset_path)

# Raw text of every evaluation document, in dataset order ...
string = [doc["text"] for doc in dataset]
# ... and the matching token count of each one (no special tokens added).
string_tokens = [len(tokenizer.encode(text, add_special_tokens=False)) for text in string]

    
model = TTMM(
    corpus=corpusNorm,
    device=device,
    tokenizer=tokenizer,
    rbf_length_scale=args.ls,
    rbf_threshold=args.threshold,
    prefix_length=args.prefixLen,
    adapterLocation=adapterLocation,
    verbose=True,
    modelName=args.model_name,
    null_adapter=args.null_adapter_path,
    model_id=args.model_name,
)

# Qwen2.5-1.5B is evaluated with a shorter maximum length than any other model.
_eval_max_length = 8192 if args.model_name == "Qwen/Qwen2.5-1.5B" else 16384
eval_model = HFLM(pretrained=model, tokenizer=tokenizer, device=device, max_length=_eval_max_length)


# Seed every RNG again right before evaluation (presumably so the evaluation
# starts from a known RNG state whatever the model setup consumed -- confirm).
seed = 23
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed)
# NOTE: Python reads PYTHONHASHSEED at interpreter startup, so setting it here
# does not change hashing in this process; it is only inherited by children.
os.environ['PYTHONHASHSEED'] = str(seed)

def evaluateSelf():
    """Score the dataset with ``eval_model`` and print perplexity past the prefix.

    Reads the module-level ``string`` (document texts), ``string_tokens``
    (token count per document), ``tokenizer``, ``eval_model`` and
    ``args.prefixLen``. Every document is scored twice with
    ``loglikelihood_rolling``: once in full and once cut down to its first
    ``args.prefixLen`` tokens. The per-document difference of the two scores
    is used as the log-likelihood of the tokens that follow the prefix.

    Prints the raw scores, their differences, the corpus-level perplexity
    ``exp(-sum(diff) / total_tokens)``, the per-document perplexities, and the
    mean and standard deviation of the latter. When no document has a token
    beyond the prefix (or the dataset is empty) perplexity is undefined: a
    message is printed and the function returns instead of dividing by zero.

    Returns:
        None. All results are written to stdout.
    """
    prefix_len = args.prefixLen

    # Text of the first `prefix_len` tokens of every document.
    # NOTE(review): this assumes that scoring the decoded prefix is comparable
    # to scoring the same tokens inside the full text -- a decode/encode round
    # trip is not guaranteed to reproduce the exact token sequence.
    string_limited = [
        tokenizer.decode(
            tokenizer.encode(text, add_special_tokens=False)[:prefix_len],
            skip_special_tokens=True,
        )
        for text in string
    ]
    eval_reqs = build_eval_reqs(string)
    eval_reqs_limited = build_eval_reqs(string_limited)

    resp = eval_model.loglikelihood_rolling(eval_reqs, disable_tqdm=False)
    # Printed once here, before the second scoring pass starts, and once more
    # below together with the remaining results.
    print(f"Likelihoods of full text: {resp}", flush=True)
    resp_limited = eval_model.loglikelihood_rolling(eval_reqs_limited, disable_tqdm=False)

    # Score attributed to the tokens after the prefix, one value per document.
    diff = [full - limited for full, limited in zip(resp, resp_limited)]

    print(f"Likelihoods of full text: {resp}", flush=True)
    print(f"Likelihood limited to prefix_length: {resp_limited}", flush=True)
    print(f"Difference in Likelihoods: {diff}", flush=True)

    ## Computing the Perplexity

    # Step 1: Sum the negative log-likelihoods
    total_nll = -sum(diff)

    # Step 2: Count the tokens that follow the prefix in each document
    total_tokens = 0
    per_doc_perp = []
    for doc_diff, doc_tokens in zip(diff, string_tokens):
        # Equivalent to doc_tokens - min(prefix_len, doc_tokens); never negative.
        tokens = max(doc_tokens - prefix_len, 0)
        total_tokens += tokens
        # Documents that fit entirely inside the prefix have no per-document
        # perplexity and are left out of the per-document statistics.
        if tokens > 0:
            per_doc_perp.append(math.exp(-doc_diff / tokens))

    if total_tokens == 0:
        # Nothing to average over: avoid ZeroDivisionError below.
        print("No tokens beyond prefix_length to score; perplexity is undefined.", flush=True)
        return

    # Step 3: Compute the average negative log-likelihood per token
    average_nll = total_nll / total_tokens

    print(average_nll, flush=True)

    # Step 4: Compute perplexity
    perplexity = math.exp(average_nll)

    print(f"Perplexity: {perplexity}", flush=True)
    print(f"Per Doc Perplexities: {per_doc_perp}", flush=True)
    print(f"Mean Per Doc Perplexity: {np.mean(per_doc_perp)}", flush=True)
    print(f"Standard Deviation: {np.std(per_doc_perp)}", flush=True)


    
# Script entry point: run the evaluation as soon as this module is executed.
evaluateSelf()