import argparse
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision.models as models
import torch.nn as nn
import os


# -------------------------------
# 1) Argparse
# -------------------------------
# Command-line interface: the checkpoint path is mandatory, the prefix used
# for the output .npy files is optional.
parser = argparse.ArgumentParser(
    description="Extract CIFAR-10 train & test features with a trained ResNet18+MLPResidual",
)
parser.add_argument(
    "--weights-path",
    required=True,
    type=str,
    help="Path to the .pth file with trained model weights",
)
parser.add_argument(
    "--output-prefix",
    default="cifar10_resnet",
    type=str,
    help="Prefix for output .npy files",
)
args = parser.parse_args()

# -------------------------------
# 2) Device
# -------------------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# -------------------------------
# 3) MLPResidual
# -------------------------------
class MLPResidual(nn.Module):
    """Three-layer MLP with a learned linear skip connection.

    The main path is ``fc1 -> ReLU -> fc2 -> ReLU -> fc3``.  The skip path is
    a single linear projection of the input (``residual``).  The module
    returns the sum of the two paths.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of the two hidden layers.
        output_dim: Width of the output (both paths project to this size).
    """

    def __init__(self, input_dim=512, hidden_dim=512, output_dim=512):
        super().__init__()
        # Main path.  Attribute names double as state_dict keys, so they must
        # stay as they are for saved checkpoints to load.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.relu = nn.ReLU()
        # Skip path: projects the raw input straight to the output width.
        self.residual = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return ``fc3(relu(fc2(relu(fc1(x))))) + residual(x)``."""
        hidden = self.fc1(x)
        hidden = self.relu(hidden)
        hidden = self.fc2(hidden)
        hidden = self.relu(hidden)
        main_branch = self.fc3(hidden)
        skip_branch = self.residual(x)
        return main_branch + skip_branch

# -------------------------------
# 4) Load model + weights
# -------------------------------
print(f"Loading weights from {args.weights_path}")
# Rebuild the architecture the checkpoint is loaded into: ResNet18 with its
# last child module dropped, followed by Flatten -> MLPResidual -> a 10-way
# linear layer.
# resnet18() is called without arguments on purpose: the default is randomly
# initialised weights on both old and new torchvision releases, and it avoids
# the deprecated `pretrained=` keyword (the weights are overwritten by the
# checkpoint below anyway).
resnet18 = models.resnet18()
backbone = nn.Sequential(*list(resnet18.children())[:-1])
model = nn.Sequential(
    backbone,
    nn.Flatten(),
    MLPResidual(512, 512, 512),
    nn.Linear(512, 10)
)
# NOTE(security): torch.load unpickles the file, so only pass checkpoints from
# a trusted source.  On torch versions that support it, consider adding
# `weights_only=True` to this call.
model.load_state_dict(torch.load(args.weights_path, map_location=device))
# eval() switches the model to inference mode before any features are computed.
model.to(device).eval()
# Only the first stage of the Sequential (the ResNet backbone) is used to
# produce features; the MLPResidual and the final linear layer are loaded with
# the checkpoint but are not applied during extraction.
feature_extractor = model[0]

# -------------------------------
# 5) Transforms, dataset & loader params
# -------------------------------
# Dataset root directory and the batch size used by the loaders.
DATA_DIR = "./data"
BATCH_SIZE = 32

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each channel with the mean/std listed here.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# One CIFAR-10 dataset per split, keyed by split name ("train" first, then
# "test").  Both are created with download=True and share the same transform.
datasets_dict = {
    split_name: datasets.CIFAR10(
        root=DATA_DIR,
        train=(split_name == "train"),
        download=True,
        transform=transform,
    )
    for split_name in ("train", "test")
}

# -------------------------------
# 6) Extract & save
# -------------------------------
# NOTE(review): this directory is created, but nothing below writes into it
# (outputs go wherever --output-prefix points).  It is kept so that anything
# relying on its existence keeps working; confirm whether it is still needed.
os.makedirs("RSI-CFRL", exist_ok=True)

# Create the directory part of --output-prefix (if there is one) up front.
# Otherwise a missing folder is only discovered by np.save after a complete
# extraction pass over the split has already been computed.
output_dir = os.path.dirname(args.output_prefix)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

for split, ds in datasets_dict.items():
    print(f"Extracting features for {split} split...")
    # shuffle=False keeps the saved features and labels in dataset order.
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)
    feats, lbls = [], []

    # No gradients are needed for inference-only feature extraction.
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc=split):
            imgs = imgs.to(device)
            # Collapse everything after the batch dimension into one vector
            # per image.
            out = feature_extractor(imgs).view(imgs.size(0), -1)
            feats.append(out.cpu().numpy())
            lbls.append(labels.numpy())

    # Stitch the per-batch arrays into one array per split.
    feats = np.concatenate(feats, axis=0)
    lbls = np.concatenate(lbls, axis=0)

    feat_file = f"{args.output_prefix}_{split}_features.npy"
    label_file = f"{args.output_prefix}_{split}_labels.npy"
    np.save(feat_file, feats)
    np.save(label_file, lbls)

    print(f"Saved:\n"
          f" • {feat_file} (shape: {feats.shape})\n"
          f" • {label_file} (shape: {lbls.shape})")

print("Feature extraction complete.")
