#!/usr/bin/env python3
"""
Run AMQ's OWQ quantization on Qwen3 model with a predefined bit allocation.

This uses the real GPTQ-style OWQ quantizer from AMQ.
Imports are done directly to bypass AMQ's __init__.py which has import issues.
"""

import argparse
import json
import os
import sys
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Add AMQ quantization directly to path (bypass __init__.py)
amq_path = Path(__file__).resolve().parent.parent / "amq" / "amq"
sys.path.insert(0, str(amq_path))

# Direct imports to bypass AMQ's __init__.py
from quantization.base import BASE, get_owq_calib_dataset
from quantization.owq import OWQ


def _load_bit_allocation(config_file):
    """Read the JSON bit-allocation file and validate its ``q`` array.

    Args:
        config_file: Path to a JSON file whose top-level object has a ``q``
            key holding one numeric bit-width per layer.

    Returns:
        A ``(config, q_array, avg_bits)`` tuple: the parsed JSON object, the
        validated ``q`` list, and the arithmetic mean of its entries.

    Raises:
        KeyError: If the JSON object has no ``q`` key.
        ValueError: If ``q`` is not a non-empty list of numbers.
    """
    with open(config_file, encoding="utf-8") as f:
        config = json.load(f)

    q_array = config["q"]

    # Fail early with a clear message instead of a ZeroDivisionError (empty
    # list) or an opaque TypeError from sum() (non-numeric entries).
    if not isinstance(q_array, list) or not q_array:
        raise ValueError(
            f"'q' in {config_file} must be a non-empty list of bit-widths"
        )
    invalid = [
        bits for bits in q_array
        if isinstance(bits, bool) or not isinstance(bits, (int, float))
    ]
    if invalid:
        raise ValueError(
            f"'q' in {config_file} contains non-numeric entries: {invalid!r}"
        )

    avg_bits = sum(q_array) / len(q_array)
    return config, q_array, avg_bits


def _select_qwen_config(model_config, model_id):
    """Pick the entry of the Qwen3 config mapping to use for ``model_id``.

    An entry keyed by the exact model name (last path component of
    ``model_id``) wins. Otherwise the size substring in ``model_id`` decides
    which of the ``Qwen3-1.7B`` / ``Qwen3-4B`` entries is preferred, and the
    other one serves as a fallback.

    Args:
        model_config: Mapping parsed from ``configs/qwen3.json``.
        model_id: Model identifier or local path given on the command line.

    Returns:
        The selected config entry, or ``None`` when no candidate key exists.
    """
    exact_name = Path(model_id).name
    if exact_name in model_config:
        return model_config[exact_name]

    lowered = model_id.lower()
    if "4b" in lowered and "1.7" not in lowered:
        candidates = ("Qwen3-4B", "Qwen3-1.7B")
    else:
        # Covers both explicit 1.7B models and the generic fallback.
        candidates = ("Qwen3-1.7B", "Qwen3-4B")

    for position, key in enumerate(candidates):
        if key in model_config:
            if position > 0:
                print(
                    f"Warning: no '{candidates[0]}' entry in Qwen3 config; "
                    f"falling back to '{key}'",
                    file=sys.stderr,
                )
            return model_config[key]
    return None


def main():
    """Quantize a model with AMQ's OWQ quantizer using a fixed bit allocation.

    Parses the command line, loads the per-layer bit allocation from the JSON
    config file, loads the model and tokenizer, runs OWQ with a wikitext2
    calibration set, and writes the model, tokenizer and a
    ``quantization_config.json`` summary to the output directory.

    Raises:
        KeyError: If the config file has no ``q`` key.
        ValueError: If ``q`` is not a non-empty list of numbers.
    """
    parser = argparse.ArgumentParser(description="Run AMQ OWQ quantization")
    parser.add_argument("--model_id", type=str, required=True, help="HuggingFace model ID")
    parser.add_argument("--config_file", type=str, required=True, help="Path to JSON config with q array")
    parser.add_argument("--output_path", type=str, required=True, help="Output directory")
    parser.add_argument("--group_size", type=int, default=128, help="Quantization group size")
    parser.add_argument("--n_samples", type=int, default=128, help="Number of calibration samples")
    parser.add_argument("--seqlen", type=int, default=2048, help="Calibration sequence length")

    args = parser.parse_args()

    # Load and validate the bit allocation before any expensive model loading.
    config, q_array, avg_bits = _load_bit_allocation(args.config_file)
    num_layers = len(q_array)

    print("=" * 60)
    print("AMQ OWQ Quantization (GPTQ-style)")
    print("=" * 60)
    print(f"Model: {args.model_id}")
    print(f"Layers: {num_layers}")
    print(f"Avg bits: {avg_bits:.2f}")
    print(f"Bit distribution: {config.get('bit_distribution', 'N/A')}")
    print(f"Output: {args.output_path}")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model and tokenizer
    print("\nLoading model...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )

    # The allocation is indexed by layer, so a length mismatch is almost
    # certainly a config/model mix-up. Only warn: how OWQ treats a mismatch
    # is not visible from this script.
    model_layers = getattr(model.config, "num_hidden_layers", None)
    if model_layers is not None and model_layers != num_layers:
        print(
            f"Warning: config has {num_layers} bit-widths but the model "
            f"reports {model_layers} hidden layers",
            file=sys.stderr,
        )

    # Build architecture dict for OWQ
    # OWQ expects arch['linear'][linear_name][layer_idx] = bits
    linear_names = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
                    "self_attn.o_proj", "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]

    # Each linear gets its own copy so the lists are not aliased.
    arch = {"linear": {name: q_array.copy() for name in linear_names}}

    # Load config for Qwen3 (optional: None is passed when the file is absent)
    config_path = amq_path / "configs" / "qwen3.json"
    if config_path.exists():
        with open(config_path, encoding="utf-8") as f:
            model_config = json.load(f)
        qwen_config = _select_qwen_config(model_config, args.model_id)
    else:
        qwen_config = None

    # Get calibration dataset
    print("\nLoading calibration dataset...")
    samples = get_owq_calib_dataset(
        data='wikitext2',
        tokenizer=tokenizer,
        n_samples=args.n_samples,
        seqlen=args.seqlen
    )

    # Run OWQ quantization
    print("\nRunning OWQ quantization...")
    owq = OWQ(
        model=model,
        tokenizer=tokenizer,
        method='owq',
        arch=arch,
        avg_bits=avg_bits,
        group_size=args.group_size,
        config=qwen_config,
        dev=device
    )

    # NOTE(review): both n_samples and nsamples are forwarded, as before;
    # confirm against OWQ.run whether both keywords are actually needed.
    # The return value is not used by this script.
    owq.run(
        samples=samples,
        n_samples=args.n_samples,
        seqlen=args.seqlen,
        true_sequential=True,
        percdamp=0.01,
        act_order=False,
        static_groups=False,
        no_frob_norm=False,
        nsamples=args.n_samples
    )

    # Save quantized model
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving quantized model to {output_path}...")
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)

    # Save quantization info
    quant_info = {
        "model_id": args.model_id,
        "method": "owq",
        "q": q_array,
        "avg_bits": avg_bits,
        "group_size": args.group_size,
        "n_samples": args.n_samples,
        "seqlen": args.seqlen,
    }
    with open(output_path / "quantization_config.json", "w", encoding="utf-8") as f:
        json.dump(quant_info, f, indent=2)

    print("\n" + "=" * 60)
    print("OWQ Quantization Complete!")
    print("=" * 60)


# Run the quantization pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
