#!/usr/bin/env python3
"""
QR-Adaptor AMQ Search: Mixed Precision with 4/8-bit + Dynamic Rank

Uses PPL perturbation for sensitivity and allocates bits based on layer importance.
8-bit layers are kept as FP16 (not quantized) for maximum precision.

Supports configurable bit budget for higher precision (e.g., 6-7 bits average).
"""

import argparse
import json
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import numpy as np


def compute_ppl_loss(model, tokenizer, texts, device="cuda", max_length=256):
    """Compute the average language-modeling loss of ``model`` over ``texts``.

    Each text is tokenized on its own (truncated to ``max_length`` tokens) and
    passed through the model with its input ids reused as labels. The result
    is the unweighted mean of the per-text losses.

    A text is skipped, and does not count towards the mean, when:
    - its stripped length is under 20 characters,
    - the forward pass raises an ``Exception`` (a warning is emitted once the
      loop finishes), or
    - the resulting loss is NaN or infinite.

    ``KeyboardInterrupt`` and ``SystemExit`` are deliberately not caught, so
    the run can still be aborted from the keyboard.

    Args:
        model: Callable model exposing ``eval()`` and returning an object
            with a ``loss`` tensor when called with ``labels``.
        tokenizer: Callable returning a mapping of tensors for one text.
        texts: Iterable of strings to score.
        device: Device the tokenized tensors are moved to before the call.
        max_length: Truncation length handed to the tokenizer.

    Returns:
        The mean loss as a float, or ``0.0`` when no text yielded a finite
        loss. NOTE(review): ``0.0`` is indistinguishable from a perfect score,
        so callers comparing losses should be aware of this fallback.
    """
    import warnings

    model.eval()
    total_loss = 0.0
    total_samples = 0
    failed_passes = 0
    last_error = None

    with torch.no_grad():
        for text in texts:
            if len(text.strip()) < 20:
                continue
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Best-effort scoring: one failing sample must not abort the whole
            # measurement, but only ordinary errors are tolerated.
            try:
                outputs = model(**inputs, labels=inputs["input_ids"])
                loss = outputs.loss.item()
            except Exception as exc:
                failed_passes += 1
                last_error = exc
                continue

            # Non-finite losses would poison the running sum, so drop them.
            if np.isfinite(loss):
                total_loss += loss
                total_samples += 1

    if failed_passes:
        warnings.warn(
            f"compute_ppl_loss: {failed_passes} forward pass(es) failed and "
            f"were skipped (last error: {last_error!r})",
            RuntimeWarning,
            stacklevel=2,
        )

    # max(..., 1) avoids a division by zero when nothing was scored.
    return total_loss / max(total_samples, 1)


def add_layer_noise(layer, scale=0.3):
    """Perturb every non-empty parameter of ``layer`` in place with Gaussian noise.

    The noise for each parameter is drawn with ``torch.randn_like`` and scaled
    by ``scale`` times that parameter's mean absolute value. It serves as a
    stand-in for quantization error.

    Returns:
        A dict mapping parameter name to a clone of its pre-noise data, to be
        handed back to ``restore_weights``.
    """
    backups = {}
    with torch.no_grad():
        for param_name, tensor in layer.named_parameters():
            # Nothing to perturb (or back up) for empty parameters.
            if tensor.numel() == 0:
                continue
            backups[param_name] = tensor.data.clone()
            # Noise amplitude follows the parameter's own typical magnitude.
            magnitude = scale * tensor.abs().mean().item()
            tensor.data.add_(torch.randn_like(tensor) * magnitude)
    return backups


def restore_weights(layer, original_weights):
    """Copy saved tensors back into the matching parameters of ``layer``.

    Parameters with no entry in ``original_weights`` are left untouched, and
    entries naming no parameter of ``layer`` are ignored.
    """
    with torch.no_grad():
        for param_name, tensor in layer.named_parameters():
            if param_name not in original_weights:
                continue
            # In-place copy keeps the parameter object (and its storage) intact.
            tensor.data.copy_(original_weights[param_name])


def compute_dynamic_allocation(num_layers, bit_budget, sensitivities):
    """
    Split ``num_layers`` into 8-bit and 4-bit counts that meet ``bit_budget``.

    With only 4-bit and 8-bit layers available, the average bit width is
        (8 * n_8bit + 4 * n_4bit) / num_layers
    and, since n_4bit = num_layers - n_8bit, solving for the target gives
        n_8bit = (bit_budget - 4) * num_layers / 4
    which is rounded to the nearest integer and clamped to [0, num_layers].

    Examples:
    - bit_budget >= 8: every layer is 8-bit
    - bit_budget = 6: about half 8-bit, half 4-bit
    - bit_budget <= 4: every layer is 4-bit

    Only the counts are computed here. ``sensitivities`` is accepted but not
    read; deciding which layers receive 8 bits is left to the caller.

    Returns:
        Tuple ``(n_8bit, n_4bit)``.
    """
    if bit_budget <= 4:
        requested = 0
    elif bit_budget >= 8:
        requested = num_layers
    else:
        # Strictly between the two widths: interpolate linearly.
        requested = int(round((bit_budget - 4) * num_layers / 4))

    # Keep the count inside the valid range.
    high_precision = min(num_layers, requested)
    if high_precision < 0:
        high_precision = 0

    return high_precision, num_layers - high_precision


def run_amq_search(model_id, output_path, bit_budget=6.0, n_calibration=32, uniform_rank=False, base_rank=16):
    """
    AMQ Search with 4/8-bit + Dynamic Rank.

    Loads the model, measures how much the loss rises when each transformer
    layer is perturbed with noise, and gives the most sensitive layers the
    8-bit slot until the bit budget is used up. The resulting plan is written
    to ``output_path`` as JSON and also returned.

    This function only produces the plan; it does not quantize anything.
    Layers marked 8-bit are intended to be kept as FP16 (not quantized by
    HQQ) by whatever consumes the plan.

    Dynamic allocation based on bit_budget:
    - bit_budget=4: all 4-bit
    - bit_budget=6: ~50% 8-bit, ~50% 4-bit
    - bit_budget=7: ~75% 8-bit, ~25% 4-bit
    - bit_budget=8: all 8-bit (FP16)

    Rank allocation (when uniform_rank=False):
    - 8-bit layers → r=32 (higher rank for FP16)
    - 4-bit layers → r=16 (moderate rank)

    When uniform_rank=True (AMQ baseline):
    - All layers use the same rank (base_rank)

    Args:
        model_id: Model name or path handed to ``from_pretrained``.
        output_path: Destination of the JSON plan; parent directories are
            created when missing.
        bit_budget: Target average bits per layer.
        n_calibration: Number of calibration texts to collect.
        uniform_rank: Use ``base_rank`` for every layer when True.
        base_rank: Rank used in uniform mode.

    Returns:
        The plan as a dict (the same content that is written to disk).

    Raises:
        ValueError: If the layer list is found neither under ``model.model.layers``
            nor under ``model.transformer.h``.
    """
    print("=" * 70)
    print("QR-Adaptor AMQ Search: 4/8-bit + Dynamic Rank")
    print("  Note: 8-bit = FP16 (no quantization)")
    print("=" * 70)
    print(f"Model: {model_id}")
    print(f"Target bit budget: {bit_budget:.1f}")
    print("Bit options: [4, 8]")
    print("=" * 70)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model
    # SECURITY NOTE: trust_remote_code=True executes Python shipped with the
    # model repository; only point model_id at sources you trust.
    print("\nLoading model...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )

    # Get layers: two common layouts are supported.
    if hasattr(model, 'model') and hasattr(model.model, 'layers'):
        layers = model.model.layers
    elif hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
        layers = model.transformer.h
    else:
        raise ValueError("Unsupported architecture")

    num_layers = len(layers)
    print(f"Number of layers: {num_layers}")

    # Load calibration data
    # NOTE(review): calibration draws from the wikitext-2 *test* split; if
    # that split is also used for evaluation, the plan has seen the eval data.
    print("\nLoading calibration data...")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    calibration_texts = [t for t in dataset["text"] if len(t) > 100][:n_calibration]

    # NOTE(review): only the first 8 calibration texts are ever scored, so
    # n_calibration values above 8 have no effect on the measurements.
    eval_texts = calibration_texts[:8]

    # Baseline
    print("\nMeasuring baseline loss...")
    baseline_loss = compute_ppl_loss(model, tokenizer, eval_texts, device)
    print(f"Baseline loss: {baseline_loss:.4f}")

    # Sensitivity
    print("\n" + "=" * 70)
    print("Measuring per-layer sensitivity...")
    print("=" * 70)

    sensitivities = []
    for layer_idx in tqdm(range(num_layers), desc="Analyzing layers"):
        layer = layers[layer_idx]
        original_weights = add_layer_noise(layer, scale=0.3)
        # Always undo the perturbation, even if evaluation is interrupted or
        # fails, so the model is never left with noisy weights.
        try:
            perturbed_loss = compute_ppl_loss(model, tokenizer, eval_texts, device)
        finally:
            restore_weights(layer, original_weights)
        # NOTE(review): compute_ppl_loss returns 0.0 when every sample gave a
        # non-finite loss, so a layer that breaks the model completely is
        # recorded here with sensitivity 0.
        sensitivity = max(0, perturbed_loss - baseline_loss)
        sensitivities.append(sensitivity)
        print(f"  Layer {layer_idx:2d}: sensitivity = {sensitivity:.6f}")

    # Normalize to [0, 1]; when all values are equal fall back to a linear
    # ramp so the ranking below is still defined.
    sens_array = np.array(sensitivities)
    if sens_array.max() > sens_array.min():
        sens_normalized = (sens_array - sens_array.min()) / (sens_array.max() - sens_array.min())
    else:
        sens_normalized = np.linspace(0, 1, num_layers)

    sorted_indices = np.argsort(sens_normalized)[::-1]  # Most sensitive first

    # Dynamic allocation based on bit budget
    n_8bit, n_4bit = compute_dynamic_allocation(num_layers, bit_budget, sensitivities)

    print("\n" + "=" * 70)
    print(f"Target bit budget: {bit_budget:.1f}")
    if uniform_rank:
        print(f"Allocation: {n_8bit} @ 8-bit(FP16), {n_4bit} @ 4-bit | Uniform rank r={base_rank}")
    else:
        print(f"Allocation: {n_8bit} @ 8-bit(FP16)/r=32, {n_4bit} @ 4-bit/r=16")
    print("=" * 70)

    q_array = [4] * num_layers

    if uniform_rank:
        # AMQ Baseline: mixed quant + uniform rank
        r_array = [base_rank] * num_layers
    else:
        # QR-Adaptor: mixed quant + mixed rank
        r_array = [16] * num_layers

    # Most sensitive → 8-bit (FP16)
    for i in range(n_8bit):
        idx = sorted_indices[i]
        q_array[idx] = 8
        if not uniform_rank:
            r_array[idx] = 32  # Only set higher rank if not uniform

    # Rest are 4-bit (already default)

    avg_bits = sum(q_array) / len(q_array)
    avg_rank = sum(r_array) / len(r_array)

    # Count the ranks actually in use. The "r=16" and "r=32" keys are always
    # present (as before); any other rank, e.g. a custom base_rank in uniform
    # mode, gets its own key instead of being silently dropped.
    rank_distribution = {"r=16": 0, "r=32": 0}
    for rank in r_array:
        rank_key = f"r={rank}"
        rank_distribution[rank_key] = rank_distribution.get(rank_key, 0) + 1

    result = {
        "model_id": model_id,
        "num_layers": num_layers,
        "bit_options": [4, 8],
        "target_bit_budget": bit_budget,
        "actual_avg_bits": avg_bits,
        "actual_avg_rank": avg_rank,
        "sensitivity_method": "ppl_perturbation",
        "baseline_loss": baseline_loss,
        "sensitivities_raw": sensitivities,
        "sensitivities_normalized": sens_normalized.tolist(),
        "q": q_array,
        "r": r_array,
        "bit_distribution": {
            "4bit": q_array.count(4),
            "8bit": q_array.count(8)
        },
        "rank_distribution": rank_distribution
    }

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    print("\n" + "=" * 70)
    print("AMQ Search Complete!")
    print("=" * 70)
    print(f"Actual avg bits: {avg_bits:.2f}")
    print(f"Avg rank: {avg_rank:.2f}")
    print(f"Distribution: {result['bit_distribution']}")
    print(f"Saved to: {output_path}")

    print("\nFinal allocation:")
    for i, (bits, rank, sens) in enumerate(zip(q_array, r_array, sensitivities)):
        marker = "★" if bits == 8 else " "
        print(f"  {marker} Layer {i:2d}: {bits}-bit, r={rank} (sens={sens:.4f})")

    return result


def main():
    """Parse the command-line options and run the AMQ search with them."""
    cli = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--model_id", {"type": str, "required": True}),
        ("--output", {"type": str, "required": True}),
        ("--bit_budget", {"type": float, "default": 6.0,
                          "help": "Target average bits (4-8). Default: 6.0"}),
        ("--samples", {"type": int, "default": 32}),
        ("--uniform_rank", {"action": "store_true",
                            "help": "Use uniform rank (AMQ baseline, no mixed rank)"}),
        ("--base_rank", {"type": int, "default": 16,
                         "help": "Base rank for uniform_rank mode. Default: 16"}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    run_amq_search(
        opts.model_id,
        opts.output,
        opts.bit_budget,
        opts.samples,
        uniform_rank=opts.uniform_rank,
        base_rank=opts.base_rank,
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
