import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

# Custom Dataset that holds features and labels
class FeatureDataset(Dataset):
    """Map-style dataset that pairs each feature entry with its label.

    Args:
        features: Sized, indexable collection of samples.
        labels: Sized, indexable collection holding one label per sample.

    Raises:
        ValueError: If ``features`` and ``labels`` differ in length.
    """

    def __init__(self, features, labels):
        # Fail fast: a mismatch would otherwise only surface later as an
        # index error (too few labels) or as silently ignored labels.
        if len(features) != len(labels):
            raise ValueError(
                f"features and labels must have the same length, "
                f"got {len(features)} and {len(labels)}"
            )
        self.features = features
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(feature, label)`` pair stored at position ``idx``."""
        return self.features[idx], self.labels[idx]

# Load the saved data dictionaries.
# map_location="cpu" places the tensors on the CPU regardless of the device
# they were saved from: the model in this script is never moved off the CPU,
# and the files then also load on machines without a GPU.
# NOTE(review): torch.load unpickles the file contents -- only load files that
# come from a trusted source.
clean_data = torch.load("test.pt", map_location="cpu")
bd_data = torch.load("test_backdoored.pt", map_location="cpu")

# Extract image features (assumed to be of shape [N, feature_dim])
clean_features = clean_data["image_features"]
bd_features = bd_data["image_features"]

# Create labels: 0 for clean, 1 for backdoored
clean_labels = torch.zeros(clean_features.size(0), dtype=torch.long)
bd_labels = torch.ones(bd_features.size(0), dtype=torch.long)

# Stack both sets along the sample axis: clean rows first, backdoored after.
X = torch.cat([clean_features, bd_features], dim=0)
y = torch.cat([clean_labels, bd_labels], dim=0)

# Create a dataset
dataset = FeatureDataset(X, y)

# Split the dataset into 80% training and 20% validation.
# A fixed seed makes the partition identical on every run, so validation
# numbers from different runs are computed on the same held-out samples.
SPLIT_SEED = 0
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Create DataLoaders for training and validation
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
# Number of batches per epoch in each loader.
print(len(train_loader))
print(len(val_loader))

# Define a simple classifier model (an MLP)
class SimpleClassifier(nn.Module):
    """One-hidden-layer MLP that maps a feature vector to class logits.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits. Defaults to 2
            (clean and backdoored).
    """

    def __init__(self, input_dim, hidden_dim=128, num_classes=2):
        super().__init__()
        # Keep the attribute name ``fc`` and the layer order unchanged:
        # state_dict keys ("fc.0.*", "fc.2.*") are derived from them, so
        # existing checkpoints stay loadable.
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalized) logits for ``x``."""
        return self.fc(x)

# The input width is read from the stacked feature matrix instead of being
# hard-coded, so the classifier always matches the loaded features.
input_dim = X.shape[1]
print(input_dim)
model = SimpleClassifier(input_dim)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 10
checkpoint_path = 'detector_caltech.pth'

# Reuse a previously saved detector when one exists; train (and save) only
# when it does not. Training first and then loading the checkpoint over the
# model would throw the freshly trained weights away, and would crash at the
# very end of the run if the checkpoint file were missing.
# Delete the checkpoint file to force retraining.
# NOTE(review): torch.load unpickles the file contents -- only load
# checkpoints that come from a trusted source.
# NOTE(review): a checkpoint trained on a different train/validation split may
# already have seen samples of the current validation set -- confirm before
# reporting the accuracy below.
try:
    saved_state = torch.load(checkpoint_path, map_location="cpu")
except FileNotFoundError:
    saved_state = None

if saved_state is not None:
    model.load_state_dict(saved_state)
else:
    # Training loop
    # Guard the per-epoch average against an empty training loader.
    num_batches = max(len(train_loader), 1)
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batch_features, batch_labels in train_loader:
            optimizer.zero_grad()
            # Ensure the features are float32
            outputs = model(batch_features.float())
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/num_batches:.4f}")

    # Persist the trained weights so later runs can skip training.
    torch.save(model.state_dict(), checkpoint_path)

# Evaluation on the validation set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch_features, batch_labels in val_loader:
        outputs = model(batch_features.float())
        # Get predicted class (index with highest logit)
        predicted = outputs.argmax(dim=1)
        print(outputs)
        print(predicted)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

# An empty validation set yields NaN instead of a ZeroDivisionError.
val_accuracy = correct / total if total else float("nan")
print(f"Validation Accuracy: {val_accuracy:.4f}")