

import torch
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
from tqdm import tqdm
import os


def parse_args():
    """Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace with the attributes ``model_path`` (required),
        ``max_length``, ``stride``, ``dtype`` and ``device``.
    """
    # Resolved once up front: prefer the GPU when torch reports one available.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, add_argument keyword arguments), listed in the order they should
    # appear in the generated --help output.
    option_table = (
        ("--model_path", {
            "type": str,
            "required": True,
            "help": "Path to the model or HuggingFace repo id.",
        }),
        ("--max_length", {
            "type": int,
            "default": 2048,
            "help": "Context window size. Qwen usually supports 4k/8k/32k, but 2048 is standard for PPL bench.",
        }),
        ("--stride", {
            "type": int,
            "default": 2048,
            "help": "Stride for sliding window. 512 is standard.",
        }),
        ("--dtype", {
            "type": str,
            "default": "bfloat16",
            "choices": ["float16", "bfloat16", "float32"],
            "help": "Data type for evaluation.",
        }),
        ("--device", {
            "type": str,
            "default": fallback_device,
        }),
    )

    cli = argparse.ArgumentParser(
        description="Calculate Perplexity on WikiText-2 using Sliding Window."
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

# Display names for the supported evaluation corpora (used in the final report).
_DATASET_LABELS = {"wikitext": "WikiText-2", "ptb": "PTB", "c4": "C4"}


def _load_eval_texts(dataset_name):
    """Return the raw text samples of the evaluation corpus called *dataset_name*.

    Args:
        dataset_name: One of "wikitext", "ptb" or "c4".

    Returns:
        A sequence of strings, one per dataset sample.

    Raises:
        ValueError: If *dataset_name* is not one of the supported corpora.
    """
    if dataset_name == "wikitext":
        print("Loading WikiText-2...")
        test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        return test["text"]

    if dataset_name == "ptb":
        print("Loading PTB...")
        test = load_dataset("ptb_text_only", "penn_treebank", split="test")
        return test["sentence"]

    if dataset_name == "c4":
        print("Loading C4 (Validation Subset)...")
        # Streamed so that only the first 2000 validation documents are read.
        ds = load_dataset("allenai/c4", "en", split="validation", streaming=True, trust_remote_code=True)
        texts = []
        for i, item in enumerate(ds):
            if i >= 2000:
                break
            texts.append(item["text"])
        return texts

    raise ValueError(
        f"Unsupported dataset {dataset_name!r}; expected one of {sorted(_DATASET_LABELS)}."
    )


def _sliding_window_perplexity(model, input_ids, max_length, stride):
    """Compute the perplexity of *model* over *input_ids* with a sliding window.

    Each window covers at most *max_length* tokens and starts *stride* tokens
    after the previous one. Only tokens not already scored by an earlier
    window contribute to the loss; the rest are masked with label -100.

    Args:
        model: Causal LM called as ``model(input_ids, labels=...)`` whose output
            exposes a mean cross-entropy ``loss``.
        input_ids: Token id tensor indexed as ``[batch, position]``.
        max_length: Maximum number of tokens fed to the model per window.
        stride: Number of tokens the window advances each step.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If *max_length* or *stride* is not positive, or if the
            input contains no token that can be scored.
    """
    if max_length <= 0 or stride <= 0:
        raise ValueError(
            f"max_length and stride must be positive (got max_length={max_length}, stride={stride})."
        )

    seq_len = input_ids.size(1)
    nll_sum = 0.0  # Sum of per-token negative log-likelihoods over all windows.
    n_tokens = 0   # Number of tokens that contributed to nll_sum.
    prev_end_loc = 0

    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        # Tokens new to this window; differs from stride on the final window.
        trg_len = end_loc - prev_end_loc

        window_ids = input_ids[:, begin_loc:end_loc].to(model.device)
        target_ids = window_ids.clone()
        # Context tokens already scored by a previous window are ignored.
        target_ids[:, :-trg_len] = -100

        # HF causal LMs shift labels by one internally, so position 0 is never
        # a prediction target; count only the labels that are actually scored.
        num_loss_tokens = int((target_ids[:, 1:] != -100).sum().item())

        # A window with nothing to score (e.g. a single trailing token) would
        # yield a NaN mean loss and poison the result, so it is skipped.
        if num_loss_tokens > 0:
            with torch.no_grad():
                outputs = model(window_ids, labels=target_ids)
            # outputs.loss is a mean over the scored tokens; re-weight it by
            # the token count so that windows of different size are averaged
            # per token rather than per window.
            nll_sum += outputs.loss.item() * num_loss_tokens
            n_tokens += num_loss_tokens

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    if n_tokens == 0:
        raise ValueError("Not enough tokens to evaluate perplexity.")

    return torch.exp(torch.tensor(nll_sum / n_tokens, dtype=torch.float64)).item()


def main():
    """Load a causal LM, evaluate its sliding-window perplexity and print it.

    Reads the model path, window size, stride and dtype from the command line
    (see ``parse_args``). If the model or tokenizer cannot be loaded, the error
    is printed and the function returns without evaluating.
    """
    args = parse_args()

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    print(f"Loading model from {args.model_path}...")
    print(f"Using dtype: {torch_dtype}")

    # NOTE: device_map="auto" leaves device placement to the library, so
    # args.device is not consulted here; inputs are moved to model.device.
    # To evaluate a quantized model, pass a BitsAndBytesConfig as
    # quantization_config to from_pretrained.
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch_dtype,
            device_map="auto",
            trust_remote_code=True,
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    model.eval()

    dataset_name = "ptb"  # One of: "wikitext", "ptb", "c4"
    raw_text_list = _load_eval_texts(dataset_name)

    print(f"Tokenizing {len(raw_text_list)} samples...")
    # The whole corpus is scored as one long token stream.
    encodings = tokenizer("\n\n".join(raw_text_list), return_tensors="pt")

    seq_len = encodings.input_ids.size(1)
    print(f"Evaluating PPL (Total Tokens: {seq_len})...")

    ppl = _sliding_window_perplexity(model, encodings.input_ids, args.max_length, args.stride)

    print("\n" + "=" * 30)
    print(f"Model: {args.model_path}")
    print(f"Perplexity ({_DATASET_LABELS[dataset_name]}): {ppl:.2f}")
    print("=" * 30)

# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger model loading.
if __name__ == "__main__":
    main()





