import os
# Force offline mode for all Hugging Face library calls (hub, datasets,
# transformers); models and datasets are expected to be available locally.
os.environ.update({
    "HF_HUB_OFFLINE": "1",
    "HF_DATASETS_OFFLINE": "1",
    "TRANSFORMERS_OFFLINE": "1",
})

# Memory fragmentation fix. Must be in the environment before torch sets up
# its CUDA allocator, which is why it is assigned ahead of `import torch` below.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import logging
import copy
import gc
import datetime
import csv
import json
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm 

import scripts.config as config
import scripts.data_setup as data_setup
import scripts.ood_evaluation as ood_eval 
import scripts.datasets_classes as ds_classes 

from quantization.apply import (
    apply_simple_ptq, apply_smoothquant, apply_igq_vit, apply_qvit,
    apply_quantization_aware_training, apply_learned_step_size_quantization, 
    apply_qat_lora, apply_cosine_qat,
    apply_qwt_ptq, apply_apq_ptq, apply_rotation_ptq, apply_outlier_aware_ptq,
    apply_rotation_lsq, apply_qvlm_ptq 
)

# Root logging configuration for this script. The prefix comes from the record's
# own level name, so INFO lines still read "INFO: ..." while warnings/errors that
# propagate from imported libraries are no longer mislabelled as "INFO:".
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# --- REGISTRIES ---
# Each registry maps a method label -> (apply function, default kwargs).
# config.ACTIVE_* lists select entries by label, and the label becomes the
# CSV "Method" value (standard QAT labels get a distillation-mode suffix).

# Post-training quantization: applied with a calibration dataloader only.
ALL_PTQ_METHODS = {
    "Simple PTQ": (apply_simple_ptq, config.SIMPLE_PTQ_KWARGS),
    "SmoothQuant PTQ": (apply_smoothquant, {}),
    "IGQ-ViT PTQ": (apply_igq_vit, config.IGQ_KWARGS),
    "QwT PTQ": (apply_qwt_ptq, config.QWT_KWARGS),
    "APQ-ViT PTQ": (apply_apq_ptq, config.APQ_KWARGS),
    "Rotation PTQ": (apply_rotation_ptq, {}),
    "OutlierAware PTQ": (apply_outlier_aware_ptq, config.OUTLIER_AWARE_KWARGS),
    "Q-VLM": (apply_qvlm_ptq, config.QVLM_KWARGS),
}

# Standard QAT: every entry except "CosQAT" is run once per active
# distillation mode, with that mode's settings merged into the kwargs.
ALL_STANDARD_QAT_METHODS = {
    "QAT": (apply_quantization_aware_training, config.QAT_KWARGS),
    "LSQ": (apply_learned_step_size_quantization, config.LSQ_KWARGS),
    "Rotation + LSQ": (apply_rotation_lsq, config.LSQ_KWARGS),
    "CosQAT": (apply_cosine_qat, config.COS_QAT_KWARGS),
}

# QAT methods used with their fixed kwargs (no distillation-mode expansion).
ALL_FIXED_QAT_METHODS = {
    "QAT-LoRA": (apply_qat_lora, config.QAT_LORA_KWARGS),
    "Q-ViT": (apply_qvit, config.QVIT_KWARGS),
}

# --- CACHING SETUP ---
# Relative to the working directory: mined negative words (JSON) and
# pre-computed text features (torch files).
MINING_CACHE_DIR = "ood_mining_cache"
TEXT_FEAT_CACHE_DIR = "ood_text_feat_cache"
for _cache_dir in (MINING_CACHE_DIR, TEXT_FEAT_CACHE_DIR):
    os.makedirs(_cache_dir, exist_ok=True)
del _cache_dir

# --- LOGGING SETUP ---
# Column order of every results CSV written by this script.
OOD_CSV_FIELDNAMES = [
    # Model / scope identity
    "Model_Key", "Architecture", "Quant_Scope",
    # Benchmark pair
    "ID_Dataset", "OOD_Dataset",
    # Experiment identity
    "Scenario", "Method", "Run_ID", "Seed",
    # Quantization configuration
    "W_Bits", "A_Bits", "Text_Quantized", "Quant_Config_Str",
    # ID classification accuracy
    "ID_Acc",
    # OOD detection metrics for one scoring method
    "OOD_Scoring_Method", "AUROC", "FPR95",
    "ID_Mean", "ID_Std", "OOD_Mean", "OOD_Std",
    # Outcome of the experiment
    "Status",
]

def cleanup():
    """Release memory between experiments.

    Python garbage is collected first, so tensors that are only reachable from
    dead objects are dropped. PyTorch's CUDA caching allocator is then asked to
    return its unused cached blocks.
    """
    gc.collect()  # must run first: empty_cache() cannot free still-referenced tensors
    torch.cuda.empty_cache()

def init_csv_logging(model_key, quant_scope):
    """Create a new results CSV holding only the header row and return its path.

    The file is placed in ``results_ood/`` and named
    ``ood_<model_key>_<quant_scope>_<YYYYmmdd_HHMMSS>_<job>.csv``. Any "/" in
    *model_key* becomes "-". ``<job>`` is ``$SLURM_JOB_ID`` when set, otherwise
    ``PID<process id>``.
    """
    results_dir = "results_ood"
    os.makedirs(results_dir, exist_ok=True)

    path_safe_key = model_key.replace("/", "-")
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_tag = os.environ.get("SLURM_JOB_ID", f"PID{os.getpid()}")
    csv_path = os.path.join(
        results_dir, f"ood_{path_safe_key}_{quant_scope}_{stamp}_{run_tag}.csv"
    )

    with open(csv_path, mode='w', newline='') as handle:
        csv.DictWriter(handle, fieldnames=OOD_CSV_FIELDNAMES).writeheader()

    print(f"Logging initialized: {csv_path}")
    return csv_path

def log_result(filename, row_data):
    """Append *row_data* as one row of the results CSV at *filename* (best effort).

    Columns follow OOD_CSV_FIELDNAMES. Any failure is printed rather than
    raised, so a logging problem never aborts a long benchmark sweep.
    """
    try:
        with open(filename, mode='a', newline='') as out:
            row_writer = csv.DictWriter(out, fieldnames=OOD_CSV_FIELDNAMES)
            row_writer.writerow(row_data)
    except Exception as err:
        print(f"FAILED TO LOG ROW TO CSV ({filename}): {err}")

# --- CACHING HELPERS ---

def get_mined_negatives_cached(model, tokenizer, id_dataset_key, id_classes, device):
    """Load the pre-mined negative words for *id_dataset_key* from the mining cache.

    Mining is not performed here. ``mine_negatives_utility.py`` must already have
    written ``<MINING_CACHE_DIR>/<model>_<id_dataset_key>_negatives.json``.
    ``model``, ``tokenizer``, ``id_classes`` and ``device`` are unused; they are
    kept only for call-site compatibility.

    Returns:
        The decoded JSON content, presumably a list of words (confirm against the
        mining utility). Returns ``[]`` when neither "NegLabel" nor "EOE" is listed
        in ``config.OOD_METHODS``, because nothing downstream needs the words then.

    Raises:
        RuntimeError: if the cache file is missing, unreadable or not valid JSON.
            The original error is chained as ``__cause__``.
    """
    if not any(m in config.OOD_METHODS for m in ("NegLabel", "EOE")):
        return []

    safe_model_key = config._target_key.replace(" ", "_").replace("/", "-")
    cache_file = os.path.join(MINING_CACHE_DIR, f"{safe_model_key}_{id_dataset_key}_negatives.json")

    try:
        # JSON is UTF-8 by specification; don't depend on the platform default encoding.
        with open(cache_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError as e:
        # NegLabel/EOE are known to be active at this point, so a missing cache is fatal.
        print(f"FATAL: Mining cache file not found for {id_dataset_key} (Expected: {cache_file}).")
        raise RuntimeError(
            "Missing mining cache. Run 'mine_negatives_utility.py' externally to create cache files."
        ) from e
    except Exception as e:
        print(f"ERROR: Failed to load mining cache: {e}.")
        raise RuntimeError("Corrupted mining cache found.") from e

def _encode_prompt_ensemble(model, tokenizer, words, templates, device, batch_size=500, desc=None):
    """Return the L2-normalized, template-averaged text embedding of each entry in *words*.

    Each template in *templates* is filled with every word via ``str.format``.
    The prompts are tokenized and encoded in chunks of *batch_size*, and each
    chunk's embeddings are L2-normalized. The per-template results are summed,
    divided by the number of templates and normalized again. A tqdm progress bar
    over the templates is shown only when *desc* is given.
    """
    summed = None
    template_iter = templates if desc is None else tqdm(templates, desc=desc, leave=False)
    for temp in template_iter:
        prompts = [temp.format(w) for w in words]
        chunk_feats = []
        for start in range(0, len(prompts), batch_size):
            tokens = tokenizer(prompts[start:start + batch_size]).to(device)
            chunk_feats.append(F.normalize(model.encode_text(tokens), dim=-1))
        feats = torch.cat(chunk_feats, dim=0)
        summed = feats if summed is None else summed + feats
    return F.normalize(summed / len(templates), dim=-1)


def generate_and_save_text_features(model, tokenizer, id_classes, id_template, mined_negative_words, device, cache_key=None):
    """Encode the text features used for OOD scoring with *model*, optionally caching them.

    The model is not moved here, so it must already be usable with tensors on
    *device*. Encoding runs under ``torch.no_grad()`` and CUDA autocast, with
    autocast toggled by ``config.USE_AMP``.

    Args:
        model: object exposing ``encode_text``.
        tokenizer: callable taking a list of strings and returning a tensor-like
            object with ``.to()``.
        id_classes: ID class names. Each is formatted into every entry of
            ``config.CLIP_IMAGENET_TEMPLATES``.
        id_template: unused; kept for call-site compatibility.
        mined_negative_words: NegLabel/EOE negative words. A falsy value skips
            that (expensive) encoding.
        device: device the tokens are moved to and the results are returned on.
        cache_key: if truthy, file name under TEXT_FEAT_CACHE_DIR to which CPU
            copies of the three results are saved.

    Returns:
        ``(f_id, f_neg, f_unk)``:
        - ``f_id``: ID class features.
        - ``f_neg``: the CLIPN-A rejection-prompt feature, or ``None`` unless
          "CLIPN-A" is in ``config.OOD_METHODS``.
        - ``f_unk``: mined-negative features, or ``None`` when no words are given.
    """
    templates = config.CLIP_IMAGENET_TEMPLATES

    with torch.no_grad(), torch.amp.autocast('cuda', enabled=config.USE_AMP):
        # 1. ID features (prompt ensemble over all templates).
        f_id = _encode_prompt_ensemble(model, tokenizer, id_classes, templates, device)

        # 2. CLIPN-A rejection prompt (single prompt, no template ensemble).
        f_neg = None
        if "CLIPN-A" in config.OOD_METHODS:
            tk_neg = tokenizer([config.CLIPN_A_PARAMS["rejection_prompt"]]).to(device)
            f_neg = F.normalize(model.encode_text(tk_neg), dim=-1)

        # 3. NegLabel/EOE unknowns (prompt ensemble over the mined words).
        f_unk = None
        if mined_negative_words:
            desc = "        [Encoding Neg/EOE]" if cache_key else "        [Encoding Neg/EOE] (QAT)"
            f_unk = _encode_prompt_ensemble(
                model, tokenizer, mined_negative_words, templates, device, desc=desc
            )

    if cache_key:
        feats_to_save = {
            'f_id': f_id.cpu(),
            'f_neg': f_neg.cpu() if f_neg is not None else None,
            'f_unk': f_unk.cpu() if f_unk is not None else None,
        }
        cache_path = os.path.join(TEXT_FEAT_CACHE_DIR, cache_key)
        # Write to a temp file and atomically rename it into place. A job killed
        # mid-save then leaves no truncated cache behind for later runs to trip over.
        tmp_path = f"{cache_path}.tmp{os.getpid()}"
        try:
            torch.save(feats_to_save, tmp_path)
            os.replace(tmp_path, cache_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    return f_id.to(device), (f_neg.to(device) if f_neg is not None else None), (f_unk.to(device) if f_unk is not None else None)

def get_text_features_cached(model, tokenizer, id_dataset_key, id_classes, id_template, mined_negatives, sc_type, m_label, w_bit, a_bit, qt_bool, device):
    """Return ``(f_id, f_neg, f_unk)`` text features for evaluating *model* on one ID dataset.

    Hybrid caching strategy:

    1. ID features (``f_id``) and the CLIPN-A rejection feature (``f_neg``) are
       ALWAYS encoded fresh with the CURRENT ``model``. Accuracy and the standard
       OOD scores therefore reflect the actual, possibly quantized, model.
    2. Mined-negative features (``f_unk``) are read from the master FP32 cache
       ``<TEXT_FEAT_CACHE_DIR>/<model>_<id_dataset_key>_FP32_text.pt``. This avoids
       re-encoding thousands of prompts per experiment.
       - Only the FP32 baseline call (``sc_type == 'fp32'``) may create or repair
         that cache.
       - Any other caller that finds it missing, unreadable or empty re-encodes
         with its own model and does NOT cache the result. Quantized features are
         thus never stored under the FP32 name.

    ``m_label``, ``w_bit``, ``a_bit`` and ``qt_bool`` are not used here; they are
    kept for call-site compatibility. ``f_unk`` is ``None`` when
    *mined_negatives* is empty.
    """
    # 1. Cheap part, always from the CURRENT model. Mined words are skipped here.
    f_id, f_neg, _ = generate_and_save_text_features(
        model, tokenizer, id_classes, id_template,
        mined_negative_words=None,
        device=device, cache_key=None
    )

    if not mined_negatives:
        return f_id, f_neg, None

    # 2. Heavy part: mined negatives from the master FP32 cache.
    safe_model_key = config._target_key.replace(" ", "_").replace("/", "-")
    fp32_cache_name = f"{safe_model_key}_{id_dataset_key}_FP32_text.pt"
    fp32_cache_path = os.path.join(TEXT_FEAT_CACHE_DIR, fp32_cache_name)
    is_fp32_model = sc_type == 'fp32'

    if os.path.exists(fp32_cache_path):
        try:
            # Load to CPU: only f_unk is needed, so f_id/f_neg never occupy GPU memory.
            cached_data = torch.load(fp32_cache_path, map_location='cpu')
            cached_unk = cached_data['f_unk']
            if cached_unk is None:
                raise ValueError("cache file contains no 'f_unk' features")
            return f_id, f_neg, cached_unk.to(device)
        except Exception as e:
            print(f"    [Warning] Failed to load FP32 cache for negatives: {e}. Re-encoding (Slow).")
    elif is_fp32_model:
        print("    [Cache Miss] Generating Master FP32 Mined Negatives Cache...")

    if not is_fp32_model:
        print("    [Warning] FP32 negatives cache unavailable; encoding mined negatives with the "
              "current (non-FP32) model and NOT caching the result.")

    _, _, f_unk = generate_and_save_text_features(
        model, tokenizer, id_classes, id_template, mined_negatives, device,
        # Only FP32 features may be written under the master FP32 cache name.
        cache_key=fp32_cache_name if is_fp32_model else None,
    )
    return f_id, f_neg, f_unk

# Per-ID-dataset resources (loaders, class names, template, mined negatives),
# keyed by base loader key and filled lazily by get_cached_or_create_resources().
LOADER_CACHE: dict = {}

def get_cached_or_create_resources(base_loader_key, preprocess, model_fp32, tokenizer):
    """Return the ID-side resources for *base_loader_key*, building them on first use.

    Returns the tuple ``(id_test_loader, id_cal_loader, id_train_loader,
    id_classes, id_template, mined_negatives)``, memoized in LOADER_CACHE.
    If the dataset yields no test loader, six ``None`` values are returned and
    nothing is cached, so a later call will try again.

    ``model_fp32`` and ``tokenizer`` are only forwarded to
    get_mined_negatives_cached, which reads the words from its JSON cache.
    """
    field_order = (
        'id_test_loader', 'id_cal_loader', 'id_train_loader',
        'id_classes', 'id_template', 'mined_negatives',
    )

    entry = LOADER_CACHE.get(base_loader_key)
    if entry is None:
        loaders, test_loader, class_names, template = data_setup.get_dataset_loaders(
            base_loader_key, preprocess, get_train=True
        )
        if test_loader is None:
            return (None,) * len(field_order)
        # Unpack only after the None check, since loaders may be unusable when
        # there is no test loader.
        cal_loader, train_loader = loaders

        negatives = get_mined_negatives_cached(
            model_fp32, tokenizer, base_loader_key, class_names, config.TARGET_DEVICE
        )
        entry = dict(zip(
            field_order,
            (test_loader, cal_loader, train_loader, class_names, template, negatives),
        ))
        LOADER_CACHE[base_loader_key] = entry

    return tuple(entry[name] for name in field_order)


# --- MAIN ---
def _pair_fields(pair):
    """Return the ID/OOD dataset columns of a results row for one benchmark pair."""
    return {"ID_Dataset": pair['id_name'], "OOD_Dataset": pair['ood_name']}


def _metric_fields(id_acc, ood_m, metrics):
    """Return the accuracy/metric columns of a successful results row for one OOD scoring method."""
    return {
        "ID_Acc": id_acc, "OOD_Scoring_Method": ood_m,
        "AUROC": metrics["AUROC"], "FPR95": metrics["FPR95"],
        "ID_Mean": metrics["ID_Mean"], "ID_Std": metrics["ID_Std"],
        "OOD_Mean": metrics["OOD_Mean"], "OOD_Std": metrics["OOD_Std"],
        "Status": "Success",
    }


def _build_benchmark_pairs(preprocess, model_fp32, tokenizer):
    """Return the (ID, OOD) evaluation matrix defined by ``config.OOD_BENCHMARK_PAIRS``.

    ID-side resources are first cached in LOADER_CACHE for every ID source
    listed in ``config.DATASET_PATHS``. A pair is skipped when its ID source
    has no cached resources or its OOD dataset yields no test loader. Each OOD
    test loader is built once and shared by every pair that uses it. Entries
    hold loaders and metadata only; text-feature tensors are generated per model.
    """
    def base_key(name):
        # Names may carry a '_far_ood' suffix; datasets are loaded without it.
        return name.split('_far_ood')[0]

    id_source_keys = {base_key(id_name) for id_name in config.OOD_BENCHMARK_PAIRS}

    print("\nCaching ID Resources (Text Features & Loaders) for defined ID Sources...")
    for base_id_key in id_source_keys:
        if base_id_key in config.DATASET_PATHS:
            get_cached_or_create_resources(base_id_key, preprocess, model_fp32, tokenizer)
        else:
            print(f"Warning: ID Source Key '{base_id_key}' found in OOD_BENCHMARK_PAIRS but not in DATASET_PATHS.")

    ood_test_loaders = {}
    pairs = []
    for id_name, ood_list in config.OOD_BENCHMARK_PAIRS.items():
        id_key = base_key(id_name)
        resources = LOADER_CACHE.get(id_key)
        if not resources:
            continue

        for ood_name in ood_list:
            ood_key = base_key(ood_name)
            if ood_key not in ood_test_loaders:
                _, loader, _, _ = data_setup.get_dataset_loaders(ood_key, preprocess, get_train=False)
                ood_test_loaders[ood_key] = loader
            ood_test_loader = ood_test_loaders[ood_key]
            if ood_test_loader is None:
                continue

            pairs.append({
                'id_name': id_name, 'ood_name': ood_name,
                'id_key': id_key,
                'id_test_loader': resources['id_test_loader'],
                'ood_test_loader': ood_test_loader,
                'id_classes': resources['id_classes'],
                'id_template': resources['id_template'],
                'mined_negatives': resources['mined_negatives'],
                'id_cal_loader': resources['id_cal_loader'],
                'id_train_loader': resources['id_train_loader'],
            })
    return pairs


def _log_fp32_baseline(pairs, model_fp32, tokenizer, result_csv_path, model_key):
    """Evaluate the FP32 model on every pair and log its results under every run index.

    The model is moved to ``config.TARGET_DEVICE`` for the evaluation and
    returned to the CPU afterwards. As the ``sc_type='fp32'`` caller, this is
    also the call that creates the master FP32 mined-negative cache if it is
    missing.
    """
    print("\nCalculating and Logging FP32 Baseline results...")
    model_fp32.to(config.TARGET_DEVICE)

    for pair in pairs:
        f_id, f_neg, f_unk = get_text_features_cached(
            model_fp32, tokenizer, pair['id_key'], pair['id_classes'], pair['id_template'],
            pair['mined_negatives'], 'fp32', 'FP32', 32, 32, False, config.TARGET_DEVICE
        )
        ood_results, id_acc = ood_eval.run_ood_benchmark_eval(
            model_fp32, pair['id_test_loader'], pair['ood_test_loader'],
            f_id, f_neg, f_unk, config.TARGET_DEVICE
        )
        del f_id, f_neg, f_unk  # free GPU memory before the next pair

        print(f"Baseline for {pair['id_name']} vs {pair['ood_name']} calculated (Acc: {id_acc:.4f}).")

        base_row = {
            "Model_Key": model_key, "Architecture": config.CLIP_MODEL_ARCHITECTURE,
            "Quant_Scope": "FP32_BASELINE", **_pair_fields(pair),
            "Scenario": "FP32", "Method": "FP32",
            "W_Bits": 32, "A_Bits": 32, "Text_Quantized": False, "Quant_Config_Str": "FP32",
        }
        # The baseline is evaluated once, and the same result is logged under every
        # run index so per-run aggregation always finds an FP32 row.
        for run_idx in range(config.NUM_RUNS):
            seed = config.RANDOM_SEED + run_idx
            for ood_m, metrics in ood_results.items():
                log_result(result_csv_path, {
                    **base_row, "Run_ID": run_idx, "Seed": seed,
                    **_metric_fields(id_acc, ood_m, metrics),
                })

    model_fp32.to('cpu')
    cleanup()


def _build_method_queue(sc_type):
    """Return ``(label, (apply_fn, kwargs), sc_type)`` tuples for every active method of *sc_type*.

    - PTQ uses ``config.ACTIVE_PTQ_METHODS``.
    - QAT expands each active standard QAT method, except CosQAT, once per
      entry of ``config.ACTIVE_DISTILLATION_MODES``, then appends the active
      fixed QAT methods.

    Names missing from the registries are skipped.
    """
    queue = []
    if sc_type == 'ptq':
        for m in config.ACTIVE_PTQ_METHODS:
            if m in ALL_PTQ_METHODS:
                queue.append((m, ALL_PTQ_METHODS[m], sc_type))
    elif sc_type == 'qat':
        for m in config.ACTIVE_STANDARD_QAT_METHODS:
            if m not in ALL_STANDARD_QAT_METHODS:
                continue
            if m == "CosQAT":
                queue.append((m, ALL_STANDARD_QAT_METHODS[m], sc_type))
                continue
            apply_fn, base_kwargs = ALL_STANDARD_QAT_METHODS[m]
            for dist_name, dist_cfg in config.ACTIVE_DISTILLATION_MODES.items():
                cfg = dist_cfg.copy()
                # No teacher is loaded when ENABLE_TEACHER is off, so disable distillation.
                if not config.ENABLE_TEACHER:
                    cfg['distill_weight'] = 0.0
                queue.append((f"{m}{dist_name}", (apply_fn, {**base_kwargs, **cfg}), sc_type))
        for m in config.ACTIVE_FIXED_QAT_METHODS:
            if m in ALL_FIXED_QAT_METHODS:
                queue.append((m, ALL_FIXED_QAT_METHODS[m], sc_type))
    return queue


def _real_id_scenarios(base_id_key, resources):
    """Return the real-ID PTQ/QAT scenarios for one cached ID dataset.

    No scenario is produced without a calibration loader; the QAT scenario also
    needs a training loader. Loaders are compared against None instead of being
    tested for truthiness. Truthiness falls back to ``len()`` on objects such as
    DataLoader, which can raise for iterable-style datasets.
    """
    cal_loader = resources['id_cal_loader']
    if cal_loader is None:
        return []

    scenarios = [{
        "name": f"RealID_PTQ_{base_id_key}", "type": "ptq", "loader": cal_loader,
        "prompts": resources['id_classes'], "cal": cal_loader,
        "id_key_match": base_id_key, "scenario_type": "realid_ptq",
    }]
    if resources['id_train_loader'] is not None:
        scenarios.append({
            "name": f"RealID_QAT_{base_id_key}", "type": "qat", "loader": resources['id_train_loader'],
            "cal": cal_loader, "prompts": resources['id_classes'],
            "id_key_match": base_id_key, "scenario_type": "realid_qat",
        })
    return scenarios


def _run_quant_experiment(sc, m_label, apply_fn, base_kwargs, sc_type, w_bit, a_bit, pairs, ctx):
    """Quantize a fresh copy of the FP32 model with one method/bit-width and log its OOD results.

    *ctx* carries the per-run constants: ``model_key``, ``model_fp32``,
    ``teacher_model``, ``tokenizer``, ``result_csv_path``, ``run_idx``, ``seed``,
    ``qt_bool`` and ``quant_scope``.

    Any exception is reported with its traceback. Every pair not yet evaluated
    then gets a row whose Status starts with "Failed". A broken configuration
    therefore neither aborts the sweep nor silently disappears from the CSV.
    References to the quantized model and text features are dropped before
    cleanup() runs.
    """
    import traceback

    cleanup()
    qt_bool = ctx['qt_bool']
    config_str = f"W{w_bit}A{a_bit}_Txt{int(qt_bool)}"
    print(f"    [Quantizing] {sc['name']} | {m_label} | {config_str}")

    # Re-seed per experiment so a configuration's result does not depend on
    # which other methods/bit-widths happened to run before it.
    data_setup.set_seed(ctx['seed'])

    base_row = {
        "Model_Key": ctx['model_key'], "Architecture": config.CLIP_MODEL_ARCHITECTURE,
        "Quant_Scope": ctx['quant_scope'],
        "Scenario": sc['name'], "Method": m_label,
        "Run_ID": ctx['run_idx'], "Seed": ctx['seed'],
        "W_Bits": w_bit, "A_Bits": a_bit, "Text_Quantized": qt_bool, "Quant_Config_Str": config_str,
    }

    temp_model = eval_model = None
    f_id = f_neg = f_unk = None
    n_done = 0
    try:
        # Copy from the CPU-resident master so the copy does not cost extra VRAM.
        temp_model = copy.deepcopy(ctx['model_fp32'])
        run_kwargs = {
            'target_device': config.TARGET_DEVICE, 'tokenizer': ctx['tokenizer'],
            'prompts': sc.get('prompts'), 'quantize_text': qt_bool,
            'weight_bits': w_bit, 'act_bits': a_bit, **base_kwargs
        }

        if sc_type == 'ptq':
            eval_model = apply_fn(temp_model, calibration_dataloader=sc['loader'], **run_kwargs)
        elif sc_type == 'qat':
            eval_model = apply_fn(
                temp_model, training_dataloader=sc['loader'], calibration_dataloader=sc['cal'],
                teacher=ctx['teacher_model'], **run_kwargs
            )
        if eval_model is None:
            raise RuntimeError("Model returned None.")

        for pair in pairs:
            # Hybrid text features: f_id/f_neg from this model, f_unk from the FP32 cache.
            f_id, f_neg, f_unk = get_text_features_cached(
                eval_model, ctx['tokenizer'], pair['id_key'], pair['id_classes'], pair['id_template'],
                pair['mined_negatives'], sc_type, m_label, w_bit, a_bit, qt_bool, config.TARGET_DEVICE
            )
            ood_results, id_acc = ood_eval.run_ood_benchmark_eval(
                eval_model, pair['id_test_loader'], pair['ood_test_loader'],
                f_id, f_neg, f_unk, config.TARGET_DEVICE
            )
            f_id = f_neg = f_unk = None  # release text features before the next pair

            pair_row = {**base_row, **_pair_fields(pair)}
            for ood_m, metrics in ood_results.items():
                log_result(ctx['result_csv_path'], {**pair_row, **_metric_fields(id_acc, ood_m, metrics)})
            n_done += 1

    except Exception as e:
        print(f"    [Error] {str(e)}")
        traceback.print_exc()
        failure_status = f"Failed: {e}"
        for pair in pairs[n_done:]:
            log_result(ctx['result_csv_path'], {**base_row, **_pair_fields(pair), "Status": failure_status})

    finally:
        # Drop every reference so gc/empty_cache can actually reclaim GPU memory.
        f_id = f_neg = f_unk = None
        eval_model = temp_model = None
        cleanup()


def _run_scenario(sc, pairs, ctx):
    """Run every active method and bit-width for scenario *sc*, evaluating on *pairs*."""
    if not pairs:
        # Nothing to evaluate on, so skip the (possibly very expensive) quantization.
        print(f"    [Skip] {sc['name']}: no benchmark pairs to evaluate.")
        return
    for m_label, (apply_fn, base_kwargs), sc_type in _build_method_queue(sc['type']):
        for w_bit, a_bit in config.BIT_WIDTHS_TO_TEST:
            _run_quant_experiment(sc, m_label, apply_fn, base_kwargs, sc_type, w_bit, a_bit, pairs, ctx)


def main():
    """Entry point: log the FP32 baseline, then run every active quantization experiment.

    1. Load the FP32 model (parked on CPU), the optional teacher and proxy loaders.
    2. Build the (ID, OOD) benchmark pairs and log the FP32 baseline.
    3. For each run seed and text-quantization mode, quantize a fresh copy of
       the FP32 model for every active proxy / real-ID scenario, method and
       bit-width. All results are appended to one CSV under ``results_ood/``.
    """
    model_key = config._target_key
    quant_scope = os.environ.get("TARGET_QUANT_SCOPE", "ALL")
    active_text_quant_modes = config.TEXT_QUANTIZATION_MODES

    print(f"\n{'='*80}\nSTARTING OOD BENCHMARK: {model_key} [{quant_scope}]\n{'='*80}")

    # --- GLOBAL INIT ---
    result_csv_path = init_csv_logging(model_key, quant_scope)
    data_setup.set_seed(config.RANDOM_SEED)
    model_fp32, tokenizer, preprocess = data_setup.get_model_and_tokenizer()

    # Park the master FP32 model on the CPU. Every experiment deep-copies it, and
    # it only visits the GPU for the baseline evaluation.
    model_fp32.eval()
    model_fp32.to('cpu')
    print("Master FP32 model moved to CPU to conserve GPU memory.")

    # Frozen FP32 teacher for distillation-based QAT. It goes on a second GPU
    # when one exists, otherwise on the CPU.
    teacher_model = None
    if config.ENABLE_TEACHER:
        print("Loading Teacher Model...")
        teacher_model = copy.deepcopy(model_fp32)
        for p in teacher_model.parameters():
            p.requires_grad = False
        teacher_model.eval()

        if torch.cuda.device_count() > 1:
            teacher_device = 'cuda:1'
            print(f"Teacher moved to {teacher_device} (Secondary GPU) to save Main GPU VRAM.")
        else:
            teacher_device = 'cpu'
            print("Teacher moved to CPU to save VRAM (Single GPU detected).")
        teacher_model.to(teacher_device)

    proxy_loaders = {}
    if config.ENABLE_PROXY_EXPERIMENTS:
        for p_name in config.ACTIVE_PROXY_DATASETS:
            if p_name in config.PROXY_DATASETS:
                path, _ = config.PROXY_DATASETS[p_name]
                loader_iterable = data_setup.create_train_iterable(str(path), preprocess, config.BATCH_SIZE)
                if loader_iterable:
                    proxy_loaders[p_name] = loader_iterable
                    print(f"Loaded proxy dataset: {p_name}")

    # --- BENCHMARK MATRIX + FP32 BASELINE ---
    all_benchmark_pairs = _build_benchmark_pairs(preprocess, model_fp32, tokenizer)
    _log_fp32_baseline(all_benchmark_pairs, model_fp32, tokenizer, result_csv_path, model_key)

    # --- PROXY SCENARIOS (identical for every run and text-quantization mode) ---
    proxy_scenarios = []
    if config.ENABLE_PROXY_EXPERIMENTS:
        if config.ACTIVE_PTQ_METHODS:
            proxy_scenarios.extend(
                {"name": p_name, "type": "ptq", "loader": p_loader, "cal": p_loader, "prompts": None,
                 "proxy_source": p_name, "scenario_type": "proxy_ptq"}
                for p_name, p_loader in proxy_loaders.items()
            )
        if config.ACTIVE_STANDARD_QAT_METHODS or config.ACTIVE_FIXED_QAT_METHODS:
            proxy_scenarios.extend(
                {"name": p_name, "type": "qat", "loader": p_loader, "cal": p_loader, "prompts": None,
                 "proxy_source": p_name, "scenario_type": "proxy_qat"}
                for p_name, p_loader in proxy_loaders.items()
            )

    # --- QUANTIZATION EXPERIMENT LOOP ---
    print(f"\n{'='*80}\nSTARTING QUANTIZATION EXPERIMENTS\n{'='*80}")

    for run_idx in range(config.NUM_RUNS):
        current_seed = config.RANDOM_SEED + run_idx
        data_setup.set_seed(current_seed)

        print(f"\n--- RUN {run_idx+1}/{config.NUM_RUNS} (Seed: {current_seed}) ---")

        for qt_bool in active_text_quant_modes:
            run_quant_scope = "VISUAL_TEXT" if qt_bool else "VISUAL_ONLY"
            print(f"  [Scope] Text Quantized: {qt_bool} ({run_quant_scope})")

            ctx = {
                'model_key': model_key, 'model_fp32': model_fp32, 'teacher_model': teacher_model,
                'tokenizer': tokenizer, 'result_csv_path': result_csv_path,
                'run_idx': run_idx, 'seed': current_seed,
                'qt_bool': qt_bool, 'quant_scope': run_quant_scope,
            }

            # 1. Proxy scenarios: evaluated on every benchmark pair.
            for sc in proxy_scenarios:
                _run_scenario(sc, all_benchmark_pairs, ctx)

            # 2. Real-ID scenarios: evaluated only on pairs sharing their ID dataset.
            if config.ENABLE_REAL_ID_EXPERIMENTS:
                for base_id_key, resources in LOADER_CACHE.items():
                    matching_pairs = [p for p in all_benchmark_pairs if p['id_key'] == base_id_key]
                    for sc in _real_id_scenarios(base_id_key, resources):
                        _run_scenario(sc, matching_pairs, ctx)

# Script entry point: importing this module only defines functions and runs the
# module-level setup; the benchmark itself starts only when executed directly.
if __name__ == "__main__":
    main()