#!/usr/bin/env python3
"""
Train LoRA on HQQ quantized Qwen3-1.7B models.
Loads HQQ quantized model and applies LoRA fine-tuning.

Usage:
    python train_hqq_lora.py \
        --hqq_model_path experiments/motivating_example/quantized_models/config_D_ours \
        --qra_config experiments/motivating_example/configs/config_D_ours.json \
        --output_dir experiments/motivating_example/checkpoints/config_D_ours
"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path

import torch
from datasets import load_dataset, Dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer, Trainer, TrainingArguments

# Make project-local packages importable: PROJECT_ROOT is the directory that
# contains this file, and it is prepended to sys.path unless already present.
# NOTE(review): presumably this is what lets the lazy `data_loaders` import in
# prepare_dataset() resolve -- confirm that package lives next to this script.
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))


def _from_pretrained_causal_lm(model_id: str, load_in_8bit: bool):
    """Load ``model_id`` with ``AutoModelForCausalLM`` as 8-bit or FP16.

    Shared by the Hub and local-checkpoint code paths so both honour
    ``load_in_8bit`` in exactly the same way.
    """
    # Imported lazily so that code paths which never reach from_pretrained
    # do not require transformers at call time.
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    load_kwargs = {"device_map": "auto", "trust_remote_code": True}
    if load_in_8bit:
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0,
        )
    else:
        load_kwargs["torch_dtype"] = torch.float16
    return AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)


def load_hqq_model(model_path: str, load_in_8bit: bool = False):
    """Load HQQ quantized model, standard HF model, or FP16 model from HuggingFace hub.

    Resolution order:

    1. If ``model_path`` contains ``/`` and does not exist on disk, it is
       treated as a HuggingFace Hub model ID.
    2. Otherwise, if the directory holds standard HF weights
       (``model.safetensors``, ``pytorch_model.bin`` or their sharded
       variants), it is loaded with ``AutoModelForCausalLM``.
    3. Otherwise -- or if step 2 raised -- the HQQ loader is tried.

    Args:
        model_path: Path to model or HuggingFace model ID
        load_in_8bit: If True, load with BitsAndBytes 8-bit quantization.
            Applies to Hub models and to local standard HF checkpoints; it has
            no effect on the HQQ fallback.

    Returns:
        The loaded model object.

    Raises:
        RuntimeError: If the HQQ fallback fails (including when the ``hqq``
            package cannot be imported). The underlying exception is chained
            as ``__cause__``.
    """
    local_path = Path(model_path)

    # Check if this is a HuggingFace hub model (contains / and doesn't exist locally)
    if "/" in model_path and not local_path.exists():
        precision = "8-bit" if load_in_8bit else "FP16"
        print(f"Loading {precision} model from HuggingFace Hub: {model_path}")
        return _from_pretrained_causal_lm(model_path, load_in_8bit)

    # Standard HF layouts: single-file or sharded, safetensors or PyTorch bin.
    has_standard_weights = (
        (local_path / "model.safetensors").exists()
        or (local_path / "pytorch_model.bin").exists()
        or any(local_path.glob("model-*.safetensors"))
        or any(local_path.glob("pytorch_model-*.bin"))
    )

    if has_standard_weights:
        # Try loading as standard HF model first (for OWQ quantized models).
        # Deliberately best-effort: on any failure we report it and fall
        # through to the HQQ loader below.
        try:
            if load_in_8bit:
                print(f"Loading standard HF model in 8-bit from: {local_path}")
            else:
                print(f"Loading standard HF model from: {local_path}")
            return _from_pretrained_causal_lm(str(local_path), load_in_8bit)
        except Exception as e:
            print(f"Failed to load as HF model: {e}")

    if load_in_8bit:
        print("Note: load_in_8bit has no effect when loading an HQQ checkpoint")

    # Try HQQ format as fallback
    try:
        from hqq.models.hf.base import AutoHQQHFModel
        print(f"Loading HQQ model from: {local_path}")
        return AutoHQQHFModel.from_quantized(str(local_path))
    except Exception as e:
        raise RuntimeError(f"Failed to load model from {local_path}: {e}") from e


def get_target_modules():
    """Return the names of the linear projections that receive LoRA adapters.

    A fresh list is built on every call, so callers may mutate the result.
    """
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    mlp_projections = ["gate_proj", "up_proj", "down_proj"]
    return attention_projections + mlp_projections


def load_qra_config(config_path: str) -> dict:
    """Load QRA config for per-layer rank pattern.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        The parsed JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's locale
    # encoding, which would mis-decode non-ASCII content on some systems.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)


def prepare_dataset(tokenizer, dataset_name="alpaca", max_samples=None, max_length=512):
    """Prepare a tokenized causal-LM dataset for training.

    Every example is truncated/padded to exactly ``max_length`` tokens.
    ``labels`` mirror ``input_ids`` except at padded positions, which are set
    to ``-100`` so the loss ignores them. This matters in particular when the
    pad token is the EOS token: without masking, the model would be trained
    to emit long runs of EOS.

    Args:
        tokenizer: HuggingFace tokenizer
        dataset_name: "alpaca" or "gsm8k"
        max_samples: Optional limit on training samples
        max_length: Maximum sequence length

    Returns:
        A ``datasets.Dataset`` holding only the tokenizer outputs plus
        ``labels``; the original columns are removed.
    """
    from data_loaders.loaders import load_dataset_for_training

    # Load formatted data.
    # NOTE(review): assumes this returns a list of dicts with a "text" field,
    # as required by Dataset.from_list and tokenize() below -- confirm.
    records = load_dataset_for_training(dataset_name, "train", max_samples)

    # Convert to HuggingFace Dataset
    dataset = Dataset.from_list(records)

    def tokenize(example):
        encoded = tokenizer(
            example["text"],
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_attention_mask=True,
        )
        # Keep real tokens as targets; mask padding out of the loss.
        encoded["labels"] = [
            token_id if is_real_token else -100
            for token_id, is_real_token in zip(
                encoded["input_ids"], encoded["attention_mask"]
            )
        ]
        return encoded

    return dataset.map(tokenize, remove_columns=dataset.column_names)


def _build_lora_config(qra_config, target_modules):
    """Build the PEFT ``LoraConfig`` described by a QRA config dict.

    ``qra_config["r"]`` is read as a list with one rank per layer (index ==
    layer index). If it holds more than one distinct value, the most common
    rank becomes the default ``r`` and every other layer gets an explicit
    ``rank_pattern`` / ``alpha_pattern`` entry. Otherwise a single uniform
    rank is used, falling back to ``qra_config["base_r"]`` (default 8) when
    the list is missing or empty. ``lora_alpha`` is always twice the rank.
    """
    from collections import Counter

    shared_kwargs = {
        "target_modules": target_modules,
        "lora_dropout": 0.05,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    r_array = qra_config.get("r", [])

    if not r_array or len(set(r_array)) <= 1:
        # Uniform rank
        if r_array:
            uniform_rank = int(sum(r_array) / len(r_array))
        else:
            uniform_rank = qra_config.get("base_r", 8)
        print(f"Using uniform rank: {uniform_rank}")
        return LoraConfig(r=uniform_rank, lora_alpha=uniform_rank * 2, **shared_kwargs)

    # Dynamic rank: different ranks for different layers
    print("Using per-layer dynamic rank pattern!")
    print(f"  Rank distribution: {dict(enumerate(r_array))}")

    # The most common rank becomes the default; only exceptions are listed.
    base_rank = Counter(r_array).most_common(1)[0][0]

    # Keys follow "model.layers.{layer_idx}.{self_attn|mlp}.{module_name}".
    attention_modules = {"q_proj", "k_proj", "v_proj", "o_proj"}
    rank_pattern = {}
    alpha_pattern = {}
    for layer_idx, layer_rank in enumerate(r_array):
        if layer_rank == base_rank:
            continue
        for module in target_modules:
            parent = "self_attn" if module in attention_modules else "mlp"
            key = f"model.layers.{layer_idx}.{parent}.{module}"
            rank_pattern[key] = layer_rank
            alpha_pattern[key] = layer_rank * 2

    print(f"  Base rank: {base_rank}")
    print(f"  Layers with different rank: {len(rank_pattern) // len(target_modules)}")

    return LoraConfig(
        r=base_rank,
        lora_alpha=base_rank * 2,
        rank_pattern=rank_pattern,
        alpha_pattern=alpha_pattern,
        **shared_kwargs,
    )


def main():
    """CLI entry point: load a model, attach LoRA adapters, train and save.

    Steps: parse arguments, load the tokenizer and model, prepare the model
    for k-bit training, build the LoRA config from the QRA JSON file, tokenize
    the training set, run the HF ``Trainer``, then save the adapter and
    tokenizer to ``--output_dir`` and print the peak CUDA memory allocated.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--hqq_model_path", type=str, required=True)
    ap.add_argument("--qra_config", type=str, required=True)
    ap.add_argument("--output_dir", type=str, required=True)
    ap.add_argument("--num_train_epochs", type=int, default=1,
                   help="Training epochs; a positive --max_steps overrides this "
                        "(pass --max_steps -1 to train by epochs)")
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument("--max_steps", type=int, default=500)
    ap.add_argument("--learning_rate", type=float, default=2e-4)
    ap.add_argument("--dataset", type=str, default="alpaca", choices=["alpaca", "gsm8k"],
                   help="Training dataset: alpaca (general) or gsm8k (math)")
    ap.add_argument("--max_samples", type=int, default=None,
                   help="Maximum training samples (None for full dataset)")
    ap.add_argument("--max_length", type=int, default=512,
                   help="Maximum sequence length; examples are padded/truncated to this")
    ap.add_argument("--load_in_8bit", action="store_true",
                   help="Load model in 8-bit using BitsAndBytes LLM.int8")
    args = ap.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    banner = "=" * 60
    print(banner)
    print("Training HQQ + LoRA Model")
    print(f"HQQ Model: {args.hqq_model_path}")
    print(f"QRA Config: {args.qra_config}")
    print(f"Output: {args.output_dir}")
    print(banner)

    # Load tokenizer
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(args.hqq_model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Padding to max_length requires a pad token; reuse EOS if none is set.
        tokenizer.pad_token = tokenizer.eos_token

    # Load HQQ model
    model = load_hqq_model(args.hqq_model_path, load_in_8bit=args.load_in_8bit)

    # Prepare for k-bit training
    model = prepare_model_for_kbit_training(model)

    # Build the (possibly per-layer) LoRA config from the QRA config file
    qra_config = load_qra_config(args.qra_config)
    lora_config = _build_lora_config(qra_config, get_target_modules())

    print("Applying LoRA...")
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Prepare dataset
    print("Preparing dataset...")
    train_dataset = prepare_dataset(
        tokenizer, args.dataset, args.max_samples, max_length=args.max_length
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=4,
        learning_rate=args.learning_rate,
        fp16=True,
        logging_steps=10,
        save_steps=100,
        save_total_limit=2,
        report_to="none",
        remove_unused_columns=False,
    )

    # Train
    print("Starting training...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )

    trainer.train()

    # Track peak memory (peak since process start on the current CUDA device)
    peak_memory_gb = 0.0
    if torch.cuda.is_available():
        peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)

    # Save
    print(f"Saving model to {args.output_dir}...")
    model.save_pretrained(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    print(f"\n{banner}")
    print("Training complete!")
    print(f"Saved to: {args.output_dir}")
    print(f"PEAK_MEMORY_GB: {peak_memory_gb:.3f}")
    print(banner)


if __name__ == "__main__":
    main()
