from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
import obf_ours, obf_groupcover, obf_arrowcloak
import model_arrowcloak, model_groupcover, model_ours
import torch
import argparse
import time
import gc
from tqdm import tqdm


# Values accepted by --model_type and --method.  "casual" is the spelling this
# CLI uses for the AutoModelForCausalLM mode; it is kept so existing
# invocations keep working.
_MODEL_TYPES = ("casual", "classification")
_METHODS = ("ours", "groupcover", "arrowcloak")

# Number of token ids skipped at each end of the vocabulary when mocking inputs.
_SPECIAL_TOKEN_MARGIN = 500


def _parse_args():
    """Parse the command line and reject sizes that cannot produce any work.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse on a usage error, including a
            non-positive ``--token_length``, ``--batch_size`` or
            ``--num_iterations``.
    """
    parser = argparse.ArgumentParser(description="Test CausalLM model with custom optimization")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model")
    parser.add_argument("--token_length", type=int, default=512, help="Length of input tokens to mock")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
    parser.add_argument("--logits_to_keep", type=int, default=1, help="Number of logits to keep for OTP")
    parser.add_argument("--block_size", type=int, default=4, help="Block size for obfuscation")
    parser.add_argument("--optimized_stage", type=int, default=2, help="Optimization stage (-1, 0, 1, 2, or 3)")
    parser.add_argument("--model_type", type=str, default="casual", help="Model type: 'casual' or 'classification'")
    parser.add_argument("--num_labels", type=int, default=2, help="Number of labels for classification")
    parser.add_argument("--num_iterations", type=int, default=1000, help="Number of iterations for accuracy testing")
    parser.add_argument("--method", type=str, default="ours", help="Obfuscation method (ours, groupcover, arrowcloak)")
    args = parser.parse_args()

    # With zero iterations the final accuracy would be 0 / 0, and empty
    # batches/sequences give the models nothing to predict on.
    for option in ("token_length", "batch_size", "num_iterations"):
        if getattr(args, option) < 1:
            parser.error(f"--{option} must be a positive integer")
    return args


def _validate_choices(args):
    """Fail fast on an unsupported ``--model_type`` or ``--method``.

    Checking both before any model is loaded avoids paying for the model load
    and the full baseline pass only to discover a typo in ``--method``.

    Raises:
        ValueError: If either option holds an unsupported value.
    """
    if args.model_type not in _MODEL_TYPES:
        raise ValueError("model_type must be 'casual' or 'classification'")
    if args.method not in _METHODS:
        raise ValueError("method must be 'ours', 'groupcover', or 'arrowcloak'")


def _load_model_and_tokenizer(args):
    """Load the reference model (in eval mode) and its tokenizer.

    NOTE(security): ``trust_remote_code=True`` lets the checkpoint supply its
    own Python code, which is executed on load. Only pass a ``--model_path``
    you trust.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    if args.model_type == "casual":
        model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True)
        model.config.logits_to_keep = args.logits_to_keep
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            args.model_path, num_labels=args.num_labels, trust_remote_code=True
        )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    # Fall back to the EOS id when the tokenizer defines no padding token.
    if tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
    else:
        model.config.pad_token_id = tokenizer.eos_token_id

    model.eval()
    return model, tokenizer


def _mock_inputs(vocab_size, batch_size, token_length):
    """Create one batch of random token ids with an all-ones attention mask.

    Ids are drawn from ``[500, vocab_size - 500)`` to skip special tokens at
    both ends of the vocabulary. When the vocabulary is too small for that
    window to be non-empty, the full range ``[0, vocab_size)`` is used instead.

    Returns:
        A dict with ``input_ids`` of shape ``(batch_size, token_length)`` and
        dtype ``torch.long``, and a matching ``attention_mask`` of ones.
    """
    if vocab_size > 2 * _SPECIAL_TOKEN_MARGIN:
        low, high = _SPECIAL_TOKEN_MARGIN, vocab_size - _SPECIAL_TOKEN_MARGIN
    else:
        low, high = 0, vocab_size
    input_ids = torch.randint(low=low, high=high, size=(batch_size, token_length), dtype=torch.long)
    # No padding is produced, so every position is attended to.
    return {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}


def _to_device(batch, device):
    """Return a copy of ``batch`` with every tensor value moved to ``device``.

    Non-tensor values (such as the integer ``logits_to_keep``) pass through
    unchanged.
    """
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in batch.items()
    }


def _predict(model, batch, device):
    """Run ``model`` on ``batch`` without gradients and return argmax indices on CPU."""
    with torch.no_grad():
        outputs = model(**_to_device(batch, device))
    return outputs.logits.argmax(dim=-1).cpu()


def _build_custom_model(args, original_model):
    """Obfuscate ``original_model`` with ``args.method`` and convert the result.

    The original model is moved to the CPU after obfuscation and before it is
    handed to the method's ``convert_to_custom_model``.

    Returns:
        The custom model produced by the selected method's converter.

    Raises:
        ValueError: If ``args.method`` is not a supported method.
    """
    if args.method == "ours":
        # obf_score is enabled only for classification models.
        obf_result = obf_ours.obfuscate_model(
            original_model,
            block_size=args.block_size,
            obf_score=(args.model_type == "classification"),
        )
        v_list0 = obf_result["v_list0"]
        indices_list0 = obf_result["indices_list0"]
        v_list = obf_result["v_list"]
        indices_list = obf_result["indices_list"]
        original_model = original_model.to("cpu")
        return model_ours.convert_to_custom_model(
            original_model, v_list0, indices_list0, v_list, indices_list,
            optimized_stage=args.optimized_stage, simulate=False,
        )

    if args.method == "groupcover":
        obf_result = obf_groupcover.obfuscate_model(original_model)
        converter = model_groupcover
    elif args.method == "arrowcloak":
        obf_result = obf_arrowcloak.obfuscate_model(original_model)
        converter = model_arrowcloak
    else:
        raise ValueError("method must be 'ours', 'groupcover', or 'arrowcloak'")

    original_model = original_model.to("cpu")
    return converter.convert_to_custom_model(original_model, obf_result, simulate=False)


def _release_memory():
    """Collect garbage, then release unused cached CUDA memory.

    The collector runs first so that tensors it frees are already back in the
    caching allocator when ``empty_cache`` is called.
    """
    gc.collect()
    torch.cuda.empty_cache()


def main():
    """Measure how often an obfuscated model agrees with the original model.

    Steps:
      1. Parse and validate the command-line options.
      2. Load the reference model and tokenizer.
      3. Pre-generate ``--num_iterations`` random input batches.
      4. Run the reference model on every batch and keep its argmax
         predictions on the CPU.
      5. Obfuscate the model with ``--method`` and convert it to the custom
         model.
      6. Run the custom model on the same batches and report the per-iteration
         and overall fraction of matching predictions.

    Raises:
        ValueError: If ``--model_type`` or ``--method`` is unsupported.
        SystemExit: Raised by argparse on a usage error.
    """
    args = _parse_args()
    _validate_choices(args)

    original_model, tokenizer = _load_model_and_tokenizer(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pre-generate all mock inputs so both models see exactly the same batches.
    print(f"\n=== Pre-generating {args.num_iterations} Inputs ===")
    all_inputs = []
    for _ in tqdm(range(args.num_iterations), desc="Generating inputs"):
        inputs = _mock_inputs(tokenizer.vocab_size, args.batch_size, args.token_length)
        if args.model_type == "casual":
            inputs["logits_to_keep"] = args.logits_to_keep
        all_inputs.append(inputs)

    # Baseline: predictions of the original model, stored on the CPU.
    print("\n=== Running Original Model on All Inputs ===")
    original_model = original_model.to(device)
    original_outputs_list = [
        _predict(original_model, inputs, device)
        for inputs in tqdm(all_inputs, desc="Running original model")
    ]

    custom_model = _build_custom_model(args, original_model)
    custom_model.eval()

    # Drop this function's reference to the original model before the custom
    # model is moved to the device.
    print("\n=== Recycling Original Model ===")
    del original_model
    _release_memory()

    total_correct = 0
    total_tokens = 0

    custom_model = custom_model.to(device)
    print(f"\n=== Testing Custom Model Accuracy over {args.num_iterations} Iterations ===")
    progress = tqdm(
        zip(all_inputs, original_outputs_list),
        desc="Testing iterations",
        total=args.num_iterations,
    )
    for iteration, (inputs, original_indices) in enumerate(progress, start=1):
        custom_indices = _predict(custom_model, inputs, device)

        # tqdm.write keeps this dump from corrupting the progress bar.
        tqdm.write(f"{custom_indices} {original_indices}")

        correct = int((custom_indices == original_indices).sum())
        compared = original_indices.numel()
        total_correct += correct
        total_tokens += compared
        tqdm.write(f"Iteration {iteration}: Accuracy = {correct / compared:.4f}")

        # Clear cache to save memory
        _release_memory()

    final_accuracy = total_correct / total_tokens
    print("\n=== Final Results ===")
    print(f"Total iterations: {args.num_iterations}")
    print(f"Total tokens: {total_tokens}")
    print(f"Total correct: {total_correct}")
    print(f"Final accuracy: {final_accuracy*100:.4f}%")

# Run the accuracy comparison only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()