# scripts/config.py

import os
from pathlib import Path
import torch

# ==============================================================================
#  SECTION 1: EXPERIMENT CONTROL PANEL (USER EDITABLE)
# ==============================================================================

# --- A. Scenario Toggles ---
# Each flag independently switches one experimental pipeline on or off.
#   ENABLE_PROXY_EXPERIMENTS   : calibrate / QAT on external WebDataset shards (e.g. CC3M).
#   ENABLE_NOISE_EXPERIMENTS   : data-free calibration on synthetic Gaussian noise.
#   ENABLE_REAL_ID_EXPERIMENTS : use a split of the target domain's training set
#                                (e.g. ImageNet train).
#   ENABLE_SPURIOUS_BENCHMARK  : enables the "spurious" benchmark run.
#                                NOTE(review): its consumer is not in this file; confirm meaning.
ENABLE_PROXY_EXPERIMENTS = True
ENABLE_NOISE_EXPERIMENTS = False
ENABLE_REAL_ID_EXPERIMENTS = False
ENABLE_SPURIOUS_BENCHMARK = True

# --- Performance & Logic Toggles ---
# Use Automatic Mixed Precision (FP16/BF16) for faster training/inference.
USE_AMP = True

# The AMP dtype is pinned explicitly so comparisons stay fair: BFloat16 vs Float16
# matters for quantization results.
#   - CUDA: bfloat16 when the device supports it, otherwise float16.
#   - CPU:  bfloat16.
# torch.cuda.is_bf16_supported() is consulted only when CUDA is available. Some
# torch releases query the current CUDA device inside it and raise on CPU-only
# machines, which would make importing this config fail. bfloat16 is the dtype
# CPU autocast supports across releases; older ones disable CPU autocast for float16.
if torch.cuda.is_available():
    AMP_DEVICE_TYPE = 'cuda'
    AMP_DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    AMP_DEVICE_TYPE = 'cpu'
    AMP_DTYPE = torch.bfloat16

# If False, the teacher model is NOT loaded into memory.
# Distillation losses will be forced to 0.0. Saves ~50% VRAM.
ENABLE_TEACHER = True

# If False, skips the Phase 2 Logit Scale Tuning step.
ENABLE_LOGIT_TUNING = True
# --- B. Proxy Dataset Selection ---
# Consulted only when ENABLE_PROXY_EXPERIMENTS is True; the script iterates over
# these keys. Supported keys (see PROXY_DATASETS below):
#   "CC3M"     : Conceptual Captions 3M (balanced, general purpose)
#   "YFCC"     : YFCC100M subset (noisy, wild internet data)
#   "DataComp" : DataComp-1B best subset (high-quality filtered)
#   "SBU"      : SBU Captions (older, smaller dataset)
ACTIVE_PROXY_DATASETS = ["CC3M", "YFCC", "SBU"]

# --- C. Quantization Precision (WxAy) ---
# Each entry is a (weight_bits, activation_bits) pair to benchmark.
# Reference combinations:
#   (8, 8) standard integer quantization
#   (6, 8) weight compression, standard activations
#   (6, 6) aggressive compression (requires QAT/LoRA/IGQ)
#   (4, 8) extreme compression (experimental)
BIT_WIDTHS_TO_TEST = [(8, 8), (6, 8), (4, 8)]

# --- D. Text Encoder Quantization ---
# Controls the scope of quantization; every entry of TEXT_QUANTIZATION_MODES is run:
#   False : quantize the visual encoder ONLY (text encoder kept FP32).
#   True  : quantize BOTH the visual and text encoders.
#
# Selected via the TARGET_QUANT_SCOPE environment variable:
#   "VISUAL_ONLY" -> [False]
#   "VISUAL_TEXT" -> [True]
#   "ALL"         -> [False, True]   (default: both loops in one job)
# Matching ignores case and surrounding whitespace; an unset or empty variable means
# "ALL". Any other value raises ValueError (like the TARGET_MODEL_KEY check) instead of
# silently running both modes because of a typo.
_QUANT_SCOPE_MODES = {
    "VISUAL_ONLY": [False],
    "VISUAL_TEXT": [True],
    "ALL": [False, True],
}
_raw_quant_scope = os.environ.get("TARGET_QUANT_SCOPE", "")
_quant_scope = _raw_quant_scope.strip().upper() or "ALL"

if _quant_scope not in _QUANT_SCOPE_MODES:
    raise ValueError(
        f"TARGET_QUANT_SCOPE '{_raw_quant_scope}' is invalid; "
        f"expected one of {list(_QUANT_SCOPE_MODES)}."
    )

# Copy so callers mutating the list cannot corrupt the lookup table.
TEXT_QUANTIZATION_MODES = list(_QUANT_SCOPE_MODES[_quant_scope])

# --- E. Method Selection ---
# 1. Post-Training Quantization (PTQ): calibration only, no training.
#    "Simple PTQ"        MinMax quantization.
#    "SmoothQuant PTQ"   activation smoothing (migrates difficulty to weights).
#    "IGQ-ViT PTQ"       Instance-Aware Group Quantization (optimized for ViT).
#    "QwT PTQ"           Quantization without Tears (bias correction).
#    "APQ-ViT PTQ"       Attention Power-of-Two quantization.
#    "Rotation PTQ"      SpinQuant/RoLoRA-style rotation.
#    "OutlierAware PTQ"  RegCache simulation (see OUTLIER_AWARE_KWARGS).
#    "Q-VLM"             visual-encoder optimization before quantization (see QVLM_KWARGS).
ACTIVE_PTQ_METHODS = [
    "Simple PTQ",
    "SmoothQuant PTQ",
    "IGQ-ViT PTQ",
    "QwT PTQ",
    "APQ-ViT PTQ",
    "Rotation PTQ",
    "OutlierAware PTQ",
    "Q-VLM",
]

# 2. Standard Quantization-Aware Training (QAT): fine-tunes fake-quantized weights.
#    "QAT"             fake quantization with a Straight-Through Estimator (STE).
#    "LSQ"             Learned Step Size Quantization (trains the scaling factors).
#    "Rotation + LSQ"  presumably rotation (as in "Rotation PTQ") followed by LSQ;
#                      NOTE(review): confirm against the method implementation.
ACTIVE_STANDARD_QAT_METHODS = [
    "QAT",
    "LSQ",
    "Rotation + LSQ",
]

# 3. Fixed / specialized QAT: specialized architectures or freezing strategies.
#    "QAT-LoRA"  low-rank adaptation QAT (trains low-rank adapters, main weights frozen).
#    "Q-ViT"     information rectification & distillation (requires the teacher;
#                optimized for ViT).
ACTIVE_FIXED_QAT_METHODS = [
    "QAT-LoRA",
    "Q-ViT",
]

# --- F. Distillation Settings ---
# Loss weighting used by the QAT methods, keyed by a label suffix.
# The keys intentionally start with a space. They are presumably appended to a
# method name to build a run label, so keep them byte-exact.
#   " (Contrastive Only)" : standard CLIP loss only (image <-> text).
#   " (Hybrid)"           : CLIP loss + MSE(student image feat, teacher image feat).
# A "Distill Only" mode (pure MSE feature mimicry, ignoring text/labels) is a
# documented option but is not enabled here.
ACTIVE_DISTILLATION_MODES = {
    " (Contrastive Only)": dict(main_loss_weight=1.0, distill_weight=0.0),
    " (Hybrid)": dict(main_loss_weight=0.5, distill_weight=0.5),
}

# ==============================================================================
#  SECTION 2: SYSTEM CONFIGURATION & PATHS
# ==============================================================================

# --- Hardware & Seeds ---
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
TARGET_DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
RANDOM_SEED = 42
NUM_RUNS = 3  # presumably repetitions per configuration; NOTE(review): confirm in the runner

# ==============================================================================
#  MODEL ZOO REGISTRY
# ==============================================================================
# Registry of supported model configurations. Each entry records the loader
# 'type', the architecture name 'arch' and the pretrained tag 'data'.
MODEL_ZOO = {

    # --- ALIGN & SIGLIP (Swapped to BASE variants) ---
    # 1. "ALIGN" Base: xlm-roberta-base text tower (instead of large) + ViT-B-32.
    #    NOTE(review): despite the key, this open_clip arch/tag is an XLM-RoBERTa
    #    CLIP checkpoint, not Google's ALIGN release; confirm this is intended.
    "ALIGN_RoBERTa_ViT-B":     {'type': 'open_clip', 'arch': 'xlm-roberta-base-ViT-B-32', 'data': 'laion5b_s13b_b90k'},

    # 2. SigLIP Base: standard B-16.
    "SigLIP_ViT-B-16":         {'type': 'open_clip', 'arch': 'ViT-B-16-SigLIP', 'data': 'webli'},

    # --- LARGE MODELS (now the heaviest) ---
    "CLIP_ViT-L-14_Laion2B":   {'type': 'open_clip', 'arch': 'ViT-L-14', 'data': 'laion2b_s32b_b82k'},
    "CLIP_ViT-L-14_OpenAI":    {'type': 'open_clip', 'arch': 'ViT-L-14-quickgelu', 'data': 'openai'},

    # --- MEDIUM MODELS ---
    "EVA02_B-16":              {'type': 'open_clip', 'arch': 'EVA02-B-16', 'data': 'merged2b_s8b_b131k'},
    "ConvNeXt_Base":           {'type': 'open_clip', 'arch': 'convnext_base', 'data': 'laion400m_s13b_b51k'},

    # --- STANDARD BASE MODELS ---
    "CLIP_ViT-B-32_Laion2B":   {'type': 'open_clip', 'arch': 'ViT-B-32', 'data': 'laion2b_s34b_b79k'},
    # NOTE(review): 'datacomp_xl_s13b_b90k' names a DataComp-XL checkpoint; confirm
    # it is the intended "DFN" model (or fix the key/tag).
    "DFN_ViT-B-32":            {'type': 'open_clip', 'arch': 'ViT-B-32', 'data': 'datacomp_xl_s13b_b90k'},
    "CLIP_ViT-B-32_OpenAI":    {'type': 'open_clip', 'arch': 'ViT-B-32-quickgelu', 'data': 'openai'},
    "CoCa_ViT-B-32":           {'type': 'open_clip', 'arch': 'coca_ViT-B-32', 'data': 'laion2b_s13b_b90k'},
}

# Model selection via the TARGET_MODEL_KEY environment variable.
# When unset, defaults to "CLIP_ViT-B-32_OpenAI" (handy for local testing).
_target_key = os.environ.get("TARGET_MODEL_KEY", "CLIP_ViT-B-32_OpenAI")

if _target_key not in MODEL_ZOO:
    # List the valid keys so a typo in a job script is quick to diagnose.
    raise ValueError(
        f"Model Key '{_target_key}' not found in MODEL_ZOO registry. "
        f"Available keys: {', '.join(MODEL_ZOO)}"
    )

MODEL_CONFIG = MODEL_ZOO[_target_key]

# Aliases kept for compatibility with existing scripts.
CLIP_MODEL_ARCHITECTURE = MODEL_CONFIG['arch']
CLIP_MODEL_PRETRAINED_DATASET = MODEL_CONFIG['data']

# --- Paths ---
# PROJECT_ROOT is the current working directory, so every path below depends on
# where the script is launched from.
try:
    PROJECT_ROOT = Path(os.getcwd())
except OSError:
    # os.getcwd() raises OSError (e.g. FileNotFoundError if the working directory
    # was deleted); it never raises NameError, which the old handler caught.
    # Fall back to a relative root so importing the config still succeeds.
    PROJECT_ROOT = Path('.')

DATA_ROOT = PROJECT_ROOT / "quantization/data_quantization/"

# WebDataset Shards for Proxy Experiments
SHARD_PATH_CC3M = DATA_ROOT / "cc3m_shards_00000.tar"
SHARD_PATH_YFCC = DATA_ROOT / "yfcc_shard_10k.tar"
SHARD_PATH_SBU = DATA_ROOT / "sbu_shard_10k.tar"

EVAL_DATASET_SUBSAMPLE_RATIO = 1.0

DATASETS_ROOT = PROJECT_ROOT / "datasets"

# ImageNet-1k validation folder. The ImageNet subset benchmarks
# (imagenet10 / imagenet20 / imagenet500id) all read from this same folder.
IMAGENET_VAL_PATH = os.path.join(DATASETS_ROOT, "imagenet_1k_val_reorg")

# Evaluation dataset roots; keys match datasets_classes.py.
# NOTE: values are plain strings (built with os.path.join), not Path objects.
DATASET_PATHS = {
    # Standard classification
    "cifar10": os.path.join(DATASETS_ROOT, "cifar10"),
    "cifar100": os.path.join(DATASETS_ROOT, "cifar100"),
    "imagenet1kval": IMAGENET_VAL_PATH,
    # Textures & scenes
    "dtd": os.path.join(DATASETS_ROOT, "dtd"),
    "places365": os.path.join(DATASETS_ROOT, "places365_standard/val"),
    "sun397": os.path.join(DATASETS_ROOT, "sun397/formatted_test"),
    # Out-of-distribution / robustness (each requires a specific folder structure)
    "imagenet-a": os.path.join(DATASETS_ROOT, "imagenet-a"),    # adversarial
    "imagenet-r": os.path.join(DATASETS_ROOT, "imagenet-r"),    # renditions
    "imagenet-v2": os.path.join(DATASETS_ROOT, "imagenet-v2"),  # new test set
    "inaturalist": os.path.join(DATASETS_ROOT, "iNaturalist_OOD_test_10k_v2"),
    # ImageNet subsets (share the ImageNet-1k validation folder)
    "imagenet10": IMAGENET_VAL_PATH,
    "imagenet20": IMAGENET_VAL_PATH,
    "imagenet500id": IMAGENET_VAL_PATH,
}

# Datasets the evaluation loop runs on (keys of DATASET_PATHS).
EVAL_DATASETS_TO_TEST = [
    "cifar10",
    "cifar100",
    "imagenet1kval",
    "imagenet10",
    "imagenet20",
    "imagenet500id",
    "sun397",
]

# DataComp shard path. PROXY_DATASETS below references this name; previously it
# was never defined, so importing this module raised NameError.
# NOTE(review): the filename follows the YFCC/SBU "<name>_shard_10k.tar" pattern
# and is an assumption. Point it at the real DataComp shard before adding
# "DataComp" to ACTIVE_PROXY_DATASETS.
SHARD_PATH_DATACOMP = DATA_ROOT / "datacomp_shard_10k.tar"

# Proxy calibration/QAT sources: key -> (WebDataset shard path, 1000).
# NOTE(review): the second element is presumably the number of samples drawn
# from the shard; confirm against the loader.
PROXY_DATASETS = {
    "CC3M": (SHARD_PATH_CC3M, 1000),
    "YFCC": (SHARD_PATH_YFCC, 1000),
    "DataComp": (SHARD_PATH_DATACOMP, 1000),
    "SBU": (SHARD_PATH_SBU, 1000),
}

# ID/OOD split files.
ID_DATA_PATH = DATASETS_ROOT / "id_data.txt"
OOD_DATA_PATH = DATASETS_ROOT / "ood_data.txt"
# ==============================================================================
#  SECTION 3: HYPERPARAMETERS
# ==============================================================================

BATCH_SIZE = 100
EVAL_BATCH_SIZE = 512

# Sample counts per experiment type.
NUM_CALIBRATION_SAMPLES_REAL = 128    # PTQ calibration, real data
NUM_TRAINING_SAMPLES_REAL = 1000      # QAT fine-tuning, real data
NUM_CALIBRATION_SAMPLES_NOISE = 512   # PTQ calibration, synthetic noise
NUM_TRAINING_SAMPLES_NOISE = 1024     # QAT fine-tuning, synthetic noise

# Training hyperparameters.
LEARNING_RATE = 1e-6         # main weights / LoRA
LSQ_LEARNING_RATE = 1e-4     # scaling factors (LSQ only)
QAT_TRAINING_STEPS = 10_000  # steps in the QAT loop

# Phase 2 logit-scale tuning: runs after quantization to fix the Expected
# Calibration Error (ECE).
LOGIT_TUNE_STEPS = 50
LOGIT_TUNE_LR = 1e-2
# NOTE(review): listed under logit tuning, but its consumer is not visible in
# this file; confirm which loop uses it.
GRAD_ACCUM_STEPS = 1

# Prompt template for synthetic/noise calibration.
CIFAR100_TEMPLATE = "a photo of a {}."

# ==============================================================================
#  SECTION 4: METHOD-SPECIFIC ARGUMENTS
# ==============================================================================

# Simple (MinMax) PTQ takes no extra arguments.
SIMPLE_PTQ_KWARGS = {}

QAT_KWARGS = dict(learning_rate=LEARNING_RATE)

LSQ_KWARGS = dict(
    learning_rate=LEARNING_RATE,
    lsq_learning_rate=LSQ_LEARNING_RATE,  # separate LR for the learned step sizes
)

# IGQ-ViT: Instance-Aware Group Quantization.
# Groups channels to reduce quantization error in LayerNorm/GELU-heavy architectures.
IGQ_KWARGS = dict(num_groups=8)
# Configuration for Cosine QAT.
COS_QAT_KWARGS = {
    'learning_rate': 5e-6,
    'total_steps': 100,
    'contrastive_weight': 0,
    # Fraction of steps that use PURE cosine similarity before switching to the
    # hybrid loss. NOTE(review): the old comment said "first 50% steps", but the
    # value is 1. That presumably means the whole run is pure cosine, which fits
    # contrastive_weight = 0. Use 0.5 if the 50% warm-up was the intent.
    'warmup_pct': 1,
}
# Q-ViT: Information Rectification Module (IRM) & DGD distillation.
QVIT_KWARGS = dict(learning_rate=LEARNING_RATE)

# QAT-LoRA: low-rank adaptation. Base weights stay frozen; only the A/B matrices train.
QAT_LORA_KWARGS = dict(
    learning_rate=LEARNING_RATE,
    lora_rank=4,
    lora_alpha=4,
)


# OutlierAware (RegCache simulation).
OUTLIER_AWARE_KWARGS = dict(
    # Intended: keep the top 1% of activations in FP16.
    # NOTE(review): 0.01 means 1% only if the consumer treats it as a fraction;
    # read as a 0-100 percentile it would be 0.01%. Confirm.
    outlier_percentile=0.01,
)

# APQ-ViT: power-of-two quantization for attention scores.
APQ_KWARGS = dict(attn_method='pot')

QWT_KWARGS = dict(
    num_calibration_batches=32,
    method='qwt',
)

# Q-VLM (visual encoder optimization): tunes the visual encoder with an entropy
# proxy before quantization.
QVLM_KWARGS = dict(
    epochs=1,                 # VEO is usually fast: 1-5 epochs on the calibration set
    learning_rate=1e-6,       # low LR so pre-trained features are not destroyed
    distill_weight=1.0,       # preserve original semantics (L_err in the paper)
    quant_error_weight=0.5,   # reduce quantization noise (L_ent proxy)
    method='qvlm',
)




# Adaptation benchmark: adapt the model on ID datasets, then test quantization
# methods on the adapted models.
ENABLE_ADAPTATION_BENCHMARK = True

# The datasets to adapt on (ID).
# Constraint (per the original note): the OOD dataset is the member of the pair
# that is NOT selected here.
# NOTE(review): that note named "cifar10/imagenet1kval", but this list holds
# "cifar100", and both members are currently listed. Confirm the intended pair
# and how the OOD set is chosen.
ADAPTATION_TARGETS = [
    "cifar100",
    "imagenet1kval",
]

# Adaptation hyperparameters.
ADAPTATION_CONFIG = {
    'batch_size': 100,
    'steps': 10000,            # number of batches to train (prevents an endless ImageNet loop)

    # Full fine-tuning
    'full_ft_lr': 1e-6,        # very low to preserve pre-trained knowledge

    # LoRA
    'lora_lr': 1e-4,           # standard LoRA LR
    'lora_rank': 16,
    'lora_alpha': 16,

    'save_checkpoints': False,
}

# Quantization methods to test on the adapted models.
ADAPTATION_QUANT_METHODS = [
    "Simple PTQ",
    "SmoothQuant PTQ",
    "Rotation PTQ",
    "QAT",
    "LSQ",
    "Rotation + LSQ",
]