import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# ======== Model Definition ========== #
class ImprovedResNet(nn.Module):
    """Small three-stage convolutional image classifier.

    Despite the name, this is a plain conv stack with no residual connections:
    three rounds of ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)`` followed
    by a single linear layer.

    Args:
        num_classes: number of output logits.
        in_channels: number of input image channels (default 3, e.g. RGB).
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        # For 32x32 inputs the three pools already yield 4x4 maps, so this is an
        # exact identity there. For other input sizes it resizes the maps to 4x4,
        # so the classifier's fixed 128*4*4 input width still matches.
        # It has no parameters, so state_dict / parameter order are unchanged.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.classifier = nn.Linear(128 * 4 * 4, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for an NCHW image batch."""
        x = self.pool(self.features(x))
        x = torch.flatten(x, 1)
        return self.classifier(x)

# ======== Trigger Injection ========== #
def add_backdoor_trigger(images):
    """Return a copy of *images* with a solid square stamped in the bottom-right corner.

    The last 5 rows and last 5 columns along the final two dimensions are set
    to 1.0 for every index of the first two dimensions. The input tensor is
    not modified.

    Args:
        images: tensor with exactly four dimensions (presumably NCHW; the
            slicing below requires four indices).

    Returns:
        A new tensor with the same shape and dtype as *images*.
    """
    corner = (slice(None), slice(None), slice(-5, None), slice(-5, None))
    stamped = images.clone()
    stamped[corner] = 1.0
    return stamped

# ======== Benign Training ========== #
def pretrain(model, dataloader, device, epochs=10, *, lr=0.1, momentum=0.9):
    """Train *model* in place with SGD and cross-entropy loss.

    The model itself is not moved; only the batches are sent to *device*, so
    the model must already live there.

    Args:
        model: module to train.
        dataloader: iterable of ``(inputs, targets)`` batches.
        device: device each batch is moved to.
        epochs: number of passes over *dataloader*.
        lr: SGD learning rate (default matches the previous hard-coded value).
        momentum: SGD momentum (default matches the previous hard-coded value).

    Returns:
        A list with the mean training loss of each epoch. An epoch with no
        batches reports ``nan``. The model is left in training mode.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    model.train()
    epoch_losses = []
    for _ in range(epochs):
        running_loss, num_batches = 0.0, 0
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        epoch_losses.append(running_loss / num_batches if num_batches else float("nan"))
    return epoch_losses

# ======== Simplified POLAR ========== #
class POLARDemo:
    """Toy parameter-perturbation routine used by this demo.

    Wraps a model and edits it in place:
    - additive Gaussian noise on a randomly chosen subset of parameter tensors;
    - a fixed increase of one output class's row in the model's ``classifier``
      layer.

    Args:
        model: ``nn.Module`` exposing a ``classifier`` attribute with ``weight``
            and ``bias`` tensors (required by :meth:`apply_attack`).
        device: device evaluation batches are moved to.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device

    def select_layers(self):
        """Return a random 0/1 float mask with one entry per parameter tensor.

        Each entry is drawn independently with probability 0.5 from torch's
        global RNG.
        """
        num_params = sum(1 for _ in self.model.parameters())
        return torch.bernoulli(torch.full((num_params,), 0.5))

    def apply_attack(self, selected_mask, target_class=0):
        """Perturb the model's parameters in place.

        Parameter tensors whose mask entry equals 1 receive additive noise
        ``0.01 * randn``. Then row *target_class* of ``classifier.weight`` is
        increased by 2.0 and ``classifier.bias[target_class]`` by 1.0.

        Args:
            selected_mask: 1-D sequence/tensor with one entry per parameter
                tensor, in ``model.parameters()`` order.
            target_class: output index whose classifier row/bias is raised.

        Raises:
            ValueError: if the mask length differs from the number of parameter
                tensors. The check runs before any parameter is modified.
        """
        params = list(self.model.parameters())
        if len(selected_mask) != len(params):
            raise ValueError(
                f"selected_mask has {len(selected_mask)} entries but the model "
                f"has {len(params)} parameter tensors"
            )
        with torch.no_grad():
            for flag, param in zip(selected_mask, params):
                if flag == 1:
                    param.add_(0.01 * torch.randn_like(param))
            self.model.classifier.weight[target_class] += 2.0
            self.model.classifier.bias[target_class] += 1.0

    def evaluate(self, dataloader):
        """Return top-1 accuracy over *dataloader*, or 0.0 if it yields no samples.

        Switches the model to eval mode and leaves it there.
        """
        self.model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        # Guard mirrors evaluate_bsr: an empty loader must not divide by zero.
        return correct / total if total > 0 else 0.0

# ======== Dataset Loader ========== #
def get_data_loaders(root="./data", train_size=2048, val_size=512, batch_size=64):
    """Build CIFAR-10 loaders over the leading examples of each split.

    Downloads CIFAR-10 into *root* if it is not already present. Images are
    converted with ``ToTensor`` only (no normalization or augmentation).

    Args:
        root: dataset directory.
        train_size: number of leading training-split examples to use; clamped
            to the split's length.
        val_size: number of leading test-split examples to use; clamped to the
            split's length.
        batch_size: batch size for both loaders.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles; the
        validation loader does not.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    val_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)
    # Clamp so a size larger than the split cannot yield out-of-range indices.
    train_indices = list(range(min(train_size, len(train_dataset))))
    val_indices = list(range(min(val_size, len(val_dataset))))
    train_loader = DataLoader(Subset(train_dataset, train_indices), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(Subset(val_dataset, val_indices), batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

# ======== BSR Evaluation ========== #
def evaluate_bsr(model, dataloader, device, trigger_fn, target_label, *, exclude_target_class=True):
    """Return the fraction of triggered inputs that *model* classifies as *target_label*.

    Each batch's inputs are passed through *trigger_fn*, moved to *device* and
    classified.

    By default, samples whose true label already equals *target_label* are
    skipped. Predicting the target for them says nothing about the trigger, and
    counting them can inflate the rate. Pass ``exclude_target_class=False`` to
    count every sample, which was the previous behaviour.

    Args:
        model: classifier returning logits of shape ``(batch, classes)``.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device the triggered inputs are moved to.
        trigger_fn: callable applied to each input batch.
        target_label: class index counted as a success.
        exclude_target_class: drop samples whose true label is *target_label*.

    Returns:
        Success fraction in ``[0, 1]``, or 0.0 if no sample was counted.
        The model is left in eval mode.
    """
    model.eval()
    total = success = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            if exclude_target_class:
                keep = labels != target_label
                if not keep.any():
                    continue
                inputs = inputs[keep]
            inputs = trigger_fn(inputs).to(device)
            preds = model(inputs).argmax(dim=1)
            success += (preds == target_label).sum().item()
            total += inputs.size(0)
    return success / total if total > 0 else 0.0

# ======== Main Logic ========== #
def main(seed=0):
    """Run the demo: train, measure, perturb the model, measure again.

    Args:
        seed: passed to ``torch.manual_seed``. It seeds torch's global RNG,
            which weight init, loader shuffling, the layer mask and the
            injected noise draw from, so runs can be repeated. GPU kernels
            may still add nondeterminism.
    """
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ImprovedResNet().to(device)
    train_loader, val_loader = get_data_loaders()

    pretrain(model, train_loader, device)
    polar = POLARDemo(model, device)
    clean_acc = polar.evaluate(val_loader)
    # Baseline: how often the unmodified model already predicts the target on
    # triggered inputs, so the post-attack BSR has a reference point.
    clean_bsr = evaluate_bsr(model, val_loader, device, add_backdoor_trigger, target_label=0)
    print(f"[Before Attack] Clean Accuracy: {clean_acc:.4f}")
    print(f"[Before Attack] Backdoor Success Rate (BSR): {clean_bsr * 100:.2f}%")

    layer_mask = polar.select_layers()
    polar.apply_attack(layer_mask, target_class=0)

    post_acc = polar.evaluate(val_loader)
    bsr = evaluate_bsr(model, val_loader, device, add_backdoor_trigger, target_label=0)

    print(f"[After Attack] Accuracy: {post_acc:.4f}")
    print(f"[After Attack] Backdoor Success Rate (BSR): {bsr * 100:.2f}%")

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

