import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score
from sklearn.model_selection import KFold


# Load data
def load_npz_split(path):
    """Load one dataset split from a ``.npz`` archive.

    Args:
        path: path to an ``.npz`` file containing an ``"X"`` array and,
            optionally, a ``"y"`` array.

    Returns:
        ``(X, y)``. ``X`` is float32 and at least 2-D: a single 1-D sample is
        promoted to shape ``(1, n_features)``. ``y`` is an int64 array of at
        least 1-D, or ``None`` when the archive has no ``"y"`` entry (for
        example an unlabelled test split).
    """
    # np.load on an .npz returns an NpzFile that holds the file open; the
    # context manager guarantees it is closed once the arrays are copied out.
    with np.load(path) as d:
        X = np.asarray(d["X"], dtype=np.float32)
        y = np.asarray(d["y"], dtype=np.int64) if "y" in d.files else None
    if X.ndim == 1:  # edge case: single sample
        X = X[None, :]
    if y is not None:
        # Keep labels indexable alongside X even if stored as a scalar.
        y = np.atleast_1d(y)
    return X, y


# Data augmentation function
def augment_data(X, std=0.01, rng=None):
    """Return a copy of ``X`` with zero-mean Gaussian noise added element-wise.

    Args:
        X: array of samples (any shape).
        std: standard deviation of the noise (default 0.01).
        rng: optional ``np.random.Generator``. When ``None`` the legacy global
            ``np.random`` state is used, matching the previous behaviour.

    Returns:
        ``X + noise`` with the same shape as ``X``. Floating-point inputs keep
        their dtype; previously float32 input was silently promoted to
        float64 by the float64 noise, doubling memory. Non-floating inputs
        are returned as float64.
    """
    X = np.asarray(X)
    if rng is None:
        noise = np.random.normal(0, std, X.shape)
    else:
        noise = rng.normal(0, std, X.shape)
    out_dtype = X.dtype if np.issubdtype(X.dtype, np.floating) else np.float64
    return (X + noise).astype(out_dtype, copy=False)


# Define the neural network model with an additional hidden layer and dropout
class ECGClassifier(nn.Module):
    """Feed-forward classifier with three ReLU + dropout hidden layers (512, 256, 128).

    Args:
        input_size: number of features per sample (size of the last input dim).
        num_classes: number of output logits.
        dropout: dropout probability applied after each hidden layer
            (default 0.5).
    """

    def __init__(self, input_size, num_classes, dropout=0.5):
        super().__init__()
        # Layer attribute names (fc1..fc4) are kept stable so saved
        # state_dicts remain loadable.
        self.fc1 = nn.Linear(input_size, 512)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(256, 128)
        self.dropout3 = nn.Dropout(dropout)
        self.fc4 = nn.Linear(128, num_classes)  # Output layer

    def forward(self, x):
        """Return raw (unnormalised) class logits for ``x``.

        No softmax is applied here; the training loss (CrossEntropyLoss)
        expects logits.
        """
        # torch.relu avoids constructing a new nn.ReLU module on every call.
        x = self.dropout1(torch.relu(self.fc1(x)))
        x = self.dropout2(torch.relu(self.fc2(x)))
        x = self.dropout3(torch.relu(self.fc3(x)))
        return self.fc4(x)


# Load datasets
X_train, y_train = load_npz_split("./input/train.npz")
X_val, y_val = load_npz_split("./input/val.npz")
X_test, _ = load_npz_split("./input/test.npz")

# Seed the global NumPy and torch RNGs (augmentation noise, weight init,
# dropout masks, batch order) so results are reproducible run to run.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Augment training data
X_train_augmented = augment_data(X_train)

# Hyperparameters
input_size = X_train.shape[1]
# CrossEntropyLoss requires every target to lie in [0, num_classes).
# len(np.unique(y_train)) under-sizes the output layer whenever a class id is
# absent from train or the ids are not contiguous, so size it from the
# largest label found in any labelled split instead.
_all_labels = np.concatenate([y_train, y_val])
if _all_labels.size == 0 or _all_labels.min() < 0:
    raise ValueError("expected non-empty, non-negative integer class labels")
num_classes = int(_all_labels.max()) + 1
num_epochs = 20
batch_size = 64
learning_rate = 0.001  # Initial AdamW lr; the cyclic schedule overwrites it every epoch
weight_decay = 0.01  # Weight decay for AdamW
patience = 5  # Early stopping patience
max_lr = 0.001  # Max learning rate for CLR
base_lr = 0.0001  # Base learning rate for CLR

# K-Fold Cross Validation
# Shuffle so folds are not biased by any ordering in train.npz; the fixed
# random_state keeps the split reproducible.
kf = KFold(n_splits=5, shuffle=True, random_state=42)
f1_scores = []


def _cyclic_lr(epoch):
    """Return the triangular cyclic learning rate for ``epoch``."""
    step_size = num_epochs / 3
    cycle = np.floor(1 + epoch / (2 * step_size))
    x = np.abs(epoch / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)


def _train_model(X_tr, y_tr, X_es, y_es):
    """Train a fresh ECGClassifier with early stopping on ``(X_es, y_es)``.

    Returns the model with the weights from the epoch that had the lowest
    early-stopping loss restored, left in eval mode.
    """
    model = ECGClassifier(input_size, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )
    X_es_t = torch.as_tensor(X_es, dtype=torch.float32)
    y_es_t = torch.as_tensor(y_es, dtype=torch.int64)

    best_val_loss = float("inf")
    best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        lr = _cyclic_lr(epoch)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # Must be re-enabled every epoch: the validation pass below switches
        # the model to eval mode. Previously train() was called only once, so
        # every epoch after the first trained with dropout disabled.
        model.train()
        perm = np.random.permutation(len(X_tr))  # fresh batch order per epoch
        for start in range(0, len(X_tr), batch_size):
            idx = perm[start : start + batch_size]
            inputs = torch.as_tensor(X_tr[idx], dtype=torch.float32)
            labels = torch.as_tensor(y_tr[idx], dtype=torch.int64)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            val_loss = criterion(model(X_es_t), y_es_t).item()

        # Early stopping check; snapshot the best weights so they can be
        # restored instead of keeping the (worse) final-epoch weights.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = {
                k: v.detach().clone() for k, v in model.state_dict().items()
            }
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

    model.load_state_dict(best_state)
    model.eval()
    return model


def _predict(model, X):
    """Return predicted class indices for ``X`` as a NumPy array."""
    model.eval()
    with torch.no_grad():
        logits = model(torch.as_tensor(X, dtype=torch.float32))
    return logits.argmax(dim=1).numpy()


for train_index, val_index in kf.split(X_train_augmented):
    # Train on the noisy augmented samples, but early-stop and score on the
    # clean originals so the F1 estimate reflects real inputs.
    model = _train_model(
        X_train_augmented[train_index],
        y_train[train_index],
        X_train[val_index],
        y_train[val_index],
    )
    predicted = _predict(model, X_train[val_index])
    f1 = f1_score(y_train[val_index], predicted, average="weighted")
    f1_scores.append(f1)

# Average F1 Score
average_f1_score = np.mean(f1_scores)
print(f"Average F1 Score: {average_f1_score}")

# Final model: train on all augmented training data and early-stop on the
# provided validation split, which was previously loaded but never used.
# Previously the test set was predicted with whichever fold model happened
# to be trained last, seeing only 80% of the training data.
model = _train_model(X_train_augmented, y_train, X_val, y_val)
val_f1 = f1_score(y_val, _predict(model, X_val), average="weighted")
print(f"Validation F1 Score: {val_f1}")

# Predictions on the test set
model.eval()
with torch.no_grad():
    test_outputs = model(torch.as_tensor(X_test, dtype=torch.float32))
    # Index of the largest logit per row (first one on ties, like torch.max).
    test_predictions = test_outputs.argmax(dim=1)

# Save predictions to submission file. Create the output directory first:
# DataFrame.to_csv does not create missing parent directories.
os.makedirs("./working", exist_ok=True)
submission_df = pd.DataFrame(
    {"id": np.arange(len(test_predictions)), "label": test_predictions.numpy()}
)
submission_df.to_csv("./working/submission.csv", index=False)
