from model import GPT, eval_loss
import tiktoken
import os
import json
import time
import torch
from fire import Fire
import numpy as np
from tqdm import tqdm, trange
# Point tiktoken at a local cache directory (relative to the current working
# directory). This has to be set before the `get_encoding` call below, which
# is where the encoding files get loaded — presumably so evaluation can run
# without network access; confirm.
os.environ["TIKTOKEN_CACHE_DIR"] = "./tiktoken_cache"
# GPT-2 BPE tokenizer, shared by every `eval_loss` call in this script.
enc = tiktoken.get_encoding("gpt2")

def load_and_eval(model_path, texts: list[str], batch_size: int = 32):
    """Load a GPT checkpoint and compute one loss per text in *texts*.

    Args:
        model_path: Checkpoint path, passed to ``GPT.from_pretrained``.
        texts: Texts to score. They are sent to ``eval_loss`` in consecutive
            chunks of ``batch_size``.
        batch_size: Number of texts per ``eval_loss`` call.

    Returns:
        A list of per-text losses in input order. Each loss is converted to a
        plain Python ``float``, so the list can be serialized with ``json``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading model from {model_path}...")
    model = GPT.from_pretrained(model_path, device)
    results: list[float] = []
    # Pure evaluation: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for start in trange(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            # eval_loss also returns log-probs; only the losses are kept here.
            losses, _log_probs = eval_loss(model, batch_texts, enc)
            # zip caps the output at one loss per text in the batch.
            # float() turns tensor/NumPy scalars into JSON-serializable floats
            # and is a no-op for values that are already Python floats.
            for _text, loss in zip(batch_texts, losses):
                results.append(float(loss))
    return results

def load_test_data(setting, data_dir: str = "unbalanced_pretrain"):
    """Return the evaluation texts for *setting*, one entry per file line.

    Lines are returned exactly as ``readlines`` yields them (trailing
    newlines are kept).

    Args:
        setting: Which split to load.

            - ``"test"`` reads ``sft_power_test.txt``.
            - ``"full"``, ``"simple"`` and ``"sample"`` read ``test.txt``.
              ``"full"`` keeps every line, ``"simple"`` keeps every 5th line
              (indices 0, 5, 10, ...), and ``"sample"`` keeps every 100th
              line (indices 0, 100, 200, ...).
        data_dir: Directory containing the data files.

    Raises:
        ValueError: If *setting* is not a supported value.
    """
    # Stride over the lines of test.txt for each subsampled setting.
    strides = {"full": 1, "simple": 5, "sample": 100}
    if setting == "test":
        test_path = os.path.join(data_dir, "sft_power_test.txt")
        with open(test_path, "r", encoding="utf-8") as f:
            return f.readlines()
    # Validate before touching the file system. Previously an unknown
    # setting fell through and silently returned None.
    if setting not in strides:
        raise ValueError(
            f"unknown setting {setting!r}; expected 'test' or one of {sorted(strides)}"
        )
    with open(os.path.join(data_dir, "test.txt"), "r", encoding="utf-8") as f:
        lines = f.readlines()
    return lines[::strides[setting]]
    

def main(setting: str, checkpoint: str, path: str, batch_size: int = 256):
    """Evaluate one checkpoint on a test split and save the per-text losses.

    Meant to be driven from the command line through ``Fire``.

    Args:
        setting: Test split to evaluate: ``"full"``, ``"simple"``,
            ``"sample"`` or ``"test"``.
        checkpoint: Checkpoint file name inside *path*.
        path: Run directory containing the checkpoint. Its last path
            component names the output sub-directory under ``logs/eval``.
        batch_size: Number of texts per evaluation batch.

    Raises:
        ValueError: If *setting* is not one of the supported values.

    Side effects:
        Prints progress, elapsed time and the mean loss. Writes the list of
        per-text losses as JSON to
        ``logs/eval/<run dir name>/<setting>_<checkpoint>.json``.
    """
    valid_settings = ("full", "simple", "sample", "test")
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if setting not in valid_settings:
        raise ValueError(f"setting must be one of {valid_settings}, got {setting!r}")
    # Fire may parse numeric-looking CLI arguments as int/float; the checkpoint
    # is used in string concatenation below.
    checkpoint = str(checkpoint)
    if not path.endswith("/"):
        path += "/"
    # With the trailing "/" ensured, the second-to-last component is the run dir.
    basepath = path.split("/")[-2]
    print(f"Evaluating {setting} with checkpoint {checkpoint} at path {path}")
    t0 = time.perf_counter()
    texts = load_test_data(setting)
    results = load_and_eval(path + checkpoint, texts, batch_size)
    print(f"Evaluation took {time.perf_counter() - t0} seconds")
    print("Mean Loss", np.mean(results))
    out_dir = f"logs/eval/{basepath}"
    # makedirs also creates the intermediate "logs/eval" directory.
    os.makedirs(out_dir, exist_ok=True)
    out_path = f"{out_dir}/{setting}_{checkpoint}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f)
    print(f"Evaluation results saved to {out_path}")
    
if __name__ == "__main__":
    # Expose `main` as a CLI via Fire. Positional arguments map, in order, to
    # setting, checkpoint, path and batch_size.
    Fire(main)
# Example usage:
# python evaluation.py sample state_step005550.pt logs/pretrain_power_new_20250128_112926_5196157b-8ca4-4a7b-bc9a-ed6f0c1442b4 16
