import torch
from models.llama_kivi import LlamaForCausalLM_KIVI
from models.mistral_kivi import MistralForCausalLM_KIVI
from models.llama_sinkq import LlamaForCausalLM_SinkQ
from models.mistral_sinkq import MistralForCausalLM_SinkQ
from transformers import AutoTokenizer,AutoConfig, AutoModelForCausalLM
import time
import argparse

def parse_args():
    """Define the benchmark's command-line flags and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the model selection (``model_path``,
        ``model_type``, ``method``), the quantization hyper-parameters
        (``k_bits``, ``v_bits``, ``group_size``, ``residual_length``,
        ``sink_num``, ``sink_max_size``) and the benchmark configuration
        (``batch_size``, ``prompt_length``, ``output_length``, ``num_repeats``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=None, help="model path")
    parser.add_argument(
        "--model_type",
        type=str,
        default="llama",
        help="model type",
        choices=["llama", "mistral"],
    )
    parser.add_argument(
        "--method",
        type=str,
        default='SinkQ',
        help="quant method",
        choices=["FP16", "KIVI", "SinkQ"],
    )

    # Remaining flags are all plain integers: (flag, default), in declaration order.
    integer_flags = (
        # quantization hyper-parameters
        ("--k_bits", 2),
        ("--v_bits", 2),
        ("--group_size", 128),
        ("--residual_length", 128),
        ("--sink_num", 3),
        ("--sink_max_size", 32),
        # benchmark configuration
        ("--batch_size", 16),
        ("--prompt_length", 64),
        ("--output_length", 384),
        ("--num_repeats", 3),
    )
    for flag, default_value in integer_flags:
        parser.add_argument(flag, type=int, default=default_value)

    return parser.parse_args()

def _load_model(args, config):
    """Load the causal-LM class selected by ``args.method`` and ``args.model_type``.

    The pairs ("KIVI"|"SinkQ", "llama"|"mistral") map to the project's model
    classes. Every other combination (i.e. ``--method FP16``) falls back to
    ``AutoModelForCausalLM``. Weights are loaded in float16 with
    ``device_map="auto"``.
    """
    quantized_classes = {
        ("KIVI", "llama"): LlamaForCausalLM_KIVI,
        ("KIVI", "mistral"): MistralForCausalLM_KIVI,
        ("SinkQ", "llama"): LlamaForCausalLM_SinkQ,
        ("SinkQ", "mistral"): MistralForCausalLM_SinkQ,
    }
    model_cls = quantized_classes.get(
        (args.method, args.model_type), AutoModelForCausalLM
    )
    return model_cls.from_pretrained(
        args.model_path,
        config=config,
        torch_dtype=torch.float16,
        device_map="auto",
    )


def main(args):
    """Benchmark decode throughput and peak GPU memory for one model/method.

    Steps:
        1. Build a config carrying the quantization hyper-parameters from ``args``.
        2. Load the matching model.
        3. Run one warm-up ``generate`` call.
        4. Time ``args.num_repeats`` further calls on a synthetic batch.
        5. Print the average throughput (tokens per second) and the peak
           memory allocated by PyTorch on the current CUDA device (GiB)
           during the timed runs.

    Args:
        args: Namespace produced by ``parse_args()``.

    Raises:
        ValueError: if ``args.num_repeats`` is less than 1. Without this check
            the average time would divide by zero, but only after the model
            has been loaded and warmed up.
    """
    if args.num_repeats < 1:
        raise ValueError(f"--num_repeats must be >= 1, got {args.num_repeats}")

    config = AutoConfig.from_pretrained(args.model_path)
    # Extra attributes handed to the model through ``config``. They are
    # presumably read by the KIVI/SinkQ model classes (confirm in models/).
    config.k_bits = args.k_bits
    config.v_bits = args.v_bits
    config.group_size = args.group_size
    config.residual_length = args.residual_length
    config.sink_num = args.sink_num
    config.sink_max_size = args.sink_max_size

    # NOTE(review): the slow tokenizer is used for mistral; the reason is not
    # visible here.
    use_fast = args.model_type == "llama"
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path, trust_remote_code=True, use_fast=use_fast
    )
    model = _load_model(args, config)

    # Synthetic prompt "t,t,...,t" containing prompt_length - 1 't' characters.
    # This is the same text as ('t,' * prompt_length)[:-3]. Every row is
    # identical, so the batch needs no padding. The real token count is
    # printed below.
    prompt = ",".join(["t"] * (args.prompt_length - 1))
    context = [prompt] * args.batch_size
    inputs = tokenizer(context, return_tensors="pt").to('cuda')
    input_ids = inputs['input_ids']
    print("="*40)
    print(f'''Decode config:  
                batch size: {args.batch_size}
                seqlen: {input_ids.shape[1]}+{args.output_length}
                model:{args.model_path}''')
    print("="*40)

    with torch.no_grad():
        # Warm-up call (its output is discarded). It keeps one-time start-up
        # costs out of the timing and out of the peak-memory reading, which is
        # reset immediately afterwards.
        model.generate(**inputs, max_new_tokens=args.output_length)
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(args.num_repeats):
            model.generate(**inputs, max_new_tokens=args.output_length)
        # CUDA work is asynchronous; wait for it before stopping the clock.
        torch.cuda.synchronize()
        avg_time = (time.perf_counter() - start) / args.num_repeats
        # NOTE(review): assumes every run generates the full output_length
        # tokens (no early stop on EOS). Otherwise this overstates throughput.
        throughput = (args.output_length * args.batch_size) / avg_time
        # NOTE(review): max_memory_allocated() reports only the current device;
        # with device_map="auto" spread over several GPUs this undercounts.
        memory = torch.cuda.max_memory_allocated() / 1024 ** 3

    print("="*40)
    print(f'''Decode results:  
                Throughput: {throughput}
                Memory peak:{memory}''')
    print("="*40)
        
if __name__ == "__main__":
    # Script entry point: read the CLI flags, then run the benchmark with them.
    args = parse_args()
    main(args)