from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import obf_ours, obf_groupcover, obf_arrowcloak
import model_arrowcloak, model_groupcover, model_ours
import torch
import argparse
import time
import gc


def _sync_if_cuda(device):
    """Wait for queued CUDA work to finish; do nothing on non-CUDA devices."""
    if device.type == "cuda":
        torch.cuda.synchronize()


def _timed_forward(model, inputs, device):
    """Run one warm-up pass and one timed forward pass.

    Returns:
        A ``(outputs, elapsed_ms)`` tuple for the timed (second) pass.
    """
    _ = model(**inputs)  # warm-up pass, excluded from the measurement
    _sync_if_cuda(device)
    start_time = time.perf_counter()
    outputs = model(**inputs)
    # CUDA kernels are launched asynchronously; wait before reading the clock.
    _sync_if_cuda(device)
    elapsed_ms = (time.perf_counter() - start_time) * 1000
    return outputs, elapsed_ms


def _print_forward_time(label, elapsed_ms):
    """Print the forward-pass latency banner for the device named by *label*."""
    print(f"----------------------------------------------------\nOriginal model {label}(forward): {elapsed_ms:.6f} milliseconds\n----------------------------------------------------")


def main():
    """Time the original model, obfuscate it, and compare both models' outputs.

    Steps:
      1. Parse the command line and validate ``--model_type`` / ``--method``.
      2. Load the model and tokenizer and build random mock inputs.
      3. Time a forward pass on CPU, then on CUDA when it is available
         (otherwise the second measurement also runs on CPU).
      4. Obfuscate the model with the selected method and convert it to the
         matching custom model.
      5. Run the custom model on the same inputs and print the comparison.

    Raises:
        ValueError: If ``--model_type`` or ``--method`` is not a supported value.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Test CausalLM model with custom optimization")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model")
    parser.add_argument("--token_length", type=int, default=512, help="Length of input tokens to mock")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
    parser.add_argument("--logits_to_keep", type=int, default=1, help="Number of logits to keep for OTP")
    parser.add_argument("--block_size", type=int, default=4, help="Block size for obfuscation")
    parser.add_argument("--optimized_stage", type=int, default=2, help="Optimization stage (-1, 0, 1, 2, or 3)")
    parser.add_argument("--model_type", type=str, default="casual", help="Model type: 'casual' (alias: 'causal') or 'classification'")
    parser.add_argument("--num_labels", type=int, default=2, help="Number of labels for classification")
    parser.add_argument("--method", type=str, default="ours", help="Obfuscation method (ours, groupcover, arrowcloak)")
    args = parser.parse_args()

    # "casual" is the historical spelling used by this CLI; accept the correct
    # spelling "causal" as an alias so both keep working.
    model_type = "casual" if args.model_type == "causal" else args.model_type

    # Validate the choices up front so a typo fails before the (slow) model
    # load and timing runs rather than after them.
    if model_type not in ("casual", "classification"):
        raise ValueError("model_type must be 'casual' or 'classification'")
    if args.method not in ("ours", "groupcover", "arrowcloak"):
        raise ValueError("method must be 'ours', 'groupcover', or 'arrowcloak'")

    # Load model and tokenizer
    if model_type == "casual":
        original_model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
    else:
        original_model = AutoModelForSequenceClassification.from_pretrained(args.model_path, num_labels=args.num_labels, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    # Fall back to the EOS token when the tokenizer defines no padding token.
    original_model.config.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    # Mock input data
    def mock_inputs(batch_size, token_length):
        """Return random ``input_ids`` plus an all-ones ``attention_mask``."""
        vocab_size = tokenizer.vocab_size
        # Skip 500 ids at each end of the vocabulary (special tokens), unless
        # the vocabulary is too small to leave a non-empty range.
        margin = 500 if vocab_size > 1000 else 0
        input_ids = torch.randint(
            low=margin,
            high=vocab_size - margin,
            size=(batch_size, token_length),
            dtype=torch.long,
        )

        # Create attention mask (all 1s since we're not padding)
        attention_mask = torch.ones_like(input_ids)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    # Generate mock inputs
    inputs = mock_inputs(args.batch_size, args.token_length)

    # Baseline: time the original model on CPU.
    cpu_device = torch.device("cpu")
    original_model.eval()
    with torch.no_grad():
        _, elapsed_ms = _timed_forward(original_model, inputs, cpu_device)
        _print_forward_time("CPU", elapsed_ms)

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    original_model = original_model.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    if model_type == "casual":
        inputs["logits_to_keep"] = args.logits_to_keep

    # Test original model; these outputs are the reference for the comparison.
    device_label = "GPU" if device.type == "cuda" else "CPU"
    original_model.eval()
    with torch.no_grad():
        original_outputs, elapsed_ms = _timed_forward(original_model, inputs, device)
        _print_forward_time(device_label, elapsed_ms)

    # Obfuscate the model, then convert it to the matching custom model.
    # Obfuscation runs on `device`; conversion is done after moving to CPU.
    if args.method == "ours":
        obf_result = obf_ours.obfuscate_model(
            original_model,
            block_size=args.block_size,
            obf_score=(model_type == "classification"),
        )
        original_model = original_model.to("cpu")
        custom_model = model_ours.convert_to_custom_model(
            original_model,
            obf_result["v_list0"],
            obf_result["indices_list0"],
            obf_result["v_list"],
            obf_result["indices_list"],
            optimized_stage=args.optimized_stage,
            simulate=False,
        )
    elif args.method == "groupcover":
        obf_result = obf_groupcover.obfuscate_model(original_model)
        original_model = original_model.to("cpu")
        # You can get the MAE result of groupcover and arrowcloak in simulate mode.
        # In SGX mode, our implementation sacrifices precision for performance.
        custom_model = model_groupcover.convert_to_custom_model(original_model, obf_result, simulate=False)
    else:  # "arrowcloak" (validated above)
        obf_result = obf_arrowcloak.obfuscate_model(original_model)
        original_model = original_model.to("cpu")
        # You can get the MAE result of groupcover and arrowcloak in simulate mode.
        # In SGX mode, our implementation sacrifices precision for performance.
        custom_model = model_arrowcloak.convert_to_custom_model(original_model, obf_result, simulate=False)

    custom_model = custom_model.to(device)

    # Test custom model: one warm-up pass, then the pass whose outputs are
    # compared. NOTE(review): no timing is taken here; presumably the custom
    # model reports its own timings, with the blank lines separating the
    # warm-up output from the measured run -- confirm.
    custom_model.eval()
    with torch.no_grad():
        _ = custom_model(**inputs)
        print("\n\n")
        custom_outputs = custom_model(**inputs)

    # Compare logits
    print("\n=== Comparing Model Outputs ===")

    # The mean absolute logits difference is not printed for method "ours" on
    # causal models; only the arg-max indices are compared in that case.
    if model_type == "classification" or args.method != "ours":
        print(f"logits difference: {torch.mean(torch.abs(original_outputs.logits - custom_outputs.logits))}\n")

    _, original_indices = torch.max(original_outputs.logits, dim=-1)
    _, custom_indices = torch.max(custom_outputs.logits, dim=-1)

    print(f"Expected output:\n{original_indices}\n")
    print(f"Secure inference output:\n{custom_indices}\n")

    # Number of positions where the two models predict a different index.
    indices_diff = torch.sum(original_indices != custom_indices)
    print(f"Indices difference: {indices_diff.item()}")
    

# Run the script's entry point only when executed directly, not on import.
if __name__ == "__main__":
    main()