import os
import sys
import gc
import copy
import csv
import torch
import numpy as np
import datetime
import warnings
from tqdm import tqdm

# --- SPEED & MEMORY SETTINGS ---
# These environment variables are assigned before the project modules and
# open_clip are imported further down in this file.
# NOTE(review): presumably that ordering is deliberate so the libraries see the
# values at import time — confirm before moving these lines.
# NOTE(review): torch itself is already imported above this point; confirm the
# CUDA allocator still honours PYTORCH_CUDA_ALLOC_CONF when it is set here.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Offline switches for the Hugging Face hub / datasets / transformers stacks
# (presumably prevents network lookups — verify against those libraries' docs).
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["HF_DATASETS_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

# Select the 'high' float32 matmul precision mode for the whole process.
torch.set_float32_matmul_precision('high')
# Silence UserWarning messages raised from the webdataset module only.
warnings.filterwarnings("ignore", category=UserWarning, module="webdataset")

# --- IMPORTS ---
sys.path.append(os.getcwd())
import scripts.config as config
import scripts.data_setup as data_setup
import open_clip

# Import Methods
from quantization.apply import (
    apply_simple_ptq, apply_smoothquant, apply_igq_vit, apply_qvit,
    apply_quantization_aware_training, apply_learned_step_size_quantization, 
    apply_qat_lora, apply_cosine_qat,
    apply_qwt_ptq, apply_apq_ptq, apply_rotation_ptq, apply_outlier_aware_ptq,
    apply_rotation_lsq, apply_qvlm_ptq 
)

# ==============================================================================
# 1. SETUP & CONFIGURATION
# ==============================================================================

# Pick the GPU when one is visible, otherwise run everything on the CPU, and
# publish the choice to the shared project config.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
config.TARGET_DEVICE = DEVICE

# Force Single Process to prevent logging bugs
config.NUM_WORKERS = 0
config.EVAL_NUM_WORKERS = 0

# Ranks: every interval starts at 0 and ends at one of the upper bounds.
UPPER_BOUNDS = [8, 12, 16, 24, 32, 64, 128, 256, 512, 768]
CUMULATIVE_INTERVALS = [(0, upper) for upper in UPPER_BOUNDS]

# Bits: (weight_bits, activation_bits) pairs evaluated for every method.
BIT_CONFIGS = [(8, 8), (6, 8), (4, 8)]

# CSV Headers
CSV_SATURATION_FIELDS = [
    "Model_Arch",
    "Method_Family",
    "Method",
    "W_Bits",
    "A_Bits",
    "Cumulative_Rank",
    "Avg_Cosine_Sim",
]
CSV_SPECTRAL_FIELDS = [
    "Model_Arch",
    "Method_Family",
    "Method",
    "W_Bits",
    "A_Bits",
    "Rank_Index",
    "Sigma",
    "SQNR",
]

# Methods: each table maps a display name to (apply function, base kwargs).
_PTQ_METHODS = {
    "Simple PTQ": (apply_simple_ptq, config.SIMPLE_PTQ_KWARGS),
    "Rot-PTQ": (apply_rotation_ptq, {}),
    "SmoothQuant": (apply_smoothquant, {}),
    "IGQ-ViT": (apply_igq_vit, config.IGQ_KWARGS),
    "QwT": (apply_qwt_ptq, config.QWT_KWARGS),
    "APQ-ViT": (apply_apq_ptq, config.APQ_KWARGS),
    "OutlierAware": (apply_outlier_aware_ptq, config.OUTLIER_AWARE_KWARGS),
    "Q-VLM": (apply_qvlm_ptq, config.QVLM_KWARGS),
}

_STANDARD_QAT_METHODS = {
    "QAT": (apply_quantization_aware_training, config.QAT_KWARGS),
    "LSQ": (apply_learned_step_size_quantization, config.LSQ_KWARGS),
    "Rot-LSQ": (apply_rotation_lsq, config.LSQ_KWARGS),
    "CosQAT": (apply_cosine_qat, config.COS_QAT_KWARGS),
}

_FIXED_QAT_METHODS = {
    "QAT-LoRA": (apply_qat_lora, config.QAT_LORA_KWARGS),
    "Q-ViT": (apply_qvit, config.QVIT_KWARGS),
}

# (family name, method table, method type); the type string is 'ptq' or 'qat'.
ALL_FAMILIES = [
    ("PTQ", _PTQ_METHODS, 'ptq'),
    ("Standard_QAT", _STANDARD_QAT_METHODS, 'qat'),
    ("Fixed_QAT", _FIXED_QAT_METHODS, 'qat'),
]

# Passed to every apply function as `quantize_text`; also selects the
# "VisualText" / "VisualOnly" tag used in the result file names.
QUANTIZE_TEXT = False

# ==============================================================================
# 2. UTILS
# ==============================================================================

def cleanup():
    """Free memory between runs.

    Runs a Python garbage-collection pass first, then asks PyTorch to empty
    its CUDA cache.
    """
    # Collect first so tensors held only by unreachable objects are dropped
    # before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

def get_target_layer_name(model):
    """
    Return the dotted attribute path of the last block in ``model.visual``.

    Three layouts are recognised, checked in this order:
      1. ``visual.transformer.resblocks``            (OpenAI CLIP ViT)
      2. ``visual.trunk.blocks``                     (SigLIP / EVA / CoCa)
      3. ``visual.trunk.stages[-1].blocks``          (ConvNeXt)

    Raises:
        ValueError: if none of the layouts matches; the visual tower's
            attributes are printed first to help debugging.
    """
    visual = model.visual

    # 1. Standard OpenAI CLIP ViT
    if hasattr(visual, 'transformer'):
        transformer = visual.transformer
        if hasattr(transformer, 'resblocks'):
            final_idx = len(transformer.resblocks) - 1
            return f"visual.transformer.resblocks.{final_idx}"

    if hasattr(visual, 'trunk'):
        trunk = visual.trunk

        # 2. SigLIP / EVA / CoCa (Common 'trunk.blocks' structure)
        if hasattr(trunk, 'blocks'):
            final_idx = len(trunk.blocks) - 1
            return f"visual.trunk.blocks.{final_idx}"

        # 3. ConvNeXt (Stages structure)
        if hasattr(trunk, 'stages'):
            stage_idx = len(trunk.stages) - 1
            final_stage = trunk.stages[stage_idx]
            if hasattr(final_stage, 'blocks'):
                block_idx = len(final_stage.blocks) - 1
                return f"visual.trunk.stages.{stage_idx}.blocks.{block_idx}"

    # Fallback / Debug
    print("WARNING: Could not detect standard ViT/ConvNeXt layer structure.")
    print("Available top-level visual attributes:", dir(visual))
    raise ValueError(f"Could not determine target layer for {config.MODEL_CONFIG['arch']}")

def init_csvs(model_name, pretrained_name, quant_mode):
    """Create the two result CSVs for this run and write their header rows.

    The file names combine the sanitised model name, pretrained tag, the
    quantisation mode string and a ``YYYYmmdd_HHMMSS`` timestamp, and both
    files are placed in a ``results`` directory (created if missing).

    Returns:
        tuple: ``(saturation_csv_path, spectral_csv_path)``.
    """
    os.makedirs("results", exist_ok=True)
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    def _sanitize(text):
        # Slashes and spaces are replaced so the value is usable in a file name.
        return text.replace("/", "-").replace(" ", "_")

    run_tag = f"{_sanitize(model_name)}_{_sanitize(pretrained_name)}_{quant_mode}_{stamp}"

    sat_path = f"results/saturation_sweep_{run_tag}.csv"
    spec_path = f"results/spectral_stats_{run_tag}.csv"

    # Opening with mode 'w' already starts from an empty file, so only the
    # header row needs to be written.
    for path, header in ((sat_path, CSV_SATURATION_FIELDS), (spec_path, CSV_SPECTRAL_FIELDS)):
        with open(path, mode='w', newline='') as out:
            csv.DictWriter(out, fieldnames=header).writeheader()

    return sat_path, spec_path

def log_saturation(filename, row):
    """Append one saturation-sweep record (a dict keyed by the saturation
    CSV column names) to ``filename``."""
    with open(filename, mode='a', newline='') as out:
        writer = csv.DictWriter(out, fieldnames=CSV_SATURATION_FIELDS)
        writer.writerow(row)

def log_spectral(filename, model_arch, family, method, w, a, sigmas, sqnr=None):
    """Append one row per rank index to the spectral-statistics CSV.

    Args:
        filename: path of the CSV to append to.
        model_arch, family, method, w, a: values copied into every row
            (architecture, method family, method name, weight/activation bits).
        sigmas: object exposing ``.tolist()`` that yields the singular values.
        sqnr: optional object exposing ``.tolist()`` with one value per rank;
            when ``None`` the SQNR column is filled with ``"N/A"``.
    """
    sigma_vals = sigmas.tolist()
    if sqnr is None:
        sqnr_vals = ["N/A"] * len(sigma_vals)
    else:
        sqnr_vals = sqnr.tolist()

    # Only log top 768 ranks to keep CSV manageable, even if model dim is larger
    n_rows = min(len(sigma_vals), 768)

    # Columns shared by every row of this call.
    shared = {
        "Model_Arch": model_arch,
        "Method_Family": family,
        "Method": method,
        "W_Bits": w,
        "A_Bits": a,
    }
    records = [
        {**shared, "Rank_Index": idx, "Sigma": sigma_vals[idx], "SQNR": sqnr_vals[idx]}
        for idx in range(n_rows)
    ]

    with open(filename, mode='a', newline='') as out:
        csv.DictWriter(out, fieldnames=CSV_SPECTRAL_FIELDS).writerows(records)

# ==============================================================================
# 3. PROFILER LOGIC
# ==============================================================================

class UnifiedProfiler:
    """Attach forward hooks to one sub-module of a model to capture or
    rank-truncate its output.

    The target sub-module is located once, in ``__init__``, by walking a
    dotted attribute path. Every hook registered by the methods below is
    removed again before the method returns, including when the forward pass
    raises.
    """

    def __init__(self, model, layer_name):
        """Resolve ``layer_name`` (e.g. ``"visual.trunk.blocks.11"``) on ``model``.

        Raises:
            ValueError: if any component of the path is missing; the original
                ``AttributeError`` is chained as the cause.
        """
        self.model = model
        self.target_module = None
        # Walk the dot-separated path one attribute at a time.
        curr = model
        try:
            for part in layer_name.split('.'):
                curr = getattr(curr, part)
        except AttributeError as err:
            raise ValueError(f"Layer path {layer_name} invalid for model.") from err
        self.target_module = curr

    def capture_activations(self, dataloader):
        """Run ``dataloader`` through ``model.encode_image`` and collect the
        target module's output, reduced to one vector per sample.

        Args:
            dataloader: iterable yielding ``(images, _)`` pairs.

        Returns:
            A CPU float32 tensor: the per-batch captures concatenated on dim 0.
        """
        activations = []

        def hook(module, inputs, output):
            # Reduce the module output to a single vector per sample.
            if output.ndim == 3:
                # ViT: assumes batch-first (B, L, C) with the class token at
                # index 0 — TODO confirm for the architecture in use.
                x = output[:, 0, :]
            elif output.ndim == 4:
                # ConvNeXt/CNN: assumes (B, C, H, W); global average pool.
                x = output.mean(dim=(2, 3))
            else:
                x = output
            activations.append(x.float().detach().cpu())

        handle = self.target_module.register_forward_hook(hook)
        try:
            self.model.eval()
            with torch.inference_mode(), torch.amp.autocast('cuda'):
                for imgs, _ in tqdm(dataloader, desc="Capturing Activations", leave=False):
                    self.model.encode_image(imgs.to(DEVICE, non_blocking=True))
        finally:
            # Always detach the hook so a failed forward pass cannot leave it
            # installed on the model.
            handle.remove()

        return torch.cat(activations, dim=0)

    def run_saturation_sweep(self, dataloader, intervals, ref_model=None):
        """Measure how well rank-truncated activations preserve the embedding.

        For every batch and every interval the target module's output is
        replaced by its SVD reconstruction restricted to the leading ``end``
        singular values. The resulting image embedding is compared, by cosine
        similarity, with the unmodified embedding of ``ref_model`` (or of
        ``self.model`` when ``ref_model`` is ``None``).

        Args:
            dataloader: iterable yielding ``(images, _)`` pairs.
            intervals: iterable of ``(start, end)`` pairs; only ``end`` is used.
            ref_model: optional model providing the reference embeddings.

        Returns:
            dict mapping ``"0-<end>"`` to the mean cosine similarity over all
            samples.

        Raises:
            ValueError: if ``dataloader`` yields no samples.
        """
        results_sum = {f"0-{e}": 0.0 for (_, e) in intervals}
        total_samples = 0
        # Read by mask_hook through the closure; updated before each forward pass.
        current_end = 0

        def mask_hook(module, inputs, output):
            orig_shape = output.shape

            # 1. Flatten to a 2-D (samples, channels) matrix where possible.
            if output.ndim == 3:
                # ViT: Process all tokens for SVD to be accurate
                x = output.reshape(-1, output.shape[-1])
            elif output.ndim == 4:
                # CNN: Reshape (B, C, H, W) -> (B*H*W, C)
                # This treats every spatial location as a sample
                b, c, h, w = orig_shape
                x = output.permute(0, 2, 3, 1).reshape(-1, c)
            else:
                x = output

            # 2. SVD
            U, S, Vh = torch.linalg.svd(x.float(), full_matrices=False)

            # 3. Keep only the leading `current_end` singular values
            #    (the ellipsis also covers a batched S).
            mask = torch.zeros_like(S)
            mask[..., :current_end] = 1

            # 4. Reconstruct
            # NOTE(review): this hook runs inside the caller's autocast region,
            # so these matmuls may execute in reduced precision even though the
            # SVD input was cast to float32 — confirm this is intended.
            x_recon = U @ torch.diag_embed(S * mask) @ Vh

            # 5. Reshape back
            if output.ndim == 4:
                # (B*H*W, C) -> (B, H, W, C) -> (B, C, H, W)
                x_recon = x_recon.reshape(b, h, w, c).permute(0, 3, 1, 2)
            else:
                x_recon = x_recon.reshape(orig_shape)

            # The returned tensor is used in place of the module's output.
            return x_recon.type_as(output)

        self.model.eval()
        if ref_model is not None:
            ref_model.eval()
        # Explicit None check: module truthiness can depend on __len__.
        target_model = ref_model if ref_model is not None else self.model

        with torch.inference_mode(), torch.amp.autocast('cuda'):
            for images, _ in tqdm(dataloader, desc="Saturation Sweep", leave=False):
                images = images.to(DEVICE, non_blocking=True)

                # Reference embedding: computed while no hook is installed.
                gt_emb = target_model.encode_image(images)
                gt_emb = gt_emb / gt_emb.norm(dim=-1, keepdim=True)

                for (_, end) in intervals:
                    current_end = end
                    handle = self.target_module.register_forward_hook(mask_hook)
                    try:
                        pred_emb = self.model.encode_image(images)
                    finally:
                        # Never leave the masking hook behind: it would corrupt
                        # the next reference embedding or later uses of the model.
                        handle.remove()
                    pred_emb = pred_emb / pred_emb.norm(dim=-1, keepdim=True)

                    # Sum of per-sample cosine similarities for this batch.
                    sim = (pred_emb * gt_emb).sum(dim=-1).sum().item()
                    results_sum[f"0-{end}"] += sim

                total_samples += images.shape[0]

        if total_samples == 0:
            raise ValueError("Saturation sweep received no samples from the dataloader.")

        return {k: v / total_samples for k, v in results_sum.items()}

def compute_sqnr(X_fp32, X_quant, V_fp32):
    """Per-direction signal-to-quantisation-noise ratio in dB.

    Both activation matrices are centred with the column mean of ``X_fp32``
    and projected onto the columns of ``V_fp32``. For each column the signal
    power is the mean squared clean projection and the noise power is the mean
    squared difference between the clean and quantised projections.

    Returns:
        A CPU tensor holding ``10 * log10(signal / noise)`` per column of
        ``V_fp32`` (both powers are offset by 1e-10 before dividing).
    """
    # Calculations run on the configured device.
    clean, noisy, basis = (t.to(DEVICE) for t in (X_fp32, X_quant, V_fp32))

    # The quantised activations are centred with the *clean* mean as well, so
    # a shift of the mean counts as noise.
    center = clean.mean(dim=0, keepdim=True)
    proj_clean = (clean - center) @ basis
    proj_noisy = (noisy - center) @ basis

    eps = 1e-10
    signal = proj_clean.pow(2).mean(dim=0)
    noise = (proj_clean - proj_noisy).pow(2).mean(dim=0)

    ratio_db = 10 * torch.log10((signal + eps) / (noise + eps))
    return ratio_db.cpu()

# ==============================================================================
# 4. MAIN LOOP
# ==============================================================================

def main():
    """Run the full benchmark: FP32 baseline, then every method / bit-width pair.

    For the baseline and for each quantised model the target layer's
    activations are captured on the validation loader, their singular values
    (and, for quantised models, per-direction SQNR against the FP32 basis) are
    appended to the spectral CSV, and a cumulative-rank saturation sweep is
    appended to the saturation CSV.

    A failure inside one method / bit-width run is reported and the loop moves
    on; GPU-resident objects of that run are released either way.
    """
    model_key = config.MODEL_CONFIG['arch']
    pretrained_key = config.MODEL_CONFIG['data']
    mode_str = "VisualText" if QUANTIZE_TEXT else "VisualOnly"

    print("--- FAST BENCHMARK SUITE [AMP ENABLED] ---")
    print(f"Target: {model_key} | {pretrained_key}")

    sat_csv, spec_csv = init_csvs(model_key, pretrained_key, mode_str)
    data_setup.set_seed(config.RANDOM_SEED)

    # 1. Load Resources (CPU First)
    # The config device is switched to CPU only for the duration of the load
    # and restored even if loading fails.
    original_target = config.TARGET_DEVICE
    config.TARGET_DEVICE = "cpu"
    try:
        master_model_cpu, tokenizer, preprocess = data_setup.get_model_and_tokenizer()
    finally:
        config.TARGET_DEVICE = original_target

    # Determine Correct Layer Dynamically
    target_layer_name = get_target_layer_name(master_model_cpu)
    print(f"Detected Target Layer: {target_layer_name}")

    cc3m_path_str = str(config.SHARD_PATH_CC3M)
    calib_loader = data_setup.create_train_iterable(cc3m_path_str, preprocess, config.BATCH_SIZE)
    _, full_val_loader, _, _ = data_setup.get_dataset_loaders("imagenet1kval", preprocess, get_train=False)

    # 2. FP32 Baseline
    print("\n--- Processing FP32 Baseline ---")
    baseline_model = copy.deepcopy(master_model_cpu).to(DEVICE)
    profiler = UnifiedProfiler(baseline_model, target_layer_name)

    # A. Capture & Spectrum
    X_fp32 = profiler.capture_activations(full_val_loader)
    X_gpu = X_fp32.to(DEVICE)
    X_centered = X_gpu - X_gpu.mean(dim=0, keepdim=True)
    _, S_fp32, Vh_fp32 = torch.linalg.svd(X_centered, full_matrices=False)
    # Columns of V_fp32 are the FP32 principal directions reused for SQNR below.
    V_fp32 = Vh_fp32.T

    log_spectral(spec_csv, model_key, "Baseline", "FP32", 32, 32, S_fp32.cpu(), sqnr=None)

    # B. Saturation
    sat_metrics = profiler.run_saturation_sweep(full_val_loader, CUMULATIVE_INTERVALS, ref_model=None)
    for k, v in sat_metrics.items():
        log_saturation(sat_csv, {
            "Model_Arch": model_key, "Method_Family": "Baseline", "Method": "FP32",
            "W_Bits": 32, "A_Bits": 32, "Cumulative_Rank": int(k.split('-')[1]), "Avg_Cosine_Sim": v
        })

    # X_centered is released too: it is not needed once the SVD is done.
    del baseline_model, profiler, X_gpu, X_centered, S_fp32, Vh_fp32
    cleanup()

    # 3. Method Loop
    for fam_name, method_dict, m_type in ALL_FAMILIES:
        for method_name, (apply_fn, base_kwargs) in method_dict.items():
            for (w_bit, a_bit) in BIT_CONFIGS:
                print(f"\n>>> {method_name} (W{w_bit}A{a_bit})")

                # Pre-bind every per-run name so the `finally` clause can drop
                # them no matter where a failure happens.
                student = local_teacher = ref_model = profiler = None
                X_quant = X_quant_g = X_q_centered = S_quant = sqnr_vals = None
                try:
                    student = copy.deepcopy(master_model_cpu).to(DEVICE)
                    if m_type == 'qat':
                        # Frozen FP32 copy handed to the QAT method as `teacher`.
                        local_teacher = copy.deepcopy(master_model_cpu).to(DEVICE)
                        for p in local_teacher.parameters():
                            p.requires_grad = False
                        local_teacher.eval()

                    # Copy so the shared base kwargs are never mutated.
                    run_kwargs = base_kwargs.copy()
                    run_kwargs.update({
                        'weight_bits': w_bit, 'act_bits': a_bit, 'target_device': DEVICE,
                        'tokenizer': tokenizer, 'quantize_text': QUANTIZE_TEXT, 'prompts': None
                    })

                    if m_type == 'qat':
                        dist_args = config.ACTIVE_DISTILLATION_MODES.get('Distill', {})
                        run_kwargs.update(dist_args)
                        student = apply_fn(student, training_dataloader=calib_loader,
                                           calibration_dataloader=calib_loader, teacher=local_teacher, **run_kwargs)
                    else:
                        student = apply_fn(student, calibration_dataloader=calib_loader, **run_kwargs)

                    # --- Analysis ---
                    profiler = UnifiedProfiler(student, target_layer_name)

                    # 1. Capture & Spectrum
                    X_quant = profiler.capture_activations(full_val_loader)
                    X_quant_g = X_quant.to(DEVICE)
                    X_q_centered = X_quant_g - X_quant_g.mean(dim=0, keepdim=True)
                    _, S_quant, _ = torch.linalg.svd(X_q_centered, full_matrices=False)

                    # 2. SQNR
                    sqnr_vals = compute_sqnr(X_fp32, X_quant, V_fp32)

                    log_spectral(spec_csv, model_key, fam_name, method_name, w_bit, a_bit, S_quant.cpu(), sqnr_vals)

                    # 3. Saturation
                    # Reuse the teacher as the FP32 reference when one exists;
                    # otherwise build a fresh FP32 copy.
                    if local_teacher is not None:
                        ref_model = local_teacher
                    else:
                        ref_model = copy.deepcopy(master_model_cpu).to(DEVICE).eval()
                    sat_metrics = profiler.run_saturation_sweep(full_val_loader, CUMULATIVE_INTERVALS, ref_model=ref_model)

                    for k, v in sat_metrics.items():
                        log_saturation(sat_csv, {
                            "Model_Arch": model_key, "Method_Family": fam_name, "Method": method_name,
                            "W_Bits": w_bit, "A_Bits": a_bit, "Cumulative_Rank": int(k.split('-')[1]), "Avg_Cosine_Sim": v
                        })

                except Exception as e:
                    # Best-effort sweep: report the failure and move on to the
                    # next configuration.
                    print(f"Failed: {e}")
                finally:
                    # Drop every reference *before* cleanup() so the memory can
                    # actually be reclaimed; otherwise a failed run's models
                    # would stay alive while the next run allocates its own.
                    student = local_teacher = ref_model = profiler = None
                    X_quant = X_quant_g = X_q_centered = S_quant = sqnr_vals = None
                    cleanup()

    print("\nDone.")
    print(f"Saturation Data: {sat_csv}")
    print(f"Spectral Data:   {spec_csv}")

# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()