"""Script for analyzing LoRA adapters by computing statistics on their weights.

Usage:
    python adapter_analyze.py --adapter_path path/to/adapter
"""

import os
import argparse
import numpy as np
from safetensors.torch import load_file, save_file
import matplotlib.pyplot as plt
import shutil

# Module names used to bucket adapter tensors: a tensor is attributed to a
# module when that name occurs anywhere in the tensor's key.
target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]


def _layer_indices(key):
    """Return every layer number that appears in *key* as ``layers.<digits>``."""
    parts = key.split(".")
    return [
        int(following)
        for part, following in zip(parts, parts[1:])
        if part == "layers" and following.isdigit()
    ]


def _new_stats():
    """Return a fresh accumulator for parameter count, sum and absolute sum."""
    return {"params": 0, "sum": 0.0, "l1_sum": 0.0}


def _add_means(stats):
    """Add ``"mean"`` and ``"l1_mean"`` to *stats* in place (0.0 when empty)."""
    count = stats["params"]
    stats["mean"] = stats["sum"] / count if count > 0 else 0.0
    stats["l1_mean"] = stats["l1_sum"] / count if count > 0 else 0.0


def analyze_adapter(adapter_path: str):
    """Compute and display statistics for a LoRA adapter.

    Reads ``adapter_model.safetensors`` from ``adapter_path`` and prints the
    parameter count, the mean, and the mean of absolute values, overall,
    per layer, and per module in ``target_modules``.

    Args:
        adapter_path: Path to adapter directory

    Returns:
        A dict with keys ``"path"``, ``"num_params"``, ``"mean"``,
        ``"l1_mean"``, ``"layer_stats"`` (keyed by layer number, in ascending
        order) and ``"module_stats"`` (keyed by module name). Each stats entry
        holds ``"params"``, ``"sum"``, ``"l1_sum"``, ``"mean"`` and
        ``"l1_mean"``.
    """
    adapter_weights = load_file(os.path.join(adapter_path, "adapter_model.safetensors"))

    total_params = 0
    running_sum = 0.0
    running_l1_sum = 0.0
    layer_stats = {}
    module_stats = {module: _new_stats() for module in target_modules}

    for key, tensor in adapter_weights.items():
        # Accumulate in float64 on the torch side: bfloat16 tensors cannot be
        # converted to NumPy, and half-precision sums lose accuracy quickly.
        values = tensor.detach().double()
        tensor_size = values.numel()
        tensor_sum = values.sum().item()
        tensor_l1_sum = values.abs().sum().item()

        total_params += tensor_size
        running_sum += tensor_sum
        running_l1_sum += tensor_l1_sum

        # Every layer named in the key and every module whose name is a
        # substring of the key receives this tensor's contribution.
        buckets = [
            layer_stats.setdefault(layer, _new_stats())
            for layer in _layer_indices(key)
        ]
        buckets.extend(module_stats[module] for module in target_modules if module in key)
        for stats in buckets:
            stats["params"] += tensor_size
            stats["sum"] += tensor_sum
            stats["l1_sum"] += tensor_l1_sum

    mean = running_sum / total_params if total_params > 0 else 0.0
    l1_mean = running_l1_sum / total_params if total_params > 0 else 0.0

    # Report layers in ascending numeric order.
    layer_stats = {layer: layer_stats[layer] for layer in sorted(layer_stats)}
    for stats in layer_stats.values():
        _add_means(stats)
    for stats in module_stats.values():
        _add_means(stats)

    print(f"Adapter: {adapter_path}")
    print(f"Number of parameters: {total_params:,}")
    print(f"Overall Mean: {mean:.6f}")
    print(f"Overall Mean of absolute values (L1): {l1_mean:.6f}")

    print("\nPer-layer statistics:")
    for layer, stats in layer_stats.items():
        if stats["params"] > 0:
            print(f"  Layer {layer}:")
            print(f"    Parameters: {stats['params']:,}")
            print(f"    Mean: {stats['mean']:.6f}")
            print(f"    L1 Mean: {stats['l1_mean']:.6f}")

    print("\nPer-module statistics:")
    for module in target_modules:
        stats = module_stats[module]
        if stats["params"] > 0:
            print(f"  {module}:")
            print(f"    Parameters: {stats['params']:,}")
            print(f"    Mean: {stats['mean']:.6f}")
            print(f"    L1 Mean: {stats['l1_mean']:.6f}")

    return {
        "path": adapter_path,
        "num_params": total_params,
        "mean": mean,
        "l1_mean": l1_mean,
        "layer_stats": layer_stats,
        "module_stats": module_stats,
    }
    
def weight_histogram(adapter_path: str, output_file: str = "weight_histogram.png"):
    """Save a 100-bin histogram of every weight in a LoRA adapter.

    Args:
        adapter_path: Path to adapter directory containing
            ``adapter_model.safetensors``.
        output_file: Where to write the image. Defaults to
            ``weight_histogram.png`` in the current working directory.

    Raises:
        ValueError: If the adapter file contains no tensors.
    """
    adapter_weights = load_file(os.path.join(adapter_path, "adapter_model.safetensors"))
    if not adapter_weights:
        raise ValueError(f"No tensors found in adapter at {adapter_path}")

    flattened = []
    for tensor in adapter_weights.values():
        # Sub-32-bit floats (notably bfloat16) have no NumPy equivalent and
        # make .numpy() raise; widening to float32 is lossless for them.
        if tensor.is_floating_point() and tensor.element_size() < 4:
            tensor = tensor.float()
        flattened.append(tensor.detach().flatten().numpy())
    all_weights_combined = np.concatenate(flattened)

    # Draw on a dedicated figure so an already-open pyplot figure is never
    # drawn over, and close it even if saving fails.
    fig, ax = plt.subplots()
    try:
        ax.hist(all_weights_combined, bins=100)
        ax.set_title("Histogram of all adapter weights")
        ax.set_xlabel("Weight value")
        ax.set_ylabel("Frequency")
        fig.savefig(output_file)
    finally:
        plt.close(fig)
    
def truncate_weights(adapter_path: str, output_path: str, threshold: float = 0.01):
    """Zero all but the largest-magnitude weights of an adapter and save a copy.

    The cutoff is the ``(1 - threshold) * 100`` percentile of the absolute
    values of all adapter weights taken together, so roughly the top
    ``threshold`` fraction of weights (by magnitude) is kept and every weight
    whose magnitude is strictly below the cutoff is set to zero.

    The whole adapter directory is copied to ``output_path`` and the copied
    ``adapter_model.safetensors`` is then overwritten with the truncated
    weights.

    Args:
        adapter_path: Path to the source adapter directory.
        output_path: Directory to write the truncated adapter to.
        threshold: Fraction of weights to keep, in ``[0, 1]``.

    Raises:
        ValueError: If ``threshold`` is outside ``[0, 1]`` or the adapter file
            contains no tensors.
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError(f"threshold must be within [0, 1], got {threshold}")

    adapter_weights = load_file(os.path.join(adapter_path, "adapter_model.safetensors"))
    if not adapter_weights:
        raise ValueError(f"No tensors found in adapter at {adapter_path}")

    def _widened(tensor):
        # Sub-32-bit floats (notably bfloat16) cannot be converted to NumPy
        # and would round the cutoff when compared; float32 holds them exactly.
        if tensor.is_floating_point() and tensor.element_size() < 4:
            return tensor.float()
        return tensor

    magnitudes = [
        np.abs(_widened(tensor).detach().flatten().numpy())
        for tensor in adapter_weights.values()
    ]
    threshold_value = np.percentile(np.concatenate(magnitudes), (1 - threshold) * 100)
    print(f"Threshold value: {threshold_value}")

    # Use a plain Python float for the torch comparison.
    cutoff = float(threshold_value)
    for tensor in adapter_weights.values():
        tensor[_widened(tensor).abs() < cutoff] = 0.0

    # Copy the adapter folder (config, tokenizer files, ...) to the output
    # path; copytree creates the destination directory itself.
    shutil.copytree(adapter_path, output_path, dirs_exist_ok=True)

    # Overwrite the copied weights with the truncated ones.
    save_file(adapter_weights, os.path.join(output_path, "adapter_model.safetensors"))

if __name__ == "__main__":
    # Command-line entry point: always writes the weight histogram and a
    # truncated copy of the adapter; statistics are printed only on request.
    parser = argparse.ArgumentParser(description="Analyze LoRA adapter weights")
    parser.add_argument(
        "--adapter_path", type=str, required=True, help="Path to adapter directory"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="tmp/emergent-misalignment/olmo-32b-insecure-truncated-t10",
        help="Directory to write the truncated adapter to",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.1,
        help="Fraction of largest-magnitude weights to keep when truncating",
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Also print overall, per-layer and per-module weight statistics",
    )

    args = parser.parse_args()
    if args.stats:
        analyze_adapter(args.adapter_path)
    weight_histogram(args.adapter_path)
    truncate_weights(args.adapter_path, args.output_path, args.threshold)