import os
from parallel.start import start
from parallel.ppl_utils import get_wikitext2, compute_perplexity
from parallel.config import create_config
import argparse


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line flags for this script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            reads ``sys.argv[1:]``, so existing ``parse_args()`` callers
            behave exactly as before.

    Returns:
        A namespace with two boolean attributes, ``quant_act`` and
        ``quant_kv``; each is ``False`` unless its flag was given.
    """
    parser = argparse.ArgumentParser()
    # Both options are plain on/off switches (store_true => default False).
    parser.add_argument(
        "--quant_act",
        action="store_true",
        help="pass quant_act=True to create_config (default: False)",
    )
    parser.add_argument(
        "--quant_kv",
        action="store_true",
        help="pass quant_kv=True to create_config (default: False)",
    )
    return parser.parse_args(argv)


def main():
    """Build a config from the CLI flags, load the model, and print its perplexity.

    Steps, in order:
      1. Parse ``--quant_act`` / ``--quant_kv`` via ``parse_args``.
      2. Read the checkpoint location from the ``CKPT_DIR`` environment
         variable, falling back to a hard-coded path.
      3. Call ``create_config`` and ``start`` to obtain a model and tokenizer.
      4. Fetch data with ``get_wikitext2(..., is_testset=True)`` and pass it to
         ``compute_perplexity``; print the result prefixed with ``PPL:``.
    """
    cli = parse_args()

    # NOTE(review): the fallback looks like a placeholder path; set CKPT_DIR
    # (or edit the default) before running.
    checkpoint_path = os.environ.get("CKPT_DIR", "/path/to/model/checkpoint")

    # Hard-coded settings handed positionally to create_config. Their exact
    # meaning is defined by parallel.config and is not visible in this file.
    q_value = 14
    activation_betas = [3.47, 4.74, 6.90, 18.11]
    k_betas = [3.50, 4.58, 6.47, 17.06]
    v_betas = [3.53, 5.59, 9.62, 29.03]
    quant_config = create_config(
        q_value,
        cli.quant_act,
        cli.quant_kv,
        activation_betas,
        k_betas,
        v_betas,
    )

    # Second positional argument of start(); the original local was named
    # is_llama_2.
    llama2_flag = False
    model, tokenizer = start(checkpoint_path, llama2_flag, quant_config)

    eval_data = get_wikitext2(tokenizer=tokenizer, is_testset=True)
    sequence_length = 2048
    perplexity = compute_perplexity(model, eval_data, sequence_length)
    print("PPL:", perplexity)

# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
