import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

# Custom Dataset that holds features and labels
class FeatureDataset(Dataset):
    """Map-style dataset that pairs each feature row with its label.

    Args:
        features: Indexable, sized collection of samples (for example a
            tensor whose first axis enumerates the samples).
        labels: Indexable, sized collection of targets aligned one-to-one
            with ``features``.

    Raises:
        ValueError: If ``features`` and ``labels`` differ in length.
    """

    def __init__(self, features, labels):
        # Fail fast: __len__ only looks at the features, so a mismatch would
        # otherwise surface late as an IndexError inside a DataLoader (or as
        # silently ignored trailing labels).
        if len(features) != len(labels):
            raise ValueError(
                "features and labels must have the same length, "
                f"got {len(features)} and {len(labels)}"
            )
        self.features = features
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(feature, label)`` pair stored at ``idx``."""
        return self.features[idx], self.labels[idx]

# Load the saved feature dictionaries.
# NOTE(review): torch.load unpickles its input, so only point these paths at
# trusted files.
# map_location="cpu" keeps every loaded tensor on the CPU, matching the labels
# created below (default device) and the model, which is never moved to a GPU.

# Held-out pair, used only for the final accuracy measurement.
clean_data_val = torch.load("test.pt", map_location="cpu")
bd_data_val = torch.load("test_backdoored.pt", map_location="cpu")

# Pair the classifier is trained on.
clean_data = torch.load("test_flowers.pt", map_location="cpu")
bd_data = torch.load("test_backdoored_flowers.pt", map_location="cpu")

# Extract image features (assumed to be of shape [N, feature_dim])
clean_features = clean_data["image_features"]
bd_features = bd_data["image_features"]

clean_features_val = clean_data_val["image_features"]
bd_features_val = bd_data_val["image_features"]

# Create labels: 0 for clean, 1 for backdoored
clean_labels = torch.zeros(clean_features.size(0), dtype=torch.long)
bd_labels = torch.ones(bd_features.size(0), dtype=torch.long)

clean_labels_val = torch.zeros(clean_features_val.size(0), dtype=torch.long)
bd_labels_val = torch.ones(bd_features_val.size(0), dtype=torch.long)

# Combine clean and backdoored samples into one training set
X = torch.cat([clean_features, bd_features], dim=0)
y = torch.cat([clean_labels, bd_labels], dim=0)

# Combine clean and backdoored samples into one held-out validation set
X_val = torch.cat([clean_features_val, bd_features_val], dim=0)
y_val = torch.cat([clean_labels_val, bd_labels_val], dim=0)

# Wrap both sets as datasets
dataset = FeatureDataset(X, y)
dataset_val = FeatureDataset(X_val, y_val)

# The classifier is fitted on the full training set and scored on the
# separate held-out files, so no random train/validation split is taken here
# (the earlier 80/20 split was computed but never fed to a loader).
train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
# Accuracy does not depend on sample order, so evaluation is not shuffled.
val_loader = DataLoader(dataset_val, batch_size=64, shuffle=False)

# Report the number of batches in each loader
print(len(train_loader))
print(len(val_loader))

# Define a simple classifier model (an MLP)
class SimpleClassifier(nn.Module):
    """Single-hidden-layer MLP that maps feature vectors to class logits.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits. Defaults to 2
            (clean vs. backdoored).
    """

    def __init__(self, input_dim, hidden_dim=128, num_classes=2):
        super().__init__()
        # The attribute name and layer order are kept as-is so state_dict
        # keys stay "fc.0.*" and "fc.2.*".
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return raw (unnormalised) logits, one row per input row."""
        return self.fc(x)

# The input width is read from the second axis of the training matrix
input_dim = X.size(1)
model = SimpleClassifier(input_dim)

# Cross-entropy over the class logits, optimised with Adam
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Fit the classifier
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        # Features are cast to float32 before the forward pass
        logits = model(inputs.float())
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Reported value is the mean of the per-batch losses for this epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}")

# Score the fitted model on the validation loader
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, targets in val_loader:
        logits = model(inputs.float())
        # Predicted class is the position of the largest logit in each row
        predicted = logits.max(dim=1).indices
        hits = predicted == targets
        total += targets.size(0)
        correct += hits.sum().item()

val_accuracy = correct / total
print(f"Validation Accuracy: {val_accuracy:.4f}")
