import torch
import numpy as np
import json
import argparse
import os
from tqdm import tqdm
import warnings

# Safetensors dtype tags -> torch dtypes (the tags used by common checkpoints).
_SAFETENSORS_DTYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    "I32": torch.int32,
    "I16": torch.int16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}


def _load_safetensors(path):
    """Read a ``.safetensors`` file into a ``{name: tensor}`` dict using only torch + stdlib.

    File layout: an 8-byte little-endian header length, a JSON header that maps
    tensor names to ``dtype`` / ``shape`` / ``data_offsets`` (offsets relative to
    the start of the data section), then the raw tensor bytes. Tensor bytes are
    interpreted in host byte order, so this assumes a little-endian host.

    Raises:
        ValueError: if a tensor uses a dtype tag not in ``_SAFETENSORS_DTYPES``.
    """
    with open(path, "rb") as fh:
        header_len = int.from_bytes(fh.read(8), "little")
        header = json.loads(fh.read(header_len))
        # A writable buffer lets torch.frombuffer share memory without warnings.
        payload = bytearray(fh.read())

    tensors = {}
    for name, info in header.items():
        if name == "__metadata__":
            continue
        dtype = _SAFETENSORS_DTYPES.get(info["dtype"])
        if dtype is None:
            raise ValueError(f"Unsupported safetensors dtype {info['dtype']!r} for tensor {name!r}")
        start, end = info["data_offsets"]
        itemsize = torch.empty((), dtype=dtype).element_size()
        count = (end - start) // itemsize
        if count == 0:
            # torch.frombuffer rejects zero-length reads.
            flat = torch.empty(0, dtype=dtype)
        else:
            flat = torch.frombuffer(payload, dtype=dtype, count=count, offset=start)
        tensors[name] = flat.reshape(info["shape"])
    return tensors


def load_adapter(adapter_dir):
    """Load LoRA weights from a single adapter directory.

    ``adapter_model.bin`` is preferred; if it is absent,
    ``adapter_model.safetensors`` is read instead.

    Args:
        adapter_dir: Directory containing the adapter weight file.

    Returns:
        The loaded checkpoint. For the safetensors path this is a
        ``{name: tensor}`` dict on CPU; for the ``.bin`` path it is whatever was
        pickled (presumably a name -> tensor state dict), mapped to CPU.

    Raises:
        FileNotFoundError: if neither weight file exists in ``adapter_dir``.
    """
    bin_path = os.path.join(adapter_dir, "adapter_model.bin")
    if os.path.exists(bin_path):
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and can
        # execute code from a malicious checkpoint. Only load trusted adapters;
        # weights_only=True is likely sufficient for a plain tensor state dict — confirm.
        return torch.load(bin_path, map_location='cpu', weights_only=False)

    safetensors_path = os.path.join(adapter_dir, "adapter_model.safetensors")
    if os.path.exists(safetensors_path):
        # safetensors is not a pickle format, so torch.load cannot read it.
        return _load_safetensors(safetensors_path)

    raise FileNotFoundError(f"LoRA file not found in {adapter_dir}")

def is_effectively_zero(values, threshold=1e-10):
    """Check if values are effectively zero (accounting for floating point errors).

    Args:
        values: Array-like of numbers.
        threshold: Largest absolute value still considered zero (exclusive).

    Returns:
        ``True`` if every entry has ``abs(value) < threshold``. An empty input is
        vacuously zero. A NaN entry makes the result ``False``.
    """
    magnitudes = np.abs(np.asarray(values))
    if magnitudes.size == 0:
        # np.max raises ValueError on an empty array (e.g. when top_k == 0).
        return True
    return bool(magnitudes.max() < threshold)

def _pair_lora_factors(weights):
    """Group ``lora_B`` ("up") and ``lora_A`` ("down") weights by module name.

    The module name is the key with ``lora_B.weight`` / ``lora_A.weight`` and every
    occurrence of ``base_model.model.`` removed. Returns
    ``{name: {"up": B, "down": A}}``; a module that has only one factor gets only
    that key. Keys matching neither pattern are ignored.
    """
    pairs = {}
    for key, weight in weights.items():
        if "lora_B.weight" in key:
            role, marker = "up", "lora_B.weight"
        elif "lora_A.weight" in key:
            role, marker = "down", "lora_A.weight"
        else:
            continue
        layer_name = key.replace(marker, "").replace("base_model.model.", "")
        pairs.setdefault(layer_name, {})[role] = weight
    return pairs


def _spectrum_metrics(singular_values, layer_name, top_k):
    """Derive energy fractions, cumulative energy and effective rank from a spectrum.

    p_i = sigma_i^2 / sum_j sigma_j^2, S_k = cumsum(p_i), and
    r_eff = exp(-sum_i p_i * log p_i) (entropy-based effective rank).
    Warns if the spectrum is degenerate or r_eff falls outside roughly [1, top_k].
    """
    sigma_sq = singular_values ** 2
    sum_sigma_sq = np.sum(sigma_sq)

    if sum_sigma_sq < 1e-10:
        # Degenerate spectrum: fall back to a uniform distribution.
        p_i = np.ones_like(sigma_sq) / len(sigma_sq)
        warnings.warn(f"Layer {layer_name} has near-zero singular values, using uniform p_i")
    else:
        p_i = sigma_sq / sum_sigma_sq

    S_k = np.cumsum(p_i)

    # Clamp only inside the log; zero-probability terms still contribute 0 * log(eps) = 0.
    p_i_safe = np.where(p_i > 1e-10, p_i, 1e-10)
    entropy = -np.sum(p_i * np.log(p_i_safe))
    r_eff = np.exp(entropy)

    # r_eff should lie between 1 and top_k; allow small floating point slack.
    if not (0.9 <= r_eff <= top_k + 0.1):
        warnings.warn(f"Layer {layer_name} has unusual r_eff: {r_eff}")

    return {
        'singular_values': singular_values.tolist(),
        'p_i': p_i.tolist(),
        'S_k': S_k.tolist(),
        'r_eff': float(r_eff),
    }


def compute_singular_values_and_metrics(weights, top_k=32, filter_zero=True, zero_threshold=1e-10, *, device=None):
    """Compute top-k singular values and metrics for each layer's Delta W = B @ A.

    A module is processed only if it has both a ``lora_B`` ("up") and a
    ``lora_A`` ("down") weight. Modules with only one factor are skipped
    silently.

    Args:
        weights: LoRA weights dictionary (parameter name -> tensor).
        top_k: Number of largest singular values to keep (must be >= 1).
        filter_zero: Whether to drop modules whose kept singular values are all
            below ``zero_threshold`` (treated as untrained).
        zero_threshold: Threshold used by ``filter_zero``.
        device: Device for the matmul and SVD. Defaults to CUDA when available,
            otherwise CPU.

    Returns:
        ``{layer_name: {"singular_values", "p_i", "S_k", "r_eff"}}``. The first
        three values are lists of floats and ``r_eff`` is a float.

    Raises:
        ValueError: if ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    if device is None:
        # The previous hard-coded "cuda" crashed on CPU-only machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    results = {}
    skipped_layers = []

    pbar = tqdm(_pair_lora_factors(weights).items())
    for layer_name, w in pbar:
        if "up" not in w or "down" not in w:
            continue

        # SVD supports only float/double (not fp16/bf16), so promote to >= float32.
        # Moving the low-rank factors first transfers less data than moving Delta W.
        compute_dtype = torch.promote_types(
            torch.promote_types(w["up"].dtype, w["down"].dtype), torch.float32
        )
        up = w["up"].to(device=device, dtype=compute_dtype)
        down = w["down"].to(device=device, dtype=compute_dtype)

        delta_w = up @ down
        if delta_w.dim() > 2:
            # reshape (not view) also works for non-contiguous results.
            delta_w = delta_w.reshape(-1, delta_w.size(-1))

        # Only the singular values are used, so skip computing U and V.
        singular_values = torch.linalg.svdvals(delta_w)[:top_k].cpu().numpy()

        # Check if this layer is effectively zero (untrained)
        if filter_zero and is_effectively_zero(singular_values, zero_threshold):
            skipped_layers.append(layer_name)
            pbar.set_description(f"Skipped {layer_name} (zero)")
            continue

        results[layer_name] = _spectrum_metrics(singular_values, layer_name, top_k)

    if skipped_layers:
        print(f"\nFiltered out {len(skipped_layers)} untrained/zero modules:")
        for layer in skipped_layers[:5]:
            print(f"  - {layer}")
        if len(skipped_layers) > 5:
            print(f"  ... and {len(skipped_layers) - 5} more")

    return results

def main():
    """CLI entry point: analyse one LoRA adapter directory and write metrics to JSON.

    By default the output goes to ``<dir>/singular_values.json``. Untrained
    (all-zero) modules are excluded unless ``--no-filter-zero`` is passed.
    """
    cli = argparse.ArgumentParser(description="Compute singular values and metrics for LoRA weights.")
    cli.add_argument('dir', help="Directory containing the LoRA adapter_model.bin or .safetensors.")
    cli.add_argument('--output', default=None, help="Output JSON file path.")
    cli.add_argument('--no-filter-zero', action='store_true',
                     help="Include untrained/zero modules in the output (default: filter them out)")
    cli.add_argument('--zero-threshold', type=float, default=1e-10,
                     help="Threshold for determining if a module is untrained (default: 1e-10)")
    cli.add_argument('--top-k', type=int, default=32,
                     help="Number of top singular values to compute (default: 32)")
    opts = cli.parse_args()

    filtering = not opts.no_filter_zero
    adapter_weights = load_adapter(opts.dir)
    metrics = compute_singular_values_and_metrics(
        adapter_weights,
        top_k=opts.top_k,
        filter_zero=filtering,
        zero_threshold=opts.zero_threshold,
    )

    # Fall back to a file next to the adapter when no explicit path was given.
    out_path = opts.output
    if out_path is None:
        out_path = os.path.join(opts.dir, "singular_values.json")

    with open(out_path, 'w') as fh:
        json.dump(metrics, fh, indent=4)

    print(f"Results saved to {out_path}")
    print(f"Total modules processed: {len(metrics)}")
    if filtering:
        print(f"Zero threshold used: {opts.zero_threshold}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
