
import os
# Force offline mode
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import logging
import copy
import gc
import datetime
import csv
import torch
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from tqdm import tqdm 
from imagecorruptions import corrupt
# --- PROJECT-LOCAL IMPORTS ---

import scripts.config as config
import scripts.data_setup as data_setup
from scripts.evaluation import calculate_calibration_metrics

from quantization.apply import (
    apply_simple_ptq, apply_smoothquant, apply_igq_vit, apply_qvit,
    apply_quantization_aware_training, apply_learned_step_size_quantization, 
    apply_qat_lora, apply_cosine_qat,
    apply_qwt_ptq, apply_apq_ptq, apply_rotation_ptq, apply_outlier_aware_ptq,
    apply_rotation_lsq, apply_qvlm_ptq 
)

# Prefix every record with its actual level name. A hard-coded "INFO:" prefix
# would mislabel warnings and errors that reach the root logger.
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# --- CORRUPTION PIPELINE ---

# CLIP normalization constants (per-channel mean / std), stored with shape
# (1, 3, 1, 1) so they broadcast against (B, 3, H, W) image batches.
CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)

# Hard-coded switch: the `imagecorruptions` import at the top of this file is
# unconditional, so the torchvision fallback is only used if this is set to
# False by hand.
HAS_IMAGECORRUPTIONS = True
def _get_stats(device):
    """Return ``(CLIP_MEAN, CLIP_STD)`` moved to ``device``."""
    mean_on_device = CLIP_MEAN.to(device)
    std_on_device = CLIP_STD.to(device)
    return mean_on_device, std_on_device

def tensor_to_numpy_uint8(x_batch):
    """
    Converts Normalized Tensor (B,C,H,W) -> Denormalized Numpy (B,H,W,C) in [0, 255].

    The CLIP normalization is undone, values are clamped to [0, 1], scaled to
    [0, 255] and rounded to the nearest integer before the uint8 cast.

    Args:
        x_batch: image batch normalized with CLIP_MEAN / CLIP_STD.

    Returns:
        numpy.ndarray of dtype uint8 in channels-last layout.
    """
    mean, std = _get_stats(x_batch.device)
    pixels = (x_batch * std + mean).clamp(0, 1) * 255.0
    # The float -> uint8 cast truncates toward zero. Without rounding, a pixel
    # that denormalizes to e.g. 126.99999 (float error from the normalize /
    # denormalize round trip) would become 126 instead of 127.
    pixels = pixels.round()
    return pixels.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)

def numpy_uint8_to_tensor(x_np, device):
    """
    Converts Numpy (B,H,W,C) in [0, 255] -> CLIP-normalized Tensor (B,C,H,W) on ``device``.
    """
    # Scale to [0, 1] as float on the target device.
    unit_range = torch.from_numpy(x_np).float().to(device) / 255.0
    # Channels-last -> channels-first.
    channels_first = unit_range.permute(0, 3, 1, 2)
    mean, std = _get_stats(device)
    normalized = (channels_first - mean) / std
    return normalized

# --- CORRUPTION FUNCTIONS ---

def apply_clean(x):
    """Identity "corruption": hand the batch back untouched."""
    return x

def _apply_imagecorruptions_lib(x_batch, corruption_name, severity=1):
    """Corrupt a normalized batch with ``imagecorruptions.corrupt``.

    The batch is denormalized to uint8, each image is corrupted on its own,
    and the result is re-normalized onto the input's device.
    """
    uint8_batch = tensor_to_numpy_uint8(x_batch)

    # corrupt() is called once per image and expects (H,W,C) uint8.
    corrupted_frames = [
        corrupt(frame, corruption_name=corruption_name, severity=severity)
        for frame in uint8_batch
    ]

    stacked = np.stack(corrupted_frames)
    return numpy_uint8_to_tensor(stacked, x_batch.device)

def _apply_torchvision_fallback(x_batch, corruption_type):
    """Torchvision-based stand-in for the imagecorruptions library.

    Supports 'gaussian_noise', 'defocus_blur', 'brightness' and 'contrast';
    any other ``corruption_type`` leaves the image unchanged. Input and output
    are both CLIP-normalized tensors.

    NOTE(review): the original comments call these constants "severity 3
    equivalent"; that has not been checked against imagecorruptions.
    """
    mean, std = _get_stats(x_batch.device)
    pixels = x_batch * std + mean  # undo normalization

    def renormalize(img):
        return (img - mean) / std

    if corruption_type == 'gaussian_noise':
        jitter = torch.randn_like(pixels) * 0.02
        return renormalize((pixels + jitter).clamp(0, 1))

    if corruption_type == 'defocus_blur':
        blurred = TF.gaussian_blur(pixels, kernel_size=[7, 7], sigma=[2.5, 2.5])
        return renormalize(blurred)

    if corruption_type == 'brightness':
        brightened = TF.adjust_brightness(pixels, brightness_factor=1.5)
        return renormalize(brightened.clamp(0, 1))

    if corruption_type == 'contrast':
        contrasted = TF.adjust_contrast(pixels, contrast_factor=1.5)
        return renormalize(contrasted.clamp(0, 1))

    # Unrecognized type: pass the denormalized image straight through.
    return renormalize(pixels)

# --- WRAPPERS ---

def apply_noise_std(x):
    """Gaussian-noise corruption; uses the torchvision fallback when the library flag is off."""
    if not HAS_IMAGECORRUPTIONS:
        return _apply_torchvision_fallback(x, 'gaussian_noise')
    return _apply_imagecorruptions_lib(x, 'gaussian_noise')

def apply_blur_std(x):
    """Defocus-blur corruption; uses the torchvision fallback when the library flag is off."""
    if not HAS_IMAGECORRUPTIONS:
        return _apply_torchvision_fallback(x, 'defocus_blur')
    return _apply_imagecorruptions_lib(x, 'defocus_blur')

def apply_brightness_std(x):
    """Brightness corruption; uses the torchvision fallback when the library flag is off."""
    if not HAS_IMAGECORRUPTIONS:
        return _apply_torchvision_fallback(x, 'brightness')
    return _apply_imagecorruptions_lib(x, 'brightness')

def apply_contrast_std(x):
    """Contrast corruption; uses the torchvision fallback when the library flag is off."""
    if not HAS_IMAGECORRUPTIONS:
        return _apply_torchvision_fallback(x, 'contrast')
    return _apply_imagecorruptions_lib(x, 'contrast')

# Corruption name (as used in the robustness config) -> callable that takes a
# normalized image batch and returns the corrupted batch.
CORRUPTION_FUNCS = {
    "Clean": apply_clean,
    "Noise": apply_noise_std,
    "Blur": apply_blur_std,
    "Brightness": apply_brightness_std,
    "Contrast": apply_contrast_std,
}

# --- REGISTRIES ---
# Method label -> (apply function, extra keyword arguments) for PTQ methods.
ALL_PTQ_METHODS = {
    "Simple PTQ": (apply_simple_ptq, config.SIMPLE_PTQ_KWARGS),
    "SmoothQuant PTQ": (apply_smoothquant, {}),
    "IGQ-ViT PTQ": (apply_igq_vit, config.IGQ_KWARGS),
    "QwT PTQ": (apply_qwt_ptq, config.QWT_KWARGS),
    "APQ-ViT PTQ": (apply_apq_ptq, config.APQ_KWARGS),
    "Rotation PTQ": (apply_rotation_ptq, {}),
    "OutlierAware PTQ": (apply_outlier_aware_ptq, config.OUTLIER_AWARE_KWARGS),
    "Q-VLM": (apply_qvlm_ptq, config.QVLM_KWARGS),
}

# Method label -> (apply function, extra keyword arguments) for standard QAT methods.
ALL_STANDARD_QAT_METHODS = {
    "QAT": (apply_quantization_aware_training, config.QAT_KWARGS),
    "LSQ": (apply_learned_step_size_quantization, config.LSQ_KWARGS),
    # Shares its keyword arguments with plain LSQ.
    "Rotation + LSQ": (apply_rotation_lsq, config.LSQ_KWARGS),
    "CosQAT": (apply_cosine_qat, config.COS_QAT_KWARGS),
}

# Method label -> (apply function, extra keyword arguments) for the "fixed" QAT methods.
ALL_FIXED_QAT_METHODS = {
    "QAT-LoRA": (apply_qat_lora, config.QAT_LORA_KWARGS),
    "Q-ViT": (apply_qvit, config.QVIT_KWARGS),
}

# --- SETUP ---
# Column order of the results CSV; used for both the header and every row.
CSV_FIELDNAMES = [
    # Model identity
    "Model_Key",
    "Architecture",
    "Quant_Scope",
    # Evaluated data
    "Suite",
    "Variant_Name",
    "Shift_Type",
    # Experiment identity
    "Scenario",
    "Method",
    "Run_ID",
    "Seed",
    # Quantization settings
    "W_Bits",
    "A_Bits",
    "Text_Quantized",
    "Quant_Config_Str",
    # Results
    "Accuracy",
    "ECE",
    "Status",
]

def cleanup():
    """Run the garbage collector, then ask PyTorch to empty its CUDA cache."""
    # Collect first so tensors owned by dead Python objects are released
    # before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

def init_csv(model_key, quant_scope):
    """Create a timestamped results CSV containing only the header row.

    The file is placed in ``results_robustness/`` (created if missing); "/" in
    ``model_key`` is replaced with "-" for the file name.

    Returns:
        Path of the newly created CSV file.
    """
    output_dir = "results_robustness"
    os.makedirs(output_dir, exist_ok=True)

    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    sanitized_key = model_key.replace("/", "-")
    csv_name = f"robustness_{sanitized_key}_{quant_scope}_{stamp}.csv"
    filename = os.path.join(output_dir, csv_name)

    with open(filename, mode='w', newline='') as handle:
        header_writer = csv.DictWriter(handle, fieldnames=CSV_FIELDNAMES)
        header_writer.writeheader()

    print(f"Logging to: {filename}")
    return filename

def log_row(filename, data):
    """Append ``data`` as one row to the results CSV.

    Best-effort: any exception raised while writing is printed and swallowed,
    so a logging failure does not abort the caller.
    """
    try:
        with open(filename, mode='a', newline='') as handle:
            row_writer = csv.DictWriter(handle, fieldnames=CSV_FIELDNAMES)
            row_writer.writerow(data)
    except Exception as err:
        print(f"Log Error: {err}")

# --- FEATURE HELPERS ---
def get_text_features(model, tokenizer, classes, template, device):
    """
    Builds one template-ensembled, L2-normalized text feature per class.

    Every template in ``config.CLIP_IMAGENET_TEMPLATES`` is formatted with each
    class name, encoded in chunks of 500 prompts, normalized, and summed across
    templates; the mean is then normalized again.

    NOTE(review): ``template`` is accepted but never read here -- the ImageNet
    template list from config is always used. Confirm this is intended.
    """
    prompt_templates = config.CLIP_IMAGENET_TEMPLATES

    with torch.no_grad(), torch.amp.autocast('cuda', enabled=config.USE_AMP):
        running_sum = None
        for prompt_template in prompt_templates:
            prompts = [prompt_template.format(c) for c in classes]

            # Encode in chunks of 500 prompts.
            chunk_feats = [
                F.normalize(
                    model.encode_text(tokenizer(prompts[start:start + 500]).to(device)),
                    dim=-1,
                )
                for start in range(0, len(prompts), 500)
            ]
            template_feats = torch.cat(chunk_feats, dim=0)

            if running_sum is None:
                running_sum = template_feats
            else:
                running_sum = running_sum + template_feats

        mean_feats = running_sum / len(prompt_templates)
        f_id = F.normalize(mean_feats, dim=-1)

    return f_id

# --- EVALUATION LOGIC ---

@torch.no_grad()
def evaluate_metrics(model, loader, corruption_fn, f_id, device):
    """
    Computes accuracy and the calibration metric over ``loader``.

    Each image batch is optionally passed through ``corruption_fn`` (skipped
    when it is falsy), encoded, normalized, and scored against the text
    features ``f_id`` using the model's logit scale (clamped to at most 100).

    Returns:
        ``(acc, ece)`` where ``acc`` is a Python float and ``ece`` is whatever
        ``calculate_calibration_metrics`` returns.
    """
    model.eval()
    scale = model.logit_scale.exp().to(device).clamp(max=100.0)

    logit_chunks, label_chunks = [], []

    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device, non_blocking=True)

        # Synthetic corruption, when one was supplied.
        if corruption_fn:
            batch_images = corruption_fn(batch_images)

        batch_labels = batch_labels.to(device, non_blocking=True)

        with torch.amp.autocast('cuda', enabled=config.USE_AMP):
            image_feats = F.normalize(model.encode_image(batch_images), dim=-1)

            # Cosine similarity against the class text features, scaled to logits.
            batch_logits = scale * (image_feats @ f_id.T)

            # Keep results on the CPU so they do not accumulate on the device.
            logit_chunks.append(batch_logits.float().cpu())
            label_chunks.append(batch_labels.cpu())

    logits = torch.cat(logit_chunks)
    targets = torch.cat(label_chunks)

    predictions = logits.argmax(dim=-1)
    acc = (predictions == targets).float().mean().item()
    ece = calculate_calibration_metrics(logits, targets)

    return acc, ece

def _log_suite_result(csv_path, base_log_params, suite_name, variant_name,
                      shift_type, acc=None, ece=None, status="Success"):
    """Append one result row: the shared run parameters plus per-variant fields."""
    row = base_log_params.copy()
    row.update({
        "Suite": suite_name,
        "Variant_Name": variant_name,
        "Shift_Type": shift_type,
        "Accuracy": acc,
        "ECE": ece,
        "Status": status,
    })
    log_row(csv_path, row)


def run_suite_evaluation(model, tokenizer, preprocess, device, csv_path, base_log_params):
    """
    Evaluates ``model`` on every suite in ``config.ROBUSTNESS_CONFIG`` and logs
    one CSV row per evaluated variant. Used for both FP32 and quantized models.

    A suite containing "shifted_datasets" is evaluated on each listed dataset
    (Shift_Type "Natural"). Otherwise, a suite containing "corruptions" is
    evaluated on its "source_id" dataset once per listed corruption
    (Shift_Type "Synthetic"). A suite that has both keys only runs the
    natural-shift branch.

    Args:
        model: model exposing ``encode_image`` / ``encode_text`` / ``logit_scale``.
        tokenizer: tokenizer passed through to ``get_text_features``.
        preprocess: image transform passed to ``data_setup.get_dataset_loaders``.
        device: device used for evaluation.
        csv_path: results file that rows are appended to.
        base_log_params: dict of columns shared by every row of this run; it
            is copied, never mutated.
    """
    for suite_name, suite_cfg in config.ROBUSTNESS_CONFIG.items():

        # A. NATURAL SHIFTS (ImageNet-V2, etc)
        if "shifted_datasets" in suite_cfg:
            for ds_name in suite_cfg["shifted_datasets"]:
                # Load dataset dynamically
                _, ds_loader, ds_classes, ds_temp = data_setup.get_dataset_loaders(
                    ds_name, preprocess, get_train=False
                )
                if not ds_loader:
                    continue

                # Text features for THIS variant's class list
                f_id = get_text_features(model, tokenizer, ds_classes, ds_temp, device)

                acc, ece = evaluate_metrics(model, ds_loader, None, f_id, device)
                _log_suite_result(
                    csv_path, base_log_params, suite_name, ds_name, "Natural", acc, ece
                )

        # B. SYNTHETIC SHIFTS (CIFAR + Noise)
        elif "corruptions" in suite_cfg:
            src_id = suite_cfg['source_id']
            _, src_loader, src_classes, src_temp = data_setup.get_dataset_loaders(
                src_id, preprocess, get_train=False
            )
            if not src_loader:
                continue

            # Text features for the source dataset (classes do not change per corruption)
            f_id = get_text_features(model, tokenizer, src_classes, src_temp, device)

            for corr_name in suite_cfg['corruptions']:
                variant_name = f"{src_id}_{corr_name}"
                corr_fn = CORRUPTION_FUNCS.get(corr_name)

                # An unknown name used to fall through as None, which
                # evaluate_metrics treats as "no corruption": clean data was
                # evaluated and reported as a successful corrupted variant.
                # Record the problem instead of producing a mislabelled result.
                if corr_fn is None:
                    print(f"Unknown corruption '{corr_name}' in suite '{suite_name}'; skipping.")
                    _log_suite_result(
                        csv_path, base_log_params, suite_name, variant_name, "Synthetic",
                        status=f"Skipped: unknown corruption '{corr_name}'",
                    )
                    continue

                acc, ece = evaluate_metrics(model, src_loader, corr_fn, f_id, device)
                _log_suite_result(
                    csv_path, base_log_params, suite_name, variant_name, "Synthetic", acc, ece
                )

def main():
    """Run the robustness benchmark: FP32 baseline plus every active quantized variant.

    For each run (seed) and each text-quantization mode, the FP32 model is
    evaluated on all configured suites. Then, for each proxy scenario, every
    active QAT method is applied at every configured bit width and the
    resulting model is evaluated the same way. Rows are appended to a fresh
    CSV created by ``init_csv``. A failure in any single configuration is
    printed with its traceback and the loop moves on.
    """
    import traceback  # only needed for the crash reports below

    model_key = config._target_key
    quant_scope = os.environ.get("TARGET_QUANT_SCOPE", "ALL")
    active_txt_modes = config.TEXT_QUANTIZATION_MODES

    print(f"\n{'='*80}\nSTARTING ROBUSTNESS BENCHMARK (NO OOD): {model_key}\n{'='*80}")

    csv_path = init_csv(model_key, quant_scope)
    data_setup.set_seed(config.RANDOM_SEED)
    model_fp32, tokenizer, preprocess = data_setup.get_model_and_tokenizer()
    model_fp32.eval().cpu()

    # 1. Prepare Proxy Scenarios (For Calibration)
    proxy_scenarios = []
    if config.ENABLE_PROXY_EXPERIMENTS:
        for p_name in config.ACTIVE_PROXY_DATASETS:
            path, _ = config.PROXY_DATASETS.get(p_name, (None, None))
            if path:
                ldr = data_setup.create_train_iterable(str(path), preprocess, config.BATCH_SIZE)
                if ldr:
                    proxy_scenarios.append({
                        "name": p_name, "type": "ptq", "loader": ldr, "cal": ldr, "prompts": None
                    })

    # --- EXPERIMENT LOOP ---
    for run_idx in range(config.NUM_RUNS):
        curr_seed = config.RANDOM_SEED + run_idx
        data_setup.set_seed(curr_seed)

        for qt_bool in active_txt_modes:
            q_scope_str = "VISUAL_TEXT" if qt_bool else "VISUAL_ONLY"

            # --- 1. RUN FP32 BASELINE (Once per Scope/Seed) ---
            print(f"[{run_idx}] Running FP32 Baseline for {q_scope_str}...")
            try:
                model_fp32.to(config.TARGET_DEVICE)

                fp32_log_params = {
                    "Model_Key": model_key, "Architecture": config.CLIP_MODEL_ARCHITECTURE,
                    "Quant_Scope": q_scope_str,
                    "Scenario": "Baseline", "Method": "FP32",
                    "Run_ID": run_idx, "Seed": curr_seed,
                    "W_Bits": 32, "A_Bits": 32, "Text_Quantized": False,
                    "Quant_Config_Str": "FP32"
                }

                run_suite_evaluation(
                    model_fp32, tokenizer, preprocess,
                    config.TARGET_DEVICE, csv_path, fp32_log_params
                )

            except Exception as e:
                print(f"FP32 Baseline Failed: {e}")
                traceback.print_exc()
            finally:
                # Park the FP32 model on the CPU between evaluations.
                model_fp32.cpu()
                cleanup()

            # --- 2. RUN QUANTIZATION SCENARIOS ---
            for sc in proxy_scenarios:

                # Filter Active Methods
                methods_to_run = []

                # 1. PTQ methods are currently disabled. To re-enable:
                # for m in config.ACTIVE_PTQ_METHODS:
                #     if m in ALL_PTQ_METHODS:
                #         methods_to_run.append((m, ALL_PTQ_METHODS[m]))

                # 2. Add Standard QAT Methods
                methods_to_run += [
                    (m, ALL_STANDARD_QAT_METHODS[m])
                    for m in config.ACTIVE_STANDARD_QAT_METHODS
                    if m in ALL_STANDARD_QAT_METHODS
                ]

                # 3. Add Fixed QAT Methods
                methods_to_run += [
                    (m, ALL_FIXED_QAT_METHODS[m])
                    for m in config.ACTIVE_FIXED_QAT_METHODS
                    if m in ALL_FIXED_QAT_METHODS
                ]

                for m_label, (apply_fn, base_kwargs) in methods_to_run:
                    for w_bit, a_bit in config.BIT_WIDTHS_TO_TEST:
                        cleanup()
                        config_str = f"W{w_bit}A{a_bit}_Txt{int(qt_bool)}"
                        print(f"[{run_idx}] {sc['name']} | {m_label} | {config_str}")

                        # Bind both names up front so the 'finally' block can
                        # always release them, even if deepcopy itself fails.
                        temp_model = None
                        eval_model = None

                        try:
                            # Quantize a private copy; model_fp32 stays pristine.
                            temp_model = copy.deepcopy(model_fp32)

                            # QAT methods need training data and a teacher model.
                            extra_args = {}
                            is_qat = (m_label in ALL_STANDARD_QAT_METHODS) or (m_label in ALL_FIXED_QAT_METHODS)

                            if is_qat:
                                extra_args['training_dataloader'] = sc['loader']
                                # The FP32 model is on the CPU at this point;
                                # device placement is left to the apply function.
                                extra_args['teacher'] = model_fp32

                            eval_model = apply_fn(
                                temp_model,
                                calibration_dataloader=sc['loader'],
                                target_device=config.TARGET_DEVICE,
                                tokenizer=tokenizer,
                                prompts=sc.get('prompts'),
                                quantize_text=qt_bool,
                                weight_bits=w_bit,
                                act_bits=a_bit,
                                **extra_args,
                                **base_kwargs
                            )

                            quant_log_params = {
                                "Model_Key": model_key, "Architecture": config.CLIP_MODEL_ARCHITECTURE,
                                "Quant_Scope": q_scope_str,
                                "Scenario": sc['name'],
                                "Method": m_label, "Run_ID": run_idx, "Seed": curr_seed,
                                "W_Bits": w_bit, "A_Bits": a_bit, "Text_Quantized": qt_bool,
                                "Quant_Config_Str": config_str
                            }

                            # Evaluate
                            run_suite_evaluation(
                                eval_model, tokenizer, preprocess,
                                config.TARGET_DEVICE, csv_path, quant_log_params
                            )

                        except Exception as e:
                            print(f"CRASH in Quantization: {e}")
                            traceback.print_exc()
                        finally:
                            # Drop BOTH references before cleanup(). Releasing
                            # only eval_model leaves temp_model pointing at the
                            # quantized copy, so its memory could not be freed
                            # until the next iteration rebinds the name.
                            temp_model = None
                            eval_model = None
                            cleanup()


if __name__ == "__main__":
    main()