from argparse import ArgumentParser
import argparse
import gc
from transformers import BitsAndBytesConfig, BertTokenizer, BertForSequenceClassification, BitsAndBytesConfig
import torch
from torch.profiler import profile, record_function, ProfilerActivity
import time
import logging

# Measure the forward-pass latency of the model for a single sequence length
def measure_latency(args, model, tokenizer, sequence_length, batch_size=1):
    """Time a single no-grad forward pass of ``model`` on random token IDs.

    Args:
        args: Object with a ``device`` attribute (anything ``torch.device``
            accepts); the dummy input is placed on that device.
        model: Callable invoked as ``model(input_ids=<tensor>)``.
        tokenizer: Object exposing ``vocab_size``, used as the exclusive
            upper bound for the random token IDs.
        sequence_length: Number of token IDs per row of the dummy input.
        batch_size: Number of rows in the dummy input.

    Returns:
        Elapsed wall-clock time of the forward pass, in seconds (float).
    """
    device = torch.device(args.device)
    # CUDA kernels are launched asynchronously: without an explicit
    # synchronize the host timer can stop before the GPU has finished,
    # which under-reports the latency.
    sync_cuda = device.type == "cuda" and torch.cuda.is_available()

    # Random input IDs of the requested shape; built before the timed region
    # so tensor creation and the host-to-device copy are not measured.
    dummy_input_ids = torch.randint(
        low=0,
        high=tokenizer.vocab_size,
        size=(batch_size, sequence_length),
    ).to(device)

    with torch.no_grad():
        if sync_cuda:
            # Drain pending work (e.g. the input copy) before starting.
            torch.cuda.synchronize(device)
        # perf_counter is monotonic and high resolution, unlike time.time().
        start_time = time.perf_counter()
        model(input_ids=dummy_input_ids)
        if sync_cuda:
            # Wait for the forward pass to actually complete on the GPU.
            torch.cuda.synchronize(device)
        end_time = time.perf_counter()

    return end_time - start_time

def main(args):
    """Sweep input lengths and log BERT forward-pass latency for each one.

    Loads the ``bert-base-uncased`` tokenizer and sequence-classification
    model, moves the model to ``args.device``, then calls
    ``measure_latency`` for every length in
    ``range(args.min_tokens, args.max_tokens, args.step_size)`` and writes
    one INFO log line per length.

    Args:
        args: Parsed command-line namespace. Reads ``output_file``,
            ``load_in_8bit``, ``device``, ``min_tokens``, ``max_tokens``,
            ``step_size`` and ``aggressive_memory``.
    """
    # With filename=None, logging.basicConfig falls back to stderr.
    logging.basicConfig(filename=args.output_file, level=logging.INFO)

    # Load the tokenizer
    model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(model_name)

    # 8-bit quantization configuration.
    # NOTE: this config is built but NOT applied: the quantization_config
    # argument below is commented out, so --load-in-8bit currently has no
    # effect on the loaded model.
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=args.load_in_8bit,  # Enables 8-bit quantization
        llm_int8_threshold=6.0,  # Threshold for speed vs. accuracy trade-off
    )

    # Load the (unquantized) BERT model for sequence classification
    model = BertForSequenceClassification.from_pretrained(
        model_name,
        # quantization_config=bnb_config,
    )

    # Move the model to the requested device
    device = torch.device(args.device)
    model = model.to(device)

    # The position-embedding table bounds the longest usable input; longer
    # sequences make the forward pass fail, so the sweep must stop there.
    max_positions = getattr(model.config, "max_position_embeddings", None)

    for desired_token_length in range(args.min_tokens, args.max_tokens, args.step_size):
        if max_positions is not None and desired_token_length > max_positions:
            logging.warning(
                "stopping sweep at sequence length %d: model supports at most %d positions",
                desired_token_length,
                max_positions,
            )
            break

        latency = measure_latency(args, model, tokenizer, desired_token_length)

        if args.aggressive_memory:
            # Tensors from the forward pass are local to measure_latency and
            # already unreferenced here; collect garbage and hand cached
            # CUDA blocks back to the driver.
            gc.collect()
            if device.type == "cuda":
                torch.cuda.empty_cache()

        # log info
        logging.info('sequence length is : {0}, execution time is:{time_used:.3f}'.format(desired_token_length, time_used = latency))

    # Disabled profiling snippet, kept for reference.
    # NOTE: input_ids and target_ids are not defined in this function, so
    # this cannot be re-enabled as written.
    #
    # with profile(activities=[
    #     ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    #     with record_function("model_inference"):
    #         outputs = model(input_ids, labels=target_ids)

    # print(prof.key_averages().table(sort_by="cuda_time_total"))
    # print(prof.key_averages().table(sort_by="self_cuda_memory_usage"))
if __name__ == "__main__":
    # Command-line interface for the latency sweep.
    cli = ArgumentParser()

    # Integer options bounding the sweep: range(min, max, step).
    for flag, default_value in (
        ("--max-tokens", 30720),
        ("--min-tokens", 256),
        ("--step-size", 256),
    ):
        cli.add_argument(flag, type=int, default=default_value)

    # Log destination; no default is given, so it is None when omitted.
    cli.add_argument("--output-file", type=str)

    # String-valued options. --device-map is accepted but not read by
    # main() in this file.
    for flag, default_value in (("--device", "cuda"), ("--device-map", "auto")):
        cli.add_argument(flag, type=str, default=default_value)

    # Boolean switches, False unless passed. --load-in-4bit is accepted but
    # not read by main() in this file.
    for flag in ("--aggressive-memory", "--load-in-8bit", "--load-in-4bit"):
        cli.add_argument(flag, action="store_true")

    args = cli.parse_args()
    main(args)


    