import torch
from models.llama_kivi import LlamaForCausalLM_KIVI
from models.mistral_kivi import MistralForCausalLM_KIVI
from models.llama_sinkq import LlamaForCausalLM_SinkQ
from models.mistral_sinkq import MistralForCausalLM_SinkQ
from transformers import AutoTokenizer,AutoConfig, AutoModelForCausalLM
import argparse

def parse_args():
    """Parse the command-line options for the KV-cache quantization demo.

    Returns:
        argparse.Namespace with ``model_path``, ``model_type``, ``method`` and the
        quantization hyper-parameters ``k_bits``, ``v_bits``, ``group_size``,
        ``residual_length``, ``sink_num`` and ``sink_max_size``.
    """
    parser = argparse.ArgumentParser()
    # The model path is mandatory: without it AutoConfig.from_pretrained(None) fails
    # much later with an unhelpful error, so reject it at parse time instead.
    parser.add_argument("--model_path", type=str, required=True, help="model path")
    parser.add_argument("--model_type", type=str, default="llama", help="model type", choices=["llama", "mistral"])
    parser.add_argument("--method", type=str, default="SinkQ", help="quant method", choices=["FP16", "KIVI", "SinkQ"])
    # quant hyper parameters
    parser.add_argument("--k_bits", type=int, default=2, help="bit width for the key cache")
    parser.add_argument("--v_bits", type=int, default=2, help="bit width for the value cache")
    parser.add_argument("--group_size", type=int, default=128, help="quantization group size")
    parser.add_argument("--residual_length", type=int, default=128, help="residual length")
    parser.add_argument("--sink_num", type=int, default=5, help="sink_num (SinkQ setting)")
    parser.add_argument("--sink_max_size", type=int, default=32, help="sink_max_size (SinkQ setting)")
    return parser.parse_args()

# GSM8K-style 5-shot chain-of-thought prompt; the model continues after the final
# "Answer:". Built with implicit string concatenation so that neither source
# indentation nor line wrapping (nor stray '#' characters) leaks into the text
# that is fed to the model.
_GSM8K_FEW_SHOT_PROMPT = (
    "Question: Jen and Tyler are gymnasts practicing flips. Jen is practicing the triple-flip while Tyler is practicing the double-flip. "
    "Jen did sixteen triple-flips during practice. Tyler flipped in the air half the number of times Jen did. How many double-flips did Tyler do?\n"
    "Answer: Jen did 16 triple-flips, so she did 16 * 3 = <<16*3=48>>48 flips.\nTyler did half the number of flips, so he did 48 / 2 = <<48/2=24>>24 flips.\n"
    "A double flip has two flips, so Tyler did 24 / 2 = <<24/2=12>>12 double-flips.\n#### 12\n\n"
    "Question: Four people in a law firm are planning a party. "
    "Mary will buy a platter of pasta for $20 and a loaf of bread for $2. Elle and Andrea will split the cost for buying 4 cans of soda which cost $1.50 each, "
    "and chicken wings for $10. Joe will buy a cake that costs $5. How much more will Mary spend than the rest of the firm put together?\n"
    "Answer: Mary will spend $20 + $2 = $<<20+2=22>>22.\nElle and Andrea will spend $1.5 x 4 = $<<1.5*4=6>>6 for the soda.\n"
    "Elle and Andrea will spend $6 + $10 = $<<6+10=16>>16 for the soda and chicken wings.\nElle, Andrea, and Joe together will spend $16 + $5 = $<<16+5=21>>21.\n"
    "So, Mary will spend $22 - $21 = $<<22-21=1>>1 more than all of them combined.\n#### 1\n\n"
    "Question: A charcoal grill burns fifteen coals to ash "
    "every twenty minutes of grilling. The grill ran for long enough to burn three bags of coals. Each bag of coal contains 60 coals. How long did the grill run?\n"
    "Answer: The grill burned 3 * 60 = <<3*60=180>>180 coals.\nIt takes 20 minutes to burn 15 coals, so the grill ran for 180 / 15 * 20 = <<180/15*20=240>>240 minutes.\n"
    "#### 240\n\n"
    "Question: A bear is preparing to hibernate for the winter and needs to gain 1000 pounds. At the end of summer, the bear feasts on berries and small "
    "woodland animals. During autumn, it devours acorns and salmon. It gained a fifth of the weight it needed from berries during summer, and during autumn, it gained "
    "twice that amount from acorns. Salmon made up half of the remaining weight it had needed to gain. How many pounds did it gain eating small animals?\n"
    "Answer: The bear gained 1 / 5 * 1000 = <<1/5*1000=200>>200 pounds from berries.\nIt gained 2 * 200 = <<2*200=400>>400 pounds from acorns.\n"
    "It still needed 1000 - 200 - 400 = <<1000-200-400=400>>400 pounds.\nThus, it gained 400 / 2 = <<400/2=200>>200 pounds from salmon.\n"
    "Therefore, the bear gained 400 - 200 = <<400-200=200>>200 pounds from small animals.\n#### 200\n\n"
    "Question: Brendan can cut 8 yards of grass per day, he "
    "bought a lawnmower and it helped him to cut more yards by Fifty percent per day. How many yards will Brendan be able to cut after a week?\n"
    "Answer: The additional yard Brendan can cut after buying the lawnmower is 8 x 0.50 = <<8*0.50=4>>4 yards.\nSo, the total yards he can cut with the lawnmower is "
    "8 + 4 = <<8+4=12>>12.\nTherefore, the total number of yards he can cut in a week is 12 x 7 = <<12*7=84>>84 yards.\n#### 84\n\n"
    "Question: Janet’s ducks lay "
    "16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' "
    "market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nAnswer:"
)


def _build_config(args):
    """Load the config at ``args.model_path`` and attach the quantization settings from ``args``."""
    config = AutoConfig.from_pretrained(args.model_path)
    # NOTE(review): these attributes are presumably read by the KIVI/SinkQ model
    # classes; the FP16 (AutoModelForCausalLM) path likely ignores them — confirm.
    config.k_bits = args.k_bits
    config.v_bits = args.v_bits  # previously args.k_bits, which silently ignored --v_bits
    config.group_size = args.group_size
    config.residual_length = args.residual_length
    config.sink_num = args.sink_num
    config.sink_max_size = args.sink_max_size
    return config


def _load_model(args, config):
    """Instantiate the fp16 causal LM selected by ``args.method`` / ``args.model_type``.

    Any combination not listed below (i.e. ``--method FP16``) falls back to
    ``AutoModelForCausalLM``.
    """
    model_classes = {
        ("KIVI", "llama"): LlamaForCausalLM_KIVI,
        ("KIVI", "mistral"): MistralForCausalLM_KIVI,
        ("SinkQ", "llama"): LlamaForCausalLM_SinkQ,
        ("SinkQ", "mistral"): MistralForCausalLM_SinkQ,
    }
    model_cls = model_classes.get((args.method, args.model_type), AutoModelForCausalLM)
    return model_cls.from_pretrained(
        args.model_path, config=config, torch_dtype=torch.float16, device_map="auto"
    )


def main(args):
    """Load the selected model, greedily decode a fixed few-shot prompt and print the result.

    Args:
        args: namespace as produced by ``parse_args``.
    """
    config = _build_config(args)
    # NOTE(review): the slow tokenizer is used for every non-llama model; the reason
    # is not visible in this file — confirm before changing.
    use_fast = args.model_type == "llama"
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True, use_fast=use_fast)
    model = _load_model(args, config)

    prompt = _GSM8K_FEW_SHOT_PROMPT
    # Follow the model's placement (device_map="auto") instead of hard-coding .cuda().
    inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    output = model.generate(inputs, max_new_tokens=256, do_sample=False)  # greedy decoding
    # generate() returns prompt + continuation; keep only the newly generated tokens.
    new_tokens = output[0].tolist()[inputs.shape[1]:]

    print(f"===========================   input_len: {inputs.shape[1]}   ===========================")
    print(prompt)
    print(f"===========================   output_len: {len(new_tokens)}   ===========================")
    print(tokenizer.decode(new_tokens, skip_special_tokens=True))
    
if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them straight to main().
    main(parse_args())