import argparse
from pathlib import Path
from typing import Dict, List
import pickle

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
import torchvision.models as models
from tqdm import tqdm
import numpy as np

def _pool_features(t: torch.Tensor) -> torch.Tensor:
    """Global-average-pool a batched activation down to shape ``(B, C)``.

    For inputs with more than two dims (e.g. ``(B, C, H, W)`` conv feature maps)
    every trailing dim is averaged, which for 4-D inputs equals
    ``nn.AdaptiveAvgPool2d((1, 1))`` followed by flattening. Inputs that are
    already ``(B, C)`` (e.g. a linear layer's output) are returned unchanged.
    """
    if t.dim() <= 2:
        return t
    return t.flatten(2).mean(dim=2)


@torch.no_grad()
def main():
    """Extract globally average-pooled ResNet-50 activations and pickle them per layer.

    Every image of an ``ImageFolder`` dataset is read in deterministic
    (unshuffled) order. A forward hook on each module named in ``--layers``
    captures that module's output and pools it to one vector per image. The
    vectors are written to ``<out-dir>/<layer with '.' -> '_'>_avgpool.pkl`` as a
    pickled ``list`` of 1-D float32 ``np.ndarray``, one entry per image, in
    dataset order.

    Raises:
        SystemExit: via ``argparse`` when a name in ``--layers`` does not match
            any module of the model.
    """
    from torch.utils.data import Subset  # local import: only used to honour --max-images

    ap = argparse.ArgumentParser()
    ap.add_argument("--data-root", type=str, default="path/to/imagenet/val")
    ap.add_argument("--out-dir", type=str, default="outputs/features")
    ap.add_argument("--layers", nargs="+", default=["layer1.2", "layer2.3", "layer3.5", "layer4.2"],
                    help="Exact module names to hook (e.g., layer4.2)")
    ap.add_argument("--batch-size", type=int, default=512)
    ap.add_argument("--workers", type=int, default=8)
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--max-images", type=int, default=None, help="Optional limit for quick tests")
    args = ap.parse_args()

    tfm = T.Compose([
        T.Resize(256), T.CenterCrop(224), T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    ds = ImageFolder(args.data_root, transform=tfm)

    # Restrict the dataset up front, so that workers never decode images that
    # would be thrown away and tqdm reports the real number of batches.
    if args.max_images is not None:
        n_images = max(0, min(len(ds), args.max_images))
        if n_images < len(ds):
            ds = Subset(ds, range(n_images))

    # Pinned host memory only helps host->GPU copies. Without CUDA it just warns.
    use_cuda = torch.device(args.device).type == "cuda"
    loader = DataLoader(ds, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=use_cuda)

    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    model.eval().to(args.device)

    # Fail loudly on typos instead of silently hooking nothing.
    wanted = set(args.layers)
    modules = dict(model.named_modules())
    missing = wanted - modules.keys()
    if missing:
        ap.error(f"unknown module name(s) for --layers: {', '.join(sorted(missing))}")

    store: Dict[str, torch.Tensor] = {}

    def hook_for(name):
        def _fn(_m, _i, o):
            # Pool on the model's device first, so only a (B, C) tensor crosses
            # to the host instead of the full (B, C, H, W) feature map.
            store[name] = _pool_features(o.detach()).to(torch.float32).cpu()
        return _fn

    # Iterate named_modules() (not `wanted`) so hooks keep the model's module order.
    handles = [m.register_forward_hook(hook_for(name))
               for name, m in modules.items() if name in wanted]

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    buffer: Dict[str, List[np.ndarray]] = {}

    try:
        for images, _ in tqdm(loader, desc="Extracting activations"):
            images = images.to(args.device, non_blocking=True)
            store.clear()
            model(images)
            for lname, pooled in store.items():
                # Iterating a (B, C) array yields B rows of shape (C,), which keeps
                # the one-array-per-image pickle format.
                buffer.setdefault(lname, []).extend(pooled.numpy())
    finally:
        for h in handles:
            h.remove()

    for lname, acts in buffer.items():
        out_path = out_dir / f"{lname.replace('.', '_')}_avgpool.pkl"
        with open(out_path, "wb") as f:
            pickle.dump(acts, f)
        print(f"[SAVED] {lname} -> {out_path}  (N={len(acts)}, dim={len(acts[0]) if acts else 'NA'})")

if __name__ == "__main__":
    # Run the extraction only when executed as a script, never on import.
    main()
