#!/usr/bin/env python3
"""
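Quantize LLMs with HQQ using per-layer mixed precision.
Targets the Qwen3 and LLaMA3 families. The per-layer configuration comes from a
predefined name, a JSON config file, or a uniform --bits setting
(precedence: --config_file > --bits > --config).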

Usage:
    # Using predefined config names
    python quantize_hqq.py --config config_D --output_dir quantized_models/config_D_ours
    
    # Using JSON config file
    python quantize_hqq.py --config_file path/to/config.json --output_dir quantized_models/custom
    
    # Uniform quantization
    python quantize_hqq.py --bits 4 --output_dir quantized_models/uniform_4bit
"""

from __future__ import annotations

import argparse
import copy
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Any
from tqdm import tqdm

import torch


# ============================================================
# Model Configurations
# ============================================================

# Dotted attribute path to the decoder-layer container, keyed by what look like
# Hugging Face ``model_type`` strings (Qwen2/Qwen3, LLaMA, Mistral).
# All currently listed families share the same ``model.layers`` layout.
# NOTE: get_model_layers() probes attributes directly and does not read this map.
MODEL_LAYER_ATTR = dict.fromkeys(
    ("qwen3", "qwen2", "llama", "mistral"),
    "model.layers",
)

def get_model_layers(model):
    """Return the container of transformer blocks inside ``model``.

    The following attribute paths are probed in order and the first one that
    exists is returned: ``model.model.layers``, ``model.transformer.h``,
    ``model.gpt_neox.layers``.

    Raises:
        ValueError: if none of the probed paths exist on ``model``.
    """
    candidate_paths = (
        ("model", "layers"),
        ("transformer", "h"),
        ("gpt_neox", "layers"),
    )
    for outer_name, inner_name in candidate_paths:
        if not hasattr(model, outer_name):
            continue
        outer = getattr(model, outer_name)
        if hasattr(outer, inner_name):
            return getattr(outer, inner_name)
    raise ValueError("Unknown model architecture - cannot find transformer layers")


def get_num_layers(model_id: str, default: int = 32) -> int:
    """Return the number of transformer layers declared in a model's HF config.

    This is best-effort: if the config cannot be loaded, or declares no usable
    layer count, a warning is printed and ``default`` is returned instead of
    aborting.

    Args:
        model_id: Hugging Face hub ID or local path accepted by ``AutoConfig``.
        default: Fallback layer count (32 unless overridden).

    Returns:
        The first of ``num_hidden_layers``, ``n_layer``, ``num_layers`` that is
        present on the config and not ``None``; otherwise ``default``.
    """
    from transformers import AutoConfig

    try:
        # NOTE(review): trust_remote_code=True allows code shipped with the model
        # repo to run locally; only use with model IDs you trust.
        config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    except Exception as e:  # deliberate best-effort fallback for the CLI
        print(f"Warning: Failed to load config for {model_id}: {e}")
        return default

    # Different architectures name the layer count differently; skip attributes
    # that exist but are None so callers never receive a non-int.
    for attr in ("num_hidden_layers", "n_layer", "num_layers"):
        value = getattr(config, attr, None)
        if value is not None:
            return value

    print(f"Warning: Could not detect num_layers for {model_id}, using default {default}")
    return default


def get_shallow_split(model_id: str, num_layers: Optional[int] = None) -> int:
    """Return how many leading layers count as "shallow" (~30% of the model).

    The split is ``max(4, int(num_layers * 0.3))``, clamped to ``num_layers`` so
    models with fewer than 4 layers never get a split past their last layer
    (which would make shallow/deep per-layer arrays longer than the model).

    Args:
        model_id: Model whose depth is looked up when ``num_layers`` is not given.
        num_layers: Known layer count; pass it to skip a second config load.

    Returns:
        Number of shallow layers, in ``[0, num_layers]``.
    """
    if num_layers is None:
        num_layers = get_num_layers(model_id)
    # Approximately 30% of layers are "shallow", with a floor of 4.
    return min(num_layers, max(4, int(num_layers * 0.3)))


# ============================================================
# Config Generation
# ============================================================

def generate_config(config_name: str, model_id: str) -> Dict[str, Any]:
    """Build a predefined per-layer quantization config for ``model_id``.

    Args:
        config_name: One of ``config_A``, ``config_B``, ``config_C``,
            ``config_D``, ``qlora_4bit``, ``qlora_2bit``, ``qlora_3bit``.
        model_id: Model whose layer count sizes the per-layer arrays.

    Returns:
        Dict with ``name``, ``description``, and per-layer lists ``q``
        (bit-widths) and ``r`` (ranks), each ``num_layers`` long.

    Raises:
        ValueError: for an unknown ``config_name``. This is checked before any
            model config is loaded, so a typo fails fast.
    """
    # Uniform presets: name -> (bits, rank, description); every layer identical.
    uniform_presets = {
        "config_A": (2, 8, "Uniform 2-bit + Rank 8 (all layers)"),
        "config_B": (4, 8, "Uniform 4-bit + Rank 8 (all layers)"),
        "qlora_4bit": (4, 16, "QLoRA: Uniform 4-bit + Rank 16"),
        "qlora_2bit": (2, 16, "QLoRA: Uniform 2-bit + Rank 16"),
        "qlora_3bit": (3, 16, "QLoRA: Uniform 3-bit + Rank 16"),
    }
    # Split presets give shallow and deep layers different settings (see below).
    split_presets = ("config_C", "config_D")

    if config_name not in uniform_presets and config_name not in split_presets:
        raise ValueError(f"Unknown config: {config_name}")

    num_layers = get_num_layers(model_id)

    if config_name in uniform_presets:
        bits, rank, description = uniform_presets[config_name]
        return {
            "name": config_name,
            "description": description,
            "q": [bits] * num_layers,
            "r": [rank] * num_layers,
        }

    # Clamp so a split larger than the model depth cannot produce q/r lists
    # longer than num_layers.
    shallow = min(get_shallow_split(model_id), num_layers)
    deep = num_layers - shallow
    ranks = [8] * shallow + [16] * deep

    if config_name == "config_C":
        # Anti-intuition: shallow high-bit, deep low-bit
        return {
            "name": "config_C",
            "description": f"Anti-Intuition: Shallow (0-{shallow-1}) 4-bit+R8, Deep ({shallow}-{num_layers-1}) 2-bit+R16",
            "q": [4] * shallow + [2] * deep,
            "r": ranks,
        }

    # config_D (our method): shallow low-bit, deep high-bit
    return {
        "name": "config_D",
        "description": f"Ours: Shallow (0-{shallow-1}) 2-bit+R8, Deep ({shallow}-{num_layers-1}) 4-bit+R16",
        "q": [2] * shallow + [4] * deep,
        "r": ranks,
    }


def load_config(config_path: str) -> Dict[str, Any]:
    """Load a per-layer quantization config from a JSON file.

    The file is decoded as UTF-8 explicitly (the standard JSON encoding) rather
    than with the platform's locale default. Non-ASCII names or descriptions
    therefore load the same on every OS.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The parsed JSON value. ``main`` expects a dict containing a ``q`` list
        of per-layer bit-widths (``name``, ``description`` and ``r`` are optional).
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)


# ============================================================
# Quantization
# ============================================================

def manual_quantize_model(
    model,
    q_array: List[int],
    device: str = 'cuda',
    group_size: int = 32,
    lm_head_bits: Optional[int] = 4,
) -> torch.nn.Module:
    """Quantize a model's transformer layers in place with HQQ, one bit-width per layer.

    Every ``torch.nn.Linear`` inside transformer layer ``i`` (as returned by
    ``get_model_layers``) is replaced by an ``HQQLinear`` using ``q_array[i]``
    bits. Entries ``>= 8`` mean "no quantization": that layer is kept in FP16
    and only moved to ``device``.

    Args:
        model: Causal-LM whose layers are located via ``get_model_layers``.
        q_array: Per-layer bit-widths; one entry per transformer layer.
        device: Device that quantized and kept modules are placed on.
        group_size: HQQ quantization group size.
        lm_head_bits: Bit-width used to quantize ``model.lm_head`` (4 by
            default, matching the previous fixed behaviour). ``None`` keeps the
            head unquantized and only moves it to ``device``.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        ValueError: if ``len(q_array)`` differs from the number of layers.
    """
    from hqq.core.quantize import HQQLinear, BaseQuantizeConfig

    layers = get_model_layers(model)
    num_layers = len(layers)

    if len(q_array) != num_layers:
        raise ValueError(f"q_array length ({len(q_array)}) != model layers ({num_layers})")

    print(f"Manually quantizing {num_layers} layers...")

    # Pre-build one HQQ config per width that is actually quantized. Widths
    # >= 8 mean "keep FP16", so no config is made for them; BaseQuantizeConfig
    # may reject widths it does not support (e.g. 16).
    quant_configs = {
        bits: BaseQuantizeConfig(nbits=bits, group_size=group_size)
        for bits in set(q_array)
        if bits < 8
    }

    def to_hqq(linear, bits):
        """Move ``linear`` to ``device`` and wrap it as an ``HQQLinear`` with ``bits``."""
        if bits not in quant_configs:
            quant_configs[bits] = BaseQuantizeConfig(nbits=bits, group_size=group_size)
        return HQQLinear(
            linear.to(device),
            # Each HQQLinear gets its own copy of the config (kept from the
            # original implementation; presumably guards against shared mutation).
            quant_config=copy.deepcopy(quant_configs[bits]),
            compute_dtype=torch.float16,
            device=device,
        )

    # Count how many layers use each width.
    stats = {f"{b}bit": 0 for b in set(q_array)}

    for layer_idx, (layer, bits) in enumerate(
        tqdm(zip(layers, q_array), total=num_layers, desc="Quantizing layers")
    ):
        stats[f"{bits}bit"] += 1

        if bits >= 8:
            print(f"  Layer {layer_idx}: Keeping FP16 (8-bit = no quantization)")
            layer.to(device)
            continue

        # Snapshot the Linear submodules first: swapping them in while iterating
        # named_modules() would mutate the tree being walked. The empty name is
        # the layer itself and is skipped.
        targets = [
            (name, module)
            for name, module in layer.named_modules()
            if name and isinstance(module, torch.nn.Linear)
        ]
        for name, linear in targets:
            parent_name, _, attr_name = name.rpartition(".")
            parent = layer.get_submodule(parent_name)  # "" -> the layer itself
            setattr(parent, attr_name, to_hqq(linear, bits))

    # Output head.
    # NOTE(review): if lm_head shares its weight with embed_tokens (tied
    # embeddings), confirm that quantizing it here leaves the embedding intact.
    lm_head = getattr(model, "lm_head", None)
    if isinstance(lm_head, torch.nn.Linear):
        if lm_head_bits is None:
            model.lm_head = lm_head.to(device)
        else:
            model.lm_head = to_hqq(lm_head, lm_head_bits)

    # Move the remaining unquantized top-level modules to the target device.
    if hasattr(model, 'model'):
        if hasattr(model.model, 'embed_tokens'):
            model.model.embed_tokens = model.model.embed_tokens.to(device)
        if hasattr(model.model, 'norm'):
            model.model.norm = model.model.norm.to(device)
        if hasattr(model.model, 'rotary_emb'):
            model.model.rotary_emb = model.model.rotary_emb.to(device)

    print(f"Quantization stats: {stats}")
    avg_bits = sum(q_array) / len(q_array)
    print(f"Average bit-width: {avg_bits:.2f}")

    return model


def compute_model_size_mb(q_array: List[int], params_per_layer: int = 50_000_000) -> float:
    """Roughly estimate the quantized layer weights' size in MiB.

    Each entry of ``q_array`` is one layer's bit-width, and every layer is
    assumed to hold ``params_per_layer`` weights. Per-group scale/zero-point
    metadata and non-layer modules are not counted.
    """
    bits_total = sum(params_per_layer * width for width in q_array)
    bytes_total = bits_total / 8
    return bytes_total / (1 << 20)


# ============================================================
# Main
# ============================================================

def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Quantize LLM with HQQ using per-layer mixed precision",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "--model_id", type=str, default="Qwen/Qwen3-1.7B",
        help="Hugging Face model ID"
    )
    parser.add_argument(
        "--config", type=str, default=None,
        choices=["config_A", "config_B", "config_C", "config_D",
                 "qlora_4bit", "qlora_2bit", "qlora_3bit"],
        help="Predefined config name"
    )
    parser.add_argument(
        "--config_file", type=str, default=None,
        help="Path to JSON config file (overrides --config)"
    )
    parser.add_argument(
        "--bits", type=int, default=None,
        help="Uniform bit-width for all layers (overrides --config)"
    )
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Output directory for quantized model"
    )
    parser.add_argument(
        "--group_size", type=int, default=32,
        help="Quantization group size"
    )
    parser.add_argument(
        "--device", type=str, default="cuda",
        help="Device to use for quantization"
    )
    return parser


def _resolve_config(
    args: argparse.Namespace, parser: argparse.ArgumentParser
) -> tuple[str, Dict[str, Any]]:
    """Return ``(config_name, config_data)`` for the mode selected on the CLI.

    Precedence is ``--config_file`` > ``--bits`` > ``--config``. Usage errors
    (no mode given, non-positive ``--bits``, config without a usable ``q``
    list) go through ``parser.error``, which prints usage and exits with
    status 2 instead of dumping a traceback.
    """
    if args.config_file:
        config_data = load_config(args.config_file)
        if not isinstance(config_data, dict):
            parser.error(f"Config file {args.config_file!r} must contain a JSON object")
        config_name = config_data.get("name", "custom")
    elif args.bits is not None:
        # `is not None` so that --bits 0 is reported, not silently ignored.
        if args.bits < 1:
            parser.error(f"--bits must be a positive integer, got {args.bits}")
        num_layers = get_num_layers(args.model_id)
        config_data = {
            "name": f"uniform_{args.bits}bit",
            "description": f"Uniform {args.bits}-bit quantization",
            "q": [args.bits] * num_layers,
            "r": [16] * num_layers,
        }
        config_name = config_data["name"]
    elif args.config:
        config_data = generate_config(args.config, args.model_id)
        config_name = args.config
    else:
        parser.error("Must specify --config, --config_file, or --bits")

    # Later steps index and average q; reject a missing or empty list up front
    # rather than failing with KeyError / ZeroDivisionError.
    q_array = config_data.get("q")
    if not isinstance(q_array, list) or not q_array:
        parser.error(
            f"Config {config_name!r} must define a non-empty list 'q' of per-layer bit-widths"
        )
    return config_name, config_data


def main():
    """Command-line entry point: choose a per-layer config, quantize, and save.

    Writes the HQQ-quantized model, the tokenizer, and a
    ``quantization_config.json`` summary into ``--output_dir``.
    """
    parser = _build_parser()
    args = parser.parse_args()

    config_name, config_data = _resolve_config(args, parser)
    q_array = config_data["q"]
    avg_bits = sum(q_array) / len(q_array)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("HQQ Per-Layer Quantization")
    print("=" * 60)
    print(f"Model: {args.model_id}")
    print(f"Config: {config_name}")
    print(f"Description: {config_data.get('description', 'N/A')}")
    print(f"Output: {args.output_dir}")
    print(f"Bit-width distribution: {set(q_array)}")
    print(f"Average bit-width: {avg_bits:.2f}")
    print("=" * 60)

    # Heavy third-party imports are deferred until arguments are validated.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from hqq.models.hf.base import AutoHQQHFModel

    print("\nLoading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load in FP16 without a device map; manual_quantize_model moves modules
    # to args.device as it goes.
    print("Loading model in FP16...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    print("\nQuantizing model...")
    model = manual_quantize_model(
        model,
        q_array=q_array,
        device=args.device,
        group_size=args.group_size,
    )

    print(f"\nSaving quantized model to {args.output_dir}...")
    AutoHQQHFModel.save_quantized(model, save_dir=str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    # Record how the model was produced, next to the weights.
    config_info = {
        "model_id": args.model_id,
        "config_name": config_name,
        "description": config_data.get("description", ""),
        "quantizer": "hqq_manual",
        "q": q_array,
        "r": config_data.get("r", [16] * len(q_array)),
        "avg_bits": avg_bits,
        "group_size": args.group_size,
    }
    with open(output_dir / "quantization_config.json", "w") as f:
        json.dump(config_info, f, indent=2)

    print(f"\n{'=' * 60}")
    print("Quantization complete!")
    print(f"Saved to: {args.output_dir}")
    print(f"{'=' * 60}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
