#!/usr/bin/env python3
"""
Main Experiment Script for QR-Adaptor (ACL 2026)

This script orchestrates the full experimental pipeline:
1. Quantize base models with different configurations
2. Train LoRA adapters on quantized models
3. Evaluate on multiple benchmarks
4. Record all metrics (time, memory, performance)

Usage:
    # Run motivating example (Qwen3-1.7B, 4 configs)
    python experiments/main_experiment.py --experiment motivating
    
    # Run main experiments (one model, all baselines)
    python experiments/main_experiment.py --experiment main --model qwen3-1.7b
    
    # Run full benchmark suite
    python experiments/main_experiment.py --experiment main --model all
"""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import time
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Any

import torch


# ========================================
# Project Root (computed from script location)
# ========================================

# Directory containing this script (the "experiments" folder), symlinks resolved.
_SCRIPT_DIR = Path(__file__).resolve().parent
# Project root is the *grandparent* of the "experiments" folder, i.e. one level
# above the "qwen_lora_importance" package (qwen_lora_importance/../ = zhou/).
# Helper scripts are launched with this directory as cwd and addressed as
# "qwen_lora_importance/...".
PROJECT_ROOT = _SCRIPT_DIR.parent.parent


# ========================================
# Configuration
# ========================================

@dataclass
class ModelConfig:
    """Static description of one base model used in the experiments.

    ``num_layers`` sizes the per-layer bit-width / rank arrays built by the
    config factories, and ``shallow_layers`` marks where the "shallow" block
    of layers ends and the "deep" block begins.
    """
    model_id: str        # hub-style model id handed to the helper scripts, e.g. "Qwen/Qwen3-1.7B"
    short_name: str      # used as the per-model output directory name and report key
    num_layers: int      # number of decoder layers
    hidden_size: int     # hidden dimension (not read by the code in this section)
    # Layers with index in [0, shallow_layers) count as "shallow"; the rest are "deep".
    shallow_layers: int = field(default=8)


# One row per supported model; the dict key doubles as the model's short name.
#   key              hub id                         layers  hidden  shallow
_MODEL_TABLE = (
    # Qwen3 family
    ("qwen3-1.7b",   "Qwen/Qwen3-1.7B",             28,     2048,   8),
    ("qwen3-4b",     "Qwen/Qwen3-4B",               36,     2560,   10),
    ("qwen3-8b",     "Qwen/Qwen3-8B",               36,     4096,   10),
    # LLaMA3 family
    ("llama3.2-1b",  "meta-llama/Llama-3.2-1B",     16,     2048,   4),
    ("llama3.2-3b",  "meta-llama/Llama-3.2-3B",     28,     3072,   8),
    ("llama3-8b",    "meta-llama/Meta-Llama-3-8B",  32,     4096,   10),
)

MODEL_CONFIGS = {
    key: ModelConfig(hub_id, key, n_layers, hidden, n_shallow)
    for key, hub_id, n_layers, hidden, n_shallow in _MODEL_TABLE
}


# Benchmarks evaluated when a config does not specify its own list.
_DEFAULT_EVAL_TASKS = ("wikitext", "mmlu")


@dataclass
class ExperimentConfig:
    """One experiment recipe: quantization, LoRA, training and evaluation knobs.

    Per-layer settings travel in the free-form dicts: ``quant_config`` (e.g.
    ``"q"`` per-layer bit-widths or AMQ search options) and ``rank_config``
    (e.g. ``"r"`` per-layer ranks).
    """
    name: str
    description: str

    # -- quantization --
    quant_bits: List[int]   # bit-widths in play, e.g. [4] or [2, 4]; [16] means no quantization
    quant_config: Dict[str, Any] = field(default_factory=dict)

    # -- LoRA --
    lora_rank: int = 16
    lora_alpha: int = 32
    dynamic_rank: bool = False   # set by AdaLoRA and per-layer-rank configs
    rank_config: Dict[str, Any] = field(default_factory=dict)

    # -- training --
    num_epochs: int = 1
    max_steps: int = -1          # values <= 0 mean "train for num_epochs"
    batch_size: int = 4
    learning_rate: float = 2e-4

    # -- evaluation --
    eval_tasks: List[str] = field(default_factory=lambda: list(_DEFAULT_EVAL_TASKS))
    eval_shots: int = 5


# ========================================
# Baseline Configurations
# ========================================

def get_baseline_configs(model_config: ModelConfig, rank: int = 16) -> Dict[str, ExperimentConfig]:
    """Build the baseline experiment configurations for the main experiment.

    Args:
        model_config: Target model. No field of it is needed to build these
            configs: AMQ per-layer bit-widths are produced at run time by the
            AMQ search step. The parameter is kept so existing callers work.
        rank: LoRA rank for every baseline. ``lora_alpha`` is ``2 * rank``;
            AdaLoRA starts at ``2 * rank`` and targets an average of ``rank``.

    Returns:
        Ordered mapping from config name to a freshly built
        ``ExperimentConfig``. Each config owns its own ``eval_tasks`` list and
        config dicts, so mutating one config never affects another.
    """
    # Full benchmark suite shared by all baselines (copied into each config).
    full_eval_tasks = (
        "wikitext", "c4", "arc_easy", "arc_challenge",
        "piqa", "hellaswag", "winogrande", "mmlu",
    )

    def make(name: str, description: str, quant_bits: List[int], **overrides: Any) -> ExperimentConfig:
        # Defaults common to every baseline; explicit overrides take precedence.
        params: Dict[str, Any] = {"lora_rank": rank, "lora_alpha": rank * 2}
        params.update(overrides)
        return ExperimentConfig(
            name=name,
            description=description,
            quant_bits=quant_bits,
            eval_tasks=list(full_eval_tasks),
            **params,
        )

    return {
        # Upper bound: LoRA on the unquantized FP16 base model.
        "lora_fp16": make(
            "lora_fp16",
            f"LoRA (r={rank}) on FP16 Base Model (Upper Bound)",
            [16],
        ),

        # QLoRA at uniform precision (8-bit uses BitsAndBytes LLM.int8).
        "qlora_4bit": make("qlora_4bit", f"QLoRA: 4-bit Base + LoRA (r={rank})", [4]),
        "qlora_2bit": make("qlora_2bit", f"QLoRA (Low-Bit): 2-bit Base + LoRA (r={rank})", [2]),
        "qlora_3bit": make("qlora_3bit", f"QLoRA (Low-Bit): 3-bit Base + LoRA (r={rank})", [3]),
        "qlora_8bit": make("qlora_8bit", f"QLoRA: 8-bit (LLM.int8) Base + LoRA (r={rank})", [8]),

        # AdaLoRA: 4-bit base, dynamic rank starting at 2*rank, target average `rank`.
        "adalora_4bit": make(
            "adalora_4bit",
            f"AdaLoRA: 4-bit Base + Dynamic Rank (Target Avg r={rank})",
            [4],
            lora_rank=rank * 2,  # initial rank
            dynamic_rank=True,
            rank_config={"target_rank": rank, "initial_rank": rank * 2},
        ),

        # AMQ + LoRA: searched 2/4/8-bit mix; 8-bit layers are kept as FP16.
        "amq_lora": make(
            "amq_lora",
            "AMQ: Sensitivity (2/4/8-bit) + Dynamic Rank (8-bit=FP16)",
            [2, 4, 8],
            quant_config={"run_amq_search": True, "target_bits": 4.5},
        ),

        # AMQ baselines: searched mixed quantization with a UNIFORM rank
        # (mixed quantization only, no mixed rank).
        "amq_baseline_4bit": make(
            "amq_baseline_4bit",
            "AMQ Baseline: Sensitivity-based Mixed Quant (target 4-bit) + Uniform Rank",
            [2, 4, 8],
            quant_config={"run_amq_search": True, "target_bits": 4.0, "uniform_rank": True},
        ),
        "amq_baseline_6bit": make(
            "amq_baseline_6bit",
            "AMQ Baseline: Sensitivity-based Mixed Quant (target 6-bit) + Uniform Rank",
            [2, 4, 8],
            quant_config={"run_amq_search": True, "target_bits": 6.0, "uniform_rank": True},
        ),
    }


def get_motivating_configs(model_config: ModelConfig) -> Dict[str, ExperimentConfig]:
    """Build the four motivating-example configurations (A, B, C, D).

    Layers ``[0, model_config.shallow_layers)`` are "shallow", the rest "deep":

    * A (uniform low):  every layer 2-bit, uniform rank 16.
    * B (uniform high): every layer 4-bit, uniform rank 16.
    * C (anti-intuition): shallow 4-bit + rank 8, deep 2-bit + rank 16.
    * D (ours):           shallow 2-bit + rank 16, deep 4-bit + rank 8.

    Per-layer bit-widths go into ``quant_config["q"]`` and per-layer ranks into
    ``rank_config["r"]``. All configs train for 500 steps and evaluate on
    wikitext + mmlu.
    """
    num_layers = model_config.num_layers
    shallow = model_config.shallow_layers
    high_rank, low_rank = 16, 8

    def uniform(value: int) -> List[int]:
        # Same value for every layer.
        return [value] * num_layers

    def split(shallow_value: int, deep_value: int) -> List[int]:
        # Shallow block first, then the deep block.
        return [shallow_value] * shallow + [deep_value] * (num_layers - shallow)

    def make(
        name: str,
        description: str,
        quant_bits: List[int],
        q: List[int],
        r: List[int],
        dynamic_rank: bool = False,
    ) -> ExperimentConfig:
        return ExperimentConfig(
            name=name,
            description=description,
            quant_bits=quant_bits,
            quant_config={"q": q},
            lora_rank=high_rank,  # base rank
            lora_alpha=2 * high_rank,
            dynamic_rank=dynamic_rank,
            rank_config={"r": r},
            max_steps=500,
            eval_tasks=["wikitext", "mmlu"],
        )

    return {
        # A/B: uniform precision with a uniform rank of 16 (the descriptions
        # previously claimed rank 8, contradicting the arrays below).
        "config_A": make(
            "config_A",
            f"Config A (Uniform Low): All 2-bit + Rank {high_rank}",
            [2],
            uniform(2),
            uniform(high_rank),
        ),
        "config_B": make(
            "config_B",
            f"Config B (Uniform High): All 4-bit + Rank {high_rank}",
            [4],
            uniform(4),
            uniform(high_rank),
        ),
        # C (anti-intuition): shallow high bits + low rank, deep low bits + high rank.
        "config_C": make(
            "config_C",
            "Config C (Anti-Intuition): Shallow 4-bit+R8 / Deep 2-bit+R16",
            [2, 4],
            split(4, 2),
            split(low_rank, high_rank),
            dynamic_rank=True,
        ),
        # D (ours): shallow low bits + high rank, deep high bits + low rank.
        "config_D": make(
            "config_D",
            "Config D (Ours): Shallow 2-bit+R16 / Deep 4-bit+R8",
            [2, 4],
            split(2, 4),
            split(high_rank, low_rank),
            dynamic_rank=True,
        ),
    }


# ========================================
# Metrics Recording
# ========================================

@dataclass
class ExperimentMetrics:
    """Every number recorded for one experiment run, serialisable to JSON.

    Times are wall-clock seconds, memory figures are GB and storage figures
    are MB. Benchmark fields keep their defaults (0.0 for accuracies, +inf for
    perplexities) when the corresponding benchmark produced no result.
    """
    # --- identification ---
    model_name: str
    config_name: str
    timestamp: str

    # --- wall-clock time (seconds) ---
    quantization_time_sec: float = 0.0
    training_time_sec: float = 0.0
    evaluation_time_sec: float = 0.0
    total_time_sec: float = 0.0

    # --- peak memory (GB) ---
    peak_training_memory_gb: float = 0.0
    peak_eval_memory_gb: float = 0.0

    # --- on-disk size (MB) ---
    base_model_size_mb: float = 0.0
    quantized_model_size_mb: float = 0.0
    lora_adapter_size_mb: float = 0.0
    total_model_size_mb: float = 0.0

    # --- precision / rank allocation ---
    average_bitwidth: float = 0.0
    average_rank: float = 0.0
    per_layer_bits: List[int] = field(default_factory=list)
    per_layer_ranks: List[int] = field(default_factory=list)

    # --- benchmark scores ---
    wikitext2_ppl: float = float('inf')
    c4_ppl: float = float('inf')
    arc_easy_acc: float = 0.0
    arc_challenge_acc: float = 0.0
    piqa_acc: float = 0.0
    hellaswag_acc: float = 0.0
    winogrande_acc: float = 0.0
    mmlu_acc: float = 0.0
    gsm8k_acc: float = 0.0

    # --- derived ---
    avg_accuracy: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Return a recursive dict copy of all fields (via ``dataclasses.asdict``)."""
        return asdict(self)

    def save(self, path: Path):
        """Write the metrics to *path* as JSON indented by two spaces."""
        serialized = json.dumps(self.to_dict(), indent=2)
        Path(path).write_text(serialized)

    @classmethod
    def load(cls, path: Path) -> "ExperimentMetrics":
        """Rebuild an instance from a JSON file written by :meth:`save`."""
        payload = json.loads(Path(path).read_text())
        return cls(**payload)


def get_dir_size_mb(path: Path) -> float:
    """Return the size of *path* in MB, computed as bytes / (1024 * 1024).

    A regular file reports its own size; otherwise the sizes of all regular
    files found recursively below *path* are summed.
    """
    bytes_per_mb = 1024 * 1024
    if path.is_file():
        return path.stat().st_size / bytes_per_mb
    total_bytes = sum(entry.stat().st_size for entry in path.rglob('*') if entry.is_file())
    return total_bytes / bytes_per_mb


def reset_peak_memory():
    """Reset CUDA peak-memory counters and free cached blocks; no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()


def get_peak_memory_gb() -> float:
    """Return peak allocated CUDA memory (current device) in GB, or 0.0 without CUDA."""
    if not torch.cuda.is_available():
        return 0.0
    bytes_per_gb = 1024 ** 3
    return torch.cuda.max_memory_allocated() / bytes_per_gb


# ========================================
# Experiment Steps
# ========================================

class ExperimentRunner:
    """Run one (model, experiment-config) pair through quantize -> train -> evaluate.

    Each step shells out to a helper script under ``qwen_lora_importance/``
    with ``cwd=PROJECT_ROOT`` and records timings, sizes, memory and benchmark
    scores into ``self.metrics``. Outputs live in
    ``<output_root>/<model short_name>/<experiment name>/`` with
    ``quantized_model/``, ``checkpoints/``, ``results/`` and ``config/``
    sub-directories.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        exp_config: ExperimentConfig,
        output_root: Path,
        device: str = "cuda",
    ):
        self.model_config = model_config
        self.exp_config = exp_config
        # Resolve to an absolute path: every helper script runs with
        # cwd=PROJECT_ROOT, so a relative path would point to a different
        # location than the directories created here whenever the caller's
        # cwd is not PROJECT_ROOT.
        self.output_root = Path(output_root).resolve()
        # NOTE(review): `device` is stored but not read by any method of this class.
        self.device = device

        # Per-experiment output layout
        self.exp_dir = self.output_root / model_config.short_name / exp_config.name
        self.quant_dir = self.exp_dir / "quantized_model"
        self.ckpt_dir = self.exp_dir / "checkpoints"
        self.results_dir = self.exp_dir / "results"
        self.config_dir = self.exp_dir / "config"

        for d in (self.quant_dir, self.ckpt_dir, self.results_dir, self.config_dir):
            d.mkdir(parents=True, exist_ok=True)

        # Initialize metrics
        self.metrics = ExperimentMetrics(
            model_name=model_config.short_name,
            config_name=exp_config.name,
            timestamp=datetime.now().isoformat(),
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _base_model_path(self) -> str:
        """Return the model path the train/eval scripts should load.

        FP16 (``[16]``) and 8-bit (``[8]``, loaded with ``--load_in_8bit``)
        runs use the original model id; every other config uses the
        pre-quantized model written to ``self.quant_dir`` by ``step_quantize``.
        """
        if self.exp_config.quant_bits in ([16], [8]):
            return self.model_config.model_id
        return str(self.quant_dir)

    @staticmethod
    def _parse_peak_memory(stdout: Optional[str]) -> float:
        """Return the last ``PEAK_MEMORY_GB: <value>`` reported in *stdout*, else 0.0.

        Only the first token after the marker is parsed, so prefixes such as
        timestamps ("12:00:01 PEAK_MEMORY_GB: 3.5") or unit suffixes
        ("3.5 GB") do not break parsing. Malformed marker lines are skipped.
        """
        marker = 'PEAK_MEMORY_GB:'
        peak_memory = 0.0
        for line in (stdout or "").split('\n'):
            if marker not in line:
                continue
            try:
                peak_memory = float(line.split(marker, 1)[1].split()[0])
            except (IndexError, ValueError):
                continue
        return peak_memory

    @staticmethod
    def _first_present(task_res: Dict, keys: List[str]) -> Any:
        """Return ``task_res[k]`` for the first key ``k`` present in it, else None."""
        for key in keys:
            if key in task_res:
                return task_res[key]
        return None

    def save_config(self):
        """Save experiment configuration."""
        config_data = {
            "model": asdict(self.model_config),
            "experiment": asdict(self.exp_config),
        }

        # Save per-layer config if available
        if "q" in self.exp_config.quant_config:
            config_data["q"] = self.exp_config.quant_config["q"]
        if "r" in self.exp_config.rank_config:
            config_data["r"] = self.exp_config.rank_config["r"]

        with open(self.config_dir / "config.json", 'w') as f:
            json.dump(config_data, f, indent=2)

        return self.config_dir / "config.json"

    # ------------------------------------------------------------------
    # Step 1: quantization
    # ------------------------------------------------------------------

    def _run_amq_search(self) -> List[int]:
        """Run the AMQ sensitivity search and return per-layer bit-widths.

        If the search script fails, falls back to a shallow 2-bit / deep 4-bit
        split. The chosen array is also stored in ``quant_config["q"]``.
        Exceptions (e.g. an unreadable result file) propagate to the caller.
        """
        print("Running AMQ sensitivity search...")
        target_bits = self.exp_config.quant_config.get("target_bits", 3.5)
        uniform_rank = self.exp_config.quant_config.get("uniform_rank", False)
        amq_output = self.config_dir / "amq_search_result.json"

        amq_cmd = [
            sys.executable,
            "qwen_lora_importance/scripts/amq_sensitivity_search.py",
            "--model_id", self.model_config.model_id,
            "--output", str(amq_output),
            "--bit_budget", str(target_bits),
            "--samples", "64",  # Faster search
        ]
        if uniform_rank:
            # AMQ baseline: mixed quantization only, no mixed rank
            amq_cmd.append("--uniform_rank")
            amq_cmd.extend(["--base_rank", str(self.exp_config.lora_rank)])
            print("AMQ Baseline mode: using uniform rank")
        print(f"Running: {' '.join(amq_cmd)}")
        amq_result = subprocess.run(amq_cmd, capture_output=True, text=True, cwd=PROJECT_ROOT)

        if amq_result.returncode != 0:
            print("AMQ search failed, using fallback shallow/deep split")
            print(amq_result.stderr)
            shallow = self.model_config.shallow_layers
            num_layers = self.model_config.num_layers
            q_array = [2] * shallow + [4] * (num_layers - shallow)
        else:
            with open(amq_output, 'r') as f:
                amq_data = json.load(f)
            q_array = amq_data["q"]
            print(f"AMQ search complete: avg_bits={amq_data['actual_avg_bits']:.2f}")

        self.exp_config.quant_config["q"] = q_array
        return q_array

    def step_quantize(self) -> bool:
        """Step 1: quantize the base model (skipped for FP16 and 8-bit runs).

        Runs the AMQ search first when ``quant_config["run_amq_search"]`` is
        set. Records bit-width metrics, quantization time and the size of
        ``quant_dir``. Returns True on success (including the skip cases) and
        False if a subprocess fails or any step raises.
        """
        print(f"\n{'='*60}")
        print(f"Step 1: Quantizing {self.model_config.model_id}")
        print(f"Config: {self.exp_config.name}")
        print(f"{'='*60}")

        start_time = time.time()
        reset_peak_memory()

        try:
            quant_bits = self.exp_config.quant_bits
            num_layers = self.model_config.num_layers

            if len(quant_bits) == 1 and quant_bits[0] == 16:
                # FP16: No quantization needed
                print("Skipping quantization (FP16 mode)")
                self.metrics.quantization_time_sec = 0.0
                self.metrics.average_bitwidth = 16.0
                self.metrics.per_layer_bits = [16] * num_layers
                return True

            if len(quant_bits) == 1 and quant_bits[0] == 8:
                # 8-bit: no pre-quantization; train/eval load with --load_in_8bit
                print("Skipping pre-quantization (8-bit BitsAndBytes mode)")
                self.metrics.quantization_time_sec = 0.0
                self.metrics.average_bitwidth = 8.0
                self.metrics.per_layer_bits = [8] * num_layers
                return True

            if self.exp_config.quant_config.get("run_amq_search", False):
                self._run_amq_search()

            # Persist the (possibly AMQ-updated) config; the mixed HQQ path reads it.
            config_path = self.save_config()

            if "q" in self.exp_config.quant_config:
                q_array = self.exp_config.quant_config["q"]
                self.metrics.per_layer_bits = q_array
                self.metrics.average_bitwidth = sum(q_array) / len(q_array)

                # NOTE(review): only "amq_lora" is routed to AMQ-OWQ. The
                # amq_baseline_* configs also run the AMQ search over {2, 4, 8}
                # and may produce 8-bit layers, which then go to HQQ (described
                # here as 2/3/4-bit only) — confirm this is intended.
                if 8 in q_array and self.exp_config.name == "amq_lora":
                    amq_config = self.config_dir / "amq_search_result.json"
                    cmd = [
                        sys.executable,
                        "qwen_lora_importance/scripts/quantize_amq_owq.py",
                        "--model_id", self.model_config.model_id,
                        "--config_file", str(amq_config),
                        "--output_path", str(self.quant_dir),
                    ]
                    print("Using AMQ-OWQ quantization (8-bit supported)")
                else:
                    # HQQ with a per-layer config file (2/3/4-bit layers)
                    cmd = [
                        sys.executable,
                        "qwen_lora_importance/quantize_hqq.py",
                        "--model_id", self.model_config.model_id,
                        "--output_dir", str(self.quant_dir),
                        "--config_file", str(config_path),
                    ]
            else:
                # Uniform precision: HQQ with --bits
                self.metrics.average_bitwidth = quant_bits[0]
                self.metrics.per_layer_bits = [quant_bits[0]] * num_layers
                cmd = [
                    sys.executable,
                    "qwen_lora_importance/quantize_hqq.py",
                    "--model_id", self.model_config.model_id,
                    "--output_dir", str(self.quant_dir),
                    "--bits", str(quant_bits[0]),
                ]

            print(f"Running: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True, cwd=PROJECT_ROOT)

            if result.returncode != 0:
                print(f"Quantization failed:\n{result.stderr}")
                return False

            print(result.stdout)

        except Exception as e:
            print(f"Quantization error: {e}")
            return False

        self.metrics.quantization_time_sec = time.time() - start_time
        self.metrics.quantized_model_size_mb = get_dir_size_mb(self.quant_dir)

        print(f"Quantization completed in {self.metrics.quantization_time_sec:.1f}s")
        print(f"Model size: {self.metrics.quantized_model_size_mb:.1f} MB")

        return True

    # ------------------------------------------------------------------
    # Step 2: training
    # ------------------------------------------------------------------

    def step_train(self) -> bool:
        """Step 2: train the LoRA adapter with ``train_hqq_lora.py``.

        Records per-layer / average rank, training time, the peak memory the
        script reports via ``PEAK_MEMORY_GB:`` and the adapter size. Returns
        False if the subprocess fails or any step raises.
        """
        print(f"\n{'='*60}")
        print("Step 2: Training LoRA Adapter")
        print(f"{'='*60}")

        start_time = time.time()
        reset_peak_memory()

        try:
            # Ensure config is saved before training
            config_path = self.save_config()
            model_path = self._base_model_path()

            # Record rank metrics
            if "r" in self.exp_config.rank_config:
                r_array = self.exp_config.rank_config["r"]
                self.metrics.per_layer_ranks = r_array
                self.metrics.average_rank = sum(r_array) / len(r_array)
            else:
                self.metrics.average_rank = self.exp_config.lora_rank
                self.metrics.per_layer_ranks = [self.exp_config.lora_rank] * self.model_config.num_layers

            # NOTE(review): lora_rank / lora_alpha are not passed as flags;
            # presumably train_hqq_lora.py reads them from --qra_config — confirm.
            cmd = [
                sys.executable,
                "qwen_lora_importance/train_hqq_lora.py",
                "--hqq_model_path", model_path,
                "--qra_config", str(config_path),
                "--output_dir", str(self.ckpt_dir),
                "--num_train_epochs", str(self.exp_config.num_epochs),
                "--batch_size", str(self.exp_config.batch_size),
                "--learning_rate", str(self.exp_config.learning_rate),
            ]

            if self.exp_config.quant_bits == [8]:
                cmd.append("--load_in_8bit")

            if self.exp_config.max_steps > 0:
                cmd.extend(["--max_steps", str(self.exp_config.max_steps)])

            print(f"Running: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True, cwd=PROJECT_ROOT)

            if result.returncode != 0:
                print(f"Training failed:\n{result.stderr}")
                return False

            print(result.stdout)

        except Exception as e:
            print(f"Training error: {e}")
            return False

        self.metrics.training_time_sec = time.time() - start_time
        # GPU work happens in the subprocess, so its self-reported peak is used.
        self.metrics.peak_training_memory_gb = self._parse_peak_memory(result.stdout)
        self.metrics.lora_adapter_size_mb = get_dir_size_mb(self.ckpt_dir)

        print(f"Training completed in {self.metrics.training_time_sec:.1f}s")
        print(f"Peak memory: {self.metrics.peak_training_memory_gb:.2f} GB")
        print(f"Adapter size: {self.metrics.lora_adapter_size_mb:.1f} MB")

        return True

    # ------------------------------------------------------------------
    # Step 3: evaluation
    # ------------------------------------------------------------------

    def step_evaluate(self) -> bool:
        """Step 3: evaluate with ``eval_lm_harness.py``, falling back if it fails.

        ``evaluation_time_sec`` is recorded for both the harness and the
        fallback path. Returns whether the path that ran succeeded.
        """
        print(f"\n{'='*60}")
        print("Step 3: Evaluating on Benchmarks")
        print(f"Tasks: {self.exp_config.eval_tasks}")
        print(f"{'='*60}")

        start_time = time.time()
        reset_peak_memory()

        harness_stdout: Optional[str] = None
        try:
            model_path = self._base_model_path()
            tasks = ",".join(self.exp_config.eval_tasks)
            output_file = self.results_dir / "eval_results.json"

            cmd = [
                sys.executable,
                "qwen_lora_importance/eval_lm_harness.py",
                "--hqq_model_path", model_path,
                "--lora_path", str(self.ckpt_dir),
                "--output_file", str(output_file),
                "--tasks", tasks,
                "--num_fewshot", str(self.exp_config.eval_shots),
            ]

            if self.exp_config.quant_bits == [8]:
                cmd.append("--load_in_8bit")

            print(f"Running: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True, cwd=PROJECT_ROOT)

            if result.returncode != 0:
                print(f"Evaluation failed:\n{result.stderr}")
                success = self._fallback_evaluate()
            else:
                print(result.stdout)
                if output_file.exists():
                    with open(output_file, 'r') as f:
                        eval_results = json.load(f)
                    self._parse_eval_results(eval_results)
                harness_stdout = result.stdout
                success = True

        except Exception as e:
            print(f"Evaluation error: {e}")
            success = self._fallback_evaluate()

        # Recorded for every path (previously skipped when the fallback ran).
        self.metrics.evaluation_time_sec = time.time() - start_time
        if harness_stdout is not None:
            self.metrics.peak_eval_memory_gb = self._parse_peak_memory(harness_stdout)

        if success:
            print(f"Evaluation completed in {self.metrics.evaluation_time_sec:.1f}s")

        return success

    def _fallback_evaluate(self) -> bool:
        """Fallback evaluation with ``eval_hqq_model.py`` (WikiText-2 PPL + MMLU).

        Only ``wikitext2_ppl`` and ``mmlu_acc`` are filled in. Returns True
        when the script succeeds and writes its result file.
        """
        print("Using fallback evaluation...")
        if self.exp_config.quant_bits == [8]:
            # 8-bit runs have no pre-quantized model in quant_dir, and this
            # command has no 8-bit loading flag, so it cannot rebuild the
            # evaluated model; decline explicitly instead of failing obscurely.
            print("Fallback evaluation does not support 8-bit BitsAndBytes models")
            return False

        try:
            model_path = self._base_model_path()
            output_file = self.results_dir / "eval_results.json"

            cmd = [
                sys.executable,
                "qwen_lora_importance/eval_hqq_model.py",
                "--hqq_model_path", model_path,
                "--lora_path", str(self.ckpt_dir),
                "--output_file", str(output_file),
                "--max_samples", "100",
                "--eval_mmlu",
                "--mmlu_samples", "2000",
            ]

            result = subprocess.run(cmd, capture_output=True, text=True, cwd=PROJECT_ROOT)

            if result.returncode == 0 and output_file.exists():
                with open(output_file, 'r') as f:
                    eval_results = json.load(f)
                if "wikitext2_ppl" in eval_results:
                    self.metrics.wikitext2_ppl = eval_results["wikitext2_ppl"]
                if "mmlu_accuracy" in eval_results:
                    self.metrics.mmlu_acc = eval_results["mmlu_accuracy"]
                return True

            # Surface the failure instead of returning False silently.
            print(f"Fallback evaluation failed:\n{result.stderr}")
        except Exception as e:
            print(f"Fallback evaluation failed: {e}")

        return False

    def _parse_eval_results(self, results: Dict):
        """Copy lm-eval-harness scores from *results* into ``self.metrics``.

        Accepts both the ``"<metric>,<filter>"`` keys of lm-eval >= 0.4 and
        the older bare ``"<metric>"`` keys. Tasks missing from
        ``results["results"]`` leave their metric untouched; a task present
        without any known key is recorded as 0.0.
        """
        task_results = results.get("results", {})

        # Perplexities. c4 is parsed the same way if the harness reports it.
        ppl_targets = {"wikitext": "wikitext2_ppl", "c4": "c4_ppl"}
        for task_name, metric_attr in ppl_targets.items():
            task_res = task_results.get(task_name)
            if not task_res:
                continue
            ppl = self._first_present(task_res, ["word_perplexity,none", "word_perplexity"])
            if ppl is not None and ppl != float('inf'):
                setattr(self.metrics, metric_attr, ppl)

        # Accuracies: task -> (metrics attribute, candidate result keys in priority order)
        metric_mapping = {
            "arc_easy": ("arc_easy_acc", ["acc,none", "acc"]),
            "arc_challenge": ("arc_challenge_acc", ["acc_norm,none", "acc_norm"]),
            "piqa": ("piqa_acc", ["acc,none", "acc"]),
            "hellaswag": ("hellaswag_acc", ["acc_norm,none", "acc_norm"]),
            "winogrande": ("winogrande_acc", ["acc,none", "acc"]),
            "mmlu": ("mmlu_acc", ["acc,none", "acc"]),
            # gsm8k reports exact_match under per-filter keys, not "acc".
            "gsm8k": ("gsm8k_acc", [
                "exact_match,strict-match", "exact_match,flexible-extract",
                "exact_match", "acc,none", "acc",
            ]),
        }

        for task_name, (metric_attr, result_keys) in metric_mapping.items():
            if task_name in task_results:
                value = self._first_present(task_results[task_name], result_keys)
                setattr(self.metrics, metric_attr, 0.0 if value is None else value)

        # Average over the multiple-choice tasks that produced a positive score.
        accs = [
            self.metrics.arc_easy_acc,
            self.metrics.arc_challenge_acc,
            self.metrics.piqa_acc,
            self.metrics.hellaswag_acc,
            self.metrics.winogrande_acc,
            self.metrics.mmlu_acc,
        ]
        valid_accs = [a for a in accs if a > 0]
        if valid_accs:
            self.metrics.avg_accuracy = sum(valid_accs) / len(valid_accs)

    # ------------------------------------------------------------------
    # Pipeline
    # ------------------------------------------------------------------

    def run(self) -> ExperimentMetrics:
        """Run quantize -> train -> evaluate and return the collected metrics.

        If quantization or training fails, the remaining steps are skipped and
        the partial metrics are returned without being written to disk.
        Otherwise the metrics are saved to ``results/metrics.json``.
        """
        print(f"\n{'#'*70}")
        print(f"# Experiment: {self.exp_config.name}")
        print(f"# Model: {self.model_config.model_id}")
        print(f"# Description: {self.exp_config.description}")
        print(f"{'#'*70}")

        total_start = time.time()

        if not self.step_quantize():
            print("Quantization failed, skipping remaining steps")
            return self.metrics

        if not self.step_train():
            print("Training failed, skipping evaluation")
            return self.metrics

        # Evaluation failures are tolerated; metrics keep their defaults.
        self.step_evaluate()

        self.metrics.total_time_sec = time.time() - total_start
        self.metrics.total_model_size_mb = (
            self.metrics.quantized_model_size_mb + self.metrics.lora_adapter_size_mb
        )

        self.metrics.save(self.results_dir / "metrics.json")
        self._print_summary()

        return self.metrics

    def _print_summary(self):
        """Print experiment summary."""
        print(f"\n{'='*60}")
        print(f"Experiment Summary: {self.exp_config.name}")
        print(f"{'='*60}")
        print("Time:")
        print(f"  Quantization: {self.metrics.quantization_time_sec:.1f}s")
        print(f"  Training: {self.metrics.training_time_sec:.1f}s")
        print(f"  Evaluation: {self.metrics.evaluation_time_sec:.1f}s")
        print(f"  Total: {self.metrics.total_time_sec:.1f}s")
        print("\nMemory:")
        print(f"  Peak Training: {self.metrics.peak_training_memory_gb:.2f} GB")
        print(f"  Peak Eval: {self.metrics.peak_eval_memory_gb:.2f} GB")
        print("\nStorage:")
        print(f"  Quantized Model: {self.metrics.quantized_model_size_mb:.1f} MB")
        print(f"  LoRA Adapter: {self.metrics.lora_adapter_size_mb:.1f} MB")
        print(f"  Total: {self.metrics.total_model_size_mb:.1f} MB")
        print("\nConfiguration:")
        print(f"  Avg Bit-width: {self.metrics.average_bitwidth:.2f}")
        print(f"  Avg Rank: {self.metrics.average_rank:.2f}")
        print("\nPerformance:")
        print(f"  WikiText-2 PPL: {self.metrics.wikitext2_ppl:.2f}")
        print(f"  MMLU Acc: {self.metrics.mmlu_acc:.4f}")
        print(f"  Avg Accuracy: {self.metrics.avg_accuracy:.4f}")
        print(f"{'='*60}")


# ========================================
# Experiment Orchestration
# ========================================

def run_motivating_example(
    model_name: str = "qwen3-1.7b",
    output_root: Optional[Path] = None,
    configs: Optional[List[str]] = None,
):
    """Run the motivating-example configurations for one model.

    Each selected configuration is executed end-to-end by ``ExperimentRunner``.
    The per-config metrics are then aggregated into
    ``<output_root>/<model short name>/all_results.json`` and printed as a
    comparison table.

    Args:
        model_name: Key into ``MODEL_CONFIGS`` (raises ``KeyError`` if absent).
        output_root: Output directory root (``Path`` or ``str``). Defaults to
            ``PROJECT_ROOT / "qwen_lora_importance/experiments/outputs/motivating"``,
            so the location does not depend on the current working directory.
        configs: Optional subset of config names to run. Names that do not
            match any motivating config are reported and ignored.

    Returns:
        Dict mapping config name -> metrics dict (``metrics.to_dict()``).
    """
    if output_root is None:
        # Anchor to the project root instead of the CWD. The documented usage
        # runs the script from inside qwen_lora_importance/, where the bare
        # relative path would create a nested qwen_lora_importance/ tree.
        output_root = PROJECT_ROOT / "qwen_lora_importance/experiments/outputs/motivating"
    else:
        output_root = Path(output_root)

    model_config = MODEL_CONFIGS[model_name]
    exp_configs = get_motivating_configs(model_config)

    if configs:
        unknown = [c for c in configs if c not in exp_configs]
        if unknown:
            print(
                f"Warning: ignoring unknown motivating config(s) {unknown}; "
                f"available: {list(exp_configs)}"
            )
        exp_configs = {k: v for k, v in exp_configs.items() if k in configs}

    results = {}
    for name, exp_config in exp_configs.items():
        runner = ExperimentRunner(model_config, exp_config, output_root)
        results[name] = runner.run().to_dict()

    # Create the per-model directory explicitly rather than relying on a runner
    # having created it (no runner executes when the filter matched nothing).
    model_dir = output_root / model_config.short_name
    model_dir.mkdir(parents=True, exist_ok=True)
    with open(model_dir / "all_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print_comparison_table(results)

    return results


def run_main_experiment(
    model_name: str = "qwen3-1.7b",
    output_root: Optional[Path] = None,
    baselines: Optional[List[str]] = None,
    rank: int = 16,
):
    """Run the main experiment (all baselines, or a subset) for one model.

    Each selected baseline is executed end-to-end by ``ExperimentRunner``.
    The per-baseline metrics are then aggregated into
    ``<output_root>/<model short name>/all_results.json`` and printed as a
    comparison table.

    Args:
        model_name: Key into ``MODEL_CONFIGS`` (raises ``KeyError`` if absent).
        output_root: Output directory root (``Path`` or ``str``). Defaults to
            ``PROJECT_ROOT / "qwen_lora_importance/experiments/outputs/main"``,
            so the location does not depend on the current working directory.
        baselines: Optional subset of baseline names to run. Names that do not
            match any baseline are reported and ignored.
        rank: Base LoRA rank forwarded to ``get_baseline_configs``.

    Returns:
        Dict mapping baseline name -> metrics dict (``metrics.to_dict()``).
    """
    if output_root is None:
        # Anchor to the project root instead of the CWD. The documented usage
        # runs the script from inside qwen_lora_importance/, where the bare
        # relative path would create a nested qwen_lora_importance/ tree.
        output_root = PROJECT_ROOT / "qwen_lora_importance/experiments/outputs/main"
    else:
        output_root = Path(output_root)

    model_config = MODEL_CONFIGS[model_name]
    baseline_configs = get_baseline_configs(model_config, rank=rank)

    if baselines:
        unknown = [b for b in baselines if b not in baseline_configs]
        if unknown:
            print(
                f"Warning: ignoring unknown baseline(s) {unknown}; "
                f"available: {list(baseline_configs)}"
            )
        baseline_configs = {k: v for k, v in baseline_configs.items() if k in baselines}

    results = {}
    for name, exp_config in baseline_configs.items():
        runner = ExperimentRunner(model_config, exp_config, output_root)
        results[name] = runner.run().to_dict()

    # Create the per-model directory explicitly rather than relying on a runner
    # having created it (no runner executes when the filter matched nothing).
    model_dir = output_root / model_config.short_name
    model_dir.mkdir(parents=True, exist_ok=True)
    with open(model_dir / "all_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print_comparison_table(results)

    return results


def print_comparison_table(results: Dict[str, Dict]):
    """Print a fixed-width comparison table of aggregated run metrics.

    Args:
        results: Mapping of config/baseline name -> metrics dict (as produced by
            ``metrics.to_dict()``). A missing key is shown as its default: 0,
            or ``inf`` for perplexity. A key that is present with value
            ``None`` is shown as ``N/A``; previously it crashed the numeric
            format call.
    """
    # (metrics key, default when the key is missing, column width, format spec)
    columns = (
        ("average_bitwidth", 0, 8, ".2f"),
        ("average_rank", 0, 8, ".1f"),
        ("wikitext2_ppl", float("inf"), 10, ".2f"),
        ("mmlu_acc", 0, 8, ".4f"),
        ("avg_accuracy", 0, 8, ".4f"),
        ("peak_training_memory_gb", 0, 10, ".2f"),
        ("total_model_size_mb", 0, 10, ".1f"),
        ("total_time_sec", 0, 10, ".1f"),
    )

    def _cell(value, width, spec):
        """Right-align *value* in *width* columns; None renders as 'N/A'."""
        if value is None:
            return f"{'N/A':>{width}}"
        return f"{value:>{width}{spec}}"

    print(f"\n{'='*100}")
    print("COMPARISON TABLE")
    print(f"{'='*100}")

    print(f"{'Config':<20} {'Bits':>8} {'Rank':>8} {'PPL':>10} {'MMLU':>8} {'AvgAcc':>8} {'Mem(GB)':>10} {'Size(MB)':>10} {'Time(s)':>10}")
    print("-" * 100)

    for name, metrics in results.items():
        cells = [
            _cell(metrics.get(key, default), width, spec)
            for key, default, width, spec in columns
        ]
        print(f"{name:<20} " + " ".join(cells))

    print(f"{'='*100}")


# ========================================
# Main Entry Point
# ========================================

def main():
    """Parse command-line arguments and dispatch the requested experiment(s).

    ``--experiment motivating`` runs the motivating example, ``main`` runs the
    baselines, and ``full`` runs both, for one model or for every entry in
    ``MODEL_CONFIGS`` when ``--model all`` is given.
    """
    parser = argparse.ArgumentParser(
        description="QR-Adaptor Main Experiment Script",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run motivating example (4 configs on Qwen3-1.7B)
  python main_experiment.py --experiment motivating
  
  # Run specific config
  python main_experiment.py --experiment motivating --configs config_D
  
  # Run main baselines on specific model
  python main_experiment.py --experiment main --model qwen3-4b
  
  # Run specific baselines
  python main_experiment.py --experiment main --baselines qlora_4bit adalora_4bit
        """
    )

    parser.add_argument(
        "--experiment",
        type=str,
        choices=["motivating", "main", "full"],
        default="motivating",
        help="Experiment type to run"
    )

    parser.add_argument(
        "--model",
        type=str,
        choices=list(MODEL_CONFIGS.keys()) + ["all"],
        default="qwen3-1.7b",
        help="Model to experiment on"
    )

    parser.add_argument(
        "--configs",
        type=str,
        nargs="+",
        help="Specific configs to run (for motivating example)"
    )

    parser.add_argument(
        "--baselines",
        type=str,
        nargs="+",
        help="Specific baselines to run (for main experiment)"
    )

    parser.add_argument(
        "--output_root",
        type=str,
        default=None,
        help="Output directory root"
    )

    parser.add_argument(
        "--rank",
        type=int,
        default=16,
        help="Base LoRA rank for main experiments"
    )

    args = parser.parse_args()

    output_root = Path(args.output_root) if args.output_root else None
    model_names = list(MODEL_CONFIGS.keys()) if args.model == "all" else [args.model]

    # Both experiment types write <root>/<model short name>/all_results.json.
    # In "full" mode with a shared --output_root, the main run would overwrite
    # the motivating aggregate, so each run gets its own subdirectory. With no
    # --output_root, each function already uses a distinct default root.
    if args.experiment == "full" and output_root is not None:
        motivating_root = output_root / "motivating"
        main_root = output_root / "main"
    else:
        motivating_root = main_root = output_root

    do_motivating = args.experiment in ("motivating", "full")
    do_main = args.experiment in ("main", "full")

    # Per model, the motivating example runs before the main experiment
    # (the same order the original full mode used).
    for model_name in model_names:
        if do_motivating:
            run_motivating_example(model_name, motivating_root, args.configs)
        if do_main:
            run_main_experiment(model_name, main_root, args.baselines, args.rank)

    print("\n" + "="*70)
    print("ALL EXPERIMENTS COMPLETED!")
    print("="*70)


if __name__ == "__main__":
    main()
