from argparse import ArgumentParser
import argparse
import gc
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig, QuantoConfig, AwqConfig, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
from scaled_rope.patch import *
from transformers import AutoModelForCausalLM
import transformers
import datasets
import torch
from torch.profiler import profile, record_function, ProfilerActivity
import time
import logging
# Keep a handle to the stock LlamaRotaryEmbedding constructor so that
# ntk_scaled_init can delegate to it after rewriting its arguments.
old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None, scale=8):
    """Drop-in replacement for ``LlamaRotaryEmbedding.__init__`` (NTK-scaled RoPE).

    Rewrites two constructor arguments, then delegates to the saved
    ``old_init``:

    * ``base`` becomes ``base * scale ** (dim / (dim - 2))`` (base change formula).
    * ``max_position_embeddings`` is replaced by a fixed 16384; the caller's
      value is ignored.

    Args:
        dim: rotary dimension forwarded to the original constructor. ``dim == 2``
            makes the exponent divide by zero.
        max_position_embeddings: accepted for signature compatibility only.
        base: original RoPE base before rescaling.
        device: forwarded unchanged.
        scale: the alpha factor used in the base change formula.

    NOTE(review): assumes the installed transformers version calls this
    constructor with (dim, max_position_embeddings, base, device) -- TODO confirm
    against the pinned transformers release.
    """
    # Intentionally overrides the incoming max_position_embeddings.
    patched_max_positions = 16384
    ntk_exponent = dim / (dim - 2)
    patched_base = base * scale ** ntk_exponent
    old_init(self, dim, patched_max_positions, patched_base, device)
def load_model_and_apply_patches(model_path, args):
    """Load ``model_path`` as a causal LM after installing the NTK-scaled RoPE patch.

    Args:
        model_path: checkpoint directory or hub id passed to ``from_pretrained``.
        args: parsed CLI namespace. ``load_in_4bit`` / ``load_in_8bit`` decide
            whether a bitsandbytes quantization config is used at all.
            ``device_map`` (or the legacy ``map_device`` attribute, if present)
            is forwarded as ``device_map``.

    Returns:
        ``(model, tokenizer)``. The model is loaded with ``torch_dtype=torch.float16``.

    Side effects:
        Replaces ``LlamaRotaryEmbedding.__init__`` with ``ntk_scaled_init`` for the
        whole process. The patch is never undone.
    """
    # Only build a quantization config when quantization was requested. A
    # BitsAndBytesConfig with both flags False is not a "no quantization"
    # setting, and passing one may still route loading through bitsandbytes.
    quantization_config = None
    if args.load_in_4bit or args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=args.load_in_4bit,
            load_in_8bit=args.load_in_8bit,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    # Apply the NTK-scaled init patch before the model (and its rotary
    # embeddings) are constructed.
    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init

    # The CLI flag is --device-map (dest ``device_map``). Also accept a
    # ``map_device`` attribute for callers that build the namespace by hand.
    device_map = getattr(args, "device_map", None)
    if device_map is None:
        device_map = getattr(args, "map_device", None)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map=device_map,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer

def load_govtext_data(
    dataset_path='/home/yeq6/Research_project/llama/output/govreport-test-tokenized',
    device='cuda',
    min_len=2048,
    max_len=4096,
):
    """Return the first tokenized GovReport sample whose length is in ``[min_len, max_len)``.

    The saved dataset at ``dataset_path`` is read with ``datasets.load_from_disk``.
    It must expose ``tokenized_len`` and ``input_ids`` columns.
    NOTE(review): assumes a single ``Dataset``, not a ``DatasetDict`` -- confirm.

    Args:
        dataset_path: directory produced by ``Dataset.save_to_disk``.
        device: device the returned tensors are moved to.
        min_len: inclusive lower bound on ``tokenized_len``.
        max_len: exclusive upper bound on ``tokenized_len``.

    Returns:
        ``(input_ids, target_ids)``: tensors of shape ``(1, seq_len)`` on
        ``device``. ``target_ids`` is an independent clone of ``input_ids``.

    Raises:
        ValueError: if no sample falls inside the length window.
    """
    input_texts = datasets.load_from_disk(dataset_path)
    input_texts = input_texts.filter(lambda x: min_len <= x["tokenized_len"] < max_len)
    if len(input_texts) == 0:
        raise ValueError(
            f"no sample in {dataset_path!r} has tokenized_len in [{min_len}, {max_len})"
        )

    # Slicing a Dataset yields a dict of columns. Keep only the first row so the
    # tensor below has a leading batch dimension of 1.
    encoded_texts = input_texts[0:1]["input_ids"]
    labels = torch.tensor(encoded_texts)
    seq_len = labels.size(1)
    print('seq_len: ', seq_len)

    input_ids = labels.to(device)
    target_ids = input_ids.clone()

    print('input shape: ', input_ids.shape, input_ids.dtype)
    return input_ids, target_ids


def _synchronize_cuda():
    """Wait for queued work on every visible CUDA device (no-op if CUDA is unused)."""
    # Only sync when a CUDA context already exists, so CPU-only runs on a GPU
    # machine do not pay for creating one inside the timed region.
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        for device_index in range(torch.cuda.device_count()):
            torch.cuda.synchronize(device_index)


def measure_latency(args, model, tokenizer, sequence_length, batch_size=1):
    """Time a single no-grad forward pass of ``model`` on random token ids.

    Args:
        args: namespace whose ``device`` attribute is where the input is placed.
        model: callable accepting ``input_ids=`` (e.g. a Hugging Face causal LM).
        tokenizer: object exposing ``vocab_size``. Token ids are drawn uniformly
            from ``[0, vocab_size)``.
        sequence_length: tokens per sequence.
        batch_size: sequences per batch.

    Returns:
        Elapsed seconds (float) for the forward pass.
    """
    dummy_input_ids = torch.randint(
        low=0, high=tokenizer.vocab_size, size=(batch_size, sequence_length)
    ).to(args.device)

    # CUDA kernels are launched asynchronously. Drain pending work before
    # starting the clock, and wait for the forward to actually finish before
    # stopping it. Otherwise only the launch overhead is measured.
    _synchronize_cuda()
    start_time = time.perf_counter()
    with torch.no_grad():
        model(input_ids=dummy_input_ids)
    _synchronize_cuda()
    return time.perf_counter() - start_time

def _load_benchmark_model(args):
    """Return ``(model, tokenizer)`` for ``args.model``, honoring the bitsandbytes flags.

    Without ``--load-in-4bit`` / ``--load-in-8bit`` the model is loaded in full
    precision and moved to ``args.device``. With either flag it is loaded with a
    ``BitsAndBytesConfig`` and placed via ``args.device_map`` at load time,
    because bitsandbytes-quantized models are not moved with ``.to()``.

    Raises:
        ValueError: for an unknown ``args.model``, or (from ``BitsAndBytesConfig``)
            when both quantization flags are set.
    """
    # Build the config first so an invalid flag combination fails before any download.
    quantization_config = None
    if args.load_in_4bit or args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=args.load_in_4bit,
            load_in_8bit=args.load_in_8bit,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    if args.model == 'llama2_7b':
        model_name = "/home/yeq6/Research_project/llama/llama-2-7b-chat_hf"  # local LLaMA-2 chat checkpoint
        model_cls, tokenizer_cls = LlamaForCausalLM, LlamaTokenizer
    elif args.model == 'phi_1.5':
        model_name = "microsoft/phi-1_5"
        model_cls, tokenizer_cls = AutoModelForCausalLM, AutoTokenizer
    else:
        raise ValueError(f"unknown model {args.model!r}; expected 'llama2_7b' or 'phi_1.5'")

    tokenizer = tokenizer_cls.from_pretrained(model_name)
    if quantization_config is None:
        model = model_cls.from_pretrained(model_name).to(args.device)
    else:
        model = model_cls.from_pretrained(
            model_name,
            quantization_config=quantization_config,
            device_map=args.device_map,
        )
    return model, tokenizer


def main(args):
    """Log single-forward-pass latency of the selected model across input lengths.

    For every length in ``range(args.min_tokens, args.max_tokens, args.step_size)``
    (so ``max_tokens`` itself is excluded), one forward pass over random token
    ids is timed with ``measure_latency``. One INFO line per length is written to
    ``args.output_file``; ``logging.basicConfig`` uses stderr when it is None.
    """
    logging.basicConfig(filename=args.output_file, level=logging.INFO)
    model, tokenizer = _load_benchmark_model(args)

    for desired_token_length in range(args.min_tokens, args.max_tokens, args.step_size):
        latency = measure_latency(args, model, tokenizer, desired_token_length)

        if args.aggressive_memory:
            # Collect unreachable tensors, then release blocks cached by the
            # PyTorch CUDA allocator before the next (longer) input.
            gc.collect()
            torch.cuda.empty_cache()

        logging.info('sequence length is : {0}, execution time is:{time_used:.3f}'.format(desired_token_length, time_used=latency))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Measure forward-pass latency of a causal LM across input lengths."
    )
    parser.add_argument("--max-tokens", type=int, default=30720,
                        help="exclusive upper bound of the sequence lengths to time")
    parser.add_argument("--min-tokens", type=int, default=256,
                        help="first sequence length to time")
    parser.add_argument("--step-size", type=int, default=256,
                        help="increment between timed sequence lengths (must be > 0)")
    parser.add_argument("--output-file", type=str,
                        help="log file for results (stderr if omitted)")
    parser.add_argument("--model", type=str, default='llama2_7b',
                        choices=['llama2_7b', 'phi_1.5'],
                        help="which model to benchmark")
    parser.add_argument("--device", type=str, default='cuda',
                        help="torch device for the model inputs")
    parser.add_argument("--device-map", type=str, default='auto',
                        help="device_map value for from_pretrained (e.g. 'auto')")
    parser.add_argument("--aggressive-memory", action="store_true",
                        help="run gc and empty the CUDA cache after every length")
    # BitsAndBytesConfig rejects both at once, so let argparse report it up front.
    quant_group = parser.add_mutually_exclusive_group()
    quant_group.add_argument("--load-in-8bit", action="store_true",
                             help="bitsandbytes 8-bit loading flag")
    quant_group.add_argument("--load-in-4bit", action="store_true",
                             help="bitsandbytes 4-bit loading flag")
    args = parser.parse_args()
    # range() raises on a zero step and silently yields nothing on a negative one.
    if args.step_size <= 0:
        parser.error("--step-size must be a positive integer")
    main(args)


    