from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import obf_arrowcloak, obf_ours, obf_groupcover
import model_ours, model_arrowcloak, model_groupcover
import torch
import argparse
import time

def _sync_if_cuda(device):
    """Wait for queued CUDA work to finish so wall-clock timings are accurate.

    Does nothing when ``device`` is not a CUDA device, because
    ``torch.cuda.synchronize()`` raises when CUDA is unavailable and the
    script explicitly supports a CPU fallback.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()


def _benchmark_generation(model, inputs, device, label):
    """Run a warm-up and a timed greedy ``generate`` call, then print tokens/s.

    Args:
        model: Model exposing ``eval()``, ``generate()`` and
            ``config.pad_token_id``; must already live on ``device``.
        inputs: Dict with ``"input_ids"`` and ``"attention_mask"`` tensors
            located on ``device``.
        device: ``torch.device`` the model and inputs live on.
        label: Name printed in the result line (e.g. ``"Original"``).

    Returns:
        The measured throughput in newly generated tokens per second.
    """
    model.eval()
    pad_token_id = model.config.pad_token_id
    with torch.no_grad():
        # Short warm-up run so one-time setup cost is excluded from the timing.
        model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=5,
            do_sample=False,
            pad_token_id=pad_token_id,
        )
        _sync_if_cuda(device)
        start_time = time.perf_counter()
        generated = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=50,
            do_sample=False,
            pad_token_id=pad_token_id,
        )
        # Synchronize before stopping the clock: CUDA kernels run asynchronously.
        _sync_if_cuda(device)
        end_time = time.perf_counter()

    # generate() returns prompt + continuation, so subtract the prompt tokens.
    generated_tokens = generated.numel() - inputs['input_ids'].numel()
    tokens_per_second = generated_tokens / (end_time - start_time)
    # Keep the historical "GPU" wording on CUDA; report "CPU" on the fallback.
    device_label = "GPU" if device.type == "cuda" else "CPU"
    separator = "----------------------------------------------------"
    print(f"{separator}\n{label} model {device_label}(forward): {tokens_per_second:.6f} tokens/s\n{separator}")
    return tokens_per_second


def main():
    """Benchmark generation throughput of a causal LM before and after obfuscation.

    Parses command line arguments, loads the model and tokenizer from
    ``--model_path``, builds a random mock prompt, and prints tokens/s for
    the original model. It then obfuscates the model with the selected
    ``--method``, converts it to the matching custom model, and prints
    tokens/s for that model on the same inputs.

    Raises:
        ValueError: If ``--method`` is not one of the supported methods
            (defensive; argparse already rejects unknown values).
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Test CausalLM model with custom optimization")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model")
    parser.add_argument("--token_length", type=int, default=512, help="Length of input tokens to mock")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
    parser.add_argument("--block_size", type=int, default=4, help="Block size for obfuscation")
    parser.add_argument("--optimized_stage", type=int, default=2, help="Optimization stage (-1, 0, 1, 2, or 3)")
    # `choices` rejects an unknown method before the (slow) model load and benchmark.
    parser.add_argument("--method", type=str, default="ours", choices=("ours", "groupcover", "arrowcloak"),
                        help="Obfuscation method (ours, groupcover, arrowcloak)")
    args = parser.parse_args()

    # Load model and tokenizer
    original_model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    # Fall back to the EOS token when the tokenizer defines no pad token.
    original_model.config.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    # Mock input data
    def mock_inputs(batch_size, token_length):
        """Return random ``input_ids`` and an all-ones ``attention_mask``."""
        # Skip 500 ids at both ends of the vocabulary (intended to avoid
        # special tokens); for vocabularies too small to leave a non-empty
        # range, sample from the whole vocabulary instead of crashing.
        margin = 500 if tokenizer.vocab_size > 1000 else 0
        input_ids = torch.randint(
            low=margin,
            high=tokenizer.vocab_size - margin,
            size=(batch_size, token_length),
            dtype=torch.long
        )

        # Create attention mask (all 1s since we're not padding)
        attention_mask = torch.ones_like(input_ids)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    # Generate mock inputs
    inputs = mock_inputs(args.batch_size, args.token_length)

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    original_model = original_model.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Measure throughput of the original model
    _benchmark_generation(original_model, inputs, device, "Original")

    if args.method == "ours":
        # Obfuscate the model
        obf_result = obf_ours.obfuscate_model(original_model, block_size=args.block_size, obf_score=False)

        v_list0 = obf_result["v_list0"]
        indices_list0 = obf_result["indices_list0"]
        v_list = obf_result["v_list"]
        indices_list = obf_result["indices_list"]

        original_model = original_model.to("cpu")
        # Convert the obfuscated model to custom model
        custom_model = model_ours.convert_to_custom_model(original_model, v_list0, indices_list0, v_list, indices_list,
                                            optimized_stage=args.optimized_stage, simulate=False)
    elif args.method == "groupcover":
        obf_result = obf_groupcover.obfuscate_model(original_model)

        original_model = original_model.to("cpu")
        # Convert the obfuscated model to custom model
        custom_model = model_groupcover.convert_to_custom_model(original_model, obf_result, simulate=False)
    elif args.method == "arrowcloak":
        obf_result = obf_arrowcloak.obfuscate_model(original_model)

        original_model = original_model.to("cpu")
        # Convert the obfuscated model to custom model
        custom_model = model_arrowcloak.convert_to_custom_model(original_model, obf_result, simulate=False)
    else:
        raise ValueError("Invalid method")

    custom_model = custom_model.to(device)

    # Measure throughput of the custom (obfuscated) model on the same inputs
    _benchmark_generation(custom_model, inputs, device, "Custom")

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()