# Nuisance-Prompt Tuning (NPT) Implementation Plan

## Overview
This plan outlines the implementation of Nuisance-Prompt Tuning (NPT) for few-shot out-of-distribution (OOD) detection, extending the baseline LoCoOp method.

## Research Hypothesis
NPT introduces a dedicated learnable "nuisance" prompt vector alongside class prompts to explicitly capture and repel background (ID-irrelevant) features. By weighting patch-level supervision via CLIP's self-attention relevance scores, we can softly assign each patch a degree of backgroundness without fixed thresholds.

## Key Components to Implement

### 1. NPT Prompt Learner (`NPTPromptLearner`)
- **Base**: Extend existing `PromptLearner` class
- **Addition**: Learn M class-specific context vectors plus 1 nuisance context vector; the encoded text feature of the resulting nuisance prompt is denoted `b`
- **Parameters**:
  - `self.ctx`: Class context vectors (M x n_ctx x ctx_dim)
  - `self.nuisance_ctx`: Nuisance context vector (1 x n_ctx x ctx_dim)
- **Forward**: Return both class prompts and nuisance prompt
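Below is a minimal, self-contained sketch of the learner. It assumes a CoOp-style layout of `[SOS prefix] + [learnable context] + [name/EOS suffix]`, with the frozen prefix and suffix token embeddings taken from CLIP's token embedding layer. Random tensors stand in for those embeddings here. What fills the nuisance prompt's suffix (for example, the embedding of a placeholder word) is a hypothetical choice that this plan does not specify.

```python
import torch
import torch.nn as nn


class NPTPromptLearner(nn.Module):
    """Sketch: M class-specific contexts plus one nuisance context (CoOp-style layout assumed)."""

    def __init__(self, n_cls, n_ctx, ctx_dim, class_prefix, class_suffix, nuis_prefix, nuis_suffix):
        super().__init__()
        self.ctx = nn.Parameter(torch.randn(n_cls, n_ctx, ctx_dim) * 0.02)       # M x n_ctx x ctx_dim
        self.nuisance_ctx = nn.Parameter(torch.randn(1, n_ctx, ctx_dim) * 0.02)  # 1 x n_ctx x ctx_dim
        # Frozen token embeddings are buffers, so they are saved with the model but never trained.
        self.register_buffer("class_prefix", class_prefix)  # [M, 1, D] (SOS)
        self.register_buffer("class_suffix", class_suffix)  # [M, L - 1 - n_ctx, D]
        self.register_buffer("nuis_prefix", nuis_prefix)    # [1, 1, D]
        self.register_buffer("nuis_suffix", nuis_suffix)    # [1, L - 1 - n_ctx, D] (hypothetical content)

    def forward(self):
        class_prompts = torch.cat([self.class_prefix, self.ctx, self.class_suffix], dim=1)
        nuisance_prompt = torch.cat([self.nuis_prefix, self.nuisance_ctx, self.nuis_suffix], dim=1)
        return class_prompts, nuisance_prompt  # [M, L, D], [1, L, D]


M, n_ctx, D, L = 10, 16, 512, 77
learner = NPTPromptLearner(
    M, n_ctx, D,
    torch.randn(M, 1, D), torch.randn(M, L - 1 - n_ctx, D),
    torch.randn(1, 1, D), torch.randn(1, L - 1 - n_ctx, D),
)
class_prompts, nuisance_prompt = learner()
print(class_prompts.shape, nuisance_prompt.shape)
```

The frozen embeddings are registered as buffers, so `ctx` and `nuisance_ctx` are the only entries in `learner.parameters()`. An optimizer built from those parameters therefore updates only the two context tensors.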

### 2. Attention Extraction (`extract_attention_weights`)
- **Input**: CLIP Vision Transformer features
- **Process**: 
  - Extract attention weights from [CLS] token to patch tokens
  - From last transformer layer's multi-head attention
  - Normalize to [0,1] as relevance scores r_i
- **Output**: Patch background weights w_i = 1 - r_i
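The sketch below shows one way to turn a last-layer attention map into the weights w_i. It rests on three assumptions: token 0 is [CLS], the heads are averaged, and r_i is min-max normalized per image. The rows returned by `nn.MultiheadAttention` are already softmax-normalized. Applying a second softmax over the N patches would therefore push every r_i toward 1/N and every w_i toward 1. The function name `background_weights` is hypothetical, and `average_attn_weights=False` needs a recent PyTorch (1.11 or later).

```python
import torch


def background_weights(attn: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Map per-head attention [B, heads, 1 + N, 1 + N] to patch background weights w_i = 1 - r_i."""
    cls_to_patch = attn[:, :, 0, 1:].mean(dim=1)        # [CLS] row, patches only, head-averaged: [B, N]
    lo = cls_to_patch.min(dim=-1, keepdim=True).values
    hi = cls_to_patch.max(dim=-1, keepdim=True).values
    relevance = (cls_to_patch - lo) / (hi - lo + eps)   # r_i in [0, 1] per image
    return (1.0 - relevance).detach()                   # w_i; no gradient flows into the encoder


# Toy attention map from a real nn.MultiheadAttention (batch-first, per-head weights)
mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)
tokens = torch.randn(2, 1 + 49, 64)                     # [CLS] + 7x7 patch tokens
_, attn = mha(tokens, tokens, tokens, need_weights=True, average_attn_weights=False)
w = background_weights(attn)
print(w.shape, float(w.min()), float(w.max()))
```

In each image, the least-attended patch gets w_i = 1 and the most-attended patch gets w_i close to 0. This yields the soft, threshold-free backgroundness described in the hypothesis.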

### 3. NPT Loss Functions (`NPTLosses`)
- **L_global**: Standard few-shot cross-entropy on global image features vs M class prompts
- **L_patch**: Weighted patch-level cross-entropy:
  - Σ_i w_i * CE([sim(f_i, g₁),...,sim(f_i,g_M),sim(f_i,b)], label=BACKGROUND)
- **L_margin**: Margin loss to repel nuisance from class prompts:
  - Σ_m max(0, sim(b,g_m) - margin)
- **Total**: L = L_global + λ_patch * L_patch + λ_margin * L_margin
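The following is a sketch of the three losses. It makes three assumptions. First, sim(b, g_m) is cosine similarity. Second, the patch loss is normalized by Σ_i w_i and averaged over the batch; this normalization is an assumption, since the plan specifies only the weighted sum. Third, the margin term is summed over m, as written above. All function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def weighted_cross_entropy(patch_logits, target, weights, eps=1e-6):
    """Per image: sum_i w_i * CE_i / sum_i w_i (assumed normalization); then batch mean."""
    B, N, C = patch_logits.shape
    ce = F.cross_entropy(patch_logits.reshape(B * N, C), target.reshape(B * N), reduction="none")
    ce = ce.view(B, N)  # per-patch cross-entropy
    return ((weights * ce).sum(dim=1) / (weights.sum(dim=1) + eps)).mean()


def margin_loss(nuisance_feat, class_feats, margin):
    """sum_m max(0, cos(b, g_m) - margin); nuisance_feat [1, D], class_feats [M, D]."""
    sims = F.cosine_similarity(nuisance_feat.expand_as(class_feats), class_feats, dim=-1)  # [M]
    return torch.clamp(sims - margin, min=0).sum()


def npt_loss(global_logits, labels, patch_logits, weights, nuisance_feat, class_feats,
             lambda_patch=0.25, lambda_margin=0.25, margin=0.2):
    """L = L_global + lambda_patch * L_patch + lambda_margin * L_margin."""
    n_cls = class_feats.shape[0]
    loss_global = F.cross_entropy(global_logits, labels)  # over the M class prompts
    bg = torch.full(patch_logits.shape[:-1], n_cls, dtype=torch.long, device=patch_logits.device)
    loss_patch = weighted_cross_entropy(patch_logits, bg, weights)  # every patch -> nuisance index M
    loss_margin = margin_loss(nuisance_feat, class_feats, margin)
    return loss_global + lambda_patch * loss_patch + lambda_margin * loss_margin


torch.manual_seed(0)
B, N, M, D = 4, 49, 10, 512
loss = npt_loss(torch.randn(B, M), torch.randint(0, M, (B,)), torch.randn(B, N, M + 1),
                torch.rand(B, N), torch.randn(1, D), torch.randn(M, D))
print(loss.item())
```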

### 4. NPT Custom CLIP (`NPTCustomCLIP`)
- **Base**: Extend existing `CustomCLIP` class
- **Components**:
  - NPTPromptLearner instead of PromptLearner
  - Modified forward pass to extract attention weights
  - Loss computation with NPT-specific losses
- **Forward**:
  - Extract global and local image features
  - Get attention weights for patch weighting
  - Compute similarities with class and nuisance prompts
  - Return global logits over the M classes, patch logits over the M + 1 prompts (nuisance last), and the patch background weights
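The similarity step of the forward pass might look like the sketch below. It assumes L2-normalized features and a fixed temperature of 100 in place of CLIP's learned `logit_scale.exp()`, which is about 100 for pretrained CLIP. The nuisance feature is appended as the last text row, so it enters the patch logits but not the global logits. The function name `npt_logits` is hypothetical.

```python
import torch
import torch.nn.functional as F


def npt_logits(global_feat, patch_feats, class_feats, nuisance_feat, logit_scale=100.0):
    """Cosine-similarity logits for global and patch features (sketch).

    Shapes: global_feat [B, D], patch_feats [B, N, D], class_feats [M, D], nuisance_feat [1, D].
    Returns global logits [B, M] and patch logits [B, N, M + 1] with the nuisance prompt last.
    """
    g = F.normalize(global_feat, dim=-1)
    p = F.normalize(patch_feats, dim=-1)
    text = F.normalize(torch.cat([class_feats, nuisance_feat], dim=0), dim=-1)  # [M + 1, D]
    global_logits = logit_scale * g @ text[:-1].t()  # L_global sees only the M class prompts
    patch_logits = logit_scale * p @ text.t()        # L_patch sees all M + 1 prompts
    return global_logits, patch_logits


global_logits, patch_logits = npt_logits(
    torch.randn(2, 512), torch.randn(2, 196, 512), torch.randn(10, 512), torch.randn(1, 512)
)
print(global_logits.shape, patch_logits.shape)
```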

### 5. Main Entry Script (`proposed_method.py`)
- **Base**: Copy baseline.py structure
- **Modifications**:
  - Use NPTCustomCLIP instead of CustomCLIP
  - Add NPT-specific hyperparameters (λ_patch, λ_margin, margin)
  - Modified training loop with NPT losses
- **Hyperparameters**:
  - λ_patch: 0.25 (default)
  - λ_margin: 0.25 (default) 
  - margin: 0.2 (default)
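If `baseline.py` parses its options with `argparse`, the three hyperparameters could be registered as shown below. The flag names are hypothetical. If the baseline uses a config system instead, the same defaults would go there.

```python
import argparse


def add_npt_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register NPT hyperparameters on an existing parser (flag names are hypothetical)."""
    group = parser.add_argument_group("NPT")
    group.add_argument("--lambda_patch", type=float, default=0.25, help="weight of L_patch")
    group.add_argument("--lambda_margin", type=float, default=0.25, help="weight of L_margin")
    group.add_argument("--margin", type=float, default=0.2, help="similarity margin in L_margin")
    return parser


args = add_npt_args(argparse.ArgumentParser()).parse_args(["--lambda_patch", "0.5"])
print(args.lambda_patch, args.lambda_margin, args.margin)  # 0.5 0.25 0.2
```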

## Implementation Strategy

### Phase 1: Core Components
1. Implement `NPTPromptLearner` with nuisance prompt
2. Implement attention extraction from ViT
3. Implement NPT loss functions

### Phase 2: Model Integration
1. Create `NPTCustomCLIP` class
2. Integrate attention extraction into forward pass
3. Modify loss computation

### Phase 3: Entry Script
1. Create `proposed_method.py`
2. Add hyperparameter handling
3. Integrate NPT components into training/evaluation pipeline

### Phase 4: Testing & Debugging
1. Run on a small subset of the dataset
2. Verify that the background weights have shape [B, N] and lie in [0, 1]
3. Check that each loss term is finite and non-negative
4. Validate output format compatibility

## Key Technical Details

### Attention Extraction
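```python
import torch

# Recompute the last block's attention with weights exposed (OpenAI CLIP's
# ResidualAttentionBlock calls nn.MultiheadAttention with need_weights=False).
# x: input to transformer.resblocks[-1], in CLIP's LND layout [seq_len, B, width].
# If the codebase's block exposes an attention_weight() helper, it can replace these lines.
blk = transformer.resblocks[-1]
y = blk.ln_1(x)
_, attn = blk.attn(y, y, y, need_weights=True, average_attn_weights=False, attn_mask=blk.attn_mask)
cls_to_patch_attn = attn[:, :, 0, 1:].mean(dim=1)  # [B, heads, L, L] -> head-averaged [CLS]-to-patch: [B, N]
# Rows are already softmax-normalized; min-max rescale so r_i spans [0, 1] per image
# (a second softmax would make every r_i ≈ 1/N and every w_i ≈ 1).
r_min = cls_to_patch_attn.min(dim=-1, keepdim=True).values
r_max = cls_to_patch_attn.max(dim=-1, keepdim=True).values
relevance = (cls_to_patch_attn - r_min) / (r_max - r_min + 1e-6)
background_weights = (1 - relevance).detach()  # w_i = 1 - r_i
```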

### Nuisance Prompt Integration
```python
# In NPTPromptLearner.forward()
class_prompts = self.construct_class_prompts()  # [n_cls, seq_len, dim]
nuisance_prompt = self.construct_nuisance_prompt()  # [1, seq_len, dim]
all_prompts = torch.cat([class_prompts, nuisance_prompt], dim=0)
```
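The class and nuisance prompts are encoded in one batch, so their token ids must be concatenated in the same order as the embeddings. CLIP's text encoder, as used in CoOp, pools each prompt at its EOT token. It finds that position with `argmax` over the token ids, because the EOT id 49407 is the largest id in the vocabulary. The toy sketch below uses random encoder outputs and hypothetical EOT positions, and the helper name `pool_eot` is also hypothetical.

```python
import torch


def pool_eot(encoded, tokenized, text_projection):
    """Select each prompt's EOT-token output and project it (CoOp-style text pooling).

    encoded: [P, L, width] text-transformer outputs (after the final LayerNorm);
    tokenized: [P, L] token ids, where CLIP's EOT id (49407) is the largest id, so argmax finds it.
    """
    eot_idx = tokenized.argmax(dim=-1)  # [P]
    return encoded[torch.arange(encoded.shape[0]), eot_idx] @ text_projection


M, L, width = 3, 77, 512
class_tok = torch.zeros(M, L, dtype=torch.long)
class_tok[:, 5] = 49407  # toy EOT position for the class prompts
nuis_tok = torch.zeros(1, L, dtype=torch.long)
nuis_tok[:, 8] = 49407  # toy EOT position for the nuisance prompt
all_tok = torch.cat([class_tok, nuis_tok], dim=0)  # same order as the prompt embeddings: classes, then nuisance
encoded = torch.randn(M + 1, L, width)
text_features = pool_eot(encoded, all_tok, torch.eye(width))
print(text_features.shape)  # torch.Size([4, 512])
```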

### Loss Computation
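```python
import torch
import torch.nn.functional as F

# Global loss: standard cross-entropy over the M class prompts only
loss_global = F.cross_entropy(global_logits, labels)

# Patch loss: attention-weighted classification of every patch as background.
# Similarities use the encoded text features of all M + 1 prompts (nuisance last),
# not the raw prompt token embeddings.
patch_logits_with_nuisance = compute_patch_similarities(patch_features, all_text_features)  # [B, N, M + 1]
background_labels = torch.full(patch_logits_with_nuisance.shape[:-1], n_classes,
                               dtype=torch.long, device=patch_logits_with_nuisance.device)  # [B, N], nuisance index = M
loss_patch = weighted_cross_entropy(patch_logits_with_nuisance, background_labels, background_weights)

# Margin loss: repel the nuisance feature b from every class feature g_m
nuisance_class_sims = compute_nuisance_class_similarities(nuisance_feature, class_features)  # [M]
loss_margin = torch.clamp(nuisance_class_sims - margin, min=0).sum()  # Σ_m max(0, sim(b, g_m) - margin)

# Total loss
loss = loss_global + lambda_patch * loss_patch + lambda_margin * loss_margin
```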

## Expected Challenges & Solutions

### Challenge 1: Attention Extraction Complexity
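- **Issue**: In the OpenAI CLIP code, `ResidualAttentionBlock` calls `nn.MultiheadAttention` with `need_weights=False`, so the ViT does not return attention maps by default
- **Solution**: Capture the input to the last residual block and recompute its attention with `need_weights=True`; only the last layer is needed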
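The sketch below captures the last block's input without modifying CLIP's code, using a forward pre-hook on `resblocks[-1]`. To keep the example self-contained, a stack of `nn.TransformerEncoderLayer` modules stands in for CLIP's blocks; this substitution is an assumption. CLIP's blocks are pre-norm, so with CLIP one would apply `blk.ln_1` to the captured input before calling `blk.attn`. The helper name `capture_last_block_input` is hypothetical.

```python
import torch
import torch.nn as nn


def capture_last_block_input(resblocks: nn.Sequential):
    """Store the input of the last block on every forward pass (hypothetical helper)."""
    store = {}

    def hook(module, inputs):
        store["x"] = inputs[0].detach()  # first positional input to the block

    handle = resblocks[-1].register_forward_pre_hook(hook)
    return store, handle


# Stand-in for CLIP's transformer.resblocks (assumption: generic batch-first encoder layers)
resblocks = nn.Sequential(*[nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True) for _ in range(3)])
store, handle = capture_last_block_input(resblocks)
with torch.no_grad():
    resblocks(torch.randn(2, 1 + 49, 64))  # [CLS] + 7x7 patch tokens
    x = store["x"]
    _, attn = resblocks[-1].self_attn(x, x, x, need_weights=True, average_attn_weights=False)
handle.remove()
print(attn.shape)  # torch.Size([2, 4, 50, 50])
```

Removing the handle after extraction keeps the hook from firing during unrelated forward passes. Alternatively, the hook can stay registered for the whole training run if the attention is needed at every step.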

### Challenge 2: Memory Usage
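- **Issue**: Materializing full attention maps ([B, heads, 1 + N, 1 + N]) alongside the patch features increases memory
- **Solution**: Keep only the last layer's [CLS]-to-patch row ([B, N]) and compute it without gradients, since the image encoder is frozen; reduce the batch size or use gradient checkpointing if needed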

### Challenge 3: Hyperparameter Sensitivity
- **Issue**: Three hyperparameters (λ_patch, λ_margin, margin) need tuning
- **Solution**: Start from the defaults listed for `proposed_method.py` (λ_patch = 0.25, λ_margin = 0.25, margin = 0.2) and tune them on a held-out validation split

### Challenge 4: Compatibility
- **Issue**: Maintaining compatibility with baseline output format
- **Solution**: Preserve exact same output schema (scores.npz, results.json)

## Success Metrics
- [ ] Code runs without errors
- [ ] Outputs match baseline format (scores.npz, results.json)
- [ ] NPT losses computed correctly
- [ ] Attention weights extracted successfully
- [ ] Performance improvements on OOD benchmarks (FPR95, AUROC)