
import sys
import time
from pathlib import Path
from typing import Any, Literal, Optional
import re
import random
import lightning as L
import torch
import base
import torch._dynamo.config
from jsonargparse import CLI

# Seed Python's global `random` generator once, at import time, with a fixed
# value. Nothing visible in this file draws from `random`; presumably imported
# code (e.g. `base`) does — confirm before removing.
random.seed(a=22)

def main(
    num_prompts: int = 10,
    top_k: Optional[int] = 200,
    temperature: float = 1.0,
    model_checkpoints: Optional[list] = None,
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
    precision: Optional[str] = None,
    compile: bool = False,
):
    """Run ``base.main`` once for every (prompt index, checkpoint) pair.

    For each prompt index ``i`` in ``range(num_prompts)``, every checkpoint in
    ``model_checkpoints`` is run in order through ``base.main``. Each call uses
    ``token_lengths=512``, ``seed=i + 100`` and ``view_att=True``. The logging
    name is a per-checkpoint prefix with ``i`` appended.

    Args:
        num_prompts: Number of prompt iterations; iteration ``i`` uses seed
            ``i + 100``.
        top_k: Accepted on the CLI but currently NOT forwarded to ``base.main``.
        temperature: Accepted but currently NOT forwarded to ``base.main``.
        model_checkpoints: Checkpoint directories (``str`` or ``Path``) to run,
            in order. ``None`` (the default) means
            ``checkpoints/syzymon/long_llama_3b`` followed by
            ``checkpoints/openlm-research/open_llama_3b``.
        quantize: Accepted but currently NOT forwarded to ``base.main``.
        precision: Accepted but currently NOT forwarded to ``base.main``.
        compile: Accepted but currently NOT forwarded to ``base.main``.
    """
    # A None sentinel replaces the list literal that used to sit in the
    # signature; a mutable default would be shared across calls.
    if model_checkpoints is None:
        model_checkpoints = [
            Path("checkpoints/syzymon/long_llama_3b"),
            Path("checkpoints/openlm-research/open_llama_3b"),
        ]
    # Logging-name prefixes for the first two checkpoints, by position.
    model_names = ['longllama_model_verification', 'openllama_model_verification']
    for i in range(num_prompts):
        for j, model_path in enumerate(model_checkpoints):
            # Values given on the command line arrive untyped (the annotation
            # is a bare `list`); normalize to Path like the defaults.
            checkpoint_dir = Path(model_path)
            # Extra checkpoints used to raise IndexError on model_names[j];
            # derive a prefix from the directory name for them instead.
            if j < len(model_names):
                name_prefix = model_names[j]
            else:
                name_prefix = f"{checkpoint_dir.name}_model_verification"
            base.main(
                checkpoint_dir=checkpoint_dir,
                logging_name=name_prefix + str(i),
                token_lengths=512,
                seed=i + 100,
                view_att=True,
            )
            print(f'\n#### Completed Model: {model_path} ####\n')
        print(f'\n##### Completed Prompt: {i+1}\n')

if __name__ == "__main__":
    # Script entry point: hand main() to jsonargparse's CLI, which parses the
    # command-line arguments into main()'s parameters and calls it.
    CLI(main)
    
