import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import random
import numpy as np

# -----------------------------------------
# 1) Set random seeds for reproducibility
# -----------------------------------------
# Seed every RNG the script draws from: Python's `random` (index split),
# NumPy, and PyTorch (weight init, DataLoader shuffling). The three
# generators are independent, so the order of these calls does not matter.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

class MLPResidual3(nn.Module):
    """Three-layer ReLU MLP with a learned linear skip connection.

    Computes ``fc3(relu(fc2(relu(fc1(x))))) + residual(x)``, where
    ``residual`` is a linear projection from ``input_dim`` to
    ``output_dim`` so the skip path matches the output width.
    """

    def __init__(self, input_dim=512, hidden_dim=512, output_dim=512):
        super().__init__()
        # Keep this creation order: it fixes the order in which the default
        # initialisers draw from torch's RNG and the parameter/state_dict order.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.residual = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        hidden = x
        # Two hidden layers, each followed by a ReLU.
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        # Final projection (no activation) plus the linear skip path.
        return self.fc3(hidden) + self.residual(x)

# -----------------------------------------
# Utility to evaluate accuracy
# -----------------------------------------
def evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    Args:
        model: classifier whose output holds per-class scores along dim 1.
        loader: iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).
        device: device each batch is moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a float, or ``float("nan")`` when the
        loader yields no samples (instead of raising ZeroDivisionError).

    The model is switched to eval mode for the pass and then restored to
    whatever train/eval mode it was in before, so this is safe to call
    between training epochs.
    """
    was_training = model.training
    model.eval()
    correct = total = 0
    try:
        with torch.no_grad():
            for imgs, labels in loader:
                imgs, labels = imgs.to(device), labels.to(device)
                preds = model(imgs).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
    finally:
        # Don't leave the caller's model stuck in eval mode (dropout/BN off).
        model.train(was_training)
    if total == 0:
        return float("nan")
    return 100 * correct / total

# -----------------------------------------
# Main procedure: retrain from scratch on the retain split (exact-unlearning baseline)
# -----------------------------------------
def _split_indices(n_total, fraction):
    """Shuffle ``range(n_total)`` with the global ``random`` RNG and split it.

    Returns ``(keep_idx, unlearn_idx)``: the first ``int(fraction * n_total)``
    shuffled indices form the unlearn split and the remainder (in shuffled
    order) form the retain split.
    """
    all_indices = list(range(n_total))
    random.shuffle(all_indices)
    n_unlearn = int(fraction * n_total)
    return all_indices[n_unlearn:], all_indices[:n_unlearn]


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def main():
    """Train ResNet-18 + MLP head on the CIFAR-10 retain split and report accuracy.

    A random 1% of the training set is held out as the "unlearn" split and is
    never used for training, so the resulting model is the retrain-from-scratch
    baseline. Accuracy is then reported on the test set, the retain split and
    the unlearn split, all measured with the deterministic test-time transform.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_epochs = 100
    batch_size = 128
    lr = 0.001
    num_workers = 4

    # Data transforms (ImageNet normalisation statistics)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Load CIFAR-10. The training images are loaded twice: once with the
    # augmenting transform for optimisation, and once with the deterministic
    # test transform so that retain/unlearn accuracy is not measured on
    # randomly flipped images.
    train_full = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
    train_eval = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_test)
    test_set   = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

    # Select 1% (~500 samples) for unlearning; the rest is the retain split.
    keep_idx, unlearn_idx = _split_indices(len(train_full), 0.01)

    train_loader = DataLoader(Subset(train_full, keep_idx), batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    remain_eval_loader = DataLoader(Subset(train_eval, keep_idx), batch_size=batch_size,
                                    shuffle=False, num_workers=num_workers)
    unlearn_loader = DataLoader(Subset(train_eval, unlearn_idx), batch_size=batch_size,
                                shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)

    # Build model: ResNet-18 trunk (classifier removed) -> residual MLP -> 10-way head.
    # NOTE: `pretrained=` is deprecated in torchvision >= 0.13 in favour of
    # `weights=None`; kept here for compatibility with older torchvision.
    resnet = models.resnet18(pretrained=False)
    backbone = nn.Sequential(*list(resnet.children())[:-1])
    model = nn.Sequential(
        backbone,
        nn.Flatten(),
        MLPResidual3(),
        nn.Linear(512, 10)
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Training loop
    for epoch in range(1, num_epochs+1):
        avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Epoch {epoch}/{num_epochs}, Loss: {avg_loss:.4f}")

    # Final evaluations (all with the non-augmenting test transform)
    test_acc    = evaluate_accuracy(model, test_loader, device)
    remain_acc  = evaluate_accuracy(model, remain_eval_loader, device)
    unlearn_acc = evaluate_accuracy(model, unlearn_loader, device)

    print(f"Test Set Accuracy:      {test_acc:.2f}%")
    print(f"Remaining Set Accuracy: {remain_acc:.2f}%")
    print(f"Unlearn Set Accuracy:   {unlearn_acc:.2f}%")

if __name__ == "__main__":
    main()
