# run_experiments.py
import os
import ctypes

# --- PROCESS ENVIRONMENT (Must be set before importing torch) ---
# expandable_segments: Helps prevent fragmentation crashes on large cards like H200
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Offline switches for the Hugging Face hub / datasets / transformers libraries.
# NOTE(review): presumably set because the job runs without internet access and
# relies on pre-populated caches -- confirm against the cluster setup. They are
# assigned here so they are in place before the project modules are imported below.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

import logging
import copy
import gc
import sys
import datetime
import csv
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

# --- INTERNAL MODULE IMPORTS ---
import scripts.config as config
import scripts.data_setup as data_setup
import scripts.evaluation as evaluation

from quantization.apply import (
    apply_simple_ptq, apply_smoothquant, apply_igq_vit, apply_qvit,
    apply_quantization_aware_training, apply_learned_step_size_quantization, 
    apply_qat_lora, apply_cosine_qat,
    apply_qwt_ptq, apply_apq_ptq, apply_rotation_ptq, apply_outlier_aware_ptq,
    apply_rotation_lsq, apply_qvlm_ptq 
)

# Configure the root logger: INFO threshold, each record printed as "INFO: <message>".
# NOTE(review): the "INFO:" prefix is hard-coded in the format string, so WARNING and
# ERROR records are labelled "INFO:" as well -- consider '%(levelname)s: %(message)s'.
logging.basicConfig(level=logging.INFO, format='INFO: %(message)s')

# --- HARDWARE SETUP FOR DUAL H200 ---
# ACTIVE_DEVICE: GPU 0, where training / quantization runs.
# STORAGE_DEVICE: GPU 1 when a second GPU is visible (frozen teacher / reference
# models are parked there); otherwise it collapses onto the active device.
ACTIVE_DEVICE = "cuda:0"
STORAGE_DEVICE = "cuda:1" if torch.cuda.device_count() > 1 else ACTIVE_DEVICE
if STORAGE_DEVICE != ACTIVE_DEVICE:
    print(f"INFO: Dual-GPU Logic Enabled. Active: {ACTIVE_DEVICE}, Storage: {STORAGE_DEVICE}")
else:
    print("WARNING: Only 1 GPU detected. Memory pressure will be higher.")

# Update Config to point to Active Device for operations.
# This assigns an attribute on the imported `scripts.config` module object, so any
# code that reads `config.TARGET_DEVICE` after this line sees ACTIVE_DEVICE.
config.TARGET_DEVICE = ACTIVE_DEVICE

# --- METHOD REGISTRIES ---
# Every registry maps a display name (also used as the "Method" CSV column) to a
# 2-tuple `(apply_function, default_kwargs)`. main() only runs the entries whose
# names are listed in the matching `config.ACTIVE_*_METHODS` collection.

# Post-training quantization methods: invoked with a calibration dataloader only.
ALL_PTQ_METHODS = {
    "Simple PTQ": (apply_simple_ptq, config.SIMPLE_PTQ_KWARGS),
    "SmoothQuant PTQ": (apply_smoothquant, {}),
    "IGQ-ViT PTQ": (apply_igq_vit, config.IGQ_KWARGS),
    "QwT PTQ": (apply_qwt_ptq, config.QWT_KWARGS),
    "APQ-ViT PTQ": (apply_apq_ptq, config.APQ_KWARGS),
    "Rotation PTQ": (apply_rotation_ptq, {}),
    "OutlierAware PTQ": (apply_outlier_aware_ptq, config.OUTLIER_AWARE_KWARGS),
    "Q-VLM": (apply_qvlm_ptq, config.QVLM_KWARGS)
}

# QAT methods that main() expands once per entry of `config.ACTIVE_DISTILLATION_MODES`
# (the distillation settings are merged over the kwargs below). "CosQAT" is the
# exception: it is queued once, without the distillation expansion.
# Note that "Rotation + LSQ" shares `config.LSQ_KWARGS` with plain "LSQ".
ALL_STANDARD_QAT_METHODS = {
    "QAT": (apply_quantization_aware_training, config.QAT_KWARGS),
    "LSQ": (apply_learned_step_size_quantization, config.LSQ_KWARGS),
    "Rotation + LSQ": (apply_rotation_lsq, config.LSQ_KWARGS),
    "CosQAT": (apply_cosine_qat, config.COS_QAT_KWARGS),
}

# QAT methods queued exactly once with their own fixed kwargs. main() skips
# "Q-ViT" when `config.ENABLE_TEACHER` is false.
ALL_FIXED_QAT_METHODS = {
    "QAT-LoRA": (apply_qat_lora, config.QAT_LORA_KWARGS),
    "Q-ViT": (apply_qvit, config.QVIT_KWARGS),
}

# Column order of the results CSV; used for both the header and every data row.
CSV_FIELDNAMES = [
    # Which model / checkpoint was evaluated
    "Model_Key",
    "Architecture",
    "Dataset_Path",
    # Which experiment produced the row
    "Scenario",
    "Method",
    "Run_ID",
    "Seed",
    # Quantization configuration
    "W_Bits",
    "A_Bits",
    "Text_Quantized",
    "Quant_Config_Str",
    # Measured metrics
    "Eval_Dataset",
    "Acc",
    "ECE_Pre_Tuning",
    "ECE_Post_Tuning",
    "Cosine_Sim",
    # Outcome marker ("Baseline", "Success" or "Failed: ...")
    "Status",
]

def cleanup():
    """Release as much host and device memory as possible.

    Runs three stages in order: a Python garbage collection pass, a flush of
    PyTorch's CUDA caches (only when CUDA is available), and a best-effort
    ``malloc_trim(0)`` call through glibc. The last stage is silently skipped
    when ``libc.so.6`` cannot be loaded or the call fails for any reason.
    """
    # Stage 1: drop unreachable Python objects so their tensors can be freed.
    gc.collect()

    # Stage 2: hand cached CUDA blocks (and IPC handles) back to the driver.
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # Stage 3: ask glibc to return freed heap pages to the OS (Linux only).
    try:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except Exception:
        # Deliberately best-effort: non-glibc platforms simply skip this stage.
        pass

def evaluate_all_datasets(model, eval_suite, tokenizer, qt_bool):
    """Evaluate ``model`` on every dataset in ``eval_suite``.

    Args:
        model: The (quantized) model under test; switched to eval mode here.
        eval_suite: Mapping of dataset name to a dict with the keys
            ``'ref_model_fp32'``, ``'baseline_feats'``, ``'loader'``,
            ``'template'`` and ``'class_names'``.
        tokenizer: Callable turning a list of prompt strings into a tensor;
            only used when ``qt_bool`` is true.
        qt_bool: When false, the pre-computed ``'baseline_feats'`` are reused
            as text features. When true, the text prompts are re-encoded with
            ``model`` itself in batches of 256.

    Returns:
        dict: dataset name -> whatever
        ``evaluation.run_comprehensive_evaluation`` returned for it.

    Raises:
        RuntimeError: Propagated from the evaluation call. A device-mismatch
            error is retried once with the reference model temporarily moved
            to ``ACTIVE_DEVICE``; every other error is re-raised unchanged.
    """
    results_map = {}
    model.eval()

    for eval_name, eval_data in eval_suite.items():
        # FP32 reference; main() keeps it on STORAGE_DEVICE.
        ref_model = eval_data['ref_model_fp32']

        if not qt_bool:
            # Text encoder untouched: reuse the cached FP32 text features.
            current_txt_feats = eval_data['baseline_feats'].to(ACTIVE_DEVICE)
        else:
            # Text encoder is quantized: re-encode the prompts with the model under test.
            ts = [eval_data['template'].format(c) for c in eval_data['class_names']]
            fs = []
            # With AMP disabled the second context is just another no_grad (a no-op here).
            context = torch.amp.autocast('cuda') if config.USE_AMP else torch.no_grad()
            with torch.no_grad(), context:
                for i in range(0, len(ts), 256):
                    tk = tokenizer(ts[i:i+256]).to(ACTIVE_DEVICE)
                    f = model.encode_text(tk)
                    f = f / f.norm(dim=-1, keepdim=True)
                    fs.append(f)
            current_txt_feats = torch.cat(fs, dim=0)

        try:
            metrics = evaluation.run_comprehensive_evaluation(
                model,
                ref_model,  # normally still on STORAGE_DEVICE to save ACTIVE memory
                eval_data['loader'],
                current_txt_feats,
                ACTIVE_DEVICE
            )
        except RuntimeError as e:
            if "device" not in str(e).lower():
                raise

            # Retrying the identical call cannot succeed, so the retry only
            # happens when the reference model can actually be relocated.
            first_param = next(ref_model.parameters(), None)
            home_device = first_param.device if first_param is not None else None
            if home_device is None or home_device == torch.device(ACTIVE_DEVICE):
                # Nothing to align: the reference already shares the active device.
                raise

            print("WARNING: Device mismatch in eval. Ensuring tensors align.")
            ref_model.to(ACTIVE_DEVICE)
            try:
                metrics = evaluation.run_comprehensive_evaluation(
                    model, ref_model, eval_data['loader'], current_txt_feats, ACTIVE_DEVICE
                )
            finally:
                # The reference model is shared by every suite entry; always
                # put it back so ACTIVE_DEVICE memory is released again.
                ref_model.to(home_device)

        del current_txt_feats
        results_map[eval_name] = metrics

    return results_map

def init_csv_logging(model_key, quant_scope):
    """Create a fresh results CSV containing only the header row.

    The file is placed in ``results/`` (created if missing). Its name combines
    the model key (spaces -> ``_``, slashes -> ``-``), the quantization scope,
    a second-resolution timestamp and the SLURM job id, falling back to
    ``PID<pid>`` when ``SLURM_JOB_ID`` is not set.

    Args:
        model_key: Identifier of the model under test.
        quant_scope: Label of the quantization scope, embedded verbatim.

    Returns:
        str: Path of the created CSV file.
    """
    output_dir = "results"
    os.makedirs(output_dir, exist_ok=True)

    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # Make the model key safe to embed in a file name.
    key_for_path = model_key.translate(str.maketrans({" ": "_", "/": "-"}))
    fallback_id = f"PID{os.getpid()}"
    job_tag = os.environ.get("SLURM_JOB_ID", fallback_id)

    csv_name = f"results_{key_for_path}_{quant_scope}_{stamp}_{job_tag}.csv"
    csv_path = os.path.join(output_dir, csv_name)

    with open(csv_path, 'w', newline='') as handle:
        csv.DictWriter(handle, fieldnames=CSV_FIELDNAMES).writeheader()

    return csv_path

def log_result(filename, row_data):
    """Append one result row to the CSV at ``filename``.

    Logging is best-effort: any exception raised while opening or writing is
    reported on stdout and swallowed, so a logging failure does not abort the
    caller.

    Args:
        filename: Path of an existing results CSV.
        row_data: Mapping of column name -> value, written in
            ``CSV_FIELDNAMES`` order.
    """
    try:
        with open(filename, 'a', newline='') as csv_file:
            row_writer = csv.DictWriter(csv_file, fieldnames=CSV_FIELDNAMES)
            row_writer.writerow(row_data)
    except Exception as err:
        print(f"FAILED TO LOG RESULT TO CSV: {err}")

def _build_method_queue(scenario):
    """Return the ordered ``(label, (apply_fn, kwargs), kind)`` entries to run for one scenario."""
    queue = []

    for m_name in config.ACTIVE_PTQ_METHODS:
        if m_name in ALL_PTQ_METHODS:
            queue.append((m_name, ALL_PTQ_METHODS[m_name], 'ptq'))

    # QAT methods need a training loader; without one only PTQ is queued.
    if not scenario['train_loader']:
        return queue

    for m_name in config.ACTIVE_STANDARD_QAT_METHODS:
        if m_name not in ALL_STANDARD_QAT_METHODS:
            continue
        if m_name == "CosQAT":
            # CosQAT is queued once, without the distillation-mode expansion.
            queue.append((m_name, ALL_STANDARD_QAT_METHODS[m_name], 'qat'))
            continue
        apply_fn, method_kwargs = ALL_STANDARD_QAT_METHODS[m_name]
        for dist_name, dist_cfg in config.ACTIVE_DISTILLATION_MODES.items():
            run_cfg = dist_cfg.copy()
            if not config.ENABLE_TEACHER:
                run_cfg['distill_weight'] = 0.0
            queue.append((f"{m_name}{dist_name}", (apply_fn, {**method_kwargs, **run_cfg}), 'qat'))

    for m_name in config.ACTIVE_FIXED_QAT_METHODS:
        if m_name not in ALL_FIXED_QAT_METHODS:
            continue
        if m_name == "Q-ViT" and not config.ENABLE_TEACHER:
            continue
        queue.append((m_name, ALL_FIXED_QAT_METHODS[m_name], 'qat'))

    return queue


def _failure_row(run_identity, err_msg):
    """Build the CSV row for a failed run, keeping the columns that identify the run."""
    return {
        **run_identity,
        "Eval_Dataset": "ERROR", "Acc": 0.0, "ECE_Pre_Tuning": 0.0, "ECE_Post_Tuning": 0.0,
        "Cosine_Sim": 0.0, "Status": f"Failed: {err_msg[:50]}",
    }


def main():
    """Run the configured quantization sweep and stream results to a CSV file.

    Reads ``TARGET_MODEL_KEY`` (default ``"Unknown_Model"``) and
    ``TARGET_QUANT_SCOPE`` (default ``"ALL"``) from the environment, then:

    1. loads the FP32 master model (and, if ``config.ENABLE_TEACHER``, a frozen
       teacher copy) onto ``STORAGE_DEVICE``;
    2. evaluates and logs the FP32 baseline for every dataset in
       ``config.EVAL_DATASETS_TO_TEST``;
    3. builds the proxy / real-ID scenarios enabled in ``config``;
    4. for every scenario x method x bit-width x text-quantization mode x run,
       clones the master onto ``ACTIVE_DEVICE``, applies the method, evaluates
       (optionally before and after logit tuning) and appends one CSV row per
       evaluation dataset.

    A failing run is logged as a ``"Failed: ..."`` row and the sweep continues
    with the next run.
    """
    model_key = os.environ.get("TARGET_MODEL_KEY", "Unknown_Model")
    quant_scope = os.environ.get("TARGET_QUANT_SCOPE", "ALL")

    print(f"\n{'='*80}")
    print(f"STARTING HPC JOB | DUAL GPU CONFIGURATION")
    print(f"Active Node (Training): {ACTIVE_DEVICE}")
    print(f"Storage Node (Ref/Teacher): {STORAGE_DEVICE}")
    print(f"Model Key:      {model_key}")
    print(f"{'='*80}\n")

    result_csv_path = init_csv_logging(model_key, quant_scope)
    data_setup.set_seed(config.RANDOM_SEED)

    # 1. Load Master Models to STORAGE DEVICE so they stay off the training GPU.
    print("Loading Master FP32 Model to STORAGE DEVICE...")
    master_model_fp32, tokenizer, preprocess = data_setup.get_model_and_tokenizer()
    master_model_fp32.to(STORAGE_DEVICE)
    master_model_fp32.eval()

    if config.ENABLE_TEACHER:
        print("Loading Master Teacher Model to STORAGE DEVICE...")
        # The deep copy inherits the master's device placement (STORAGE_DEVICE).
        master_teacher_model = copy.deepcopy(master_model_fp32)
        for p in master_teacher_model.parameters():
            p.requires_grad = False
        master_teacher_model.eval()
    else:
        master_teacher_model = None

    # 2. Pre-compute Baselines (Using Storage Device)
    eval_suite = {}
    print(f"Pre-computing FP32 Baselines (Device: {STORAGE_DEVICE})...")

    for d_name in config.EVAL_DATASETS_TO_TEST:
        _, test_loader, class_names, template = data_setup.get_dataset_loaders(d_name, preprocess, get_train=False)
        if test_loader is None:
            continue

        texts = [template.format(c) for c in class_names]
        feats = []
        batch_size = 256

        # Encode the class prompts with the FP32 master on the storage device.
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                chunk = texts[i : i + batch_size]
                tokenized = tokenizer(chunk).to(STORAGE_DEVICE)
                f = master_model_fp32.encode_text(tokenized)
                f = f / f.norm(dim=-1, keepdim=True)
                feats.append(f)

        baseline_feats = torch.cat(feats, dim=0).cpu()  # Store on CPU

        # Baseline eval runs entirely on the storage device: the master model
        # serves as both the evaluated model and its own reference.
        metrics = evaluation.run_comprehensive_evaluation(
            master_model_fp32,
            master_model_fp32,
            test_loader,
            baseline_feats.to(STORAGE_DEVICE),
            STORAGE_DEVICE
        )

        log_result(result_csv_path, {
            "Model_Key": model_key, "Architecture": config.CLIP_MODEL_ARCHITECTURE,
            "Dataset_Path": config.CLIP_MODEL_PRETRAINED_DATASET, "Scenario": "Baseline", "Method": "FP32",
            "Run_ID": "N/A", "Seed": "N/A", "W_Bits": 32, "A_Bits": 32, "Text_Quantized": False, "Quant_Config_Str": "FP32",
            "Eval_Dataset": d_name, "Acc": metrics["Zero-Shot Accuracy"],
            "ECE_Pre_Tuning": metrics["ECE"], "ECE_Post_Tuning": metrics["ECE"],
            "Cosine_Sim": metrics.get("Avg. Cosine Similarity", 1.0), "Status": "Baseline"
        })

        eval_suite[d_name] = {
            "loader": test_loader, "baseline_feats": baseline_feats,
            "class_names": class_names, "template": template,
            "ref_model_fp32": master_model_fp32  # shared reference, lives on STORAGE_DEVICE
        }
        cleanup()

    # 3. Define Scenarios
    scenarios = []
    if config.ENABLE_PROXY_EXPERIMENTS:
        for p_name in config.ACTIVE_PROXY_DATASETS:
            path, _ = config.PROXY_DATASETS[p_name]
            loader = data_setup.create_train_iterable(str(path), preprocess, config.BATCH_SIZE)
            if loader:
                # Proxy data doubles as calibration and training data; no class prompts.
                scenarios.append({"name": f"Proxy_{p_name}", "cal_loader": loader, "train_loader": loader, "prompts": None})

    if config.ENABLE_REAL_ID_EXPERIMENTS:
        for d_name in config.EVAL_DATASETS_TO_TEST:
            loaders, _, class_names, _ = data_setup.get_dataset_loaders(d_name, preprocess, get_train=True)
            if loaders[0]:
                scenarios.append({"name": f"RealID_{d_name}", "cal_loader": loaders[0], "train_loader": loaders[1], "prompts": class_names})

    # 4. Main Experiment Loop
    for scenario in scenarios:
        sc_name = scenario['name']
        queue = _build_method_queue(scenario)

        for method_label, (apply_fn, base_kwargs), m_type in queue:
            for w_bit, a_bit in config.BIT_WIDTHS_TO_TEST:
                for qt_bool in config.TEXT_QUANTIZATION_MODES:

                    config_str = f"W{w_bit}A{a_bit}_Txt{int(qt_bool)}"
                    print(f"\n--> {model_key} | {sc_name} | {method_label} | {config_str}")

                    for run_idx in range(config.NUM_RUNS):
                        cleanup()  # Start fresh

                        current_seed = config.RANDOM_SEED + run_idx
                        data_setup.set_seed(current_seed)

                        # Columns shared by every row this run writes, success or failure.
                        run_identity = {
                            "Model_Key": model_key, "Architecture": config.CLIP_MODEL_ARCHITECTURE,
                            "Dataset_Path": config.CLIP_MODEL_PRETRAINED_DATASET,
                            "Scenario": sc_name, "Method": method_label, "Run_ID": run_idx, "Seed": current_seed,
                            "W_Bits": w_bit, "A_Bits": a_bit, "Text_Quantized": qt_bool,
                            "Quant_Config_Str": config_str,
                        }

                        student_model = None
                        local_teacher = None
                        tune_teacher = None

                        try:
                            # 1. CLONE STUDENT: Storage -> CPU -> Active
                            student_model = copy.deepcopy(master_model_fp32).to('cpu').to(ACTIVE_DEVICE)

                            # 2. CLONE TEACHER (For QAT only): Storage -> CPU -> Active
                            if config.ENABLE_TEACHER and m_type == 'qat':
                                local_teacher = copy.deepcopy(master_teacher_model).to('cpu').to(ACTIVE_DEVICE)
                                local_teacher.eval()

                            run_kwargs = {
                                'target_device': ACTIVE_DEVICE, 'tokenizer': tokenizer, 'prompts': scenario['prompts'],
                                'quantize_text': qt_bool, 'weight_bits': w_bit, 'act_bits': a_bit, **base_kwargs
                            }

                            # 3. Apply Method
                            if m_type == 'ptq':
                                student_model = apply_fn(student_model, calibration_dataloader=scenario['cal_loader'], **run_kwargs)
                            else:
                                # The teacher handed over lives on ACTIVE_DEVICE, so
                                # distillation never has to cross devices.
                                student_model = apply_fn(
                                    student_model,
                                    training_dataloader=scenario['train_loader'],
                                    calibration_dataloader=scenario['cal_loader'],
                                    teacher=local_teacher,
                                    **run_kwargs
                                )

                            # 4. Wipe Gradients (Crucial for LSQ memory leaks)
                            if student_model is not None:
                                student_model.zero_grad(set_to_none=True)
                                for param in student_model.parameters():
                                    param.grad = None

                            # 5. Evaluate
                            metrics_pre = evaluate_all_datasets(student_model, eval_suite, tokenizer, qt_bool)

                            # 6. Logit Tuning
                            metrics_post = metrics_pre
                            if config.ENABLE_LOGIT_TUNING:
                                # Prefer the teacher already on ACTIVE_DEVICE; PTQ runs have
                                # none, so build a temporary copy there when teachers are enabled.
                                tune_teacher = local_teacher
                                if tune_teacher is None and config.ENABLE_TEACHER:
                                    tune_teacher = copy.deepcopy(master_teacher_model).to(ACTIVE_DEVICE)
                                    tune_teacher.eval()

                                student_model = evaluation.tune_logit_scale(
                                    student_model, scenario['cal_loader'], tokenizer,
                                    scenario['prompts'], ACTIVE_DEVICE, teacher_model=tune_teacher
                                )
                                metrics_post = evaluate_all_datasets(student_model, eval_suite, tokenizer, qt_bool)

                                # Drop the temporary teacher as soon as tuning is done.
                                tune_teacher = None

                            # 7. Log
                            for eval_name in eval_suite.keys():
                                m_post = metrics_post[eval_name]
                                log_result(result_csv_path, {
                                    **run_identity,
                                    "Eval_Dataset": eval_name,
                                    "Acc": m_post["Zero-Shot Accuracy"],
                                    "ECE_Pre_Tuning": metrics_pre[eval_name]["ECE"],
                                    "ECE_Post_Tuning": m_post["ECE"],
                                    "Cosine_Sim": m_post.get("Avg. Cosine Similarity", 0.0), "Status": "Success"
                                })

                        except RuntimeError as e:
                            err_msg = str(e)
                            if "out of memory" in err_msg:
                                print(f"OOM ERROR in Run {run_idx}: {err_msg}")
                                # Release the models right away. Rebind to None instead of
                                # `del`: these names are read again after the handler, and
                                # an unbound local there would abort the whole sweep.
                                student_model = None
                                local_teacher = None
                                tune_teacher = None
                                cleanup()
                            else:
                                print(f"RUNTIME ERROR in Run {run_idx}: {err_msg}")

                            log_result(result_csv_path, _failure_row(run_identity, err_msg))
                        except Exception as e:
                            err_msg = str(e)
                            print(f"GENERAL ERROR in Run {run_idx}: {err_msg}")
                            # Log the full run identity so the failure can be attributed.
                            log_result(result_csv_path, _failure_row(run_identity, err_msg))

                        # --- CRITICAL CLEANUP ---
                        # Drop every model reference held by this run (including a
                        # temporary tuning teacher left behind by a failure), then
                        # collect garbage and flush the CUDA caches.
                        student_model = None
                        local_teacher = None
                        tune_teacher = None
                        cleanup()

    print("JOB COMPLETED.")

# Start the sweep only when this file is executed as a script. Importing the module
# still runs the top-level setup above (environment variables, device selection banner).
if __name__ == "__main__":
    main()