import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Read the four pre-split arrays (features and labels for train / validation)
train_data, train_labels, val_data, val_labels = map(
    np.load,
    (
        "./input/train_data.npy",
        "./input/train_label.npy",
        "./input/val_data.npy",
        "./input/val_label.npy",
    ),
)

# Features become float32 tensors
X_train = torch.tensor(train_data, dtype=torch.float32)
X_val = torch.tensor(val_data, dtype=torch.float32)

# Labels are shifted down by one so they are zero-indexed class ids (int64)
y_train = torch.tensor(train_labels - 1, dtype=torch.long)
y_val = torch.tensor(val_labels - 1, dtype=torch.long)

# Training batches: 32 samples each, reshuffled every epoch
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)


# Feedforward classifier: two Linear -> BatchNorm -> ReLU stages, then a linear head
class ImprovedNN(nn.Module):
    """Two-hidden-layer MLP with batch normalization after each hidden layer.

    The layer widths are constructor arguments; their defaults reproduce the
    original fixed architecture (46 -> 128 -> 64 -> 24), so ``ImprovedNN()``
    builds the same network and the same ``state_dict`` keys as before.

    Args:
        in_features: Number of input features per sample.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output logits (one per class).
    """

    def __init__(
        self,
        in_features: int = 46,
        hidden1: int = 128,
        hidden2: int = 64,
        num_classes: int = 24,
    ) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.bn1 = nn.BatchNorm1d(hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.bn2 = nn.BatchNorm1d(hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores (logits) for a batch of samples.

        Args:
            x: Input batch, expected shape ``(batch, in_features)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``. No softmax is applied;
            the script pairs this output with ``nn.CrossEntropyLoss``.
        """
        # NOTE: in training mode BatchNorm1d needs more than one sample per
        # batch to compute batch statistics.
        x = torch.relu(self.bn1(self.fc1(x)))
        x = torch.relu(self.bn2(self.fc2(x)))
        return self.fc3(x)


# Initialize model, loss function, and optimizer
model = ImprovedNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Plateau scheduler: lowers the learning rate once the monitored value
# (mode="min", i.e. lower is better) has failed to improve for more than
# `patience` consecutive scheduler steps.
# NOTE: the `verbose` keyword is intentionally not passed. It is deprecated
# since PyTorch 2.2 and newer releases no longer accept it; inspect
# optimizer.param_groups[i]["lr"] (or scheduler.get_last_lr()) to observe
# learning-rate changes instead.
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=5)

# Train the model with early stopping
epochs = 100
best_accuracy = 0
patience = 10  # epochs without a new best validation accuracy before stopping
patience_counter = 0
# Snapshot of the best epoch seen so far, so the final model and predictions
# are the best ones rather than whatever the last epoch produced.
best_state = None
best_predicted = None

for epoch in range(epochs):
    # One optimization pass over the training batches
    model.train()
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

    # Validate the model on the full validation set in one forward pass
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val)
        val_loss = criterion(val_outputs, y_val)
        _, predicted = torch.max(val_outputs, 1)
    accuracy = accuracy_score(y_val.numpy(), predicted.numpy())

    # Step the plateau scheduler on the validation loss. The last training
    # mini-batch loss is a noisy single-batch value and not a validation
    # signal, so it is not used here.
    lr_before = optimizer.param_groups[0]["lr"]
    scheduler.step(val_loss.item())
    lr_after = optimizer.param_groups[0]["lr"]
    if lr_after != lr_before:
        print(f"Epoch {epoch + 1}: learning rate reduced to {lr_after:.2e}")

    # Check for early stopping
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        patience_counter = 0
        # Clone the tensors: state_dict() returns references that later
        # optimizer steps would otherwise overwrite in place.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        best_predicted = predicted.clone()
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch + 1}")
        break

# Roll back to the best epoch so that the model, `predicted` and `accuracy`
# used after training all describe the best validation result.
if best_state is not None:
    model.load_state_dict(best_state)
    predicted = best_predicted
    accuracy = best_accuracy

# Assemble the two-column (Id, Label) table from the `predicted` tensor left
# by the training loop: row index first, then the class id moved back to
# 1-indexed labels.
row_ids = np.arange(len(predicted))
labels_one_indexed = predicted.numpy() + 1
submission = np.column_stack((row_ids, labels_one_indexed))

# Write submission.csv as integer CSV with an unprefixed "Id,Label" header
np.savetxt(
    "./working/submission.csv",
    submission,
    fmt="%d",
    delimiter=",",
    header="Id,Label",
    comments="",
)

# Report the validation accuracy held in `accuracy`
print(f"Validation Accuracy: {accuracy:.4f}")
