import os
import sys
import copy
import csv
import torch
import datetime
import open_clip
import re
import webdataset as wds
import logging
import ctypes
import gc
from pathlib import Path
from tqdm import tqdm
from torchvision import datasets
from torchvision.datasets.folder import default_loader
from torch.utils.data import DataLoader, Dataset
from transformers import XLMRobertaTokenizer
from open_clip.tokenizer import HFTokenizer

# ==============================================================================
#  1. USER CONFIGURATION
# ==============================================================================

# --- HARDWARE ---
# Device that models and batches are moved to: CUDA when PyTorch reports it
# as available, otherwise CPU.
TARGET_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mixed-precision switch. It is passed as `enabled=` to torch.amp.autocast in
# the feature-extraction / evaluation helpers and mirrored into
# lib_config.USE_AMP further down.
USE_AMP = True
# Fixed random seed; presumably for reproducible sampling — verify consumers.
SEED = 42

# --- PATHS ---
# Every path is resolved against the directory the script is launched from,
# so run it from the project root. (The previous `except NameError` fallback
# was unreachable: reading the working directory never raises NameError.)
PROJECT_ROOT = Path.cwd()

# Root of the CounterAnimal image folders (kept as a plain string).
COUNTERANIMAL_PATH = os.path.join(PROJECT_ROOT, "datasets/counteranimal/LAION-final")
# Directory holding the WebDataset shards used as proxy training data.
DATA_QUANT_ROOT = PROJECT_ROOT / "quantization/data_quantization/"

# Proxy dataset name -> path of the single .tar shard used for it.
PROXY_SHARDS = {
    "CC3M": os.path.join(DATA_QUANT_ROOT, "cc3m_shards_00000.tar"),
    "YFCC": os.path.join(DATA_QUANT_ROOT, "yfcc_shard_3_9k_filtered.tar"),
    "SBU":  os.path.join(DATA_QUANT_ROOT, "sbu_shard_2k_filtered.tar"),
}

# --- BENCHMARK SETTINGS ---
# (weight_bits, act_bits) pairs; main() unpacks each tuple in that order and
# forwards the values as `weight_bits` / `act_bits`.
BIT_WIDTHS = [(8, 8), (6, 8), (4, 8)]
# Values forwarded as `quantize_text`; main() labels False as "VisOnly" and
# True as "Vis+Txt".
TEXT_QUANT_MODES = [False, True]
# Batch size of the two CounterAnimal evaluation loaders.
EVAL_BATCH_SIZE = 128
# Batch size used when batching the proxy WebDataset stream.
TRAIN_BATCH_SIZE = 64
# Prompt template; `{}` is filled with each cleaned class name.
TEMPLATE = "a photo of a {}"

# --- FULL MODEL ZOO ---
# Display label -> model spec. The label is what ends up in the CSV "Model"
# column. 'arch' and 'data' are handed to
# open_clip.create_model_and_transforms(arch, pretrained=data); 'type' is not
# read by the loader in this file.
# NOTE(review): some labels do not obviously match their checkpoint tags
# (e.g. the "DFN" entry points at a `datacomp_xl_*` tag and the "ALIGN" entry
# at an xlm-roberta CLIP arch) — confirm these are the intended checkpoints.
MODEL_ZOO = {
    "CLIP_ViT-B-32_OpenAI":    {'type': 'open_clip', 'arch': 'ViT-B-32-quickgelu', 'data': 'openai'},
    "CLIP_ViT-B-32_Laion2B":   {'type': 'open_clip', 'arch': 'ViT-B-32', 'data': 'laion2b_s34b_b79k'},
    "SigLIP_ViT-B-16":         {'type': 'open_clip', 'arch': 'ViT-B-16-SigLIP', 'data': 'webli'},
    "ALIGN_RoBERTa_ViT-B":     {'type': 'open_clip', 'arch': 'xlm-roberta-base-ViT-B-32', 'data': 'laion5b_s13b_b90k'},
    "CLIP_ViT-L-14_Laion2B":   {'type': 'open_clip', 'arch': 'ViT-L-14', 'data': 'laion2b_s32b_b82k'},
    "CLIP_ViT-L-14_OpenAI":    {'type': 'open_clip', 'arch': 'ViT-L-14-quickgelu', 'data': 'openai'},
    "EVA02_B-16":              {'type': 'open_clip', 'arch': 'EVA02-B-16', 'data': 'merged2b_s8b_b131k'},
    "ConvNeXt_Base":           {'type': 'open_clip', 'arch': 'convnext_base', 'data': 'laion400m_s13b_b51k'},
    "DFN_ViT-B-32":            {'type': 'open_clip', 'arch': 'ViT-B-32', 'data': 'datacomp_xl_s13b_b90k'},
    "CoCa_ViT-B-32":           {'type': 'open_clip', 'arch': 'coca_ViT-B-32', 'data': 'laion2b_s13b_b90k'},
}

# ==============================================================================
#  2. LIBRARY PATCHING
# ==============================================================================
# Push this script's device / AMP settings into the shared library config.
# The assignments run before `quantization.apply` is imported below —
# presumably so the quantization code observes the patched values; confirm
# against the consumers of `scripts.config` before reordering these lines.
import scripts.config as lib_config
lib_config.TARGET_DEVICE = TARGET_DEVICE
lib_config.USE_AMP = USE_AMP
# Quantization entry points driven by main(): simple PTQ, rotation + LSQ, and
# quantization-aware training.
from quantization.apply import apply_simple_ptq, apply_rotation_lsq, apply_quantization_aware_training

# ==============================================================================
#  3. DATA SETUP UTILS
# ==============================================================================

def cleanup():
    """Release as much host and GPU memory as possible between runs.

    Runs the Python garbage collector, asks PyTorch to release its cached
    CUDA blocks, and finally makes a best-effort ``malloc_trim(0)`` call on
    ``libc.so.6``. The last step is optional and is skipped silently on
    platforms where that library or symbol is unavailable.
    """
    gc.collect()
    torch.cuda.empty_cache()
    try:
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except (OSError, AttributeError):
        # OSError: libc.so.6 could not be loaded (e.g. non-glibc platform).
        # AttributeError: the loaded library exposes no malloc_trim symbol.
        # Anything else (KeyboardInterrupt, SystemExit, ...) must propagate.
        pass

class SplitDataset(Dataset):
    """Map-style dataset over an explicit list of ``(path, label)`` pairs.

    Images are read lazily with torchvision's ``default_loader``; when a
    truthy ``transform`` is supplied it is applied to every loaded image.
    """

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, target = self.samples[idx]
        loaded = default_loader(image_path)
        if not self.transform:
            return loaded, target
        return self.transform(loaded), target

def load_and_split_counteranimal(preprocess):
    """Build the CounterAnimal evaluation and calibration loaders.

    The dataset at ``COUNTERANIMAL_PATH`` is read with ``ImageFolder`` and
    every sample is routed by its path below the dataset root: paths that
    contain ``"common-"`` go to the normal split, paths that contain
    ``"counter-"`` go to the counter split, and anything else is dropped.

    Args:
        preprocess: Image transform applied by all three loaders.

    Returns:
        ``(normal_loader, counter_loader, cal_loader, clean_class_names)``.
        ``cal_loader`` iterates a seeded random subset (at most 512 images)
        of the normal split. ``clean_class_names`` holds one prompt-ready
        name per ``ImageFolder`` class, in class-index order.

    Raises:
        FileNotFoundError: If ``COUNTERANIMAL_PATH`` does not exist.
    """
    import random

    if not os.path.exists(COUNTERANIMAL_PATH):
        raise FileNotFoundError(f"Path not found: {COUNTERANIMAL_PATH}")
    full_ds = datasets.ImageFolder(COUNTERANIMAL_PATH)

    normal_samples, counter_samples = [], []
    for path, label_idx in full_ds.samples:
        # Look for the split marker only below the dataset root; matching the
        # absolute path would misroute every sample whenever a parent
        # directory happens to contain "common-" or "counter-".
        rel_path = os.path.relpath(path, COUNTERANIMAL_PATH)
        if "common-" in rel_path:
            normal_samples.append((path, label_idx))
        elif "counter-" in rel_path:
            counter_samples.append((path, label_idx))

    # Folder names look like "<prefix> <name>[, <aliases>]": keep the text
    # after the first space up to the first comma; names without a space are
    # used unchanged.
    clean_class_names = []
    for raw_cls in full_ds.classes:
        parts = raw_cls.split(' ', 1)
        name = parts[1].split(',')[0].strip() if len(parts) > 1 else raw_cls
        clean_class_names.append(name)

    normal_loader = DataLoader(
        SplitDataset(normal_samples, transform=preprocess),
        batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True,
    )
    counter_loader = DataLoader(
        SplitDataset(counter_samples, transform=preprocess),
        batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True,
    )

    # Seed both the subset choice and the shuffle order so every model and
    # every rerun is calibrated on the same images in the same order.
    rng = random.Random(SEED)
    cal_samples = rng.sample(normal_samples, min(len(normal_samples), 512))
    cal_loader = DataLoader(
        SplitDataset(cal_samples, transform=preprocess),
        batch_size=64, shuffle=True, num_workers=2,
        generator=torch.Generator().manual_seed(SEED),
    )

    return normal_loader, counter_loader, cal_loader, clean_class_names

def create_proxy_loader(shard_path, preprocess):
    """Create an image/caption loader over one WebDataset shard.

    Returns ``None`` when ``shard_path`` does not exist. Otherwise the shard
    is opened with ``resampled=True``, shuffled with a buffer of 1000,
    decoded to PIL, reduced to ``("jpg", "txt")`` tuples, preprocessed and
    grouped into batches of ``TRAIN_BATCH_SIZE``.
    """
    shard_missing = not os.path.exists(shard_path)
    if shard_missing:
        return None

    def _apply_preprocess(pair):
        # pair[0] is the decoded image, pair[1] its caption.
        return preprocess(pair[0]), pair[1]

    pipeline = wds.WebDataset(shard_path, resampled=True)
    pipeline = pipeline.shuffle(1000)
    pipeline = pipeline.decode("pil")
    pipeline = pipeline.to_tuple("jpg", "txt")
    pipeline = pipeline.map(_apply_preprocess)
    pipeline = pipeline.batched(TRAIN_BATCH_SIZE)

    # Batching already happened inside the pipeline, hence batch_size=None.
    return wds.WebLoader(pipeline, batch_size=None, num_workers=2)

def whitespace_clean(text):
    """Collapse every whitespace run to one space and trim both ends."""
    # str.split() with no arguments splits on whitespace runs and drops the
    # empty leading/trailing pieces, so re-joining normalises the text.
    return " ".join(text.split())

def get_model_and_tokenizer(model_cfg):
    """Load an open_clip model together with its tokenizer and transform.

    Args:
        model_cfg: Mapping with the keys ``'arch'`` and ``'data'``.

    Returns:
        ``(model, tokenizer, preprocess)`` where the model has been moved to
        ``TARGET_DEVICE`` and switched to eval mode.

    For architectures whose name contains "roberta" the tokenizer is built
    by hand around an ``XLMRobertaTokenizer``; if that fails for any reason
    the stock ``open_clip.get_tokenizer`` result is used instead.
    """
    arch = model_cfg['arch']
    data = model_cfg['data']
    print(f"Loading {arch}...")
    model, _, preprocess = open_clip.create_model_and_transforms(arch, pretrained=data)

    arch_lower = arch.lower()
    if "roberta" not in arch_lower:
        tokenizer = open_clip.get_tokenizer(arch)
    else:
        try:
            if "base" in arch_lower:
                hf_name = "xlm-roberta-base"
            else:
                hf_name = "xlm-roberta-large"

            # Prefer a checkout in the working directory; otherwise let
            # from_pretrained resolve the hub name.
            local_copy = f"./{hf_name}"
            path_to_load = local_copy if os.path.exists(local_copy) else hf_name
            hf_tokenizer = XLMRobertaTokenizer.from_pretrained(path_to_load)

            # Bypass HFTokenizer.__init__ and populate the wrapper's
            # attributes directly.
            wrapper = HFTokenizer.__new__(HFTokenizer)
            wrapper.tokenizer = hf_tokenizer
            wrapper.context_length = 77
            wrapper.clean_fn = whitespace_clean
            wrapper.tokenizer_mode = 'slow'
            # The original author flagged this attribute as the fix for a
            # crash; because __init__ is skipped it has to be set by hand.
            wrapper.strip_sep_token = False
            tokenizer = wrapper
        except Exception as e:
            print(f"Tokenizer fallback failed: {e}")
            tokenizer = open_clip.get_tokenizer(arch)

    return model.to(TARGET_DEVICE).eval(), tokenizer, preprocess

def get_text_features(model, tokenizer, class_names):
    """Encode one prompt per class and return unit-length text features.

    Args:
        model: Object exposing ``encode_text(tokens)``.
        tokenizer: Callable mapping a list of strings to a token tensor.
        class_names: Iterable of class names substituted into ``TEMPLATE``.

    Returns:
        The ``encode_text`` output divided by its L2 norm along the last
        dimension (presumably one row per class name — confirm against the
        model's ``encode_text``).
    """
    texts = [TEMPLATE.format(c) for c in class_names]
    # autocast('cuda', enabled=True) warns and disables itself when CUDA is
    # unavailable, so only request it when the target device really is CUDA.
    amp_enabled = bool(USE_AMP) and torch.device(TARGET_DEVICE).type == 'cuda'
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=amp_enabled):
        tokens = tokenizer(texts).to(TARGET_DEVICE)
        feats = model.encode_text(tokens)
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats

def evaluate_accuracy(model, loader, text_feats):
    """Return the top-1 zero-shot accuracy of ``model`` over ``loader``.

    Args:
        model: Object exposing ``eval()`` and ``encode_image(images)``.
        loader: Iterable yielding ``(images, labels)`` tensor batches.
        text_feats: Normalised class text features; the predicted class of an
            image is the argmax of its similarity to these rows.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when the loader yields no
        samples.
    """
    model.eval()
    correct, total = 0, 0
    # autocast('cuda', enabled=True) warns and disables itself when CUDA is
    # unavailable, so only request it when the target device really is CUDA.
    amp_enabled = bool(USE_AMP) and torch.device(TARGET_DEVICE).type == 'cuda'
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=amp_enabled):
        for images, labels in tqdm(loader, leave=False):
            images = images.to(TARGET_DEVICE, non_blocking=True)
            labels = labels.to(TARGET_DEVICE)
            img_feats = model.encode_image(images)
            img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True)
            # The constant scale does not affect the argmax below.
            logits = 100.0 * img_feats @ text_feats.T
            correct += logits.argmax(dim=-1).eq(labels).sum().item()
            total += len(labels)
    return correct / total if total > 0 else 0.0

def calculate_all_metrics(acc_n, acc_c, fp32_n, fp32_c):
    """Derive the robustness metrics of a quantized model against FP32.

    Args:
        acc_n: Quantized accuracy on the normal split.
        acc_c: Quantized accuracy on the counter split.
        fp32_n: FP32 accuracy on the normal split.
        fp32_c: FP32 accuracy on the counter split.

    Returns:
        ``(rel_gap_q, delta_rsg, drop_n, drop_c, added_vuln)`` where

        * ``rel_gap_q``  = (acc_n - acc_c) / acc_n
        * ``delta_rsg``  = rel_gap_q - (fp32_n - fp32_c) / fp32_n
        * ``drop_n``     = (fp32_n - acc_n) / fp32_n
        * ``drop_c``     = (fp32_c - acc_c) / fp32_c
        * ``added_vuln`` = drop_c - drop_n

        Any ratio whose denominator is not strictly positive is ``0.0``.
    """
    # Relative normal-vs-counter gap of the quantized model and of FP32.
    quant_gap = (acc_n - acc_c) / acc_n if acc_n > 0 else 0.0
    base_gap = (fp32_n - fp32_c) / fp32_n if fp32_n > 0 else 0.0

    # Relative accuracy lost to quantization on each split.
    normal_drop = (fp32_n - acc_n) / fp32_n if fp32_n > 0 else 0.0
    counter_drop = (fp32_c - acc_c) / fp32_c if fp32_c > 0 else 0.0

    return (
        quant_gap,
        quant_gap - base_gap,
        normal_drop,
        counter_drop,
        counter_drop - normal_drop,
    )

def print_result_row(method, scope, n, c, d_rsg, vul):
    """Print one aligned console row summarising a benchmark result.

    ``method`` is left-padded to 20 columns and ``scope`` to 8; the two
    accuracies use three decimals and the two deltas are always signed.
    """
    row = f"   | {method:<20} | {scope:<8} | N={n:.3f} C={c:.3f} | dRSG: {d_rsg:+.3f} | Vuln: {vul:+.3f}"
    print(row)

# ==============================================================================
#  5. MAIN
# ==============================================================================

def _append_csv_row(csv_name, csv_fields, row):
    """Append one result row; reopening per row keeps partial results on disk."""
    with open(csv_name, 'a', newline='') as f:
        csv.DictWriter(f, fieldnames=csv_fields).writerow(row)


def main():
    """Run the spurious-correlation quantization benchmark end to end.

    For every model in ``MODEL_ZOO`` the FP32 baseline is evaluated on the
    normal and counter splits. Then, for every available proxy shard,
    quantization method, bit-width pair and text-quantization mode, a fresh
    copy of the model is quantized, evaluated, and logged to the console and
    to a CSV file under ``results/``.

    A model that fails to load is skipped, and a failing configuration is
    reported as ``[FAILED]``; neither aborts the benchmark.
    """
    print(f"\n{'='*80}")
    print("SPURIOUS CORRELATION BENCHMARK | DUAL METRICS (dRSG + Vuln)")
    print(f"Path: {COUNTERANIMAL_PATH}")
    print(f"{'='*80}\n")

    os.makedirs("results", exist_ok=True)
    # Include the date and seconds: an hour/minute-only stamp lets a later
    # run silently overwrite an earlier result file (it is opened with 'w').
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    csv_name = f"results/spurious_final_{stamp}.csv"

    csv_fields = [
        "Model", "Proxy_Data", "Method", "Bits", "Quant_Scope",
        "Normal_Acc", "Counter_Acc",
        "Rel_Gap_Q", "Delta_RSG",
        "Drop_Normal", "Drop_Counter", "Added_Vuln"
    ]

    with open(csv_name, 'w', newline='') as f:
        csv.DictWriter(f, fieldnames=csv_fields).writeheader()

    # (label, apply function, method-specific kwargs, needs proxy training).
    # Loop-invariant, so it is built once instead of once per proxy dataset.
    methods = [
        # 1. PTQ
        ("Simple PTQ", apply_simple_ptq, {}, False),

        # 2. QAT (Contrastive Only)
        ("QAT (Contrastive)", apply_quantization_aware_training, {
            'learning_rate': 1e-6,
            'main_loss_weight': 1.0, 'distill_weight': 0.0
        }, True),

        # 3. Rotation + LSQ (Hybrid)
        ("Rot+LSQ (Hybrid)", apply_rotation_lsq, {
            'learning_rate': 1e-6, 'lsq_learning_rate': 1e-4,
            'main_loss_weight': 0.5, 'distill_weight': 0.5
        }, True),
    ]

    for model_key, model_cfg in MODEL_ZOO.items():
        print(f"\n{'#'*80}")
        print(f"MODEL: {model_key}")
        print(f"{'#'*80}")
        lib_config.MODEL_CONFIG = model_cfg

        master_model = None
        loaded = False
        try:
            master_model, tokenizer, preprocess = get_model_and_tokenizer(model_cfg)
            normal_loader, counter_loader, cal_loader, class_names = load_and_split_counteranimal(preprocess)
            loaded = True
        except Exception as e:
            print(f"Skipping {model_key}: {e}")
        if not loaded:
            # The model may already sit on the device when only the dataset
            # step failed; release it (outside the except block, so the
            # traceback no longer pins anything) before the next model.
            master_model = None
            cleanup()
            continue

        # Baseline Eval
        print("\n>> Evaluating FP32 Baseline...")
        fp32_txt = get_text_features(master_model, tokenizer, class_names)
        fp32_n = evaluate_accuracy(master_model, normal_loader, fp32_txt)
        fp32_c = evaluate_accuracy(master_model, counter_loader, fp32_txt)

        m1_gap, _, _, _, _ = calculate_all_metrics(fp32_n, fp32_c, fp32_n, fp32_c)
        print(f"   FP32 | Normal: {fp32_n:.4f} | Counter: {fp32_c:.4f} | Base RelGap: {m1_gap:.4f}")

        # Log Baseline
        _append_csv_row(csv_name, csv_fields, {
            "Model": model_key, "Proxy_Data": "None", "Method": "FP32", "Bits": "32/32", "Quant_Scope": "None",
            "Normal_Acc": f"{fp32_n:.4f}", "Counter_Acc": f"{fp32_c:.4f}",
            "Rel_Gap_Q": f"{m1_gap:.4f}", "Delta_RSG": "0.0000",
            "Drop_Normal": "0.0000", "Drop_Counter": "0.0000", "Added_Vuln": "0.0000"
        })

        # Proxy Loop
        proxy_list = ["CC3M", "YFCC", "SBU"]
        for proxy_name in proxy_list:
            shard_path = PROXY_SHARDS.get(proxy_name)
            # Skip proxies that are not configured or whose shard is absent.
            if not shard_path or not os.path.exists(shard_path):
                continue

            print(f"\n>> [{proxy_name}] Adapting...")
            proxy_loader = create_proxy_loader(shard_path, preprocess)

            for m_name, apply_fn, kwargs, needs_training in methods:
                for w_bit, a_bit in BIT_WIDTHS:
                    print(f"   > Method: {m_name} | W{w_bit}A{a_bit}")

                    for qt_text in TEXT_QUANT_MODES:
                        scope = "Vis+Txt" if qt_text else "VisOnly"

                        student = teacher = None
                        try:
                            # NOTE(review): the cpu round trip is kept from
                            # the original code; confirm whether it is needed.
                            student = copy.deepcopy(master_model).to('cpu').to(TARGET_DEVICE)
                            run_kwargs = {
                                'target_device': TARGET_DEVICE, 'tokenizer': tokenizer, 'prompts': None,
                                'quantize_text': qt_text, 'weight_bits': w_bit, 'act_bits': a_bit, **kwargs
                            }

                            if needs_training:
                                teacher = copy.deepcopy(master_model).to(TARGET_DEVICE).eval()
                                student = apply_fn(student, training_dataloader=proxy_loader, calibration_dataloader=cal_loader, teacher=teacher, **run_kwargs)
                                # Not needed for evaluation; drop it early.
                                teacher = None
                            else:
                                student = apply_fn(student, calibration_dataloader=cal_loader, **run_kwargs)

                            stud_txt = get_text_features(student, tokenizer, class_names)
                            acc_n = evaluate_accuracy(student, normal_loader, stud_txt)
                            acc_c = evaluate_accuracy(student, counter_loader, stud_txt)

                            q_gap, d_rsg, d_n, d_c, vuln = calculate_all_metrics(acc_n, acc_c, fp32_n, fp32_c)

                            print_result_row(m_name, scope, acc_n, acc_c, d_rsg, vuln)

                            _append_csv_row(csv_name, csv_fields, {
                                "Model": model_key, "Proxy_Data": proxy_name, "Method": m_name,
                                "Bits": f"{w_bit}/{a_bit}", "Quant_Scope": scope,
                                "Normal_Acc": f"{acc_n:.4f}", "Counter_Acc": f"{acc_c:.4f}",
                                "Rel_Gap_Q": f"{q_gap:.4f}", "Delta_RSG": f"{d_rsg:.4f}",
                                "Drop_Normal": f"{d_n:.4f}", "Drop_Counter": f"{d_c:.4f}", "Added_Vuln": f"{vuln:.4f}"
                            })

                        except Exception as e:
                            print(f"     [FAILED] {e}")
                        finally:
                            # Drop both model references even when the run
                            # failed part-way, otherwise the teacher copy
                            # would survive cleanup() and stay on the device.
                            student = teacher = None
                            cleanup()

        master_model = None
        cleanup()

# Run the benchmark only when the file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()