import torch
import torch.nn as nn
import torch.nn.functional as F
import open_clip
import numpy as np
import copy
import re
from tqdm import tqdm
import scripts.data_setup as data_setup
import scripts.config as config
from transformers import XLMRobertaTokenizer
from open_clip.tokenizer import HFTokenizer

# --- CONFIGURATION ---
# Start from the CPU default and switch to the GPU only when CUDA is usable.
# The cuDNN benchmark mode (autotuned kernels for max throughput) is enabled
# in the same branch, since it is only relevant when running on CUDA.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
    torch.backends.cudnn.benchmark = True

# NOTE(review): BATCH_SIZE and DATA_PERCENT are not read anywhere in the
# visible part of this script; presumably the loader built by
# scripts.data_setup chooses its own values - confirm.
BATCH_SIZE = 256
DATA_PERCENT = 1.0  # FULL ImageNet Validation Set (50k images)

# --- MODEL ZOO ---
# Maps a human-readable display name to the open_clip architecture ("arch")
# and pretrained-weights tag ("data") used to instantiate the model.
# The "type" key is not read by the visible code.
# NOTE(review): "DFN_ViT-B-32" points at a datacomp_xl tag and
# "ALIGN_RoBERTa_ViT-B" at an xlm-roberta CLIP checkpoint - confirm that the
# display names match the intended checkpoints.
MODEL_ZOO = {
    "ALIGN_RoBERTa_ViT-B": {
        "type": "open_clip",
        "arch": "xlm-roberta-base-ViT-B-32",
        "data": "laion5b_s13b_b90k",
    },
    "SigLIP_ViT-B-16": {
        "type": "open_clip",
        "arch": "ViT-B-16-SigLIP",
        "data": "webli",
    },
    "CLIP_ViT-L-14_Laion2B": {
        "type": "open_clip",
        "arch": "ViT-L-14",
        "data": "laion2b_s32b_b82k",
    },
    "CLIP_ViT-L-14_OpenAI": {
        "type": "open_clip",
        "arch": "ViT-L-14-quickgelu",
        "data": "openai",
    },
    "EVA02_B-16": {
        "type": "open_clip",
        "arch": "EVA02-B-16",
        "data": "merged2b_s8b_b131k",
    },
    "ConvNeXt_Base": {
        "type": "open_clip",
        "arch": "convnext_base",
        "data": "laion400m_s13b_b51k",
    },
    "CLIP_ViT-B-32_Laion2B": {
        "type": "open_clip",
        "arch": "ViT-B-32",
        "data": "laion2b_s34b_b79k",
    },
    "DFN_ViT-B-32": {
        "type": "open_clip",
        "arch": "ViT-B-32",
        "data": "datacomp_xl_s13b_b90k",
    },
    "CLIP_ViT-B-32_OpenAI": {
        "type": "open_clip",
        "arch": "ViT-B-32-quickgelu",
        "data": "openai",
    },
    "CoCa_ViT-B-32": {
        "type": "open_clip",
        "arch": "coca_ViT-B-32",
        "data": "laion2b_s13b_b90k",
    },
}

# Dropout probabilities probed for every model (targeted, low rates).
DROPOUT_RATES = [
    0.01,
    0.015,
    0.025,
]

def whitespace_clean(text):
    """Collapse every run of whitespace into one space and trim both ends."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()

def get_imagenet_loader(preprocess=None):
    """Build the ImageNet-1k validation loader shared by the evaluation runs.

    Args:
        preprocess: Optional image transform handed to the dataset. When
            omitted, the transform that ships with the OpenAI
            ``ViT-B-32-quickgelu`` checkpoint is used (the original
            behaviour), so existing zero-argument callers are unaffected.

    Returns:
        The ``(test_loader, class_names, template)`` triple taken from
        ``data_setup.get_dataset_loaders``.
    """
    print("Loading ImageNet Validation Set (Full)...")
    if preprocess is None:
        # Only the transform is needed here; the instantiated model is
        # discarded as soon as this function returns.
        _, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='openai')
    # NOTE(review): a single transform is applied to whatever model is later
    # evaluated on this loader. Pass a model-specific `preprocess` if a model
    # expects a different input resolution or normalization - confirm per model.
    _, test_loader, class_names, template = data_setup.get_dataset_loaders("imagenet1kval", preprocess, get_train=False)
    return test_loader, class_names, template

def get_safe_tokenizer(arch_name):
    """Return a tokenizer for ``arch_name``, special-casing RoBERTa text towers.

    For architectures whose name contains "roberta" (case-insensitive), an
    ``HFTokenizer`` is assembled by hand around the ``xlm-roberta-base``
    slow tokenizer: the instance is created with ``__new__`` (so
    ``HFTokenizer.__init__`` is not run) and the attributes are set directly.
    If that fails for any reason the failure is reported and the stock
    ``open_clip.get_tokenizer`` result is returned instead.

    Args:
        arch_name: open_clip architecture name.

    Returns:
        A callable tokenizer object.
    """
    if "roberta" in arch_name.lower():
        try:
            hf_name = "xlm-roberta-base"
            internal_tokenizer = XLMRobertaTokenizer.from_pretrained(hf_name)
            # Bypass HFTokenizer.__init__ and wire the attributes manually.
            # NOTE(review): the attribute set mirrors what this script needs;
            # confirm it matches the installed open_clip HFTokenizer version.
            tokenizer = HFTokenizer.__new__(HFTokenizer)
            tokenizer.tokenizer = internal_tokenizer
            tokenizer.context_length = 77
            tokenizer.clean_fn = whitespace_clean
            tokenizer.tokenizer_mode = 'slow'
            tokenizer.strip_sep_token = False
            return tokenizer
        except Exception as exc:
            # Best-effort: keep the fallback, but do not hide why the custom
            # tokenizer could not be built - a silent switch of tokenizers
            # would make accuracy differences very hard to diagnose.
            print(
                f"  [Warning] Custom RoBERTa tokenizer setup failed for {arch_name} "
                f"({type(exc).__name__}: {exc}); falling back to open_clip.get_tokenizer."
            )
    return open_clip.get_tokenizer(arch_name)

# --- HOOK MECHANISM ---
class DropoutHook:
    """Forward hook that applies always-on dropout to a module's output.

    Instances are meant to be passed to ``Module.register_forward_hook``; the
    value returned from ``__call__`` replaces the module's output.

    Args:
        p: Dropout probability forwarded to ``F.dropout``.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, module, input, output):
        """Return ``output`` with dropout applied.

        A tensor output is dropped out directly. For a non-empty tuple output
        (as returned by some transformer layers) dropout is applied to the
        first element only and the remaining elements are passed through
        unchanged; previously such outputs made ``F.dropout`` raise.
        """
        # training=True forces dropout during eval()
        if isinstance(output, tuple) and output:
            dropped = F.dropout(output[0], p=self.p, training=True)
            return (dropped,) + tuple(output[1:])
        return F.dropout(output, p=self.p, training=True)

def attach_dropout_hooks(model, p, vision_only=False):
    """Register a ``DropoutHook(p)`` on every block-like submodule of ``model``.

    A submodule qualifies when its class name contains one of the block
    markers below (substring match). With ``vision_only=True`` only
    submodules whose qualified name contains "visual" (case-insensitive)
    are hooked.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        p: Dropout probability given to each hook.
        vision_only: Restrict hooks to the "visual" part of the model.

    Returns:
        ``(handles, count)``: the hook handles (call ``.remove()`` on each to
        detach) and how many were registered.
    """
    # Capture all possible block types in the Zoo
    block_markers = ("ResidualAttentionBlock", "ConvNeXtBlock", "Block", "BertLayer")

    handles = []
    for qualified_name, submodule in model.named_modules():
        class_name = submodule.__class__.__name__
        if not any(marker in class_name for marker in block_markers):
            continue
        if vision_only and "visual" not in qualified_name.lower():
            continue
        handles.append(submodule.register_forward_hook(DropoutHook(p)))

    return handles, len(handles)

@torch.no_grad()
def evaluate_acc(model, loader, class_names, template, device, arch_name):
    """Return the zero-shot top-1 accuracy (in percent) of ``model`` on ``loader``.

    One text prompt per class is built with ``template.format(class_name)``,
    encoded once, and every image batch is classified by the arg-max of the
    scaled similarity between normalized image and text features. Because the
    text features are computed a single time per call, any hook active on the
    text encoder affects them once per evaluation, not once per batch.

    Args:
        model: Model exposing ``encode_text`` and ``encode_image``.
        loader: Iterable yielding ``(images, labels)`` batches.
        class_names: Class names, in label-index order.
        template: Format string with one placeholder for the class name.
        device: Device the tokens, images and labels are moved to.
        arch_name: Architecture name used to select the tokenizer.

    Returns:
        Accuracy as a float in ``[0, 100]``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    tokenizer = get_safe_tokenizer(arch_name)
    texts = [template.format(c) for c in class_names]
    text_tokens = tokenizer(texts).to(device)

    # Mixed precision is only meaningful on CUDA. Requesting an enabled CUDA
    # autocast on a CPU-only run merely triggers a warning and gets disabled,
    # so switch it off explicitly in that case.
    use_amp = torch.device(device).type == "cuda"

    # Pre-compute text features once per evaluation run
    with torch.amp.autocast('cuda', enabled=use_amp):
        text_features = model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Keep the running hit count on the device so that no host/device sync is
    # forced on every batch; it is read back once after the loop.
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0

    # Progress bar for the full ImageNet pass
    pbar = tqdm(loader, desc="Evaluating ImageNet", leave=False)
    for images, labels in pbar:
        images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

        with torch.amp.autocast('cuda', enabled=use_amp):
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            logits = 100.0 * image_features @ text_features.T
            preds = logits.argmax(dim=-1)

        correct += (preds == labels).sum()
        total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate_acc received an empty loader; cannot compute accuracy.")

    return 100.0 * correct.item() / total

# --- MAIN ---
if __name__ == "__main__":
    # The loader is built once and shared by every model in the zoo.
    # NOTE(review): it therefore uses a single preprocessing transform for
    # all models - confirm this matches each model's expected input.
    loader, class_names, template = get_imagenet_loader()
    final_results = {}

    print(f"\n{'='*80}\nFULL IMAGENET REDUNDANCY PROBE: MIXED DROPOUT\n{'='*80}")
    print(f"Trials: 4 (2 Mixed, 2 Vision-Only) | Dataset: 100% ImageNet Val | FP16\n")

    for display_name, cfg in MODEL_ZOO.items():
        arch, data_tag = cfg['arch'], cfg['data']
        print(f"Processing Model: {display_name}...")

        # A model that cannot be loaded is reported and skipped so the
        # remaining models are still evaluated.
        try:
            model, _, _ = open_clip.create_model_and_transforms(arch, pretrained=data_tag)
            model = model.to(DEVICE)
        except Exception as e:
            print(f"  [Error] Failed to load {display_name}: {e}")
            continue

        # Index 0 of every list is the un-perturbed baseline (p = 0).
        model_results = {'sparsities': [0.0], 'rel_acc': [1.0], 'abs_acc': []}

        # 1. Baseline
        base_acc = evaluate_acc(model, loader, class_names, template, DEVICE, arch)
        print(f"  > Baseline Accuracy: {base_acc:.2f}%")
        model_results['abs_acc'].append(base_acc)

        # 2. Probe Rates
        for p in DROPOUT_RATES:
            print(f"  > Probing rate p={p}...")
            trial_accs = []

            # Trials 1 & 2 hook every matching block (vision + text);
            # trials 3 & 4 hook the vision tower only.
            for vision_only in (False, True):
                hooks, _ = attach_dropout_hooks(model, p, vision_only=vision_only)
                try:
                    for _ in range(2):
                        trial_accs.append(evaluate_acc(model, loader, class_names, template, DEVICE, arch))
                finally:
                    # Always detach, even if an evaluation fails, so hooks
                    # never leak into the next configuration.
                    for h in hooks:
                        h.remove()

            # Plain Python floats keep the printed RESULTS_DATA literal free
            # of NumPy scalar wrappers and directly copy-pasteable.
            avg_acc = float(np.mean(trial_accs))
            # Guard against a zero baseline instead of dividing by zero.
            rel_acc = avg_acc / base_acc if base_acc else float("nan")

            model_results['sparsities'].append(p)
            model_results['abs_acc'].append(avg_acc)
            model_results['rel_acc'].append(rel_acc)

            print(f"    - Avg Acc: {avg_acc:.2f}% (Rel Retention: {rel_acc:.4f})")

        final_results[display_name] = model_results
        # Release the model before the next one is loaded to keep peak GPU
        # memory down.
        del model
        torch.cuda.empty_cache()
        print("-" * 60)

    print("\n\n" + "="*30 + " RESULTS FOR CORRELATION PLOT " + "="*30)
    print("RESULTS_DATA = " + str(final_results))
    print("="*80)